|
1
|
|
|
from logging import getLogger |
|
2
|
|
|
import time |
|
3
|
|
|
import collections |
|
4
|
|
|
|
|
5
|
|
|
import django |
|
6
|
|
|
from django.db import models |
|
7
|
|
|
from django.db.models import query |
|
8
|
|
|
try: |
|
9
|
|
|
from django.db.models.fields.related import ReverseSingleRelatedObjectDescriptor as ForwardManyToOneDescriptor |
|
10
|
|
|
except ImportError: |
|
11
|
|
|
from django.db.models.fields.related_descriptors import ForwardManyToOneDescriptor |
|
12
|
|
|
|
|
13
|
|
|
# Version string exposed by this module.
__version__ = '1.1.0'

# Module-level logger; Prefetcher.fetch logs per-step timings at DEBUG level
# and reports failures through logger.exception before re-raising.
logger = getLogger(__name__)
|
16
|
|
|
|
|
17
|
|
|
|
|
18
|
|
|
class PrefetchManagerMixin(models.Manager):
    """Manager base class whose querysets support ``prefetch()``.

    Subclasses provide ``prefetch_definitions``: a mapping from prefetch name
    to either a plain ``Prefetcher`` instance or a callable (normally a
    ``Prefetcher`` subclass) that is instantiated when the prefetch is used.
    """

    # Legacy Django manager option, kept as in the original design.
    use_for_related_fields = True
    # Default: no prefetches. Subclasses (or PrefetchManager) replace this.
    prefetch_definitions = {}

    @classmethod
    def get_queryset_class(cls):
        """Return the queryset class that ``get_queryset`` instantiates."""
        return PrefetchQuerySet

    def __init__(self):
        super(PrefetchManagerMixin, self).__init__()
        # Validate eagerly: anything other than a plain Prefetcher instance
        # must be callable so that prefetch() can build an instance from it.
        for def_name, definition in self.prefetch_definitions.items():
            plain_instance = definition.__class__ is Prefetcher
            if not (plain_instance or callable(definition)):
                raise InvalidPrefetch("Invalid prefetch definition %s. This prefetcher needs to be a class not an instance." % def_name)

    def get_queryset(self):
        """Return a ``get_queryset_class()`` queryset carrying this manager's
        definitions, routed to the manager's database alias when one is set."""
        queryset_class = self.get_queryset_class()
        queryset = queryset_class(self.model, prefetch_definitions=self.prefetch_definitions)
        alias = getattr(self, '_db', None)
        if alias is None:
            return queryset
        return queryset.using(alias)

    def get_query_set(self):
        """
        Django <1.6 compatibility method.
        """
        return self.get_queryset()

    def prefetch(self, *args):
        """Shortcut for ``self.get_queryset().prefetch(*args)``."""
        queryset = self.get_queryset()
        return queryset.prefetch(*args)
|
50
|
|
|
|
|
51
|
|
|
|
|
52
|
|
|
class PrefetchManager(PrefetchManagerMixin):
    """Prefetch manager configured entirely through keyword arguments.

    Every keyword argument becomes an entry of ``prefetch_definitions``,
    keyed by the argument name.
    """

    def __init__(self, **definitions):
        # Assigned before the mixin's __init__ runs, because that is where
        # the definitions get validated.
        self.prefetch_definitions = definitions
        super(PrefetchManager, self).__init__()
|
56
|
|
|
|
|
57
|
|
|
|
|
58
|
|
|
class InvalidPrefetch(Exception):
    """Raised for an invalid prefetch definition or ``prefetch()`` call."""
|
60
|
|
|
|
|
61
|
|
|
|
|
62
|
|
|
class PrefetchOption(object):
    """A prefetch name bundled with arguments for its prefetcher.

    ``PrefetchQuerySet.prefetch`` accepts these in place of plain names and
    calls the prefetcher definition with the stored ``args`` / ``kwargs``.
    """

    def __init__(self, name, *args, **kwargs):
        self.name, self.args, self.kwargs = name, args, kwargs


# Short alias for building options inline, e.g. ``P('name', ...)``.
P = PrefetchOption
|
69
|
|
|
|
|
70
|
|
|
|
|
71
|
|
|
class PrefetchQuerySet(query.QuerySet):
    """QuerySet that runs prefetchers over its results.

    ``prefetch()`` records which prefetchers to run; ``iterator()`` evaluates
    the query and then hands the complete result list to each recorded
    prefetcher.
    """

    def __init__(self, model=None, query=None, using=None,
                 prefetch_definitions=None, **kwargs):
        if using is None:  # this is to support Django 1.1
            super(PrefetchQuerySet, self).__init__(model, query, **kwargs)
        else:
            super(PrefetchQuerySet, self).__init__(model, query, using, **kwargs)
        # Requested prefetch name -> (forwarders, prefetcher instance).
        self._prefetch = {}
        # Prefetch name -> definition, as supplied by the manager (may be None).
        self.prefetch_definitions = prefetch_definitions

    def _clone(self, **kwargs):
        """Clone the queryset, carrying the prefetch state over.

        ``_prefetch`` is copied rather than passed as-is: ``prefetch()``
        mutates the clone's mapping, and handing over ``self._prefetch``
        itself would make that mutation leak back into this queryset.
        """
        return super(PrefetchQuerySet, self)._clone(
            _prefetch=dict(self._prefetch),
            prefetch_definitions=self.prefetch_definitions,
            **kwargs)

    def _resolve_prefetcher(self, name):
        """Resolve a ``__``-separated prefetch name to ``(forwarders, definition)``.

        ``forwarders`` lists the forward-FK field names to follow from each
        result object; ``definition`` is the raw prefetch definition (a
        ``Prefetcher`` instance or a callable) found at the end of the chain.

        Raises:
            InvalidPrefetch: for an unknown part, a part after the prefetcher,
                a related model whose manager is not a PrefetchManagerMixin,
                or a name that does not end in a prefetch definition.
        """
        forwarders = []
        prefetcher = None
        model = self.model
        # Querysets built without the manager may carry no definitions at all.
        prefetch_definitions = self.prefetch_definitions or {}

        for what in name.split('__'):
            if prefetcher:
                raise InvalidPrefetch("Invalid part %s in prefetch call for %s on model %s. "
                                      "You cannot have any more relations after the prefetcher." % (
                                          what, name, self.model))
            if what in prefetch_definitions:
                prefetcher = prefetch_definitions[what]
                continue
            descriptor = getattr(model, what, None)
            if not isinstance(descriptor, ForwardManyToOneDescriptor):
                raise InvalidPrefetch("Invalid part %s in prefetch call for %s on model %s. "
                                      "The name is not a prefetcher nor a forward relation (fk)." % (
                                          what, name, self.model))
            field = descriptor.field
            forwarders.append(field.name)
            # Newer Django exposes the target via remote_field, older via rel.to.
            if hasattr(field, 'remote_field'):
                model = field.remote_field.model
            else:
                model = field.rel.to
            manager = model.objects
            if not isinstance(manager, PrefetchManagerMixin):
                raise InvalidPrefetch('Manager for %s is not a PrefetchManagerMixin instance.' % model)
            prefetch_definitions = manager.prefetch_definitions

        if not prefetcher:
            raise InvalidPrefetch("Invalid prefetch call with %s for on model %s. "
                                  "The last part isn't a prefetch definition." % (name, self.model))
        return forwarders, prefetcher

    def prefetch(self, *names):
        """Return a clone that runs the named prefetchers when evaluated.

        Each entry of ``names`` is a string or a ``PrefetchOption``. A name may
        be prefixed with forward foreign-key hops separated by ``__``; its
        last part must be a prefetch definition of the model those hops reach.
        Forwarded prefetches also add the hop path to ``select_related``.

        Raises:
            InvalidPrefetch: for an unresolvable name (see
                ``_resolve_prefetcher``), or when options are supplied for a
                plain ``Prefetcher`` instance.
        """
        obj = self._clone()

        for opt in names:
            if isinstance(opt, PrefetchOption):
                name = opt.name
            else:
                name = opt
                opt = None

            forwarders, prefetcher = self._resolve_prefetcher(name)

            if opt:
                if prefetcher.__class__ is Prefetcher:
                    raise InvalidPrefetch("Invalid prefetch call with %s for on model %s. "
                                          "This prefetcher (%s) needs to be a subclass of Prefetcher." % (
                                              name, self.model, prefetcher))
                instance = prefetcher(*opt.args, **opt.kwargs)
            elif prefetcher.__class__ is Prefetcher:
                # A plain Prefetcher instance is used as-is.
                instance = prefetcher
            else:
                # Otherwise the definition is a class/factory: build an instance.
                instance = prefetcher()
            obj._prefetch[name] = forwarders, instance

        for forwarders, _ in obj._prefetch.values():
            if not forwarders:
                continue
            relation_path = '__'.join(forwarders)
            if django.VERSION < (1, 7) and obj.query.select_related:
                # Pre-1.7 with select_related already active: extend the query
                # in place, skipped when a max_depth is set.
                if not obj.query.max_depth:
                    obj.query.add_select_related(relation_path)
            else:
                obj = obj.select_related(relation_path)
        return obj

    def iterator(self):
        """Evaluate the query, run every requested prefetcher, then iterate."""
        data = list(super(PrefetchQuerySet, self).iterator())
        for name, (forwarders, prefetcher) in self._prefetch.items():
            prefetcher.fetch(data, name, self.model, forwarders,
                             getattr(self, '_db', None))
        return iter(data)
|
154
|
|
|
|
|
155
|
|
|
|
|
156
|
|
|
class Prefetcher(object):
    """
    Prefetch definition. For convenience you can either subclass this and
    define the methods on the subclass or just pass the functions to the
    constructor.

    E.g., subclassing::

        class GroupPrefetcher(Prefetcher):

            @staticmethod
            def filter(ids):
                return User.groups.through.objects.filter(user__in=ids).select_related('group')

            @staticmethod
            def reverse_mapper(user_group_association):
                return [user_group_association.user_id]

            @staticmethod
            def decorator(user, user_group_associations=()):
                setattr(user, 'prefetched_groups', [i.group for i in user_group_associations])

    Or with the constructor::

        Prefetcher(
            filter = lambda ids: User.groups.through.objects.filter(user__in=ids).select_related('group'),
            reverse_mapper = lambda user_group_association: [user_group_association.user_id],
            decorator = lambda user, user_group_associations=(): setattr(user, 'prefetched_groups', [
                i.group for i in user_group_associations
            ])
        )


    Glossary:

    * filter(list_of_ids):

        A function that returns a queryset containing all the related data for a given list of keys.
        Takes a list of ids as argument.

    * reverse_mapper(related_object):

        A function that takes the related object as argument and returns a list
        of keys that maps that related object to the objects in the queryset.
        ``None`` keys are ignored.

    * mapper(object):

        Optional (defaults to ``lambda obj: obj.id``).

        A function that returns the key for a given object in your query set.

    * decorator(object, list_of_related_objects):

        A function that will save the related data on each of your objects in
        your queryset. Takes the object and a list of related objects as
        arguments. It is first called with the object alone (so the second
        argument needs a default) and again with the related objects when any
        are found. Note that you should not override existing attributes on the
        model instance here.

    * collect:

        Optional (defaults to ``False``). When true, several objects may share
        a key and all of them are decorated; when false, only the last object
        seen for a key is decorated with related data. It is always enabled
        for prefetches reached through forward relations.

    """
    collect = False

    def __init__(self, filter=None, reverse_mapper=None, decorator=None, mapper=None, collect=None):
        """Store any callables given, falling back to ones defined on a subclass.

        Raises:
            RuntimeError: if ``filter``, ``reverse_mapper`` or ``decorator`` is
                neither passed nor defined on the class.
        """
        if filter:
            self.filter = filter
        elif not hasattr(self, 'filter'):
            raise RuntimeError("You must define a filter function")

        if reverse_mapper:
            self.reverse_mapper = reverse_mapper
        elif not hasattr(self, 'reverse_mapper'):
            raise RuntimeError("You must define a reverse_mapper function")

        if decorator:
            self.decorator = decorator
        elif not hasattr(self, 'decorator'):
            raise RuntimeError("You must define a decorator function")

        if mapper:
            self.mapper = mapper

        if collect is not None:
            self.collect = collect

    @staticmethod
    def mapper(obj):
        return obj.id

    def fetch(self, dataset, name, model, forwarders, db):
        """Load related data for ``dataset`` and attach it through ``decorator``.

        :param dataset: list of objects returned by the queryset.
        :param name: prefetch name (used in log messages).
        :param model: model class of the queryset (used in log messages).
        :param forwarders: field names followed from each object to reach the
            object that is actually decorated.
        :param db: database alias for the filter queryset, or None.
        :returns: ``dataset``; its objects are decorated in place.
        :raises: re-raises any exception after logging it.
        """
        collect = self.collect or forwarders

        try:
            data_mapping = collections.defaultdict(list)
            t1 = time.time()
            for obj in dataset:
                for field in forwarders:
                    obj = getattr(obj, field, None)

                if not obj:
                    continue

                if collect:
                    data_mapping[self.mapper(obj)].append(obj)
                else:
                    data_mapping[self.mapper(obj)] = obj

                # First call without related items (presumably so the
                # decorator can set an empty default); objects that do have
                # related data are decorated again below.
                self.decorator(obj)

            t2 = time.time()
            logger.debug("Creating data_mapping for %s query took %.3f secs for the %s prefetcher.",
                         model.__name__, t2-t1, name)
            t1 = time.time()
            # filter() is documented to take a list of ids; on Python 3
            # dict.keys() is a view (not indexable), so pass a real list.
            related_data = self.filter(list(data_mapping))
            if db is not None:
                related_data = related_data.using(db)
            related_data_len = len(related_data)
            t2 = time.time()
            logger.debug("Filtering for %s related objects for %s query took %.3f secs for the %s prefetcher.",
                         related_data_len, model.__name__, t2-t1, name)
            relation_mapping = collections.defaultdict(list)

            t1 = time.time()
            for obj in related_data:
                for id_ in self.reverse_mapper(obj):
                    # Skip only missing keys; falsy but valid keys such as a
                    # primary key of 0 must still be matched.
                    if id_ is not None:
                        relation_mapping[id_].append(obj)
            for id_, related_items in relation_mapping.items():
                if id_ in data_mapping:
                    if collect:
                        for item in data_mapping[id_]:
                            self.decorator(item, related_items)
                    else:
                        self.decorator(data_mapping[id_], related_items)

            t2 = time.time()
            logger.debug("Adding the related objects on the %s query took %.3f secs for the %s prefetcher.",
                         model.__name__, t2-t1, name)
            return dataset
        except Exception:
            logger.exception("Prefetch failed for %s prefetch on the %s model:", name, model.__name__)
            raise
|
296
|
|
|
|