Total Complexity | 103 |
Total Lines | 581 |
Duplicated Lines | 10.5 % |
Changes | 0 |
Duplicate code is one of the most pungent code smells. A rule that is often used is to re-structure code once it is duplicated in three or more places.
Common duplication problems and their corresponding solutions are:
Complex classes like amd.io often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes, or suffixes.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
1 | """Contains I/O tools, including a .CIF reader and CSD reader |
||
2 | (``csd-python-api`` only) to extract periodic set representations |
||
3 | of crystals which can be passed to :func:`.calculate.AMD` and :func:`.calculate.PDD`. |
||
4 | |||
5 | These intermediate :class:`.periodicset.PeriodicSet` representations can be written |
||
6 | to a .hdf5 file with :class:`SetWriter`, which can be read back with :class:`SetReader`. |
||
7 | This is much faster than rereading a .CIF and recomputing invariants. |
||
8 | """ |
||
9 | |||
10 | import os |
||
11 | import functools |
||
12 | import warnings |
||
13 | from typing import Callable, Iterable, Sequence, Tuple |
||
14 | |||
15 | import numpy as np |
||
16 | import ase.io.cif |
||
17 | import ase.spacegroup.spacegroup |
||
18 | |||
19 | from .periodicset import PeriodicSet |
||
20 | from .utils import cellpar_to_cell |
||
21 | |||
22 | try: |
||
23 | import ccdc.io |
||
24 | import ccdc.search |
||
25 | _CSD_PYTHON_API_ENABLED = True |
||
26 | except (ImportError, RuntimeError) as _: |
||
27 | _CSD_PYTHON_API_ENABLED = False |
||
28 | |||
29 | def _custom_warning(message, category, filename, lineno, *args, **kwargs): |
||
30 | return f'{category.__name__}: {message}\n' |
||
31 | |||
32 | warnings.formatwarning = _custom_warning |
||
33 | |||
class ParseError(ValueError):
    """Raised when a structure cannot be converted to a periodic set.

    The readers in this module catch it so that one bad structure is
    skipped instead of aborting the whole read.
    """
||
36 | |||
37 | _EQUIV_SITE_TOL = 1e-3 |
||
38 | _DISORDER_OPTIONS = {'skip', 'ordered_sites', 'all_sites',} |
||
39 | _ATOM_SITE_FRACT_TAGS = [ |
||
40 | '_atom_site_fract_x', |
||
41 | '_atom_site_fract_y', |
||
42 | '_atom_site_fract_z',] |
||
43 | _ATOM_SITE_CARTN_TAGS = [ |
||
44 | '_atom_site_cartn_x', |
||
45 | '_atom_site_cartn_y', |
||
46 | '_atom_site_cartn_z',] |
||
47 | _SYMOP_TAGS = [ |
||
48 | '_space_group_symop_operation_xyz', |
||
49 | '_space_group_symop.operation_xyz', |
||
50 | '_symmetry_equiv_pos_as_xyz',] |
||
51 | |||
52 | |||
class CifReader:
    """Read all structures in a .CIF with ``ase`` or ``ccdc``
    (``csd-python-api`` only), yielding :class:`.periodicset.PeriodicSet`
    objects which can be passed to :func:`.calculate.AMD` or
    :func:`.calculate.PDD`.

    The reader wraps a single generator created in ``__init__``, so it can
    only be iterated over once.

    Args:
        path: Path to a file, or to a folder if ``folder`` is True.
        reader: ``'ase'`` or ``'ccdc'``; which backend parses the file(s).
        folder: If True, read every file in ``path`` whose extension the
            chosen backend accepts.
        remove_hydrogens: If True, drop hydrogen/deuterium sites.
        disorder: One of ``'skip'`` (skip structures with disorder),
            ``'ordered_sites'`` (drop the disordered sites) or
            ``'all_sites'`` (keep every site).
        heaviest_component: ``ccdc`` only; keep just the heaviest component.
        show_warnings: If False, warnings raised while reading are silenced.
        extract_data: Dict mapping a tag name to a callable; each callable is
            applied to the raw parsed item and the result stored in the
            periodic set's ``tags`` under that name.
        include_if: Iterable of callables applied to the raw parsed item; the
            item is skipped if any of them returns a falsy value.

    Raises:
        ValueError: If ``disorder`` or ``reader`` is not a recognised option.
        NotImplementedError: If ``heaviest_component`` is used with ``ase``.
        ImportError: If ``reader='ccdc'`` but ``csd-python-api`` is missing.

    Examples:

        ::

            # Put all crystals in a .CIF in a list
            structures = list(amd.CifReader('mycif.cif'))

            # Reads just one if the .CIF has just one crystal
            periodic_set = amd.CifReader('mycif.cif').read_one()

            # If a folder has several .CIFs each with one crystal, use
            structures = list(amd.CifReader('path/to/folder', folder=True))

            # Make list of AMDs (with k=100) of crystals in a .CIF
            amds = [amd.AMD(periodic_set, 100) for periodic_set in amd.CifReader('mycif.cif')]
    """

    def __init__(self,
                 path,
                 reader='ase',
                 folder=False,
                 remove_hydrogens=False,
                 disorder='skip',
                 heaviest_component=False,
                 show_warnings=True,
                 extract_data=None,
                 include_if=None
                 ):

        if disorder not in _DISORDER_OPTIONS:
            raise ValueError(f'disorder parameter {disorder} must be one of {_DISORDER_OPTIONS}')

        if reader not in ('ase', 'ccdc'):
            raise ValueError(f'Invalid reader {reader}; must be ase or ccdc.')

        if reader == 'ase' and heaviest_component:
            raise NotImplementedError('Parameter heaviest_component not implimented for ase, only ccdc.')

        extract_data, include_if = _validate_kwargs(extract_data, include_if)

        self.show_warnings = show_warnings
        self.extract_data = extract_data
        self.include_if = include_if
        # Name of the file currently being read; only set in folder mode.
        self.current_filename = None

        if reader == 'ase':
            extensions = {'cif'}
            file_parser = ase.io.cif.parse_cif
            converter = functools.partial(cifblock_to_periodicset,
                                          remove_hydrogens=remove_hydrogens,
                                          disorder=disorder)

        elif reader == 'ccdc':
            if not _CSD_PYTHON_API_ENABLED:
                raise ImportError("Failed to import csd-python-api; check it is installed and licensed.")
            extensions = ccdc.io.EntryReader.known_suffixes
            file_parser = ccdc.io.EntryReader
            converter = functools.partial(entry_to_periodicset,
                                          remove_hydrogens=remove_hydrogens,
                                          disorder=disorder,
                                          heaviest_component=heaviest_component)

        if folder:
            generator = self._folder_generator(path, file_parser, extensions)
        else:
            generator = file_parser(path)

        self._generator = self._map(converter, generator)

    def __iter__(self):
        yield from self._generator

    def read_one(self):
        """Read the next (usually first and only) item."""
        return next(iter(self._generator))

    def _folder_generator(self, path, file_parser, extensions):
        """Yield the parsed items of every file in ``path`` with a known extension."""
        for file in os.listdir(path):
            suff = os.path.splitext(file)[1][1:]
            if suff.lower() in extensions:
                # Remembered so _map can tag each structure with its source file.
                self.current_filename = file
                yield from file_parser(os.path.join(path, file))

    def _map(self, func: Callable, iterable: Iterable) -> Iterable[PeriodicSet]:
        """Iterates over iterable, passing items through parser and yielding the result.
        Applies warning and include_if filters, catches bad structures and warns.

        Args:
            func: Converter turning one raw parsed item into a periodic set.
            iterable: Raw parsed items.

        Yields:
            The converted periodic sets, tagged with the file name (folder
            mode only) and with the results of the ``extract_data`` callables.
        """

        for item in iterable:

            parse_error = None

            with warnings.catch_warnings(record=True) as warning_msgs:

                if not self.show_warnings:
                    warnings.simplefilter('ignore')

                if any(not check(item) for check in self.include_if):
                    continue

                try:
                    periodic_set = func(item)
                except ParseError as err:
                    # Warning here would only be recorded and then discarded,
                    # so remember the failure and report it after the block.
                    parse_error = err

            if parse_error is not None:
                if self.show_warnings:
                    warnings.warn(str(parse_error), category=UserWarning)
                continue

            # Re-issue the recorded warnings prefixed with the structure's name.
            for warning in warning_msgs:
                msg = f'{periodic_set.name}: {warning.message}'
                warnings.warn(msg, category=warning.category)

            if self.current_filename:
                periodic_set.tags['filename'] = self.current_filename

            # The loop variable must not be called ``func``: that would replace
            # the converter for every following item.
            for key, extractor in self.extract_data.items():
                periodic_set.tags[key] = extractor(item)

            yield periodic_set
||
174 | |||
175 | |||
class CSDReader:
    """Read Entries from the CSD, yielding :class:`.periodicset.PeriodicSet` objects.

    The CSDReader returns :class:`.periodicset.PeriodicSet` objects which can be passed
    to :func:`.calculate.AMD` or :func:`.calculate.PDD`.

    The reader wraps a single generator created in ``__init__``, so it can
    only be iterated over once.

    Args:
        refcodes: A refcode, an iterable of refcodes, or ``None``/``'CSD'``
            to iterate over every entry of the database.
        families: If True, read every entry returned by an identifier
            search for each of the given refcodes.
        remove_hydrogens: If True, drop hydrogen/deuterium atoms.
        disorder: One of ``'skip'`` (skip structures with disorder),
            ``'ordered_sites'`` (drop the disordered sites) or
            ``'all_sites'`` (keep every site).
        heaviest_component: If True, keep just the heaviest component.
        show_warnings: If False, warnings raised while converting are silenced.
        extract_data: Dict mapping a tag name to a callable; each callable is
            applied to the raw entry and the result stored in the periodic
            set's ``tags`` under that name.
        include_if: Iterable of callables applied to the raw entry; the entry
            is skipped if any of them returns a falsy value.

    Raises:
        ImportError: If ``csd-python-api`` could not be imported.
        ValueError: If ``disorder`` is not a recognised option.

    Examples:

        Get crystals with refcodes in a list::

            refcodes = ['DEBXIT01', 'DEBXIT05', 'HXACAN01']
            structures = list(amd.CSDReader(refcodes))

        Read refcode families (any whose refcode starts with strings in the list)::

            refcodes = ['ACSALA', 'HXACAN']
            structures = list(amd.CSDReader(refcodes, families=True))

        Create a generic reader, read crystals by name with :meth:`CSDReader.entry()`::

            reader = amd.CSDReader()
            debxit01 = reader.entry('DEBXIT01')

            # looping over this generic reader will yield all CSD entries
            for periodic_set in reader:
                ...

        Make list of AMD (with k=100) for crystals in these families::

            refcodes = ['ACSALA', 'HXACAN']
            amds = []
            for periodic_set in amd.CSDReader(refcodes, families=True):
                amds.append(amd.AMD(periodic_set, 100))
    """

    def __init__(self,
                 refcodes=None,
                 families=False,
                 remove_hydrogens=False,
                 disorder='skip',
                 heaviest_component=False,
                 show_warnings=True,
                 extract_data=None,
                 include_if=None,
                 ):

        if not _CSD_PYTHON_API_ENABLED:
            raise ImportError('Failed to import csd-python-api; check it is installed and licensed.')

        if disorder not in _DISORDER_OPTIONS:
            raise ValueError(f'disorder parameter {disorder} must be one of {_DISORDER_OPTIONS}')

        extract_data, include_if = _validate_kwargs(extract_data, include_if)

        self.show_warnings = show_warnings
        self.extract_data = extract_data
        self.include_if = include_if
        self.current_filename = None

        # The string 'CSD' (any case) means "the whole database", like None.
        if isinstance(refcodes, str) and refcodes.lower() == 'csd':
            refcodes = None

        if refcodes is None:
            families = False
        else:
            refcodes = [refcodes] if isinstance(refcodes, str) else list(refcodes)

        # families parameter reads all crystals with ids starting with passed refcodes
        if families:
            all_refcodes = []
            for refcode in refcodes:
                query = ccdc.search.TextNumericSearch()
                query.add_identifier(refcode)
                all_refcodes.extend(hit.identifier for hit in query.search())

            # filter to unique refcodes, keeping first-seen order
            refcodes = list(dict.fromkeys(all_refcodes))

        self._entry_reader = ccdc.io.EntryReader('CSD')

        converter = functools.partial(entry_to_periodicset,
                                      remove_hydrogens=remove_hydrogens,
                                      disorder=disorder,
                                      heaviest_component=heaviest_component)

        generator = self._ccdc_generator(refcodes)

        self._generator = self._map(converter, generator)

    def __iter__(self):
        yield from self._generator

    def read_one(self):
        """Read the next (usually first and only) item."""
        return next(iter(self._generator))

    def entry(self, refcode: str, **kwargs) -> PeriodicSet:
        """Read a PeriodicSet given any CSD refcode.

        Keyword arguments are passed on to :func:`entry_to_periodicset`; the
        options given to the constructor are not applied here.
        """

        entry = self._entry_reader.entry(refcode)
        periodic_set = entry_to_periodicset(entry, **kwargs)
        return periodic_set

    def _ccdc_generator(self, refcodes):
        """Generates ccdc Entries from CSD refcodes.

        With ``refcodes`` of None every entry of the reader is yielded.
        Refcodes whose lookup raises RuntimeError are reported with a
        warning and skipped.
        """

        if refcodes is None:
            yield from self._entry_reader
            return

        for refcode in refcodes:
            try:
                entry = self._entry_reader.entry(refcode)
            except RuntimeError:
                warnings.warn(f'Identifier {refcode} not found in database')
                continue
            yield entry

    def _map(self, func: Callable, iterable: Iterable) -> Iterable[PeriodicSet]:
        """Iterates over iterable, passing items through parser and yielding the result.
        Applies warning and include_if filters, catches bad structures and warns.

        Args:
            func: Converter turning one raw entry into a periodic set.
            iterable: Raw entries.

        Yields:
            The converted periodic sets, tagged with the results of the
            ``extract_data`` callables.
        """

        for item in iterable:

            parse_error = None

            with warnings.catch_warnings(record=True) as warning_msgs:

                if not self.show_warnings:
                    warnings.simplefilter('ignore')

                if any(not check(item) for check in self.include_if):
                    continue

                try:
                    periodic_set = func(item)
                except ParseError as err:
                    # Warning here would only be recorded and then discarded,
                    # so remember the failure and report it after the block.
                    parse_error = err

            if parse_error is not None:
                if self.show_warnings:
                    warnings.warn(str(parse_error), category=UserWarning)
                continue

            # Re-issue the recorded warnings prefixed with the structure's name.
            for warning in warning_msgs:
                msg = f'{periodic_set.name}: {warning.message}'
                warnings.warn(msg, category=warning.category)

            # The loop variable must not be called ``func``: that would replace
            # the converter for every following item.
            for key, extractor in self.extract_data.items():
                periodic_set.tags[key] = extractor(item)

            yield periodic_set
||
326 | |||
327 | |||
def entry_to_periodicset(entry,
                         remove_hydrogens=False,
                         disorder='skip',
                         heaviest_component=False
                         ) -> PeriodicSet:
    """ccdc.entry.Entry --> PeriodicSet.

    Args:
        entry: The ccdc entry to convert.
        remove_hydrogens: If True, remove atoms whose symbol is H or D.
        disorder: ``'skip'`` raises if the structure has disorder,
            ``'ordered_sites'`` removes the disordered atoms, ``'all_sites'``
            keeps every atom and does not remove overlapping sites.
        heaviest_component: If True and the molecule has more than one
            component, keep only the heaviest one.

    Returns:
        The periodic set, tagged with the entry identifier, asymmetric unit
        indices, multiplicities and atomic types.

    Raises:
        ParseError: If the entry has no 3D structure, has disorder while
            ``disorder='skip'``, has atoms without sites, or is left with
            no sites at all.
    """

    crystal = entry.crystal

    if not entry.has_3d_structure:
        raise ParseError('Has no 3D structure')

    molecule = crystal.disordered_molecule

    if disorder == 'skip':
        if crystal.has_disorder or entry.has_disorder or \
                any(atom_has_disorder(a.label, a.occupancy) for a in molecule.atoms):
            raise ParseError('Has disorder')

    elif disorder == 'ordered_sites':
        molecule.remove_atoms(a for a in molecule.atoms
                              if atom_has_disorder(a.label, a.occupancy))

    if remove_hydrogens:
        # Compare against the two symbols; ``in 'HD'`` is a substring test
        # that would also match an empty symbol.
        molecule.remove_atoms(a for a in molecule.atoms if a.atomic_symbol in ('H', 'D'))

    if heaviest_component and len(molecule.components) > 1:
        molecule = _heaviest_component(molecule)

    if not molecule.all_atoms_have_sites or \
            any(a.fractional_coordinates is None for a in molecule.atoms):
        raise ParseError('Has atoms without sites')

    crystal.molecule = molecule
    asym_atoms = crystal.asymmetric_unit_molecule.atoms
    asym_unit = np.array([tuple(a.fractional_coordinates) for a in asym_atoms])
    asym_unit = np.mod(asym_unit, 1)
    asym_symbols = [a.atomic_symbol for a in asym_atoms]
    cell = cellpar_to_cell(*crystal.cell_lengths, *crystal.cell_angles)

    # Fall back to the identity operation when none are given.
    sitesym = crystal.symmetry_operators
    if not sitesym:
        sitesym = ['x,y,z', ]

    if disorder != 'all_sites':
        keep_sites = _unique_sites(asym_unit)
        if np.any(~keep_sites):
            warnings.warn('May have overlapping sites; duplicates will be removed')
        asym_unit = asym_unit[keep_sites]
        asym_symbols = [sym for sym, keep in zip(asym_symbols, keep_sites) if keep]

    if asym_unit.shape[0] == 0:
        raise ParseError('Has no valid sites')

    frac_motif, asym_inds, multiplicities, inverses = expand_asym_unit(asym_unit, sitesym)
    full_types = [asym_symbols[i] for i in inverses]
    motif = frac_motif @ cell

    tags = {
        'name': entry.identifier,
        'asymmetric_unit': asym_inds,
        'wyckoff_multiplicities': multiplicities,
        'types': full_types,
    }

    return PeriodicSet(motif, cell, **tags)
||
394 | |||
395 | |||
def cifblock_to_periodicset(block,
                            remove_hydrogens=False,
                            disorder='skip'
                            ) -> PeriodicSet:
    """ase.io.cif.CIFBlock --> PeriodicSet.

    Args:
        block: The CIF block to convert.
        remove_hydrogens: If True, remove sites whose symbol is H or D.
        disorder: ``'skip'`` raises if any site has disorder,
            ``'ordered_sites'`` removes the disordered sites, ``'all_sites'``
            keeps every site and does not remove overlapping sites.
            Disorder is only looked for when the block lists occupancies.

    Returns:
        The periodic set, tagged with the block name, asymmetric unit
        indices, multiplicities and atomic types.

    Raises:
        ParseError: If the block has neither fractional nor Cartesian
            coordinates, has disorder while ``disorder='skip'``, or is left
            with no sites at all.
    """

    cell = block.get_cell().array

    # Each tag gives one column (all x, all y, all z), so the stacked array
    # is transposed to get one row per site.
    columns = [block.get(name) for name in _ATOM_SITE_FRACT_TAGS]
    if any(column is None for column in columns):
        columns = [block.get(name) for name in _ATOM_SITE_CARTN_TAGS]
        if any(column is None for column in columns):
            raise ParseError('Has no sites')
        # Cartesian rows -> fractional rows; the inverse of the
        # ``frac_motif @ cell`` product used at the end of this function.
        asym_unit = np.array(columns).T @ np.linalg.inv(cell)
    else:
        asym_unit = np.array(columns).T
    asym_unit = np.mod(asym_unit, 1)

    try:
        asym_symbols = block.get_symbols()
    except ase.io.cif.NoStructureData:
        asym_symbols = ['Unknown'] * len(asym_unit)

    # Use the first symmetry-operation tag present; identity if none is.
    sitesym = ['x,y,z', ]
    for tag in _SYMOP_TAGS:
        if tag in block:
            sitesym = block[tag]
            break
    if isinstance(sitesym, str):
        sitesym = [sitesym]

    remove_sites = set()

    occupancies = block.get('_atom_site_occupancy')
    labels = block.get('_atom_site_label')
    if occupancies is not None:
        if labels is None:
            # No labels to inspect: judge disorder by occupancy alone.
            labels = [''] * len(occupancies)
        if disorder == 'skip':
            if any(atom_has_disorder(lab, occ) for lab, occ in zip(labels, occupancies)):
                raise ParseError('Has disorder')
        elif disorder == 'ordered_sites':
            remove_sites.update(
                i for i, (lab, occ) in enumerate(zip(labels, occupancies))
                if atom_has_disorder(lab, occ))

    if remove_hydrogens:
        # Compare against the two symbols; ``in 'HD'`` is a substring test
        # that would also match an empty symbol.
        remove_sites.update(i for i, sym in enumerate(asym_symbols) if sym in ('H', 'D'))

    asym_unit = np.delete(asym_unit, sorted(remove_sites), axis=0)
    asym_symbols = [s for i, s in enumerate(asym_symbols) if i not in remove_sites]

    if disorder != 'all_sites':
        keep_sites = _unique_sites(asym_unit)
        if np.any(~keep_sites):
            warnings.warn('May have overlapping sites; duplicates will be removed')
        asym_unit = asym_unit[keep_sites]
        asym_symbols = [sym for sym, keep in zip(asym_symbols, keep_sites) if keep]

    if asym_unit.shape[0] == 0:
        raise ParseError('Has no valid sites')

    frac_motif, asym_inds, multiplicities, inverses = expand_asym_unit(asym_unit, sitesym)
    full_types = [asym_symbols[i] for i in inverses]
    motif = frac_motif @ cell

    tags = {
        'name': block.name,
        'asymmetric_unit': asym_inds,
        'wyckoff_multiplicities': multiplicities,
        'types': full_types,
    }

    return PeriodicSet(motif, cell, **tags)
||
467 | |||
468 | |||
def expand_asym_unit(
        asym_unit: np.ndarray,
        sitesym: Sequence[str]
) -> Tuple[np.ndarray, ...]:
    """Apply the symmetry operations to every site of the asymmetric unit.

    Args:
        asym_unit: Fractional coordinates of the asymmetric unit, one row
            per site.
        sitesym: Symmetry operations as strings (e.g. ``'x,y,z'``).

    Returns:
        A 4-tuple: the fractional motif (array of all distinct images), the
        index in the motif where each asymmetric site's images start, the
        number of images kept per asymmetric site, and a list giving for
        each motif point the index of the asymmetric site it came from.
    """

    rotations, translations = ase.spacegroup.spacegroup.parse_sitesym(sitesym)
    symops = list(zip(rotations, translations))

    motif_points = []
    origin_of = []          # asymmetric-unit index of each motif point
    group_starts = [0]
    group_sizes = []

    for asym_index, point in enumerate(asym_unit):
        n_kept = 0

        for rot, trans in symops:
            image = np.mod(np.dot(rot, point) + trans, 1)

            if motif_points:
                # Compare directly and across the cell boundary (difference ~1).
                direct = np.abs(image - motif_points)
                wrapped = np.abs(direct - 1)
                mask = np.all((direct <= _EQUIV_SITE_TOL) | (wrapped <= _EQUIV_SITE_TOL), axis=-1)

                if np.any(mask):
                    # Overlaps an existing point: warn unless it is an image
                    # of the same asymmetric site, then skip it.
                    for hit in np.argwhere(mask).flatten():
                        if origin_of[hit] != asym_index:
                            warnings.warn(f'Equivalent sites at positions {origin_of[hit]}, {asym_index}')
                    continue

            motif_points.append(image)
            origin_of.append(asym_index)
            n_kept += 1

        if n_kept > 0:
            group_sizes.append(n_kept)
            group_starts.append(len(motif_points))

    frac_motif = np.array(motif_points)
    # The last start is the total count, not the start of a group.
    asym_inds = np.array(group_starts[:-1])
    multiplicities = np.array(group_sizes)
    return frac_motif, asym_inds, multiplicities, origin_of
||
522 | |||
523 | |||
def atom_has_disorder(label, occupancy):
    """Return True if a site is disordered, judged by label and occupancy.

    A site counts as disordered when its label ends with ``'?'`` or its
    occupancy is a scalar below 1.

    Args:
        label: The site's label (a string).
        occupancy: The site's occupancy. Strings are converted to a number
            when possible; values that are not numeric scalars (``None``,
            placeholders such as ``'.'`` or ``'?'``) give no evidence of
            disorder.

    Returns:
        bool: Whether the site is considered disordered.
    """
    if label.endswith('?'):
        return True

    # Text occupancies would make the ``< 1`` comparison raise TypeError.
    if isinstance(occupancy, (str, bytes)):
        try:
            occupancy = float(occupancy)
        except ValueError:
            return False

    return bool(np.isscalar(occupancy) and occupancy < 1)
||
526 | |||
527 | |||
528 | def _unique_sites(asym_unit): |
||
529 | site_diffs1 = np.abs(asym_unit[:, None] - asym_unit) |
||
530 | site_diffs2 = np.abs(site_diffs1 - 1) |
||
531 | overlapping = np.triu(np.all( |
||
532 | (site_diffs1 <= _EQUIV_SITE_TOL) | (site_diffs2 <= _EQUIV_SITE_TOL), |
||
533 | axis=-1), 1) |
||
534 | return ~overlapping.any(axis=0) |
||
535 | |||
536 | |||
537 | def _heaviest_component(molecule): |
||
538 | """Heaviest component (removes all but the heaviest component of the asym unit). |
||
539 | Intended for removing solvents. Probably doesn't play well with disorder""" |
||
540 | component_weights = [] |
||
541 | for component in molecule.components: |
||
542 | weight = 0 |
||
543 | for a in component.atoms: |
||
544 | if isinstance(a.atomic_weight, (float, int)): |
||
545 | if isinstance(a.occupancy, (float, int)): |
||
546 | weight += a.occupancy * a.atomic_weight |
||
547 | else: |
||
548 | weight += a.atomic_weight |
||
549 | component_weights.append(weight) |
||
550 | largest_component_ind = np.argmax(np.array(component_weights)) |
||
551 | molecule = molecule.components[largest_component_ind] |
||
552 | return molecule |
||
553 | |||
554 | |||
555 | def _validate_kwargs(extract_data, include_if): |
||
556 | |||
557 | reserved_tags = {'motif', 'cell', 'name', |
||
558 | 'asymmetric_unit', 'wyckoff_multiplicities', |
||
559 | 'types', 'filename'} |
||
560 | |||
561 | if extract_data is None: |
||
562 | extract_data = {} |
||
563 | else: |
||
564 | if not isinstance(extract_data, dict): |
||
565 | raise ValueError('extract_data must be a dict of callables') |
||
566 | for key in extract_data: |
||
567 | if not callable(extract_data[key]): |
||
568 | raise ValueError('extract_data must be a dict of callables') |
||
569 | if key in reserved_tags: |
||
570 | raise ValueError(f'extract_data includes reserved key {key}') |
||
571 | extract_data = extract_data |
||
572 | |||
573 | if include_if is None: |
||
574 | include_if = () |
||
575 | elif not all(callable(func) for func in include_if): |
||
576 | raise ValueError('include_if must be a list of callables') |
||
577 | else: |
||
578 | include_if = include_if |
||
579 | |||
580 | return extract_data, include_if |
||
581 |