Total Complexity | 41 |
Total Lines | 348 |
Duplicated Lines | 4.31 % |
Changes | 0 |
Duplicated code is one of the most pungent code smells. A common rule of thumb is to restructure code once it is duplicated in three or more places.
Common duplication problems, and corresponding solutions are:
Complex modules like test_flush often do a lot of different things. To break such a module down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for fields/methods that share the same prefixes or suffixes.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
1 | import time |
||
2 | import pdb |
||
3 | import threading |
||
4 | import logging |
||
5 | from multiprocessing import Pool, Process |
||
6 | import pytest |
||
7 | from utils import * |
||
8 | from constants import * |
||
9 | |||
# Upper bound (seconds) for delete-related waits in this module.
DELETE_TIMEOUT = 60

# Default single-vector search request shared by the flush tests: a top-10
# L2 search with one randomly generated query vector and nprobe=10.
_default_vector_clause = {
    default_float_vec_field_name: {
        "topk": 10,
        "query": gen_vectors(1, default_dim),
        "metric_type": "L2",
        "params": {"nprobe": 10},
    }
}
default_single_query = {"bool": {"must": [{"vector": _default_vector_clause}]}}
||
19 | |||
20 | |||
class TestFlushBase:
    """
    ******************************************************************
      The following cases are used to test `flush` function
    ******************************************************************
    """

    @pytest.fixture(
        scope="function",
        params=gen_simple_index()
    )
    def get_simple_index(self, request, connect):
        # GPU mode supports only the ivf family of index types; skip the rest.
        if str(connect._cmd("mode")[1]) == "GPU":
            if request.param["index_type"] not in ivf():
                pytest.skip("Only support index_type: idmap/flat")
        return request.param

    @pytest.fixture(
        scope="function",
        params=gen_single_filter_fields()
    )
    def get_filter_field(self, request):
        yield request.param

    @pytest.fixture(
        scope="function",
        params=gen_single_vector_fields()
    )
    def get_vector_field(self, request):
        yield request.param

    def test_flush_collection_not_existed(self, connect, collection):
        '''
        target: test flush, params collection_name not existed
        method: flush, with collection not existed
        expected: error raised
        '''
        collection_new = gen_unique_str("test_flush_1")
        with pytest.raises(Exception) as e:
            connect.flush([collection_new])

    def test_flush_empty_collection(self, connect, collection):
        '''
        method: insert entities, delete them all, then flush
        expected: no error raised and the collection counts zero entities
        '''
        ids = connect.bulk_insert(collection, default_entities)
        assert len(ids) == default_nb
        status = connect.delete_entity_by_id(collection, ids)
        assert status.OK()
        connect.flush([collection])
        res = connect.count_entities(collection)
        assert 0 == res

    def test_add_partition_flush(self, connect, id_collection):
        '''
        method: add entities into a partition of the collection, flush several times
        expected: the length of ids matches and the collection row count accumulates
        '''
        connect.create_partition(id_collection, default_tag)
        ids = list(range(default_nb))
        ids = connect.bulk_insert(id_collection, default_entities, ids)
        connect.flush([id_collection])
        res_count = connect.count_entities(id_collection)
        assert res_count == default_nb
        # Re-inserting the same ids into the partition is expected to add
        # another default_nb rows (the asserts below pin that contract).
        ids = connect.bulk_insert(id_collection, default_entities, ids, partition_tag=default_tag)
        assert len(ids) == default_nb
        connect.flush([id_collection])
        res_count = connect.count_entities(id_collection)
        assert res_count == default_nb * 2

    def test_add_partitions_flush(self, connect, id_collection):
        '''
        method: add entities into two partitions of the collection, flush after each insert
        expected: the collection row count equals the total inserted
        '''
        tag_new = gen_unique_str()
        connect.create_partition(id_collection, default_tag)
        connect.create_partition(id_collection, tag_new)
        ids = list(range(default_nb))
        ids = connect.bulk_insert(id_collection, default_entities, ids, partition_tag=default_tag)
        connect.flush([id_collection])
        ids = connect.bulk_insert(id_collection, default_entities, ids, partition_tag=tag_new)
        connect.flush([id_collection])
        res = connect.count_entities(id_collection)
        assert res == 2 * default_nb

    def test_add_collections_flush(self, connect, id_collection):
        '''
        method: add entities into two collections, flush each one
        expected: each collection row count equals the number inserted into it
        '''
        collection_new = gen_unique_str()
        default_fields = gen_default_fields(False)
        connect.create_collection(collection_new, default_fields)
        connect.create_partition(id_collection, default_tag)
        connect.create_partition(collection_new, default_tag)
        ids = list(range(default_nb))
        ids = connect.bulk_insert(id_collection, default_entities, ids, partition_tag=default_tag)
        ids = connect.bulk_insert(collection_new, default_entities, ids, partition_tag=default_tag)
        connect.flush([id_collection])
        connect.flush([collection_new])
        res = connect.count_entities(id_collection)
        assert res == default_nb
        res = connect.count_entities(collection_new)
        assert res == default_nb

    def test_add_collections_fields_flush(self, connect, id_collection, get_filter_field, get_vector_field):
        '''
        method: create a second collection with different fields, add entities
                into both collections, flush each one
        expected: each collection row count equals the number inserted into it
        '''
        nb_new = 5
        filter_field = get_filter_field
        vector_field = get_vector_field
        collection_new = gen_unique_str("test_flush")
        fields = {
            "fields": [filter_field, vector_field],
            "segment_row_limit": default_segment_row_limit,
            "auto_id": False
        }
        connect.create_collection(collection_new, fields)
        connect.create_partition(id_collection, default_tag)
        connect.create_partition(collection_new, default_tag)
        entities_new = gen_entities_by_fields(fields["fields"], nb_new, default_dim)
        ids = list(range(default_nb))
        ids_new = list(range(nb_new))
        ids = connect.bulk_insert(id_collection, default_entities, ids, partition_tag=default_tag)
        ids = connect.bulk_insert(collection_new, entities_new, ids_new, partition_tag=default_tag)
        connect.flush([id_collection])
        connect.flush([collection_new])
        res = connect.count_entities(id_collection)
        assert res == default_nb
        res = connect.count_entities(collection_new)
        assert res == nb_new

    def test_add_flush_multiable_times(self, connect, collection):
        '''
        method: add entities, flush several times
        expected: no error raised; count and search both succeed
        '''
        ids = connect.bulk_insert(collection, default_entities)
        for i in range(10):
            connect.flush([collection])
        res = connect.count_entities(collection)
        assert res == len(ids)
        res = connect.search(collection, default_single_query)
        logging.getLogger().debug(res)
        assert res

    def test_add_flush_auto(self, connect, id_collection):
        '''
        method: add entities and wait for the server's auto-flush (no explicit flush)
        expected: all entities become countable within the timeout
        '''
        ids = list(range(default_nb))
        ids = connect.bulk_insert(id_collection, default_entities, ids)
        timeout = 20
        deadline = time.time() + timeout
        res = 0
        while time.time() < deadline:
            time.sleep(1)
            res = connect.count_entities(id_collection)
            if res == default_nb:
                break
        # Assert on the observed count rather than `assert False`: the
        # original re-checked elapsed time after the loop, which could fail
        # spuriously when the count arrived right at the timeout boundary,
        # and a bare `assert False` hides how far auto-flush actually got.
        assert res == default_nb

    @pytest.fixture(
        scope="function",
        params=[
            1,
            100
        ],
    )
    def same_ids(self, request):
        yield request.param

    def test_add_flush_same_ids(self, connect, id_collection, same_ids):
        '''
        method: add entities where several share the same id (duplicate count < 15 and > 15)
        expected: row count still equals the number of inserted entities
        '''
        ids = list(range(default_nb))
        # Collapse every id <= same_ids down to 0 so the batch holds duplicates.
        ids = [0 if item <= same_ids else item for item in ids]
        ids = connect.bulk_insert(id_collection, default_entities, ids)
        connect.flush([id_collection])
        res = connect.count_entities(id_collection)
        assert res == default_nb

    def test_delete_flush_multiable_times(self, connect, collection):
        '''
        method: delete an entity, flush several times
        expected: no error raised; search still succeeds
        '''
        ids = connect.bulk_insert(collection, default_entities)
        status = connect.delete_entity_by_id(collection, [ids[-1]])
        assert status.OK()
        for i in range(10):
            connect.flush([collection])
        res = connect.search(collection, default_single_query)
        logging.getLogger().debug(res)
        assert res

    # TODO: unable to set config (the leading underscore keeps pytest from
    # collecting this case until the config issue is resolved)
    @pytest.mark.level(2)
    def _test_collection_count_during_flush(self, connect, collection, args):
        '''
        method: flush the collection in a background thread, call `count_entities`
        expected: count does not time out, and is 0 once the delete is flushed
        '''
        ids = []
        for i in range(5):
            tmp_ids = connect.bulk_insert(collection, default_entities)
            connect.flush([collection])
            ids.extend(tmp_ids)
        disable_flush(connect)
        status = connect.delete_entity_by_id(collection, ids)

        def flush():
            # Separate connection: the foreground `connect` is busy counting.
            milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
            logging.error("start flush")
            milvus.flush([collection])
            logging.error("end flush")

        p = TestThread(target=flush, args=())
        p.start()
        time.sleep(0.2)
        logging.error("start count")
        res = connect.count_entities(collection, timeout=10)
        p.join()
        res = connect.count_entities(collection)
        assert res == 0

    @pytest.mark.level(2)
    def test_delete_flush_during_search(self, connect, collection, args):
        '''
        method: search asynchronously, then delete and flush while the search runs
        expected: neither operation times out; final count reflects the delete
        '''
        ids = []
        loops = 5
        for i in range(loops):
            tmp_ids = connect.bulk_insert(collection, default_entities)
            connect.flush([collection])
            ids.extend(tmp_ids)
        # A large nq keeps the async search busy long enough to overlap it
        # with the delete + flush below.
        nq = 10000
        query, query_vecs = gen_query_vectors(default_float_vec_field_name, default_entities, default_top_k, nq)
        time.sleep(0.1)
        future = connect.search(collection, query, _async=True)
        delete_ids = [ids[0], ids[-1]]
        status = connect.delete_entity_by_id(collection, delete_ids)
        connect.flush([collection])
        res = future.result()
        res_count = connect.count_entities(collection, timeout=120)
        assert res_count == loops * default_nb - len(delete_ids)
283 | |||
class TestFlushAsync:
    @pytest.fixture(scope="function", autouse=True)
    def skip_http_check(self, args):
        # Async flush is not available through the HTTP handler.
        if args["handler"] == "HTTP":
            pytest.skip("skip in http mode")

    """
    ******************************************************************
    The following cases are used to test `flush` function
    ******************************************************************
    """

    def check_status(self):
        # Callback handed to the async flush; only logs that it fired.
        logging.getLogger().info("In callback check status")

    def test_flush_empty_collection(self, connect, collection):
        '''
        method: flush collection with no vectors
        expected: status ok
        '''
        status = connect.flush([collection], _async=True).result()

    def test_flush_async_long(self, connect, collection):
        # Insert a full batch, then wait on the async flush to complete.
        inserted = connect.bulk_insert(collection, default_entities)
        pending = connect.flush([collection], _async=True)
        status = pending.result()

    def test_flush_async_long_drop_collection(self, connect, collection):
        # Queue several async flushes, then drop the collection without
        # waiting on the outstanding futures.
        for _ in range(5):
            inserted = connect.bulk_insert(collection, default_entities)
            pending = connect.flush([collection], _async=True)
        logging.getLogger().info("DROP")
        connect.drop_collection(collection)

    def test_flush_async(self, connect, collection):
        # Async flush with a completion callback attached.
        connect.bulk_insert(collection, default_entities)
        log = logging.getLogger()
        log.info("before")
        pending = connect.flush([collection], _async=True, _callback=self.check_status)
        log.info("after")
        pending.done()
        status = pending.result()
327 | |||
class TestCollectionNameInvalid(object):
    """
    Test adding vectors with invalid collection names
    """

    @pytest.fixture(
        scope="function",
        params=gen_invalid_strs()
    )
    def get_invalid_collection_name(self, request):
        yield request.param

    @pytest.mark.level(2)
    def test_flush_with_invalid_collection_name(self, connect, get_invalid_collection_name):
        '''
        method: flush with an invalid collection name
        expected: error raised
        '''
        collection_name = get_invalid_collection_name
        # `not collection_name` already covers None — the original tested
        # `is None` redundantly. A falsy name means "flush everything",
        # which is a valid request, so skip those parametrizations.
        if not collection_name:
            pytest.skip("while collection_name is None, then flush all collections")
        with pytest.raises(Exception) as e:
            connect.flush(collection_name)
||
348 |