Total Complexity | 98 |
Total Lines | 438 |
Duplicated Lines | 0 % |
Changes | 0 |
Complex classes like amd._reader often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to finding such a component is to look for fields/methods that share the same prefixes or suffixes.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
1 | import warnings |
||
|
|||
2 | from typing import Callable, Iterable, Sequence, Tuple |
||
3 | |||
4 | import numpy as np |
||
5 | import ase.spacegroup.spacegroup # parse_sitesym |
||
6 | import ase.io.cif |
||
7 | |||
8 | from .periodicset import PeriodicSet |
||
9 | from .utils import cellpar_to_cell |
||
10 | |||
11 | |||
def _warning(message, category, filename, lineno, *args, **kwargs):
    """Format a warning as a single 'filename:lineno: Category: message' line.

    Extra positional and keyword arguments are accepted and ignored.
    """
    formatted = f'{filename}:{lineno}: {category.__name__}: {message}\n'
    return formatted


# Install the one-line formatter for every warning issued through the
# warnings module.
warnings.formatwarning = _warning
||
16 | |||
17 | |||
def _atom_has_disorder(label, occupancy):
    """Return True if an atom site should be treated as disordered.

    A site counts as disordered when its label ends with '?' or when its
    occupancy is a scalar number below 1. An occupancy that is missing,
    non-scalar or not comparable with a number (e.g. None, or a placeholder
    string such as '?') is treated as carrying no disorder information
    instead of raising a TypeError.
    """
    if label.endswith('?'):
        return True
    if not np.isscalar(occupancy):
        return False
    try:
        return bool(occupancy < 1)
    except TypeError:
        # np.isscalar is True for str, but a str cannot be compared with 1
        return False
||
20 | |||
class _Reader:
    """Base Reader class. Contains parsers for converting ase CIFBlock
    and ccdc Entry objects to PeriodicSets.

    Intended use:

    First make a new method for _Reader converting object to PeriodicSet
    (e.g. named _X_to_PSet). Then make this class outline:

    class XReader(_Reader):
        def __init__(self, ..., **kwargs):

            super().__init__(**kwargs)

            # setup and checks

            # make 'iterable' which yields objects to be converted (e.g. CIFBlock, Entry)

            # set self._generator like this
            self._generator = self._map(self._X_to_PSet, iterable)
    """

    # accepted values of the 'disorder' setting (checked in __init__)
    disorder_options = {'skip', 'ordered_sites', 'all_sites'}

    # keys that extract_data is not allowed to use (checked in __init__)
    reserved_tags = {
        'motif',
        'cell',
        'name',
        'asymmetric_unit',
        'wyckoff_multiplicities',
        'types',
        'filename',
    }

    # CIF tags holding fractional coordinates (x, y, z)
    atom_site_fract_tags = [
        '_atom_site_fract_x',
        '_atom_site_fract_y',
        '_atom_site_fract_z',
    ]

    # CIF tags holding Cartesian coordinates, used when fractional ones are absent
    atom_site_cartn_tags = [
        '_atom_site_cartn_x',
        '_atom_site_cartn_y',
        '_atom_site_cartn_z',
    ]

    # CIF tags that may hold symmetry operations; the first one present is used
    symop_tags = [
        '_space_group_symop_operation_xyz',
        '_space_group_symop.operation_xyz',
        '_symmetry_equiv_pos_as_xyz',
    ]

    # two sites whose fractional coordinates all differ by at most this
    # (allowing for wrap-around at the cell boundary) are treated as the same
    equiv_site_tol = 1e-3
||
66 | |||
def __init__(
        self,
        remove_hydrogens=False,
        disorder='skip',
        heaviest_component=False,
        show_warnings=True,
        extract_data=None,
        include_if=None):
    """Validate and store the reader settings.

    Parameters:
        remove_hydrogens: if True, parsers drop sites whose symbol is H or D.
        disorder: how parsers treat disorder; must be one of
            self.disorder_options.
        heaviest_component: if True, the Entry parser keeps only the
            heaviest component.
        show_warnings: if False, parsers do not emit warnings.
        extract_data: optional dict mapping a tag to a callable. Each
            callable is called with the raw item and its result is passed
            to the PeriodicSet under that tag. Tags in self.reserved_tags
            are rejected.
        include_if: optional iterable of callables. An item is skipped
            unless every callable returns a truthy value for it.

    Raises:
        ValueError: if disorder, extract_data or include_if is invalid.
    """

    # settings
    if disorder not in self.disorder_options:
        raise ValueError(f'disorder parameter {disorder} must be one of {self.disorder_options}')

    if extract_data:
        if not isinstance(extract_data, dict):
            raise ValueError('extract_data must be a dict with callable values')
        for key, func in extract_data.items():
            if not callable(func):
                raise ValueError('extract_data must be a dict with callable values')
            if key in self.reserved_tags:
                raise ValueError(f'extract_data includes reserved key {key}')

    if include_if:
        # Materialise once: the checks are re-run for every item, so a
        # one-shot iterator would be exhausted by the validation below and
        # every later item would pass vacuously.
        try:
            include_if = list(include_if)
        except TypeError:
            raise ValueError('include_if must be a list of callables') from None
        if not all(callable(func) for func in include_if):
            raise ValueError('include_if must be a list of callables')

    self.remove_hydrogens = remove_hydrogens
    self.disorder = disorder
    self.heaviest_component = heaviest_component
    self.extract_data = extract_data
    self.include_if = include_if
    self.show_warnings = show_warnings
    self.current_identifier = None
    self.current_filename = None
    self._generator = []
||
103 | |||
def __iter__(self):
    """Yield every item produced by self._generator."""
    source = self._generator
    yield from source
||
106 | |||
def read_one(self):
    """Read the next (or first) item.

    Raises StopIteration if self._generator has nothing (left) to yield.
    """
    iterator = iter(self._generator)
    return next(iterator)
||
110 | |||
# Essentially the builtin map, except that items for which the function
# returned None are dropped. The iterable returned here is what gets stored
# in self._generator, so iterating over the Reader iterates over it.
@staticmethod
def _map(func: Callable, iterable: Iterable) -> "Iterable[PeriodicSet]":
    """Lazily apply func to each item of iterable, yielding only the
    results that are not None.
    """

    results = (func(item) for item in iterable)
    yield from (res for res in results if res is not None)
||
125 | |||
def _CIFBlock_to_PeriodicSet(self, block) -> PeriodicSet:
    """ase.io.cif.CIFBlock --> PeriodicSet. Returns None for a "bad" set.

    A block is skipped (None is returned) when it fails a check in
    self.include_if, has no fractional or Cartesian coordinates, is
    disordered while self.disorder == 'skip', or has no sites left after
    removals.
    """

    # skip if structure does not pass checks in include_if
    if self.include_if and not all(check(block) for check in self.include_if):
        return None

    # read name, cell, asym motif and atomic symbols
    self.current_identifier = block.name
    cell = block.get_cell().array

    # Each tag yields one column (all x, all y, all z); transpose so that
    # every row of the motif is one site.
    frac_columns = [block.get(tag) for tag in self.atom_site_fract_tags]
    if all(column is not None for column in frac_columns):
        asym_frac_motif = np.array(frac_columns).T
    else:
        cartn_columns = [block.get(tag) for tag in self.atom_site_cartn_tags]
        if any(column is None for column in cartn_columns):
            if self.show_warnings:
                warnings.warn(
                    f'Skipping {self.current_identifier} as coordinates were not found')
            return None
        # Cartesian -> fractional. The transpose must happen before the
        # product so that the (n_sites, 3) rows multiply the inverse cell.
        asym_frac_motif = np.array(cartn_columns).T @ np.linalg.inv(cell)

    n_sites = len(asym_frac_motif)

    try:
        asym_symbols = block.get_symbols()
    except ase.io.cif.NoStructureData:
        asym_symbols = ['Unknown'] * n_sites

    # indices of sites to remove
    remove = []
    if self.remove_hydrogens:
        remove.extend(i for i, sym in enumerate(asym_symbols) if sym in ('H', 'D'))

    # find disordered sites; keep exactly one flag per site so the flags
    # stay aligned with the motif rows
    asym_is_disordered = [False] * n_sites
    occupancies = block.get('_atom_site_occupancy')
    labels = block.get('_atom_site_label')
    if occupancies is not None:
        if labels is None:
            # no label column: judge disorder from occupancy alone
            labels = [''] * len(occupancies)
        asym_is_disordered = [
            bool(_atom_has_disorder(label, occ))
            for occ, label in zip(occupancies, labels)
        ]

        # disordered sites that are not already being removed
        already_removed = set(remove)
        disordered = [
            i for i, flag in enumerate(asym_is_disordered)
            if flag and i not in already_removed
        ]

        if self.disorder == 'skip' and len(disordered) > 0:
            if self.show_warnings:
                warnings.warn(
                    f'Skipping {self.current_identifier} as structure is disordered')
            return None

        if self.disorder == 'ordered_sites':
            remove.extend(disordered)

    # remove sites
    removed = set(remove)
    remove_indices = np.array(sorted(removed), dtype=int)
    asym_frac_motif = np.mod(np.delete(asym_frac_motif, remove_indices, axis=0), 1)
    asym_symbols = [s for i, s in enumerate(asym_symbols) if i not in removed]
    asym_is_disordered = [v for i, v in enumerate(asym_is_disordered) if i not in removed]

    # remove overlapping sites, check sites exist
    keep_sites = self._validate_sites(asym_frac_motif, asym_is_disordered)
    if keep_sites is not None:
        asym_frac_motif = asym_frac_motif[keep_sites]
        asym_symbols = [sym for sym, keep in zip(asym_symbols, keep_sites) if keep]

    if self._has_no_valid_sites(asym_frac_motif):
        return None

    # symmetry operations: first matching tag wins, identity if none is present
    sitesym = ['x,y,z', ]
    for tag in self.symop_tags:
        if tag in block:
            sitesym = block[tag]
            break

    if isinstance(sitesym, str):
        sitesym = [sitesym]

    return self._construct_periodic_set(block, asym_frac_motif, asym_symbols, sitesym, cell)
||
204 | |||
205 | |||
def _Entry_to_PeriodicSet(self, entry) -> PeriodicSet:
    """ccdc.entry.Entry --> PeriodicSet. Returns None for a "bad" set.

    An entry is skipped (None is returned) when it fails a check in
    self.include_if, has no 3D structure, is disordered while
    self.disorder == 'skip', has atoms without sites, or has no sites left
    after removals.
    """

    # skip if structure does not pass checks in include_if
    if self.include_if and not all(check(entry) for check in self.include_if):
        return None

    self.current_identifier = entry.identifier
    # structure must pass this test
    if not entry.has_3d_structure:
        if self.show_warnings:
            warnings.warn(
                f'Skipping {self.current_identifier} as entry has no 3D structure')
        return None

    crystal = entry.crystal

    # first disorder check, if skipping. If occ == 1 for all atoms but the entry
    # or crystal is listed as having disorder, skip (can't know where disorder is).
    # If occ != 1 for any atoms, we wait to see if we remove them before skipping.
    molecule = crystal.disordered_molecule
    if self.disorder == 'ordered_sites':
        # Build the list first so the atoms are not being iterated lazily
        # while remove_atoms works on the same molecule.
        # NOTE(review): assumes remove_atoms accepts any iterable of atoms - confirm.
        molecule.remove_atoms([a for a in molecule.atoms if a.label.endswith('?')])

    may_have_disorder = False
    if self.disorder == 'skip':
        may_have_disorder = any(
            _atom_has_disorder(a.label, a.occupancy) for a in molecule.atoms)

        if not may_have_disorder and (crystal.has_disorder or entry.has_disorder):
            if self.show_warnings:
                warnings.warn(f'Skipping {self.current_identifier} as structure is disordered')
            return None

    if self.remove_hydrogens:
        molecule.remove_atoms([a for a in molecule.atoms if a.atomic_symbol in ('H', 'D')])

    if self.heaviest_component:
        molecule = _Reader._heaviest_component(molecule)

    crystal.molecule = molecule

    # by here all atoms to be removed have been (except via ordered_sites).
    # If disorder == 'skip' and there were atom(s) with occ < 1 found
    # earlier, we check if all such atoms were removed. If not, skip.
    if self.disorder == 'skip' and may_have_disorder:
        still_disordered = any(
            _atom_has_disorder(a.label, a.occupancy)
            for a in crystal.disordered_molecule.atoms)
        if still_disordered:
            if self.show_warnings:
                warnings.warn(
                    f'Skipping {self.current_identifier} as structure is disordered')
            return None

    # if disorder is all_sites, we need to know where disorder is to ignore overlaps
    asym_is_disordered = []     # True/False list same length as asym unit
    if self.disorder == 'all_sites':
        asym_is_disordered = [
            bool(_atom_has_disorder(a.label, a.occupancy))
            for a in crystal.asymmetric_unit_molecule.atoms
        ]

    # check all atoms have coords. option/default remove unknown sites?
    if not molecule.all_atoms_have_sites or \
       any(a.fractional_coordinates is None for a in molecule.atoms):
        if self.show_warnings:
            warnings.warn(
                f'Skipping {self.current_identifier} as some atoms do not have sites')
        return None

    # get cell & asymmetric unit
    cell = cellpar_to_cell(*crystal.cell_lengths, *crystal.cell_angles)
    asym_atoms = crystal.asymmetric_unit_molecule.atoms
    asym_frac_motif = np.array([tuple(a.fractional_coordinates) for a in asym_atoms])
    asym_frac_motif = np.mod(asym_frac_motif, 1)
    asym_symbols = [a.atomic_symbol for a in crystal.asymmetric_unit_molecule.atoms]

    # remove overlapping sites, check sites exist
    keep_sites = self._validate_sites(asym_frac_motif, asym_is_disordered)
    if keep_sites is not None:
        asym_frac_motif = asym_frac_motif[keep_sites]
        asym_symbols = [sym for sym, keep in zip(asym_symbols, keep_sites) if keep]

    if self._has_no_valid_sites(asym_frac_motif):
        return None

    sitesym = crystal.symmetry_operators
    if not sitesym:
        sitesym = ['x,y,z', ]

    entry.crystal.molecule = crystal.disordered_molecule    # for extract_data. remove?

    return self._construct_periodic_set(entry, asym_frac_motif, asym_symbols, sitesym, cell)
||
306 | |||
def expand(
        self,
        asym_frac_motif: np.ndarray,
        sitesym: Sequence[str]
) -> Tuple[np.ndarray, ...]:
    """
    Asymmetric unit's fractional coords + sitesyms (as strings)
    -->
    frac_motif, asym_unit, multiplicities, inverses

    Every symmetry operation is applied to every asymmetric site; an image
    is kept unless it overlaps a site that was already kept. asym_unit holds
    the index in frac_motif where each contributing asymmetric site's images
    start, multiplicities the number of images kept for it, and inverses (a
    list) the asymmetric site index each kept image came from.
    """

    rotations, translations = ase.spacegroup.spacegroup.parse_sitesym(sitesym)
    operations = list(zip(rotations, translations))

    expanded_sites = []
    inverses = []
    multiplicities = []
    offsets = [0]

    for asym_index, asym_site in enumerate(asym_frac_motif):
        n_before = len(expanded_sites)

        for rot, trans in operations:
            image = np.mod(np.dot(rot, asym_site) + trans, 1)

            # the very first image has nothing to overlap with
            if expanded_sites and self._is_site_overlapping(
                    image, expanded_sites, inverses, asym_index):
                continue

            expanded_sites.append(image)
            inverses.append(asym_index)

        n_images = len(expanded_sites) - n_before
        if n_images > 0:
            multiplicities.append(n_images)
            offsets.append(len(expanded_sites))

    frac_motif = np.array(expanded_sites)
    # the last offset is the total number of sites, not the start of a group
    asym_unit = np.array(offsets[:-1])
    multiplicities = np.array(multiplicities)
    return frac_motif, asym_unit, multiplicities, inverses
||
349 | |||
def _is_site_overlapping(self, new_site, all_sites, inverses, inv):
    """Return True (and warn) if new_site overlaps with a site in all_sites.

    Two sites overlap when every fractional coordinate differs by at most
    self.equiv_site_tol, directly or across the cell boundary. A warning is
    issued once for each other asymmetric site (an entry of inverses
    different from inv) that new_site coincides with; overlaps with images
    of the same asymmetric site are silent. An empty all_sites never
    overlaps.
    """
    if len(all_sites) == 0:
        return False

    tol = self.equiv_site_tol
    diffs = np.abs(new_site - np.asarray(all_sites))
    # a difference close to 1 is a match across the cell boundary
    close = (diffs <= tol) | (np.abs(diffs - 1) <= tol)
    matches = np.flatnonzero(np.all(close, axis=-1))
    if matches.size == 0:
        return False

    if self.show_warnings:
        # dict.fromkeys keeps first-seen order while dropping repeats
        for other in dict.fromkeys(inverses[i] for i in matches):
            if other != inv:
                warnings.warn(
                    f'{self.current_identifier} has equivalent positions {other} and {inv}')
    return True
||
370 | |||
def _validate_sites(self, asym_frac_motif, asym_is_disordered):
    """Find asymmetric unit sites that duplicate an earlier site.

    Two sites overlap when every fractional coordinate differs by at most
    self.equiv_site_tol, directly or across the cell boundary. When
    self.disorder == 'all_sites', an overlap is ignored if either site is
    flagged in asym_is_disordered; if the flags do not have one entry per
    site they are disregarded (all sites treated as ordered) instead of
    raising an IndexError.

    Returns None if there is nothing to remove. Otherwise returns a boolean
    array, one entry per site, that is False for every site overlapping an
    earlier one.
    """
    tol = self.equiv_site_tol
    diffs = np.abs(asym_frac_motif[:, None] - asym_frac_motif)
    close = (diffs <= tol) | (np.abs(diffs - 1) <= tol)
    # upper triangle only: each pair is counted once and a site is never
    # compared with itself
    overlapping = np.triu(np.all(close, axis=-1), 1)

    if self.disorder == 'all_sites':
        n_sites = len(asym_frac_motif)
        flags = np.zeros(n_sites, dtype=bool)
        if len(asym_is_disordered) == n_sites:
            flags[:] = asym_is_disordered
        # drop every pair in which at least one site is disordered
        overlapping &= ~(flags[:, None] | flags[None, :])

    if not overlapping.any():
        return None

    if self.show_warnings:
        warnings.warn(
            f'{self.current_identifier} may have overlapping sites; duplicates will be removed')
    return ~overlapping.any(0)
||
390 | |||
def _has_no_valid_sites(self, motif):
    """Return True (warning if enabled) when motif has no rows, else False."""
    is_empty = motif.shape[0] == 0
    if is_empty and self.show_warnings:
        warnings.warn(
            f'Skipping {self.current_identifier} as there are no sites with coordinates')
    return is_empty
||
398 | |||
def _construct_periodic_set(self, raw_item, asym_frac_motif, asym_symbols, sitesym, cell):
    """Expand the asymmetric unit with the symmetry operations and build
    the PeriodicSet.

    The set is given the current identifier as its name, the asymmetric
    unit indices, multiplicities and atomic types, the current filename if
    one is set, and one extra entry per key of self.extract_data (the
    callable's result for raw_item).
    """
    frac_motif, asym_unit, multiplicities, inverses = self.expand(asym_frac_motif, sitesym)
    # one symbol per expanded site, taken from its asymmetric site
    types = [asym_symbols[i] for i in inverses]
    # fractional -> Cartesian
    motif = frac_motif @ cell

    metadata = {
        'name': self.current_identifier,
        'asymmetric_unit': asym_unit,
        'wyckoff_multiplicities': multiplicities,
        'types': types,
    }

    if self.current_filename:
        metadata['filename'] = self.current_filename

    extractors = self.extract_data
    if extractors is not None:
        for tag, extractor in extractors.items():
            metadata[tag] = extractor(raw_item)

    return PeriodicSet(motif, cell, **metadata)
||
419 | |||
@staticmethod
def _heaviest_component(molecule):
    """Heaviest component (removes all but the heaviest component of the asym unit).
    Intended for removing solvents. Probably doesn't play well with disorder.

    Returns molecule unchanged if it has at most one component. Otherwise
    returns the component with the largest total weight, where each atom
    with a numeric atomic weight contributes occupancy * atomic_weight (or
    just atomic_weight if its occupancy is not numeric). Ties go to the
    first such component.

    Declared a staticmethod because it takes no self; calling it through
    the class keeps working and calling it through an instance now works too.
    """
    components = molecule.components
    if len(components) <= 1:
        return molecule

    component_weights = []
    for component in components:
        weight = 0
        for a in component.atoms:
            # atoms without a numeric weight contribute nothing
            if not isinstance(a.atomic_weight, (float, int)):
                continue
            if isinstance(a.occupancy, (float, int)):
                weight += a.occupancy * a.atomic_weight
            else:
                weight += a.atomic_weight
        component_weights.append(weight)

    return components[int(np.argmax(np.array(component_weights)))]
||
438 |