1
|
|
|
from logging import getLogger |
2
|
|
|
import time |
3
|
|
|
import collections |
4
|
|
|
|
5
|
|
|
import django |
6
|
|
|
from django.db import models |
7
|
|
|
from django.db.models import query |
8
|
|
|
try: |
9
|
|
|
from django.db.models.fields.related import ReverseSingleRelatedObjectDescriptor as ForwardManyToOneDescriptor |
10
|
|
|
except ImportError: |
11
|
|
|
from django.db.models.fields.related_descriptors import ForwardManyToOneDescriptor |
12
|
|
|
|
13
|
|
|
# Version string of this module.
__version__ = '1.1.0'

# Logger named after this module; the prefetch code in this file reports
# its timings and failures through it.
logger = getLogger(__name__)
16
|
|
|
|
17
|
|
|
|
18
|
|
|
class PrefetchManagerMixin(models.Manager):
    """
    Manager that produces querysets aware of ``prefetch_definitions``.

    ``prefetch_definitions`` maps a prefetch name to either a plain
    ``Prefetcher`` instance or a callable (such as a ``Prefetcher``
    subclass) that builds the prefetcher when it is requested.
    """
    use_for_related_fields = True
    prefetch_definitions = {}

    @classmethod
    def get_queryset_class(cls):
        """Return the queryset class instantiated by ``get_queryset``."""
        return PrefetchQuerySet

    def __init__(self):
        """
        Initialise the manager and validate every prefetch definition.

        Raises ``InvalidPrefetch`` for a definition that is neither an
        instance of exactly ``Prefetcher`` nor a callable.
        """
        super(PrefetchManagerMixin, self).__init__()
        for name, prefetcher in self.prefetch_definitions.items():
            # Accepted: a direct Prefetcher instance, or anything callable.
            if prefetcher.__class__ is Prefetcher or callable(prefetcher):
                continue
            raise InvalidPrefetch("Invalid prefetch definition %s. This prefetcher needs to be a class not an instance." % name)

    def get_queryset(self):
        """
        Build the queryset for this manager's model.

        The queryset receives this manager's ``prefetch_definitions`` and,
        when the manager has a non-None ``_db``, is routed to it through
        ``using``.
        """
        queryset_class = self.get_queryset_class()
        queryset = queryset_class(
            self.model, prefetch_definitions=self.prefetch_definitions
        )

        if getattr(self, '_db', None) is None:
            return queryset
        return queryset.using(self._db)

    def get_query_set(self):
        """
        Django <1.6 compatibility method; delegates to ``get_queryset``.
        """
        return self.get_queryset()

    def prefetch(self, *args):
        """Shortcut for ``self.get_queryset().prefetch(*args)``."""
        queryset = self.get_queryset()
        return queryset.prefetch(*args)
50
|
|
|
|
51
|
|
|
|
52
|
|
|
class PrefetchManager(PrefetchManagerMixin):
    """
    Manager whose prefetch definitions are given as keyword arguments.
    """

    def __init__(self, **kwargs):
        """
        Store ``kwargs`` as ``prefetch_definitions`` and run the parent
        initialiser.
        """
        # Assigned before the parent initialiser runs, so the parent sees
        # the instance-level definitions.
        definitions = kwargs
        self.prefetch_definitions = definitions
        super(PrefetchManager, self).__init__()
56
|
|
|
|
57
|
|
|
|
58
|
|
|
class InvalidPrefetch(Exception):
    """Raised when a prefetch definition or a prefetch call is invalid."""
60
|
|
|
|
61
|
|
|
|
62
|
|
|
class PrefetchOption(object):
    """
    A prefetch name bundled with constructor arguments.

    ``name`` is the prefetch name; ``args`` and ``kwargs`` are kept as
    given so they can later be passed on when the prefetcher is built.
    """

    def __init__(self, name, *args, **kwargs):
        self.name, self.args, self.kwargs = name, args, kwargs


# Short alias for PrefetchOption.
P = PrefetchOption
69
|
|
|
|
70
|
|
|
|
71
|
|
|
class PrefetchQuerySet(query.QuerySet):
    """
    QuerySet that runs the requested prefetchers when it is iterated.

    ``_prefetch`` maps every requested prefetch name to a
    ``(forwarders, prefetcher)`` tuple, where ``forwarders`` is the list of
    forward-relation field names to walk before applying the prefetcher.
    ``prefetch_definitions`` maps prefetcher names to their definitions.
    """

    def __init__(self, model=None, query=None, using=None,
                 prefetch_definitions=None, **kwargs):
        if using is None:  # this is to support Django 1.1
            super(PrefetchQuerySet, self).__init__(model, query, **kwargs)
        else:
            super(PrefetchQuerySet, self).__init__(model, query, using, **kwargs)
        self._prefetch = {}
        self.prefetch_definitions = prefetch_definitions

    def _clone(self, **kwargs):
        """
        Clone the queryset, carrying the prefetch state over.

        The clone gets its own copy of ``_prefetch``. Sharing one dict
        between clones would let ``prefetch()`` on a derived queryset
        silently add prefetchers to the queryset it was derived from.
        """
        return super(PrefetchQuerySet, self)._clone(
            _prefetch=dict(self._prefetch),
            prefetch_definitions=self.prefetch_definitions,
            **kwargs
        )

    def _resolve_prefetcher(self, name):
        """
        Resolve a ``__``-separated prefetch name.

        Every part but the last must be a forward relation (fk) whose
        target model's ``objects`` manager is a ``PrefetchManagerMixin``;
        the last part must be a key of the prefetch definitions in effect
        at that point.

        Returns a ``(forwarders, prefetcher)`` tuple, where ``forwarders``
        lists the traversed field names and ``prefetcher`` is the matching
        definition. Raises ``InvalidPrefetch`` when the name cannot be
        resolved.
        """
        forwarders = []
        prefetcher = None
        model = self.model
        # The constructor allows prefetch_definitions=None; treat that as
        # "no definitions" so the caller gets InvalidPrefetch rather than a
        # TypeError from the membership test below.
        prefetch_definitions = self.prefetch_definitions or {}

        for what in name.split('__'):
            if prefetcher:
                raise InvalidPrefetch("Invalid part %s in prefetch call for %s on model %s. "
                                      "You cannot have any more relations after the prefetcher." % (
                                          what, name, self.model))
            if what in prefetch_definitions:
                prefetcher = prefetch_definitions[what]
                continue
            descriptor = getattr(model, what, None)
            if not isinstance(descriptor, ForwardManyToOneDescriptor):
                raise InvalidPrefetch("Invalid part %s in prefetch call for %s on model %s. "
                                      "The name is not a prefetcher nor a forward relation (fk)." % (
                                          what, name, self.model))
            field = descriptor.field
            forwarders.append(field.name)
            # Fields expose the related model either as remote_field.model
            # or as rel.to; support both.
            if hasattr(field, 'remote_field'):
                model = field.remote_field.model
            else:
                model = field.rel.to
            manager = model.objects
            if not isinstance(manager, PrefetchManagerMixin):
                raise InvalidPrefetch('Manager for %s is not a PrefetchManagerMixin instance.' % model)
            # From here on, names are looked up on the related model.
            prefetch_definitions = manager.prefetch_definitions

        if not prefetcher:
            raise InvalidPrefetch("Invalid prefetch call with %s for on model %s. "
                                  "The last part isn't a prefetch definition." % (name, self.model))
        return forwarders, prefetcher

    def prefetch(self, *names):
        """
        Return a clone with the given prefetches registered.

        Each item of ``names`` is either a prefetch name (parts separated
        by ``__``) or a ``PrefetchOption`` carrying a name plus the
        arguments used to instantiate the prefetcher. Raises
        ``InvalidPrefetch`` for names that cannot be resolved, or when a
        ``PrefetchOption`` targets a definition that is already a plain
        ``Prefetcher`` instance.
        """
        obj = self._clone()

        for opt in names:
            if isinstance(opt, PrefetchOption):
                name = opt.name
            else:
                name, opt = opt, None

            forwarders, prefetcher = self._resolve_prefetcher(name)

            if opt:
                if prefetcher.__class__ is Prefetcher:
                    raise InvalidPrefetch("Invalid prefetch call with %s for on model %s. "
                                          "This prefetcher (%s) needs to be a subclass of Prefetcher." % (
                                              name, self.model, prefetcher))
                obj._prefetch[name] = forwarders, prefetcher(*opt.args, **opt.kwargs)
            elif prefetcher.__class__ is Prefetcher:
                # Already a ready-to-use instance.
                obj._prefetch[name] = forwarders, prefetcher
            else:
                obj._prefetch[name] = forwarders, prefetcher()

        # Forward relations walked by a prefetcher are added to
        # select_related on the returned queryset.
        for forwarders, prefetcher in obj._prefetch.values():
            if forwarders:
                if django.VERSION < (1, 7) and obj.query.select_related:
                    if not obj.query.max_depth:
                        obj.query.add_select_related('__'.join(forwarders))
                else:
                    obj = obj.select_related('__'.join(forwarders))
        return obj

    def iterator(self):
        """
        Fetch all rows, run every registered prefetcher over them and
        return an iterator over the resulting list.
        """
        data = list(super(PrefetchQuerySet, self).iterator())
        for name, (forwarders, prefetcher) in self._prefetch.items():
            prefetcher.fetch(data, name, self.model, forwarders,
                             getattr(self, '_db', None))
        return iter(data)
154
|
|
|
|
155
|
|
|
|
156
|
|
|
class Prefetcher(object):
    """
    Prefetch definition. For convenience you can either subclass this and
    define the methods on the subclass or just pass the functions to the
    constructor.

    Eg, subclassing::

        class GroupPrefetcher(Prefetcher):

            @staticmethod
            def filter(ids):
                return User.groups.through.objects.filter(user__in=ids).select_related('group')

            @staticmethod
            def reverse_mapper(user_group_association):
                return [user_group_association.user_id]

            @staticmethod
            def decorator(user, user_group_associations=()):
                setattr(user, 'prefetched_groups', [i.group for i in user_group_associations])

    Or with constructor::

        Prefetcher(
            filter = lambda ids: User.groups.through.objects.filter(user__in=ids).select_related('group'),
            reverse_mapper = lambda user_group_association: [user_group_association.user_id],
            decorator = lambda user, user_group_associations=(): setattr(user, 'prefetched_groups', [
                i.group for i in user_group_associations
            ])
        )


    Glossary:

    * filter(list_of_ids):

        A function that returns a queryset containing all the related data for a given list of keys.
        Takes a list of ids as argument.

    * reverse_mapper(related_object):

        A function that takes the related object as argument and returns a list
        of keys that maps that related object to the objects in the queryset.

    * mapper(object):

        Optional (defaults to ``lambda obj: obj.id``).

        A function that returns the key for a given object in your query set.

    * decorator(object, list_of_related_objects):

        A function that will save the related data on each of your objects in
        your queryset. Takes the object and a list of related objects as
        arguments. Note that you should not override existing attributes on the
        model instance here.

        It is first called with the object alone for every object, so the
        second argument needs a default.

    """
    collect = False

    def __init__(self, filter=None, reverse_mapper=None, decorator=None, mapper=None, collect=None):
        """
        Install the given functions on the instance.

        ``filter``, ``reverse_mapper`` and ``decorator`` must each be
        passed here or be defined on the (sub)class; otherwise a
        ``RuntimeError`` is raised. ``mapper`` and ``collect`` override the
        class defaults when given.
        """
        if filter:
            self.filter = filter
        elif not hasattr(self, 'filter'):
            raise RuntimeError("You must define a filter function")

        if reverse_mapper:
            self.reverse_mapper = reverse_mapper
        elif not hasattr(self, 'reverse_mapper'):
            raise RuntimeError("You must define a reverse_mapper function")

        if decorator:
            self.decorator = decorator
        elif not hasattr(self, 'decorator'):
            raise RuntimeError("You must define a decorator function")

        if mapper:
            self.mapper = mapper

        if collect is not None:
            self.collect = collect

    @staticmethod
    def mapper(obj):
        """Default key function: the object's ``id``."""
        return obj.id

    def fetch(self, dataset, name, model, forwarders, db):
        """
        Attach related data to the objects in ``dataset``.

        ``forwarders`` is a list of attribute names walked from every
        dataset object to reach the object that is actually decorated.
        ``name`` and ``model`` are only used in log messages. When ``db``
        is not None the related queryset is routed to it with ``using``.

        Returns ``dataset``. Any exception is logged and re-raised.
        """
        # Several dataset rows can lead to the same target object when
        # forwarders are walked, so targets are collected in lists then.
        collect = self.collect or forwarders

        try:
            data_mapping = collections.defaultdict(list)
            t1 = time.time()
            for obj in dataset:
                for field in forwarders:
                    obj = getattr(obj, field, None)

                # Only a missing relation (None) is skipped; objects that
                # merely evaluate as false are still processed.
                if obj is None:
                    continue

                if collect:
                    data_mapping[self.mapper(obj)].append(obj)
                else:
                    data_mapping[self.mapper(obj)] = obj

                # Initialise the object with "no related data"; it is
                # decorated again below if related rows are found.
                self.decorator(obj)

            t2 = time.time()
            logger.debug("Creating data_mapping for %s query took %.3f secs for the %s prefetcher.",
                         model.__name__, t2-t1, name)
            t1 = time.time()
            # Pass a real list (documented contract) rather than a dict
            # keys view.
            related_data = self.filter(list(data_mapping))
            if db is not None:
                related_data = related_data.using(db)
            related_data_len = len(related_data)
            t2 = time.time()
            logger.debug("Filtering for %s related objects for %s query took %.3f secs for the %s prefetcher.",
                         related_data_len, model.__name__, t2-t1, name)
            relation_mapping = collections.defaultdict(list)

            t1 = time.time()
            for obj in related_data:
                for id_ in self.reverse_mapper(obj):
                    # Skip unset keys only; falsy keys such as 0 are valid.
                    if id_ is not None:
                        relation_mapping[id_].append(obj)
            for id_, related_items in relation_mapping.items():
                if id_ in data_mapping:
                    if collect:
                        for item in data_mapping[id_]:
                            self.decorator(item, related_items)
                    else:
                        self.decorator(data_mapping[id_], related_items)

            t2 = time.time()
            logger.debug("Adding the related objects on the %s query took %.3f secs for the %s prefetcher.",
                         model.__name__, t2-t1, name)
            return dataset
        except Exception:
            logger.exception("Prefetch failed for %s prefetch on the %s model:", name, model.__name__)
            raise
296
|
|
|
|