Total Complexity | 99 |
Total Lines | 445 |
Duplicated Lines | 0 % |
Changes | 0 |
Complex classes like amd._reader often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to finding such a component is to look for fields/methods that share the same prefixes or suffixes.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
1 | """Contains base reader class for the io module. |
||
|
|||
2 | This class implements the converters from CifBlock, Entry to PeriodicSets. |
||
3 | """ |
||
4 | |||
5 | import warnings |
||
6 | from typing import Callable, Iterable, Sequence, Tuple |
||
7 | |||
8 | import numpy as np |
||
9 | import ase.spacegroup.spacegroup # parse_sitesym |
||
10 | import ase.io.cif |
||
11 | |||
12 | from .periodicset import PeriodicSet |
||
13 | from .utils import cellpar_to_cell |
||
14 | |||
15 | |||
def _warning(message, category, filename, lineno, *args, **kwargs):
    """Format a warning as a single 'filename:lineno: Category: message' line.

    Any further positional or keyword arguments are accepted and ignored.
    """
    kind = category.__name__
    return '{}:{}: {}: {}\n'.format(filename, lineno, kind, message)


# Replace the warnings module's formatter with the one-line version above.
warnings.formatwarning = _warning
||
20 | |||
21 | |||
class _Reader:
    """Base Reader class. Contains parsers for converting ase CifBlock
    and ccdc Entry objects to PeriodicSets.

    Intended use:

    First make a new method for _Reader converting object to PeriodicSet
    (e.g. named _X_to_PSet). Then make this class outline:

    class XReader(_Reader):
        def __init__(self, ..., **kwargs):

            super().__init__(**kwargs)

            # setup and checks

            # make 'iterable' which yields objects to be converted (e.g. CIFBlock, Entry)

            # set self._generator like this
            self._generator = self._map(self._X_to_PSet, iterable)
    """

    # accepted values of the 'disorder' argument
    disorder_options = {'skip', 'ordered_sites', 'all_sites'}
    # keys that extract_data may not use
    reserved_tags = {
        'motif',
        'cell',
        'name',
        'asymmetric_unit',
        'wyckoff_multiplicities',
        'types',
        'filename',
    }
    # CIF tags holding fractional coordinates (x, y, z)
    atom_site_fract_tags = [
        '_atom_site_fract_x',
        '_atom_site_fract_y',
        '_atom_site_fract_z',
    ]
    # CIF tags holding Cartesian coordinates (x, y, z); used as a fallback
    atom_site_cartn_tags = [
        '_atom_site_cartn_x',
        '_atom_site_cartn_y',
        '_atom_site_cartn_z',
    ]
    # CIF tags that may hold the symmetry operations, tried in this order
    symop_tags = [
        '_space_group_symop_operation_xyz',
        '_space_group_symop.operation_xyz',
        '_symmetry_equiv_pos_as_xyz',
    ]

    # fractional-coordinate tolerance below which two sites count as the same
    equiv_site_tol = 1e-3

    def __init__(
            self,
            remove_hydrogens=False,
            disorder='skip',
            heaviest_component=False,
            show_warnings=True,
            extract_data=None,
            include_if=None):
        """Validate and store the reader settings.

        Raises ValueError if disorder is not one of disorder_options, if
        extract_data fails _validate_extract_data, or if include_if
        contains a non-callable.
        """

        if disorder not in _Reader.disorder_options:
            raise ValueError(f'disorder parameter {disorder} must be one of {_Reader.disorder_options}')

        if extract_data:
            _validate_extract_data(extract_data)

        if include_if and not all(callable(func) for func in include_if):
            raise ValueError('include_if must be a list of callables')

        self.remove_hydrogens = remove_hydrogens
        self.disorder = disorder
        self.heaviest_component = heaviest_component
        self.extract_data = extract_data
        self.include_if = include_if
        self.show_warnings = show_warnings
        self.current_identifier = None
        self.current_filename = None
        # subclasses replace this with the result of self._map(...)
        self._generator = []

    def __iter__(self):
        yield from self._generator

    def read_one(self):
        """Read the next (or first) item."""
        return next(iter(self._generator))

    def _warn(self, message):
        """Emit message as a warning unless show_warnings is switched off."""
        if self.show_warnings:
            # stacklevel=2 attributes the warning to the method that called _warn
            warnings.warn(message, stacklevel=2)

    # basically the builtin map, but skips items if the function returned None.
    # The object returned by this function (Iterable of PeriodicSets) is set to
    # self._generator; then iterating over the Reader iterates over
    # self._generator.
    @staticmethod
    def _map(func: Callable, iterable: Iterable) -> Iterable['PeriodicSet']:
        """Iterates over iterable, passing items through func and
        yielding the result if it is not None.
        """

        for item in iterable:
            res = func(item)
            if res is not None:
                yield res

    def _CIFBlock_to_PeriodicSet(self, block) -> 'PeriodicSet':
        """ase.io.cif.CIFBlock --> PeriodicSet. Returns None for a "bad" set."""

        # skip if structure does not pass checks in include_if
        if self.include_if and not all(check(block) for check in self.include_if):
            return None

        # read name, cell, asym motif and atomic symbols
        self.current_identifier = block.name
        cell = block.get_cell().array
        asym_frac_motif = [block.get(name) for name in _Reader.atom_site_fract_tags]
        if any(column is None for column in asym_frac_motif):
            # no fractional coordinates: fall back to Cartesian ones
            asym_motif = [block.get(name) for name in _Reader.atom_site_cartn_tags]
            if any(column is None for column in asym_motif):
                self._warn(f'Skipping {self.current_identifier} as coordinates were not found')
                return None
            # asym_motif is laid out as (3, n) rows of x, y, z. Since
            # cart = frac @ cell, frac.T = inv(cell).T @ cart.T; the result
            # stays (3, n) so the transpose below applies to both branches.
            asym_frac_motif = np.linalg.inv(cell).T @ np.array(asym_motif)
        asym_frac_motif = np.array(asym_frac_motif).T

        try:
            asym_symbols = block.get_symbols()
        except ase.io.cif.NoStructureData:
            asym_symbols = ['Unknown'] * len(asym_frac_motif)

        # indices of sites to remove
        remove = set()
        if self.remove_hydrogens:
            # exact match: a substring test against 'HD' would also match ''
            remove.update(i for i, sym in enumerate(asym_symbols) if sym in ('H', 'D'))

        # find disordered sites; asym_is_disordered holds one flag per site
        occupancies = block.get('_atom_site_occupancy')
        labels = block.get('_atom_site_label')
        if occupancies is None:
            asym_is_disordered = [False] * len(asym_frac_motif)
        else:
            if labels is None:
                labels = [''] * len(occupancies)
            asym_is_disordered = [
                bool(_atom_has_disorder(str(label), occ))
                for occ, label in zip(occupancies, labels)
            ]
        # indices where there is disorder, ignoring sites already removed
        disordered = [i for i, flag in enumerate(asym_is_disordered)
                      if flag and i not in remove]

        if self.disorder == 'skip' and disordered:
            self._warn(f'Skipping {self.current_identifier} as structure is disordered')
            return None

        if self.disorder == 'ordered_sites':
            remove.update(disordered)

        # remove sites
        asym_frac_motif = np.mod(np.delete(asym_frac_motif, sorted(remove), axis=0), 1)
        asym_symbols = [s for i, s in enumerate(asym_symbols) if i not in remove]
        asym_is_disordered = [v for i, v in enumerate(asym_is_disordered) if i not in remove]

        # remove overlapping sites, check sites exist
        keep_sites = self._validate_sites(asym_frac_motif, asym_is_disordered)
        if keep_sites is not None:
            asym_frac_motif = asym_frac_motif[keep_sites]
            asym_symbols = [sym for sym, keep in zip(asym_symbols, keep_sites) if keep]

        if self._has_no_valid_sites(asym_frac_motif):
            return None

        # first symmetry-operation tag present wins; default is the identity
        sitesym = ['x,y,z', ]
        for tag in _Reader.symop_tags:
            if tag in block:
                sitesym = block[tag]
                break

        if isinstance(sitesym, str):
            sitesym = [sitesym]

        return self._construct_periodic_set(block, asym_frac_motif, asym_symbols, sitesym, cell)

    def _Entry_to_PeriodicSet(self, entry) -> 'PeriodicSet':
        """ccdc.entry.Entry --> PeriodicSet. Returns None for a "bad" set."""

        # skip if structure does not pass checks in include_if
        if self.include_if and not all(check(entry) for check in self.include_if):
            return None

        self.current_identifier = entry.identifier
        # structure must pass this test
        if not entry.has_3d_structure:
            self._warn(f'Skipping {self.current_identifier} as entry has no 3D structure')
            return None

        crystal = entry.crystal

        # first disorder check, if skipping. If occ == 1 for all atoms but the entry
        # or crystal is listed as having disorder, skip (can't know where disorder is).
        # If occ != 1 for any atoms, we wait to see if we remove them before skipping.
        molecule = crystal.disordered_molecule
        if self.disorder == 'ordered_sites':
            molecule.remove_atoms(a for a in molecule.atoms if a.label.endswith('?'))

        may_have_disorder = False
        if self.disorder == 'skip':
            may_have_disorder = any(
                _atom_has_disorder(a.label, a.occupancy) for a in molecule.atoms)

            if not may_have_disorder and (crystal.has_disorder or entry.has_disorder):
                self._warn(f'Skipping {self.current_identifier} as structure is disordered')
                return None

        if self.remove_hydrogens:
            # exact match: a substring test against 'HD' would also match ''
            molecule.remove_atoms(
                a for a in molecule.atoms if a.atomic_symbol in ('H', 'D'))

        if self.heaviest_component:
            molecule = _heaviest_component(molecule)

        crystal.molecule = molecule

        # by here all atoms to be removed have been (except via ordered_sites).
        # If disorder == 'skip' and there were atom(s) with occ < 1 found
        # earlier, we check if all such atoms were removed. If not, skip.
        if self.disorder == 'skip' and may_have_disorder:
            if any(_atom_has_disorder(a.label, a.occupancy)
                   for a in crystal.disordered_molecule.atoms):
                self._warn(f'Skipping {self.current_identifier} as structure is disordered')
                return None

        # if disorder is all_sites, we need to know where disorder is to ignore overlaps
        asym_is_disordered = []     # True/False list same length as asym unit
        if self.disorder == 'all_sites':
            asym_is_disordered = [
                bool(_atom_has_disorder(a.label, a.occupancy))
                for a in crystal.asymmetric_unit_molecule.atoms
            ]

        # check all atoms have coords. option/default remove unknown sites?
        if not molecule.all_atoms_have_sites or \
           any(a.fractional_coordinates is None for a in molecule.atoms):
            self._warn(f'Skipping {self.current_identifier} as some atoms do not have sites')
            return None

        # get cell & asymmetric unit
        cell = cellpar_to_cell(*crystal.cell_lengths, *crystal.cell_angles)
        asym_atoms = crystal.asymmetric_unit_molecule.atoms
        asym_frac_motif = np.array([tuple(a.fractional_coordinates) for a in asym_atoms])
        asym_frac_motif = np.mod(asym_frac_motif, 1)
        asym_symbols = [a.atomic_symbol for a in asym_atoms]

        # remove overlapping sites, check sites exist
        keep_sites = self._validate_sites(asym_frac_motif, asym_is_disordered)
        if keep_sites is not None:
            asym_frac_motif = asym_frac_motif[keep_sites]
            asym_symbols = [sym for sym, keep in zip(asym_symbols, keep_sites) if keep]

        if self._has_no_valid_sites(asym_frac_motif):
            return None

        sitesym = crystal.symmetry_operators
        if not sitesym:
            sitesym = ['x,y,z', ]

        entry.crystal.molecule = crystal.disordered_molecule    # for extract_data. remove?

        return self._construct_periodic_set(entry, asym_frac_motif, asym_symbols, sitesym, cell)

    def expand(
            self,
            asym_frac_motif: np.ndarray,
            sitesym: Sequence[str]
    ) -> Tuple[np.ndarray, ...]:
        """
        Asymmetric unit's fractional coords + sitesyms (as strings)
        -->
        frac_motif, asym_unit, multiplicities, inverses

        frac_motif holds every distinct image of every site. asym_unit holds
        the index in frac_motif of the first image of each site that
        contributed at least one point, and multiplicities the number of
        images kept for each such site. inverses is a list mapping each row
        of frac_motif back to its row in asym_frac_motif.
        """

        rotations, translations = ase.spacegroup.spacegroup.parse_sitesym(sitesym)
        all_sites = []
        asym_unit = [0]
        multiplicities = []
        inverses = []

        for inv, site in enumerate(asym_frac_motif):
            multiplicity = 0

            for rot, trans in zip(rotations, translations):
                site_ = np.mod(np.dot(rot, site) + trans, 1)

                # the very first image has nothing to overlap with
                if not all_sites or \
                   not self._is_site_overlapping(site_, all_sites, inverses, inv):
                    all_sites.append(site_)
                    inverses.append(inv)
                    multiplicity += 1

            if multiplicity > 0:
                multiplicities.append(multiplicity)
                asym_unit.append(len(all_sites))

        frac_motif = np.array(all_sites)
        # the last entry is the total number of points, not a start index
        asym_unit = np.array(asym_unit[:-1])
        multiplicities = np.array(multiplicities)
        return frac_motif, asym_unit, multiplicities, inverses

    def _is_site_overlapping(self, new_site, all_sites, inverses, inv):
        """Return True if new_site overlaps with a site in all_sites.

        A warning is emitted only when the overlapping site came from a
        different asymmetric-unit site (inverses[ind] != inv); images of
        the same site coinciding are dropped silently.
        """
        tol = _Reader.equiv_site_tol
        diffs1 = np.abs(new_site - all_sites)
        # coordinates are taken mod 1, so 0 and 1 are the same position
        diffs2 = np.abs(diffs1 - 1)
        mask = np.all((diffs1 <= tol) | (diffs2 <= tol), axis=-1)

        if not np.any(mask):
            return False

        for ind in np.flatnonzero(mask):
            if inverses[ind] != inv:
                self._warn(
                    f'{self.current_identifier} has equivalent positions {inverses[ind]} and {inv}')
        return True

    def _validate_sites(self, asym_frac_motif, asym_is_disordered):
        """Find sites of the asymmetric unit that duplicate an earlier site.

        Returns a boolean array (True = keep) when overlaps were found,
        otherwise None. With disorder == 'all_sites', a pair is not treated
        as overlapping if either site is flagged in asym_is_disordered.
        """
        tol = _Reader.equiv_site_tol
        site_diffs1 = np.abs(asym_frac_motif[:, None] - asym_frac_motif)
        site_diffs2 = np.abs(site_diffs1 - 1)
        # strict upper triangle: each pair once, and only the later site is marked
        overlapping = np.triu(
            np.all((site_diffs1 <= tol) | (site_diffs2 <= tol), axis=-1), 1)

        if self.disorder == 'all_sites':
            for i, j in np.argwhere(overlapping):
                if asym_is_disordered[i] or asym_is_disordered[j]:
                    overlapping[i, j] = False

        if not overlapping.any():
            return None

        self._warn(
            f'{self.current_identifier} may have overlapping sites; duplicates will be removed')
        return ~overlapping.any(0)

    def _has_no_valid_sites(self, motif):
        """Return True (and warn) if motif has no rows, otherwise False."""
        if motif.shape[0] == 0:
            self._warn(
                f'Skipping {self.current_identifier} as there are no sites with coordinates')
            return True
        return False

    def _construct_periodic_set(self, raw_item, asym_frac_motif, asym_symbols, sitesym, cell):
        """Expand the asymmetric unit by sitesym and build the PeriodicSet.

        raw_item is the object being converted; it is passed to each
        callable in extract_data and the result stored under that key.
        """
        frac_motif, asym_unit, multiplicities, inverses = self.expand(asym_frac_motif, sitesym)
        full_types = [asym_symbols[i] for i in inverses]
        motif = frac_motif @ cell

        kwargs = {
            'name': self.current_identifier,
            'asymmetric_unit': asym_unit,
            'wyckoff_multiplicities': multiplicities,
            'types': full_types,
        }

        if self.current_filename:
            kwargs['filename'] = self.current_filename

        if self.extract_data is not None:
            for key, extractor in self.extract_data.items():
                kwargs[key] = extractor(raw_item)

        return PeriodicSet(motif, cell, **kwargs)
||
414 | |||
def _heaviest_component(molecule):
    """Return the heaviest component of molecule, or molecule itself if it
    has at most one component.

    A component's weight is the sum over its atoms of atomic_weight, scaled
    by occupancy when the occupancy is a float or int. Atoms whose
    atomic_weight is not a float or int contribute nothing.
    Intended for removing solvents. Probably doesn't play well with disorder.
    """
    components = molecule.components
    if len(components) <= 1:
        return molecule

    totals = []
    for part in components:
        total = 0
        for atom in part.atoms:
            mass = atom.atomic_weight
            if not isinstance(mass, (float, int)):
                continue
            occ = atom.occupancy
            if isinstance(occ, (float, int)):
                total += occ * mass
            else:
                total += mass
        totals.append(total)

    heaviest = np.argmax(np.array(totals))
    return components[heaviest]
||
433 | |||
def _atom_has_disorder(label, occupancy):
    """Return True if an atom looks disordered.

    An atom counts as disordered when its label ends with '?' or when its
    occupancy is a real number below 1. A non-numeric occupancy (None, or a
    string such as the CIF placeholders '?' and '.') is not treated as
    evidence of disorder; np.isscalar is true for strings, so comparing
    them with 1 would raise TypeError.
    """
    if label.endswith('?'):
        return True
    is_number = isinstance(occupancy, (int, float, np.integer, np.floating))
    return bool(is_number and occupancy < 1)
||
436 | |||
def _validate_extract_data(extract_data):
    """Check that extract_data is a dict mapping non-reserved keys to callables.

    Raises ValueError if extract_data is not a dict, if any value is not
    callable, or if any key is in _Reader.reserved_tags.
    """
    if not isinstance(extract_data, dict):
        raise ValueError('extract_data must be a dict with callable values')

    for tag, extractor in extract_data.items():
        if not callable(extractor):
            raise ValueError('extract_data must be a dict with callable values')
        if tag in _Reader.reserved_tags:
            raise ValueError(f'extract_data includes reserved key {tag}')
||