src.Prefetcher.__init__()   C
last analyzed

Complexity

Conditions 9

Size

Total Lines 21

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 9
dl 0
loc 21
rs 5.4999
1
from logging import getLogger
2
import time
3
import collections
4
5
import django
6
from django.db import models
7
from django.db.models import query
8
try:
9
    from django.db.models.fields.related import ReverseSingleRelatedObjectDescriptor as ForwardManyToOneDescriptor
10
except ImportError:
11
    from django.db.models.fields.related_descriptors import ForwardManyToOneDescriptor
12
13
# Version string of this module.
__version__ = '1.1.0'
14
15
# Module-level logger, named after this module's import path (``__name__``).
logger = getLogger(__name__)
16
17
18
class PrefetchManagerMixin(models.Manager):
    """
    Manager base class whose querysets carry a table of prefetch definitions.

    ``prefetch_definitions`` maps a prefetch name to either a plain
    ``Prefetcher`` instance or a callable (for example a ``Prefetcher``
    subclass) that is invoked later to build the prefetcher.
    """
    use_for_related_fields = True
    prefetch_definitions = {}

    @classmethod
    def get_queryset_class(cls):
        """Return the queryset class that ``get_queryset`` instantiates."""
        return PrefetchQuerySet

    def __init__(self):
        """
        Initialise the manager and validate every prefetch definition.

        Raises ``InvalidPrefetch`` for a definition that is neither an
        instance of exactly ``Prefetcher`` nor a callable.
        """
        super(PrefetchManagerMixin, self).__init__()
        for label, definition in self.prefetch_definitions.items():
            # Accepted: a direct Prefetcher instance, or anything callable.
            if definition.__class__ is Prefetcher or callable(definition):
                continue
            raise InvalidPrefetch("Invalid prefetch definition %s. This prefetcher needs to be a class not an instance." % label)

    def get_queryset(self):
        """
        Build a queryset for ``self.model`` that knows this manager's
        prefetch definitions, bound to ``self._db`` when that is set.
        """
        queryset_class = self.get_queryset_class()
        queryset = queryset_class(
            self.model, prefetch_definitions=self.prefetch_definitions
        )

        database = getattr(self, '_db', None)
        if database is None:
            return queryset
        return queryset.using(database)

    def get_query_set(self):
        """
        Django <1.6 compatibility method.
        """
        return self.get_queryset()

    def prefetch(self, *args):
        """Shortcut for ``self.get_queryset().prefetch(*args)``."""
        queryset = self.get_queryset()
        return queryset.prefetch(*args)
50
51
52
class PrefetchManager(PrefetchManagerMixin):
    """
    Manager whose prefetch definitions are supplied as keyword arguments,
    one ``name=definition`` pair per prefetch.
    """

    def __init__(self, **kwargs):
        # Store the definitions first: the parent initialiser iterates over
        # ``self.prefetch_definitions`` to validate each entry.
        self.prefetch_definitions = kwargs
        super(PrefetchManager, self).__init__()
56
57
58
class InvalidPrefetch(Exception):
    """Raised for an invalid prefetch definition or an invalid prefetch lookup."""
60
61
62
class PrefetchOption(object):
    """
    A prefetch name bundled with the positional and keyword arguments that
    are stored alongside it.
    """

    def __init__(self, name, *args, **kwargs):
        self.name, self.args, self.kwargs = name, args, kwargs


# Short alias for PrefetchOption.
P = PrefetchOption
69
70
71
class PrefetchQuerySet(query.QuerySet):
    """
    QuerySet that remembers requested prefetchers and runs them over the
    fetched rows in ``iterator()``.

    ``_prefetch`` maps each requested name to a ``(forwarders, prefetcher)``
    pair, where ``forwarders`` is the list of forward-relation field names
    leading from this queryset's model to the model the prefetcher works on.
    ``prefetch_definitions`` is the name -> definition table supplied by the
    manager.
    """

    def __init__(self, model=None, query=None, using=None,
                 prefetch_definitions=None, **kwargs):
        if using is None:  # this is to support Django 1.1
            super(PrefetchQuerySet, self).__init__(model, query, **kwargs)
        else:
            super(PrefetchQuerySet, self).__init__(model, query, using, **kwargs)
        self._prefetch = {}
        self.prefetch_definitions = prefetch_definitions

    def _clone(self, **kwargs):
        """
        Clone the queryset, carrying the prefetch state over to the copy.

        The extra attributes are set on the returned clone instead of being
        passed as keyword arguments to the parent ``_clone`` (that only works
        when the parent accepts arbitrary kwargs). NOTE(review): newer Django
        versions are believed to define ``_clone()`` without kwargs -- confirm
        against the versions you target.

        ``_prefetch`` is copied so that registering a prefetcher on the clone
        (see ``prefetch``) never alters the queryset it was cloned from.
        """
        clone = super(PrefetchQuerySet, self)._clone(**kwargs)
        clone._prefetch = dict(self._prefetch)
        clone.prefetch_definitions = self.prefetch_definitions
        return clone

    def prefetch(self, *names):
        """
        Return a clone with the given prefetchers registered.

        Each entry of ``names`` is either a string or a ``PrefetchOption``.
        A name has the form ``[fk__[fk__...]]prefetcher``: zero or more
        forward relations followed by the name of a prefetch definition
        (looked up on the manager of the model reached through the
        relations). With a ``PrefetchOption`` the definition is called with
        the option's arguments; otherwise a plain ``Prefetcher`` instance is
        used as is and any other definition is called without arguments.

        Raises ``InvalidPrefetch`` when a part cannot be resolved, when
        anything follows the prefetcher name, when no prefetcher is named, or
        when options are passed for a plain ``Prefetcher`` instance.
        """
        obj = self._clone()

        for opt in names:
            if isinstance(opt, PrefetchOption):
                name = opt.name
            else:
                name = opt
                opt = None
            parts = name.split('__')
            forwarders = []
            prefetcher = None
            model = self.model
            # A queryset built without a definitions table has ``None`` here;
            # treat it as empty so an unknown name ends in InvalidPrefetch
            # rather than a TypeError from the ``in`` test below.
            prefetch_definitions = self.prefetch_definitions or {}

            for what in parts:
                if not prefetcher:
                    if what in prefetch_definitions:
                        prefetcher = prefetch_definitions[what]
                        continue
                    descriptor = getattr(model, what, None)
                    if isinstance(descriptor, ForwardManyToOneDescriptor):
                        field = descriptor.field
                        forwarders.append(field.name)
                        # Follow the relation to the target model; which
                        # attribute holds it depends on the Django version.
                        if hasattr(field, 'remote_field'):
                            model = field.remote_field.model
                        else:
                            model = field.rel.to
                        manager = model.objects
                        if not isinstance(manager, PrefetchManagerMixin):
                            raise InvalidPrefetch('Manager for %s is not a PrefetchManagerMixin instance.' % model)
                        # Continue resolving against the related model's
                        # definitions.
                        prefetch_definitions = manager.prefetch_definitions
                    else:
                        raise InvalidPrefetch("Invalid part %s in prefetch call for %s on model %s. "
                                              "The name is not a prefetcher nor a forward relation (fk)." % (
                                                  what, name, self.model))
                else:
                    raise InvalidPrefetch("Invalid part %s in prefetch call for %s on model %s. "
                                          "You cannot have any more relations after the prefetcher." % (
                                              what, name, self.model))
            if not prefetcher:
                raise InvalidPrefetch("Invalid prefetch call with %s for on model %s. "
                                      "The last part isn't a prefetch definition." % (name, self.model))
            if opt:
                # Options can only be applied to a definition that still has
                # to be instantiated, not to a ready-made Prefetcher instance.
                if prefetcher.__class__ is Prefetcher:
                    raise InvalidPrefetch("Invalid prefetch call with %s for on model %s. "
                                          "This prefetcher (%s) needs to be a subclass of Prefetcher." % (
                                              name, self.model, prefetcher))

                obj._prefetch[name] = forwarders, prefetcher(*opt.args, **opt.kwargs)
            else:
                obj._prefetch[name] = forwarders, prefetcher if prefetcher.__class__ is Prefetcher else prefetcher()

        # Make sure every forward relation a prefetcher walks through is
        # loaded by the main query.
        for forwarders, prefetcher in obj._prefetch.values():
            if forwarders:
                if django.VERSION < (1, 7) and obj.query.select_related:
                    if not obj.query.max_depth:
                        obj.query.add_select_related('__'.join(forwarders))
                else:
                    obj = obj.select_related('__'.join(forwarders))
        return obj

    def iterator(self, *args, **kwargs):
        """
        Load all rows, run every registered prefetcher over them, and return
        an iterator over the rows.

        Any arguments are passed straight through to the parent
        ``iterator()``. The whole result is materialised into a list first
        because each prefetcher works on the complete data set.

        NOTE(review): this hook only takes effect if the Django version in
        use evaluates querysets through ``iterator()`` -- confirm for the
        versions you target.
        """
        data = list(super(PrefetchQuerySet, self).iterator(*args, **kwargs))
        for name, (forwarders, prefetcher) in self._prefetch.items():
            prefetcher.fetch(data, name, self.model, forwarders,
                             getattr(self, '_db', None))
        return iter(data)
154
155
156
class Prefetcher(object):
    """
    Prefetch definition. For convenience you can either subclass this and
    define the methods on the subclass or just pass the functions to the
    constructor.

    Eg, subclassing::

        class GroupPrefetcher(Prefetcher):

            @staticmethod
            def filter(ids):
                return User.groups.through.objects.filter(user__in=ids).select_related('group')

            @staticmethod
            def reverse_mapper(user_group_association):
                return [user_group_association.user_id]

            @staticmethod
            def decorator(user, user_group_associations=()):
                setattr(user, 'prefetched_groups', [i.group for i in user_group_associations])

    Or with constructor::

        Prefetcher(
            filter = lambda ids: User.groups.through.objects.filter(user__in=ids).select_related('group'),
            reverse_mapper = lambda user_group_association: [user_group_association.user_id],
            decorator = lambda user, user_group_associations=(): setattr(user, 'prefetched_groups', [
                i.group for i in user_group_associations
            ])
        )


    Glossary:

    * filter(list_of_ids):

        A function that returns a queryset containing all the related data for a given list of keys.
        Takes a list of ids as argument.

    * reverse_mapper(related_object):

        A function that takes the related object as argument and returns a list
        of keys that maps that related object to the objects in the queryset.

    * mapper(object):

        Optional (defaults to ``lambda obj: obj.id``).

        A function that returns the key for a given object in your query set.

    * decorator(object, list_of_related_objects):

        A function that will save the related data on each of your objects in
        your queryset. Takes the object and a list of related objects as
        arguments. Note that you should not override existing attributes on the
        model instance here. It is also called once per object with the
        object alone, so the second argument needs a default.

    * collect:

        Optional (defaults to ``False``). When true -- or whenever the
        prefetcher is reached through forward relations -- every object
        sharing a key is kept and decorated; otherwise only one object per
        key (the last one seen) is kept.

    """
    collect = False

    def __init__(self, filter=None, reverse_mapper=None, decorator=None, mapper=None, collect=None):
        """
        Store the supplied callables on the instance.

        ``filter``, ``reverse_mapper`` and ``decorator`` must each be given
        here or already exist as attributes (e.g. defined on a subclass);
        otherwise ``RuntimeError`` is raised. ``mapper`` and ``collect``
        override the class-level defaults only when provided.
        """
        if filter:
            self.filter = filter
        elif not hasattr(self, 'filter'):
            raise RuntimeError("You must define a filter function")

        if reverse_mapper:
            self.reverse_mapper = reverse_mapper
        elif not hasattr(self, 'reverse_mapper'):
            raise RuntimeError("You must define a reverse_mapper function")

        if decorator:
            self.decorator = decorator
        elif not hasattr(self, 'decorator'):
            raise RuntimeError("You must define a decorator function")

        if mapper:
            self.mapper = mapper

        if collect is not None:
            self.collect = collect

    @staticmethod
    def mapper(obj):
        """Default key function: the object's ``id``."""
        return obj.id

    def fetch(self, dataset, name, model, forwarders, db):
        """
        Load the related data for ``dataset`` and attach it to its objects.

        ``dataset`` is the iterable of fetched objects; ``name`` and
        ``model`` are only used in log messages; ``forwarders`` is the list
        of attribute names to walk from each object to the object that gets
        decorated; ``db``, when not ``None``, is applied to the filter result
        through ``.using(db)``.

        Returns ``dataset``. Any exception is logged and re-raised.
        """
        collect = self.collect or forwarders

        try:
            data_mapping = collections.defaultdict(list)
            t1 = time.time()
            for obj in dataset:
                # Walk the forward relations to reach the target object.
                for field in forwarders:
                    obj = getattr(obj, field, None)

                if not obj:
                    continue

                if collect:
                    data_mapping[self.mapper(obj)].append(obj)
                else:
                    data_mapping[self.mapper(obj)] = obj

                # First pass without related objects, so the decorator's
                # default is applied to every object.
                self.decorator(obj)

            t2 = time.time()
            logger.debug("Creating data_mapping for %s query took %.3f secs for the %s prefetcher.",
                         model.__name__, t2-t1, name)
            t1 = time.time()
            # Hand ``filter`` a real list, as its contract documents, rather
            # than a dictionary keys view.
            related_data = self.filter(list(data_mapping))
            if db is not None:
                related_data = related_data.using(db)
            related_data_len = len(related_data)
            t2 = time.time()
            logger.debug("Filtering for %s related objects for %s query took %.3f secs for the %s prefetcher.",
                         related_data_len, model.__name__, t2-t1, name)
            relation_mapping = collections.defaultdict(list)

            t1 = time.time()
            # Group the related objects by the keys they map back to.
            for obj in related_data:
                for id_ in self.reverse_mapper(obj):
                    if id_:
                        relation_mapping[id_].append(obj)
            # Second pass: attach each group to the object(s) with that key.
            for id_, related_items in relation_mapping.items():
                if id_ in data_mapping:
                    if collect:
                        for item in data_mapping[id_]:
                            self.decorator(item, related_items)
                    else:
                        self.decorator(data_mapping[id_], related_items)

            t2 = time.time()
            logger.debug("Adding the related objects on the %s query took %.3f secs for the %s prefetcher.",
                         model.__name__, t2-t1, name)
            return dataset
        except Exception:
            logger.exception("Prefetch failed for %s prefetch on the %s model:", name, model.__name__)
            raise
296