Passed
Pull Request — 2.x (#1864)
by Ramon
05:46
created

UIDReferenceField.get_object()   A

Complexity

Conditions 2

Size

Total Lines 9
Code Lines 5

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 5
dl 0
loc 9
rs 10
c 0
b 0
f 0
cc 2
nop 2
1
# -*- coding: utf-8 -*-
2
3
import six
4
5
from bika.lims import api
6
from persistent.dict import PersistentDict
7
from persistent.list import PersistentList
8
from senaite.core import logger
9
from senaite.core.schema.fields import BaseField
10
from senaite.core.schema.interfaces import IUIDReferenceField
11
from zope.annotation.interfaces import IAnnotations
12
from zope.interface import implementer
13
from zope.schema import ASCIILine
14
from Acquisition import aq_base
15
from zope.schema import List
16
17
# annotation key under which the backreference storage of an object is kept
BACKREFS_STORAGE = "senaite.core.schema.uidreferencefield.backreferences"
18
19
20
def get_backrefs(context, relationship, as_objects=False):
    """Return backreferences of the context

    The backreference storage is only read here, never created, so looking
    up backreferences does not write an empty storage to the object.

    :param context: the object whose backreferences are looked up
    :param relationship: the relationship key the backreferences were
                         stored under
    :param as_objects: return the referencing objects instead of their UIDs
    :returns: List of UIDs (or objects) that are linked by the relationship
    """
    context = aq_base(context)
    # read the backref annotation storage of the context without creating it
    backrefs = IAnnotations(context).get(BACKREFS_STORAGE)
    if not backrefs:
        return []
    # get the referenced UIDs
    backref_uids = list(backrefs.get(relationship, []))

    if not backref_uids:
        return []

    if as_objects is True:
        objs = []
        for uid in backref_uids:
            try:
                objs.append(api.get_object(uid))
            except api.APIError:
                # skip UIDs that can not be resolved to an object
                continue
        return objs

    return backref_uids
38
39
40
def get_backref_storage(context):
    """Return the backreference storage of the context

    The storage is a `PersistentDict` kept in the annotations of the context
    under the `BACKREFS_STORAGE` key. It is created and stored on first access.

    :param context: the object to get the backreference storage for
    :returns: the persistent mapping of relationship key -> UIDs
    """
    annotations = IAnnotations(context)
    storage = annotations.get(BACKREFS_STORAGE)
    if storage is None:
        storage = PersistentDict()
        annotations[BACKREFS_STORAGE] = storage
    return storage
47
48
49
@implementer(IUIDReferenceField)
class UIDReferenceField(List, BaseField):
    """Stores UID references to other objects
    """

    value_type = ASCIILine(title=u"UID")

    def __init__(self, allowed_types=None, multi_valued=True, **kw):
        """Initialize the field

        :param allowed_types: portal type(s) the referenced objects may have,
                              either a single string or a sequence of strings.
                              An empty value allows any type.
        :param multi_valued: when False, the field holds at most one reference
                             and the getters return a single value (or None)
        :param kw: passed through to the base field constructors
        """
        if allowed_types is None:
            allowed_types = ()
        self.allowed_types = allowed_types
        self.multi_valued = multi_valued
        super(UIDReferenceField, self).__init__(**kw)

    def get_relationship_key(self, context):
        """Relationship key used for backreferences

        The key used for the annotation storage on the referenced object to
        remember the current object UID.

        :param context: the content object the field belongs to
        :returns: storage key to lookup back references
        """
        portal_type = api.get_portal_type(context)
        return "%s.%s" % (portal_type, self.__name__)

    def get_uid(self, value):
        """Value -> UID

        :param value: object/UID/SuperModel
        :returns: UID or None if no UID could be determined
        """
        try:
            return api.get_uid(value)
        except api.APIError:
            return None

    def get_object(self, value):
        """Value -> object

        :param value: object/UID/SuperModel
        :returns: Object or None
        """
        try:
            return api.get_object(value)
        except api.APIError:
            return None

    def get_allowed_types(self):
        """Returns the allowed reference types

        :returns: tuple of allowed_types
        """
        allowed_types = self.allowed_types
        if not allowed_types:
            allowed_types = ()
        elif isinstance(allowed_types, six.string_types):
            allowed_types = (allowed_types, )
        return allowed_types

    def _get_objects(self, uids):
        """Resolve UIDs to objects, skipping those that can not be resolved

        :param uids: iterable of UIDs
        :returns: list of objects
        """
        # Always build a real list: `filter`/`map` are lazy on Python 3.
        # Compare with None explicitly, because `get_object` signals failure
        # with None while a resolved object may still be falsy
        # (e.g. if it defines `__len__`).
        return [obj for obj in map(self.get_object, uids) if obj is not None]

    def _get_stored_uids(self, object):
        """Return the stored UIDs as a list, regardless of `multi_valued`

        :param object: the content object the field belongs to
        :returns: list of stored UIDs (empty list when nothing is stored)
        """
        uids = super(UIDReferenceField, self).get(object)
        if not uids:
            return []
        return list(uids)

    def set(self, object, value):
        """Set UID reference

        Stores the UIDs of the given value(s) and keeps the backreferences in
        sync. Newly referenced objects get a backreference to `object`, and
        objects that are no longer referenced have theirs removed.

        :param object: the content object the field belongs to
        :param value: object/UID/SuperModel or a list/tuple of them;
                      None clears the field
        :type value: list/tuple/str
        """

        # always handle all values internally as a list
        if isinstance(value, six.string_types) or api.is_object(value):
            value = [value]
        elif value is None:
            value = []

        # convert to UIDs, skipping values that can not be resolved
        uids = [uid for uid in map(self.get_uid, value) if uid is not None]

        # currently stored UIDs.
        # NOTE: `get_raw` is not used here, because it returns a single UID
        #       (or None) for single-valued fields instead of a list
        existing = self._get_stored_uids(object)

        # filter out new/removed UIDs
        added_uids = [u for u in uids if u not in existing]
        removed_uids = [u for u in existing if u not in uids]

        # link backreferences of new uids
        for added_obj in self._get_objects(added_uids):
            self.link_backref(added_obj, object)

        # unlink backreferences of removed UIDs
        for removed_obj in self._get_objects(removed_uids):
            self.unlink_backref(removed_obj, object)

        super(UIDReferenceField, self).set(object, uids)

    def unlink_backref(self, source, target):
        """Remove backreference from the source to the target

        :param source: the object where the backref is stored (our reference)
        :param target: the object where the backref points to (our object)
        :returns: True when the backref was removed, False otherwise
        """
        target_uid = self.get_uid(target)
        # get the storage key
        key = self.get_relationship_key(target)
        # get all backreferences from the source
        backrefs = get_backref_storage(source)
        if key not in backrefs:
            logger.warning(
                "Referenced object {} has no backreferences for the key {}"
                .format(repr(source), key))
            return False
        if target_uid not in backrefs[key]:
            logger.warning("Target {} was not linked by {}"
                           .format(repr(target), repr(source)))
            return False
        backrefs[key].remove(target_uid)
        return True

    def link_backref(self, source, target):
        """Add backreference from the source to the target

        :param source: the object where the backref is stored (our reference)
        :param target: the object where the backref points to (our object)
        :returns: True when the backref was written (or already present)
        """
        target_uid = api.get_uid(target)
        # get the annotation storage key
        key = self.get_relationship_key(target)
        # get all backreferences
        backrefs = get_backref_storage(source)
        if key not in backrefs:
            backrefs[key] = PersistentList()
        if target_uid not in backrefs[key]:
            backrefs[key].append(target_uid)
        return True

    def get(self, object):
        """Get referenced objects

        :param object: the content object the field belongs to
        :returns: list of referenced objects, or a single object/None for
                  single-valued fields
        """
        return self._get(object, as_objects=True)

    def get_raw(self, object):
        """Get referenced UIDs

        NOTE: Called from the data manager `query` method
              to get the widget value

        :param object: the content object the field belongs to
        :returns: list of referenced UIDs, or a single UID/None for
                  single-valued fields
        """
        return self._get(object, as_objects=False)

    def _get(self, object, as_objects=False):
        """Returns single/multi value

        :param object: the content object the field belongs to
        :param as_objects: return objects instead of UIDs when True
        :returns: list of UIDs/objects for multi-valued fields, otherwise
                  the first UID/object or None
        """
        uids = self._get_stored_uids(object)

        if as_objects is True:
            # a real list, so that `len` and indexing below also work on py3
            uids = self._get_objects(uids)

        if self.multi_valued:
            return uids
        if len(uids) == 0:
            return None
        return uids[0]

    def _validate(self, value):
        """Validator when called from form submission

        :param value: sequence of UIDs to validate
        :raises ValueError: if a single-valued field gets more than one value,
                            a value is not a valid UID, or a referenced object
                            has a portal type that is not allowed
        """
        super(UIDReferenceField, self)._validate(value)
        # check if the fields accepts single values only
        if not self.multi_valued and len(value) > 1:
            raise ValueError("Single valued field accepts at most 1 value")

        # check for valid UIDs
        for uid in value:
            if not api.is_uid(uid):
                raise ValueError("Invalid UID: '%s'" % uid)

        # check if the type is allowed
        allowed_types = self.get_allowed_types()
        if allowed_types:
            objs = self._get_objects(value)
            types = set(map(api.get_portal_type, objs))
            if not types.issubset(allowed_types):
                raise ValueError("Only the following types are allowed: %s"
                                 % ",".join(allowed_types))
254