Passed
Pull Request — 2.x (#1864)
by Ramon
05:59
created

senaite.core.schema.uidreferencefield.get_brefs()   A

Complexity

Conditions 3

Size

Total Lines 18
Code Lines 9

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 9
dl 0
loc 18
rs 9.95
c 0
b 0
f 0
cc 3
nop 3
1
# -*- coding: utf-8 -*-
2
3
import six
4
5
from bika.lims import api
6
from persistent.dict import PersistentDict
7
from persistent.list import PersistentList
8
from senaite.core import logger
9
from senaite.core.schema.fields import BaseField
10
from senaite.core.schema.interfaces import IUIDReferenceField
11
from zope.annotation.interfaces import IAnnotations
12
from zope.interface import implementer
13
from zope.schema import ASCIILine
14
from Acquisition import aq_base
15
from zope.schema import List
16
from zope.schema.interfaces import IFromUnicode
17
18
# Annotation key under which `get_bref_storage` keeps the backreference
# storage (relationship key -> list of referencing UIDs) on referenced objects
BACKREFS_STORAGE = "senaite.core.schema.uidreferencefield.backreferences"
19
20
21
def get_brefs(context, relationship, as_objects=False):
    """Return backreferences of the context

    Backreferences are written by `UIDReferenceField.link_bref` into the
    annotation storage of the referenced object, keyed by relationship.

    This is a pure read: unlike `get_bref_storage`, it never creates the
    annotation storage, so looking up backreferences of an object that has
    none does not modify (and persist) that object.

    :param context: the referenced object whose backreferences are returned
    :param relationship: relationship key, as built by
        `UIDReferenceField.get_relationship_key`
    :param as_objects: if `True`, return the referencing objects instead of
        their UIDs
    :returns: List of UIDs (or objects) that are linked by the relationship
    """
    # read the annotations of the unwrapped object without creating the
    # storage (get_bref_storage would write an empty PersistentDict)
    annotations = IAnnotations(aq_base(context))
    brefs = annotations.get(BACKREFS_STORAGE)
    if not brefs:
        return []

    # copy the referencing UIDs into a plain list
    bref_uids = list(brefs.get(relationship, []))
    if not bref_uids:
        return []

    if as_objects is True:
        return [api.get_object(uid) for uid in bref_uids]

    return bref_uids
39
40
41
def get_bref_storage(context):
    """Return the backreference storage from the annotations of `context`

    The storage is a `PersistentDict` kept under the `BACKREFS_STORAGE`
    annotation key. If it does not exist yet, an empty one is created and
    stored in the annotations first.

    :param context: object whose annotations hold the storage
    :returns: PersistentDict of relationship key -> list of UIDs
    """
    annotations = IAnnotations(context)
    storage = annotations.get(BACKREFS_STORAGE)
    if storage is not None:
        return storage
    # first access: create the (empty) storage
    annotations[BACKREFS_STORAGE] = PersistentDict()
    return annotations[BACKREFS_STORAGE]
48
49
50
@implementer(IUIDReferenceField, IFromUnicode)
class UIDReferenceField(List, BaseField):
    """Stores UID references to other objects

    The UIDs of the referenced objects are stored as a list on the instance.
    In addition, every referenced object keeps a backreference to the
    referencing instance in its annotation storage (see `link_bref` and
    `unlink_bref`), stored under `get_relationship_key`.
    """

    value_type = ASCIILine(title=u"UID")

    def __init__(self, allowed_types=None, multi_valued=True, **kw):
        """Create a new UID reference field

        :param allowed_types: portal type or list/tuple of portal types that
            can be referenced; an empty value allows any type
        :param multi_valued: when False, at most one reference can be set and
            the getters return a single value (or None) instead of a list
        :param kw: further keyword arguments passed to the base field
        """
        if allowed_types is None:
            allowed_types = ()
        self.allowed_types = allowed_types
        self.multi_valued = multi_valued
        super(UIDReferenceField, self).__init__(**kw)

    def get_relationship_key(self, context):
        """Relationship key used for backreferences

        The key used for the annotation storage on the referenced object to
        remember the current object UID. It has the form
        "<portal_type of context>.<field name>".

        :returns: storage key to lookup back references
        """
        portal_type = api.get_portal_type(context)
        return "%s.%s" % (portal_type, self.__name__)

    def to_uid(self, value):
        """convert a value to an UID

        :param value: object/UID/SuperModel
        :returns: UID
        :raises TypeError: if no UID can be obtained for the value
        """
        try:
            return api.get_uid(value)
        except api.APIError:
            raise TypeError("Can not get UID of '%s'" % repr(value))

    def get_allowed_types(self):
        """Returns the allowed reference types

        A single string is wrapped into a tuple; an empty value yields an
        empty tuple (meaning: no type restriction).

        :returns: tuple of allowed_types
        """
        allowed_types = self.allowed_types
        if not allowed_types:
            allowed_types = ()
        elif isinstance(allowed_types, six.string_types):
            allowed_types = (allowed_types, )
        return allowed_types

    @staticmethod
    def _is_empty(value):
        """Check whether a single value is None or an empty string
        """
        return value is None or (
            isinstance(value, six.string_types) and not value)

    def _get_raw_uids(self, object):
        """Return the stored UIDs as a list

        Unlike `get_raw`, this always returns a list, also for single valued
        fields, so it can safely be used for membership tests.
        """
        uids = super(UIDReferenceField, self).get(object)
        if not uids:
            return []
        return list(uids)

    def set(self, object, value):
        """Set UID reference

        Backreferences are linked for newly referenced objects and unlinked
        for objects that are no longer referenced.

        :param object: the instance of the field
        :param value: object/UID/SuperModel or a list/tuple of them;
            None or empty values clear the references
        :type value: list/tuple/str
        :raises TypeError: if more than one value is set on a single valued
            field, a referenced type is not allowed, or a value can not be
            converted to a UID
        """

        # always handle all values internally as a list
        if value is None:
            value = []
        elif isinstance(value, six.string_types):
            value = [value]
        elif api.is_object(value):
            value = [value]
        else:
            value = list(value)

        # skip empty values, e.g. an empty string submitted by a form
        value = [v for v in value if not self._is_empty(v)]

        # check if the fields accepts single values only
        if not self.multi_valued and len(value) > 1:
            raise TypeError("Single valued field accepts at most 1 value")

        # check if the type is allowed
        allowed_types = self.get_allowed_types()
        if allowed_types:
            types = set(
                api.get_portal_type(api.get_object(v)) for v in value)
            if not types.issubset(allowed_types):
                raise TypeError("Only the following types are allowed: %s"
                                % ",".join(allowed_types))

        # convert to UIDs
        uids = [self.to_uid(v) for v in value]

        # current set UIDs, always as a list
        # NOTE: `get_raw` returns a single UID/None for single valued fields,
        #       which would break the membership tests below
        existing = self._get_raw_uids(object)

        # filter out new/removed UIDs
        added_uids = [u for u in uids if u not in existing]
        removed_uids = [u for u in existing if u not in uids]

        # link backreferences of new uids
        for uid in added_uids:
            self.link_bref(api.get_object(uid), object)

        # unlink backreferences of removed UIDs
        for uid in removed_uids:
            try:
                source = api.get_object(uid)
            except api.APIError:
                # the referenced object is gone, so there is no bref to
                # remove; this must not prevent clearing the reference
                logger.warning(
                    "Can not unlink backreference of missing object {}"
                    .format(uid))
                continue
            self.unlink_bref(source, object)

        super(UIDReferenceField, self).set(object, uids)

    def unlink_bref(self, source, target):
        """Remove backreference from the source to the target

        :param source: the object where the bref is stored (our reference)
        :param target: the object where the bref points to (our object)
        :returns: True when the bref was removed, False otherwise
        """
        target_uid = api.get_uid(target)
        # get the storage key
        key = self.get_relationship_key(target)
        # get all backreferences from the source
        brefs = get_bref_storage(source)
        if key not in brefs:
            logger.warning(
                "Referenced object {} has no backreferences for the key {}"
                .format(repr(source), key))
            return False
        if target_uid not in brefs[key]:
            logger.warning("Target {} was not linked by {}"
                           .format(repr(target), repr(source)))
            return False
        brefs[key].remove(target_uid)
        return True

    def link_bref(self, source, target):
        """Add backreference from the source to the target

        Adding an already existing backreference is a no-op.

        :param source: the object where the bref is stored (our reference)
        :param target: the object where the bref points to (our object)
        :returns: True (always)
        """
        target_uid = api.get_uid(target)
        # get the annotation storage key
        key = self.get_relationship_key(target)
        # get all backreferences
        brefs = get_bref_storage(source)
        if key not in brefs:
            brefs[key] = PersistentList()
        if target_uid not in brefs[key]:
            brefs[key].append(target_uid)
        return True

    def get(self, object):
        """Get referenced objects

        :param object: instance of the field
        :returns: list of referenced objects (single object or None for
            single valued fields)
        """
        return self._get(object, as_objects=True)

    def get_raw(self, object):
        """Get referenced UIDs

        :param object: instance of the field
        :returns: list of referenced UIDs (single UID or None for single
            valued fields)
        """
        return self._get(object, as_objects=False)

    def _get(self, object, as_objects=False):
        """Returns single/multi value

        :param object: instance of the field
        :param as_objects: Flag for UID/object returns
        :returns: list of referenced UIDs/objects for multi valued fields,
            otherwise the first UID/object or None
        """
        uids = self._get_raw_uids(object)

        if as_objects is True:
            # list comprehension instead of `map`, which is a lazy iterator
            # on Python 3 and would break `len`/indexing below
            uids = [api.get_object_by_uid(uid) for uid in uids]

        if self.multi_valued:
            return uids
        if not uids:
            return None
        return uids[0]

    def fromUnicode(self, value):
        """Validate the value and return it unchanged
        """
        self.validate(value)
        return value
230