Versioned   A
last analyzed

Complexity

Total Complexity 2

Size/Duplication

Total Lines 8
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
dl 0
loc 8
rs 10
c 0
b 0
f 0
wmc 2

2 Methods

Rating   Name   Duplication   Size   Complexity  
A map() 0 4 1
A __mapper_cls__() 0 7 2
1
"""Versioned mixin class and other utilities."""
2
3
import datetime
4
5
from sqlalchemy.ext.declarative import declared_attr
6
from sqlalchemy.orm import mapper, attributes, object_mapper
7
from sqlalchemy.orm.exc import UnmappedColumnError
8
from sqlalchemy import Table, Column, ForeignKeyConstraint, Integer, ForeignKey
9
from sqlalchemy import event
10
from sqlalchemy.orm.properties import RelationshipProperty
11
12
from sqlalchemy_utils import ArrowType
13
14
def col_references_table(col, table):
    """Return True if any foreign key on *col* references *table*.

    :param col: a column object exposing ``foreign_keys``.
    :param table: the table to test against via ``fk.references(table)``.
    :returns: ``True`` on the first matching foreign key, else ``False``.
    """
    return any(fk.references(table) for fk in col.foreign_keys)
19
20
def _is_versioning_col(col):
21
    return "version_meta" in col.info
22
23
def _history_mapper(local_mapper):
    """Create and attach a ``<ClassName>History`` mapper for a versioned class.

    For a class with its own table (root class or joined-table inheritance),
    builds a ``<table>_history`` table that copies the mapped table's
    non-versioning columns. It adds the versioning columns ``version``,
    ``<model>_changed_at`` and ``<model>_changed_by``. For single-table
    inheritance, it instead appends any new subclass columns to the parent's
    history table.

    A ``<ClassName>History`` class is then created dynamically, mapped to
    that table, and stored on the class as ``__history_mapper__``. When there
    is no parent history mapper, an integer ``version`` column and property
    are also added to the live table.

    :param local_mapper: the mapper just created for the versioned class.
    """
    cls = local_mapper.class_

    # Set the "active_history" flag on column-mapped attributes so that the
    # old version of the value is always loaded (currently sets it on all
    # attributes).
    for prop in local_mapper.iterate_properties:
        getattr(local_mapper.class_, prop.key).impl.active_history = True

    super_mapper = local_mapper.inherits
    # For a subclass this presumably resolves to the parent class's history
    # mapper (set as a class attribute below); None for a root class.
    super_history_mapper = getattr(cls, '__history_mapper__', None)

    polymorphic_on = None
    super_fks = []

    def _col_copy(col):
        # History rows may repeat values, so drop uniqueness and defaults.
        col = col.copy()
        col.unique = False
        col.default = col.server_default = None
        return col

    if not super_mapper or local_mapper.local_table is not super_mapper.local_table:
        # This class has its own table: build a separate history table.
        cols = []
        for column in local_mapper.local_table.c:
            if _is_versioning_col(column):
                continue

            col = _col_copy(column)

            # A column referencing the parent's table becomes part of a FK to
            # the parent's history table (its first primary-key column).
            if super_mapper and col_references_table(column, super_mapper.local_table):
                super_fks.append((col.key, list(super_history_mapper.local_table.primary_key)[0]))

            cols.append(col)

            if column is local_mapper.polymorphic_on:
                polymorphic_on = col

        if super_mapper:
            # The history FK to the parent also pins the matching version.
            super_fks.append(('version', super_history_mapper.local_table.c.version))

        version_meta = {"version_meta": True}  # add column.info to identify
                                               # columns specific to versioning

        # "version" stores the integer version id.  This column is
        # required.
        cols.append(Column('version', Integer, primary_key=True,
                           autoincrement=False, info=version_meta))

        # "<model>_changed_at" stores when the history row was created; the
        # default is datetime.datetime.utcnow (a naive UTC datetime).
        model_name = cls.__name__.lower()
        cols.append(Column('%s_changed_at' % model_name, ArrowType,
                           default=datetime.datetime.utcnow,
                           info=version_meta))
        # "<model>_changed_by" is an integer FK to users.id.
        # NOTE(review): assumes a "users" table exists in this metadata; the
        # value is not set here or by create_version, presumably the app does.
        cols.append(Column('%s_changed_by' % model_name, Integer, ForeignKey("users.id"),
                           info=version_meta))

        if super_fks:
            # zip(*pairs) -> (local column names, remote columns): one
            # composite foreign key to the parent history table.
            cols.append(ForeignKeyConstraint(*zip(*super_fks)))

        table = Table(local_mapper.local_table.name + '_history',
                      local_mapper.local_table.metadata,
                      *cols,
                      schema=local_mapper.local_table.schema
        )
    else:
        # single table inheritance.  take any additional columns that may have
        # been added and add them to the history table.
        for column in local_mapper.local_table.c:
            if column.key not in super_history_mapper.local_table.c:
                col = _col_copy(column)
                super_history_mapper.local_table.append_column(col)
        table = None

    if super_history_mapper:
        bases = (super_history_mapper.class_,)
    else:
        bases = local_mapper.base_mapper.class_.__bases__
    versioned_cls = type.__new__(type, "%sHistory" % cls.__name__, bases, {})

    m = mapper(
            versioned_cls,
            table,
            inherits=super_history_mapper,
            polymorphic_on=polymorphic_on,
            polymorphic_identity=local_mapper.polymorphic_identity
            )
    cls.__history_mapper__ = m

    if not super_history_mapper:
        # Root of the hierarchy: give the live table its version counter.
        local_mapper.local_table.append_column(
            Column('version', Integer, default=1, nullable=False)
        )
        local_mapper.add_property("version", local_mapper.local_table.c.version)
118
119
120
class Versioned(object):
    """Mixin that sets up history tracking for a declarative class.

    Supplies a custom ``__mapper_cls__``. It creates the normal mapper and
    then passes it to ``_history_mapper``, which builds the companion
    ``<ClassName>History`` mapper and the ``version`` column.
    """

    @declared_attr
    def __mapper_cls__(cls):
        def map_with_history(cls, *arg, **kw):
            # Build the regular mapper first; history setup inspects it.
            mp = mapper(cls, *arg, **kw)
            _history_mapper(mp)
            return mp
        return map_with_history
128
129
130
def versioned_objects(iter):
    """Yield the objects in *iter* that have a ``__history_mapper__``.

    The parameter name shadows the ``iter`` builtin; it is kept for
    compatibility with keyword callers.
    """
    for obj in iter:
        if hasattr(obj, '__history_mapper__'):
            yield obj
134
135
def create_version(obj, session, deleted=False):
    """Add a history row for *obj*'s pre-change state to *session*.

    Walks *obj*'s mapper and its history mapper up to the root in lockstep.
    For each non-versioning history column, it copies the attribute's old
    value (the first "deleted" history entry) if one exists. Otherwise it
    uses the unchanged value, and failing that the newly added value.

    If no column changed and no relationship backed by a local foreign key
    changed, nothing is recorded unless *deleted* is true. After a row is
    added, ``obj.version`` is incremented.

    :param obj: an instance of a ``Versioned`` mapped class.
    :param session: session to which the new history instance is added.
    :param deleted: record a history row even if nothing changed.
    """
    obj_mapper = object_mapper(obj)
    history_mapper = obj.__history_mapper__
    history_cls = history_mapper.class_

    obj_state = attributes.instance_state(obj)

    attr = {}

    obj_changed = False

    for om, hm in zip(obj_mapper.iterate_to_root(), history_mapper.iterate_to_root()):
        # Single-table subclasses share the parent's history table, which
        # the parent's iteration already covers.
        if hm.single:
            continue

        for hist_col in hm.local_table.c:
            if _is_versioning_col(hist_col):
                continue

            obj_col = om.local_table.c[hist_col.key]

            # get the value of the
            # attribute based on the MapperProperty related to the
            # mapped column.  this will allow usage of MapperProperties
            # that have a different keyname than that of the mapped column.
            try:
                prop = obj_mapper.get_property_by_column(obj_col)
            except UnmappedColumnError:
                # in the case of single table inheritance, there may be
                # columns on the mapped table intended for the subclass only.
                # the "unmapped" status of the subclass column on the
                # base class is a feature of the declarative module as of sqla 0.5.2.
                continue

            # expired object attributes and also deferred cols might not be in the
            # dict.  force it to load no matter what by using getattr().
            if prop.key not in obj_state.dict:
                getattr(obj, prop.key)

            a, u, d = attributes.get_history(obj, prop.key)

            if d:
                attr[hist_col.key] = d[0]
                obj_changed = True
            elif u:
                attr[hist_col.key] = u[0]
            elif a:
                # if the attribute had no value.  Guarded so that an entirely
                # empty history does not raise IndexError.
                attr[hist_col.key] = a[0]
                obj_changed = True

    if not obj_changed:
        # not changed, but we have relationships.  OK
        # check those too: a relationship change counts only if one of its
        # local columns carries a foreign key.
        for prop in obj_mapper.iterate_properties:
            if isinstance(prop, RelationshipProperty) and \
                    attributes.get_history(
                        obj, prop.key,
                        passive=attributes.PASSIVE_NO_INITIALIZE).has_changes():
                if any(p.foreign_keys for p in prop.local_columns):
                    obj_changed = True
                    break

    if not obj_changed and not deleted:
        return

    attr['version'] = obj.version
    hist = history_cls()
    for key, value in attr.items():
        setattr(hist, key, value)
    session.add(hist)
    obj.version += 1
209
210
def versioned_session(session):
    """Register a ``before_flush`` listener that records history rows.

    Before each flush, every dirty object with a ``__history_mapper__`` gets
    a history row if it changed. Every such deleted object always gets one.

    :param session: the session (or session class/factory) to instrument.
    :returns: the same *session*, to allow chaining.
    """
    @event.listens_for(session, 'before_flush')
    def before_flush(session, flush_context, instances):
        for obj in versioned_objects(session.dirty):
            create_version(obj, session)
        for obj in versioned_objects(session.deleted):
            create_version(obj, session, deleted=True)
    return session
218