Total Complexity | 173 |
Total Lines | 908 |
Duplicated Lines | 0 % |
Changes | 0 |
Complex classes like amd.io often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes, or suffixes.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
1 | """Contains I/O tools, including a .CIF reader and CSD reader |
||
2 | (``csd-python-api`` only) to extract periodic set representations |
||
3 | of crystals which can be passed to :func:`.calculate.AMD` and :func:`.calculate.PDD`. |
||
4 | |||
5 | These intermediate :class:`.periodicset.PeriodicSet` representations can be written |
||
6 | to a .hdf5 file with :class:`SetWriter`, which can be read back with :class:`SetReader`. |
||
7 | This is much faster than rereading a .CIF and recomputing invariants. |
||
8 | """ |
||
9 | |||
10 | import os |
||
11 | import warnings |
||
12 | from typing import Callable, Iterable, Sequence, Tuple, Optional |
||
13 | |||
14 | import numpy as np |
||
15 | import ase.spacegroup.spacegroup # parse_sitesym |
||
16 | import ase.io.cif |
||
17 | import h5py |
||
18 | |||
19 | from .periodicset import PeriodicSet |
||
20 | from .utils import _extend_signature, cellpar_to_cell |
||
21 | |||
22 | try: |
||
23 | import ccdc.io # EntryReader |
||
24 | import ccdc.search # TextNumericSearch |
||
25 | _CCDC_ENABLED = True |
||
26 | except (ImportError, RuntimeError) as e: |
||
1 ignored issue
–
show
|
|||
27 | _CCDC_ENABLED = False |
||
28 | |||
29 | |||
def _warning(message, category, filename, lineno, file=None, line=None):
    """Render a warning as one ``filename:lineno: Category: message`` line.

    Installed below as :func:`warnings.formatwarning`, replacing the default
    formatter process-wide. Unlike the default, no source-line excerpt is
    appended. ``file`` and ``line`` exist only to match the
    ``formatwarning`` signature and are ignored.
    """
    category_name = category.__name__
    return '{}:{}: {}: {}\n'.format(filename, lineno, category_name, message)


warnings.formatwarning = _warning
34 | |||
35 | |||
def _atom_has_disorder(label, occupancy):
    """Return True if an atomic site looks disordered.

    A site counts as disordered when its label ends with ``'?'`` or its
    occupancy is a real number below 1.

    Non-numeric occupancies are treated as "fully occupied" instead of
    raising ``TypeError`` on the ``< 1`` comparison. This covers ``None`` and
    CIF placeholders such as ``'?'`` or ``'.'``, which a CIF parser may hand
    back as strings. Non-string labels are likewise ignored.
    """
    if isinstance(label, str) and label.endswith('?'):
        return True
    is_real_number = isinstance(occupancy, (int, float, np.integer, np.floating))
    return is_real_number and bool(occupancy < 1)
38 | |||
39 | |||
class _Reader:
    """Base Reader class. Contains parsers for converting ase CifBlock
    and ccdc Entry objects to PeriodicSets.

    Intended use:

    First make a new method for _Reader converting object to PeriodicSet
    (e.g. named _X_to_PSet). Then make this class outline:

    class XReader(_Reader):
        def __init__(self, ..., **kwargs):

            super().__init__(**kwargs)

            # setup and checks

            # make 'iterable' which yields objects to be converted (e.g. CIFBlock, Entry)

            # set self._generator like this
            self._generator = self._map(self._X_to_PSet, iterable)

    Parameters accepted by ``__init__`` (shared by all readers):

    remove_hydrogens : bool
        Remove sites whose element symbol is H or D.
    disorder : str
        One of ``'skip'`` (skip structures with a disordered site),
        ``'ordered_sites'`` (remove disordered sites) or ``'all_sites'``
        (keep all sites; overlapping sites are not merged when one of them
        is disordered).
    heaviest_component : bool
        Keep only the heaviest component of the asymmetric unit. Only used
        by the ccdc Entry parser.
    show_warnings : bool
        Emit warnings when structures are skipped or altered.
    extract_data : dict or None
        Maps names to callables. Each callable receives the raw object
        (CifBlock/Entry), and its result is passed to PeriodicSet under that
        name. Names in ``reserved_tags`` are not allowed.
    include_if : iterable of callables or None
        A structure is read only if every callable returns a truthy value
        for the raw object.
    """

    disorder_options = {'skip', 'ordered_sites', 'all_sites'}
    reserved_tags = {
        'motif',
        'cell',
        'name',
        'asymmetric_unit',
        'wyckoff_multiplicities',
        'types',
    }
    atom_site_fract_tags = [
        '_atom_site_fract_x',
        '_atom_site_fract_y',
        '_atom_site_fract_z',
    ]
    atom_site_cartn_tags = [
        '_atom_site_cartn_x',
        '_atom_site_cartn_y',
        '_atom_site_cartn_z',
    ]
    symop_tags = [
        '_space_group_symop_operation_xyz',
        '_space_group_symop.operation_xyz',
        '_symmetry_equiv_pos_as_xyz',
    ]

    # Fractional positions agreeing within this tolerance in every coordinate
    # (modulo 1) are treated as the same site.
    equiv_site_tol = 1e-3

    # Element symbols removed when remove_hydrogens=True (exact match, not substring).
    _hydrogen_symbols = frozenset({'H', 'D'})

    def __init__(
            self,
            remove_hydrogens=False,
            disorder='skip',
            heaviest_component=False,
            show_warnings=True,
            extract_data=None,
            include_if=None):

        if disorder not in _Reader.disorder_options:
            raise ValueError(
                f'disorder parameter {disorder} must be one of {_Reader.disorder_options}')

        if extract_data:
            if not isinstance(extract_data, dict):
                raise ValueError('extract_data must be a dict with callable values')
            for key, func in extract_data.items():
                if not callable(func):
                    raise ValueError('extract_data must be a dict with callable values')
                if key in _Reader.reserved_tags:
                    raise ValueError(f'extract_data includes reserved key {key}')

        if include_if:
            # Materialise first: a generator would otherwise be exhausted by this
            # validation, and every structure would then pass the (empty) filter.
            include_if = list(include_if)
            if not all(callable(func) for func in include_if):
                raise ValueError('include_if must be a list of callables')

        self.remove_hydrogens = remove_hydrogens
        self.disorder = disorder
        self.heaviest_component = heaviest_component
        self.extract_data = extract_data
        self.include_if = include_if
        self.show_warnings = show_warnings
        self.current_identifier = None
        self.current_filename = None

    def __iter__(self):
        yield from self._generator

    def read_one(self):
        """Read the next (or first) item. Raises StopIteration when exhausted."""
        return next(iter(self._generator))

    @staticmethod
    def _map(func: Callable, iterable: Iterable) -> Iterable[PeriodicSet]:
        """Lazily apply ``func`` to each item of ``iterable``. Yield results that
        are not None.

        This is like the builtin map, except that converters return None to mean
        "skip this structure". Readers store the result as ``self._generator``.
        """
        for item in iterable:
            res = func(item)
            if res is not None:
                yield res

    def _warn(self, message):
        """Emit ``message`` as a warning if ``show_warnings`` is set."""
        if self.show_warnings:
            # stacklevel=2 reports the calling line, as a direct warnings.warn would
            warnings.warn(message, stacklevel=2)

    def _expand(
            self,
            asym_frac_motif: np.ndarray,
            sitesym: Sequence[str]
    ) -> Tuple[np.ndarray, ...]:
        """Apply symmetry operations to the asymmetric unit.

        Parameters
        ----------
        asym_frac_motif : np.ndarray
            Fractional coordinates of the asymmetric unit, one row per site.
        sitesym : Sequence[str]
            Symmetry operations as strings, e.g. ``'x,y,z'``.

        Returns
        -------
        frac_motif : np.ndarray
            Fractional coordinates (wrapped with ``np.mod(..., 1)``) of all
            distinct images.
        asym_unit : np.ndarray
            Index in ``frac_motif`` of the first image of each asymmetric site.
        multiplicities : np.ndarray
            Number of distinct images generated by each asymmetric site.
        inverses : list of int
            For each row of ``frac_motif``, the row of ``asym_frac_motif`` it
            came from.
        """
        rotations, translations = ase.spacegroup.spacegroup.parse_sitesym(sitesym)
        all_sites = []
        asym_unit = [0]
        multiplicities = []
        inverses = []

        for inv, site in enumerate(asym_frac_motif):
            multiplicity = 0

            for rot, trans in zip(rotations, translations):
                site_ = np.mod(np.dot(rot, site) + trans, 1)

                if not all_sites:
                    all_sites.append(site_)
                    inverses.append(inv)
                    multiplicity += 1
                    continue

                # per-coordinate distance, also accepting wrap-around (0 ~ 1)
                diffs1 = np.abs(site_ - all_sites)
                diffs2 = np.abs(diffs1 - 1)
                mask = np.all((diffs1 <= _Reader.equiv_site_tol) |
                              (diffs2 <= _Reader.equiv_site_tol),
                              axis=-1)

                if np.any(mask):
                    # Image already present. This is only notable if it
                    # came from a different asymmetric site.
                    for ind in np.argwhere(mask).flatten():
                        if inverses[ind] != inv:
                            self._warn(
                                f'{self.current_identifier} has equivalent positions {inverses[ind]} and {inv}')
                else:
                    all_sites.append(site_)
                    inverses.append(inv)
                    multiplicity += 1

            if multiplicity > 0:
                multiplicities.append(multiplicity)
                asym_unit.append(len(all_sites))

        frac_motif = np.array(all_sites)
        asym_unit = np.array(asym_unit[:-1])
        multiplicities = np.array(multiplicities)
        return frac_motif, asym_unit, multiplicities, inverses

    def _remove_overlapping_sites(self, asym_frac_motif, asym_symbols, asym_is_disordered):
        """Drop asymmetric-unit sites that coincide with an earlier site.

        Two sites overlap when every fractional coordinate agrees within
        ``equiv_site_tol`` (modulo 1). The later site of each pair is removed
        and a warning is emitted. With ``disorder='all_sites'``, pairs that
        involve a disordered site are kept. ``asym_is_disordered`` is only
        read in that mode.

        Returns the filtered ``(asym_frac_motif, asym_symbols)``.
        """
        site_diffs1 = np.abs(asym_frac_motif[:, None] - asym_frac_motif)
        site_diffs2 = np.abs(site_diffs1 - 1)
        overlapping = np.triu(np.all(
            (site_diffs1 <= _Reader.equiv_site_tol) |
            (site_diffs2 <= _Reader.equiv_site_tol),
            axis=-1), 1)

        # don't remove overlapping sites if one is disordered and disorder='all_sites'
        if self.disorder == 'all_sites':
            for i, j in np.argwhere(overlapping):
                if asym_is_disordered[i] or asym_is_disordered[j]:
                    overlapping[i, j] = False

        if not overlapping.any():
            return asym_frac_motif, asym_symbols

        self._warn(
            f'{self.current_identifier} may have overlapping sites; duplicates will be removed')
        keep_sites = ~overlapping.any(0)
        asym_frac_motif = asym_frac_motif[keep_sites]
        asym_symbols = [sym for sym, keep in zip(asym_symbols, keep_sites) if keep]
        return asym_frac_motif, asym_symbols

    def _make_periodic_set(self, asym_frac_motif, asym_symbols, cell, sitesym, source):
        """Expand the asymmetric unit by ``sitesym`` and build the PeriodicSet.

        ``source`` is the raw object (CifBlock/Entry) handed to each
        ``extract_data`` callable.
        """
        frac_motif, asym_unit, multiplicities, inverses = self._expand(asym_frac_motif, sitesym)
        motif = frac_motif @ cell

        kwargs = {
            'name': self.current_identifier,
            'asymmetric_unit': asym_unit,
            'wyckoff_multiplicities': multiplicities,
            'types': [asym_symbols[i] for i in inverses],
        }

        if self.current_filename:
            kwargs['filename'] = self.current_filename

        if self.extract_data:
            for key, func in self.extract_data.items():
                kwargs[key] = func(source)

        return PeriodicSet(motif, cell, **kwargs)

    def _CIFBlock_to_PeriodicSet(self, block) -> Optional[PeriodicSet]:
        """ase.io.cif.CIFBlock --> PeriodicSet. Returns None for a "bad" set.

        The block is first checked against ``include_if``. Then the cell,
        asymmetric-unit coordinates and symbols are read. Fractional
        coordinates are used if present; otherwise Cartesian coordinates are
        converted. Hydrogens and disordered sites are handled according to
        the reader settings, overlapping sites are merged, and the result is
        expanded by the block's symmetry operations.
        """
        # skip if structure does not pass checks in include_if
        if self.include_if:
            if not all(check(block) for check in self.include_if):
                return None

        self.current_identifier = block.name
        cell = block.get_cell().array

        # asymmetric unit coordinates, one row per site
        frac_cols = [block.get(name) for name in _Reader.atom_site_fract_tags]
        if any(col is None for col in frac_cols):
            cart_cols = [block.get(name) for name in _Reader.atom_site_cartn_tags]
            if any(col is None for col in cart_cols):
                self._warn(
                    f'Skipping {self.current_identifier} as coordinates were not found')
                return None
            # Rows are sites and motif = frac @ cell, so frac = cart @ inv(cell).
            # Transpose first: the columns come in as (3, n_sites).
            asym_frac_motif = np.array(cart_cols).T @ np.linalg.inv(cell)
        else:
            asym_frac_motif = np.array(frac_cols).T
        n_sites = len(asym_frac_motif)

        try:
            asym_symbols = block.get_symbols()
        except ase.io.cif.NoStructureData:
            asym_symbols = ['Unknown'] * n_sites

        # indices of sites to remove (a set, for O(1) membership tests)
        remove = set()
        if self.remove_hydrogens:
            remove.update(i for i, sym in enumerate(asym_symbols)
                          if sym in _Reader._hydrogen_symbols)

        # Flag disordered sites: label ending in '?' or occupancy < 1. A missing
        # column means "no information", so labels are still checked when there
        # is no occupancy column, and vice versa.
        occupancies = block.get('_atom_site_occupancy')
        labels = block.get('_atom_site_label')
        if occupancies is None:
            occupancies = [None] * n_sites
        if labels is None:
            labels = [''] * n_sites
        asym_is_disordered = [
            _atom_has_disorder(str(label), occ)
            for label, occ in zip(labels, occupancies)
        ]
        # keep one flag per site even if a column is shorter than the coordinates
        asym_is_disordered.extend([False] * (n_sites - len(asym_is_disordered)))

        disordered = [i for i, flag in enumerate(asym_is_disordered)
                      if flag and i not in remove]

        if self.disorder == 'skip' and disordered:
            self._warn(f'Skipping {self.current_identifier} as structure is disordered')
            return None

        if self.disorder == 'ordered_sites':
            remove.update(disordered)

        # remove sites, keeping per-site lists aligned, and wrap into the unit cell
        asym_frac_motif = np.mod(np.delete(asym_frac_motif, sorted(remove), axis=0), 1)
        asym_symbols = [s for i, s in enumerate(asym_symbols) if i not in remove]
        asym_is_disordered = [v for i, v in enumerate(asym_is_disordered) if i not in remove]

        asym_frac_motif, asym_symbols = self._remove_overlapping_sites(
            asym_frac_motif, asym_symbols, asym_is_disordered)

        # if no points left in motif, skip structure
        if asym_frac_motif.shape[0] == 0:
            self._warn(
                f'Skipping {self.current_identifier} as there are no sites with coordinates')
            return None

        # symmetry operations from the first tag present; identity if none
        sitesym = ['x,y,z', ]
        for tag in _Reader.symop_tags:
            if tag in block:
                sitesym = block[tag]
                break
        if isinstance(sitesym, str):
            sitesym = [sitesym]

        return self._make_periodic_set(asym_frac_motif, asym_symbols, cell, sitesym, block)

    def _Entry_to_PeriodicSet(self, entry) -> Optional[PeriodicSet]:
        """ccdc.entry.Entry --> PeriodicSet. Returns None for a "bad" set."""

        # skip if structure does not pass checks in include_if
        if self.include_if:
            if not all(check(entry) for check in self.include_if):
                return None

        self.current_identifier = entry.identifier

        # structure must pass this test
        if not entry.has_3d_structure:
            self._warn(f'Skipping {self.current_identifier} as entry has no 3D structure')
            return None

        try:
            crystal = entry.crystal
        except RuntimeError as e:
            self._warn(f'Skipping {self.current_identifier}: {e}')
            return None

        # First disorder check, if skipping. If occ == 1 for all atoms but the
        # entry or crystal is listed as having disorder, skip (we can't know where
        # the disorder is). If occ != 1 for any atom, wait to see whether those
        # atoms get removed before skipping.
        molecule = crystal.disordered_molecule
        if self.disorder == 'ordered_sites':
            # NOTE(review): only '?' labels are removed here, whereas the CIF
            # path also removes sites with occupancy < 1. Confirm this is intended.
            molecule.remove_atoms(
                [a for a in molecule.atoms if a.label.endswith('?')])

        may_have_disorder = False
        if self.disorder == 'skip':
            may_have_disorder = any(
                _atom_has_disorder(a.label, a.occupancy) for a in molecule.atoms)
            if not may_have_disorder and (crystal.has_disorder or entry.has_disorder):
                self._warn(f'Skipping {self.current_identifier} as structure is disordered')
                return None

        if self.remove_hydrogens:
            molecule.remove_atoms(
                [a for a in molecule.atoms if a.atomic_symbol in _Reader._hydrogen_symbols])

        # Keep only the heaviest component of the asym unit. Intended for
        # removing solvents; probably doesn't play well with disorder.
        if self.heaviest_component:
            components = molecule.components
            if len(components) > 1:
                component_weights = []
                for component in components:
                    weight = 0
                    for a in component.atoms:
                        if isinstance(a.atomic_weight, (float, int)):
                            if isinstance(a.occupancy, (float, int)):
                                weight += a.occupancy * a.atomic_weight
                            else:
                                weight += a.atomic_weight
                    component_weights.append(weight)
                molecule = components[int(np.argmax(np.array(component_weights)))]

        crystal.molecule = molecule

        # By here all atoms to be removed have been (except via ordered_sites).
        # If disorder == 'skip' and atom(s) with occ < 1 were found earlier,
        # check whether all such atoms were removed. If not, skip.
        if self.disorder == 'skip' and may_have_disorder:
            if any(_atom_has_disorder(a.label, a.occupancy)
                   for a in crystal.disordered_molecule.atoms):
                self._warn(f'Skipping {self.current_identifier} as structure is disordered')
                return None

        # check all atoms have coords
        if not molecule.all_atoms_have_sites or \
                any(a.fractional_coordinates is None for a in molecule.atoms):
            self._warn(f'Skipping {self.current_identifier} as some atoms do not have sites')
            return None

        # cell & asymmetric unit (rows are sites)
        cell = cellpar_to_cell(*crystal.cell_lengths, *crystal.cell_angles)
        asym_atoms = crystal.asymmetric_unit_molecule.atoms
        asym_frac_motif = np.array(
            [tuple(a.fractional_coordinates) for a in asym_atoms]).reshape(-1, 3)
        asym_frac_motif = np.mod(asym_frac_motif, 1)
        asym_symbols = [a.atomic_symbol for a in asym_atoms]

        # with disorder='all_sites', overlaps involving a disordered site are kept
        asym_is_disordered = []
        if self.disorder == 'all_sites':
            asym_is_disordered = [_atom_has_disorder(a.label, a.occupancy) for a in asym_atoms]

        asym_frac_motif, asym_symbols = self._remove_overlapping_sites(
            asym_frac_motif, asym_symbols, asym_is_disordered)

        # if no points left in motif, skip structure
        if asym_frac_motif.shape[0] == 0:
            self._warn(
                f'Skipping {self.current_identifier} as there are no sites with coordinates')
            return None

        sitesym = crystal.symmetry_operators
        if not sitesym:
            sitesym = ('x,y,z', )

        if self.extract_data:
            # NOTE(review): presumably restores the unmodified molecule for the
            # extract_data callables; whether entry.crystal returns a cached
            # object (so this has an effect) is not visible here. Verify.
            entry.crystal.molecule = crystal.disordered_molecule

        return self._make_periodic_set(asym_frac_motif, asym_symbols, cell, sitesym, entry)
482 | |||
483 | |||
class CifReader(_Reader):
    """Read all structures in a .CIF with ``ase`` or ``ccdc``
    (``csd-python-api`` only), yielding :class:`.periodicset.PeriodicSet`
    objects which can be passed to :func:`.calculate.AMD` or
    :func:`.calculate.PDD`.

    Parameters:
        path: path to a .CIF file, or to a folder when ``folder=True``.
        reader: ``'ase'`` (default) or ``'ccdc'`` (requires csd-python-api).
        folder: if True, read every file in ``path`` whose extension the
            chosen reader supports, in sorted filename order.
        **kwargs: passed to the base reader (``remove_hydrogens``,
            ``disorder``, ``heaviest_component`` (ccdc only),
            ``show_warnings``, ``extract_data``, ``include_if``).

    Examples:

        ::

            # Put all crystals in a .CIF in a list
            structures = list(amd.CifReader('mycif.cif'))

            # Reads just one if the .CIF has just one crystal
            periodic_set = amd.CifReader('mycif.cif').read_one()

            # If a folder has several .CIFs each with one crystal, use
            structures = list(amd.CifReader('path/to/folder', folder=True))

            # Make list of AMDs (with k=100) of crystals in a .CIF
            amds = [amd.AMD(periodic_set, 100) for periodic_set in amd.CifReader('mycif.cif')]
    """

    @_extend_signature(_Reader.__init__)
    def __init__(
            self,
            path,
            reader='ase',
            folder=False,
            **kwargs):

        super().__init__(**kwargs)

        if reader not in ('ase', 'ccdc'):
            raise ValueError(
                f'Invalid reader {reader}; must be ase or ccdc.')

        if reader == 'ase':
            if self.heaviest_component:
                raise NotImplementedError(
                    'Parameter heaviest_component not implemented for ase, only ccdc.')
            extensions = {'cif'}
            file_parser = ase.io.cif.parse_cif
            pset_converter = self._CIFBlock_to_PeriodicSet
        else:  # reader == 'ccdc'
            if not _CCDC_ENABLED:
                raise ImportError(
                    'Failed to import csd-python-api; '
                    'please check it is installed and licensed.')
            extensions = ccdc.io.EntryReader.known_suffixes
            file_parser = ccdc.io.EntryReader
            pset_converter = self._Entry_to_PeriodicSet

        if folder:
            generator = self._folder_generator(path, file_parser, extensions)
        else:
            generator = file_parser(path)

        self._generator = self._map(pset_converter, generator)

    def _folder_generator(self, path, file_parser, extensions):
        """Yield the items parsed from every supported file in folder ``path``.

        Files are visited in sorted name order so results are reproducible,
        since ``os.listdir`` order is arbitrary. ``current_filename`` is set
        before each file's items are yielded, so each converted set records
        the file it came from.
        """
        for file in sorted(os.listdir(path)):
            suff = os.path.splitext(file)[1][1:]
            if suff.lower() in extensions:
                self.current_filename = file
                yield from file_parser(os.path.join(path, file))
552 | |||
553 | |||
class CSDReader(_Reader):
    """Read Entries from the CSD, yielding :class:`.periodicset.PeriodicSet` objects.

    The CSDReader returns :class:`.periodicset.PeriodicSet` objects which can be passed
    to :func:`.calculate.AMD` or :func:`.calculate.PDD`.

    Parameters:
        refcodes: a refcode string, an iterable of refcodes, or None / ``'CSD'``
            to read the whole database.
        families: if True, read every entry whose refcode is returned by an
            identifier search for each given refcode (refcode families).
            Ignored when reading the whole database.
        **kwargs: passed to the base reader (``remove_hydrogens``,
            ``disorder``, ``heaviest_component``, ``show_warnings``,
            ``extract_data``, ``include_if``).

    Examples:

        Get crystals with refcodes in a list::

            refcodes = ['DEBXIT01', 'DEBXIT05', 'HXACAN01']
            structures = list(amd.CSDReader(refcodes))

        Read refcode families (any whose refcode starts with strings in the list)::

            refcodes = ['ACSALA', 'HXACAN']
            structures = list(amd.CSDReader(refcodes, families=True))

        Create a generic reader, read crystals by name with :meth:`CSDReader.entry()`::

            reader = amd.CSDReader()
            debxit01 = reader.entry('DEBXIT01')

            # looping over this generic reader will yield all CSD entries
            for periodic_set in reader:
                ...

        Make list of AMD (with k=100) for crystals in these families::

            refcodes = ['ACSALA', 'HXACAN']
            amds = []
            for periodic_set in amd.CSDReader(refcodes, families=True):
                amds.append(amd.AMD(periodic_set, 100))
    """

    @_extend_signature(_Reader.__init__)
    def __init__(
            self,
            refcodes=None,
            families=False,
            **kwargs):

        if not _CCDC_ENABLED:
            raise ImportError(
                'Failed to import csd-python-api; '
                'please check it is installed and licensed.')

        super().__init__(**kwargs)

        if isinstance(refcodes, str) and refcodes.lower() == 'csd':
            refcodes = None

        if refcodes is None:
            families = False
        else:
            refcodes = [refcodes] if isinstance(refcodes, str) else list(refcodes)

        # families parameter reads all crystals with ids starting with passed refcodes
        if families:
            all_refcodes = []
            for refcode in refcodes:
                query = ccdc.search.TextNumericSearch()
                query.add_identifier(refcode)
                all_refcodes.extend(hit.identifier for hit in query.search())
            # unique refcodes, first-seen order preserved
            refcodes = list(dict.fromkeys(all_refcodes))

        self._entry_reader = ccdc.io.EntryReader('CSD')
        self._generator = self._map(
            self._Entry_to_PeriodicSet,
            self._ccdc_generator(refcodes))

    def _ccdc_generator(self, refcodes):
        """Generate ccdc Entries from CSD refcodes (all entries if ``refcodes`` is None).

        Refcodes the database lookup rejects with RuntimeError are skipped.
        A warning is emitted for them when ``show_warnings`` is set.
        """
        if refcodes is None:
            yield from self._entry_reader
            return

        for refcode in refcodes:
            try:
                entry = self._entry_reader.entry(refcode)
            except RuntimeError:
                if self.show_warnings:
                    warnings.warn(f'Identifier {refcode} not found in database')
                continue
            # yield outside the try so only the lookup itself is guarded
            yield entry

    def entry(self, refcode: str) -> Optional[PeriodicSet]:
        """Read a PeriodicSet given any CSD refcode.

        Returns None if the reader's settings reject the entry (e.g.
        ``disorder='skip'`` and the structure is disordered). Errors raised
        by the CSD lookup itself propagate.
        """
        entry = self._entry_reader.entry(refcode)
        return self._Entry_to_PeriodicSet(entry)
652 | |||
653 | |||
class SetWriter:
    """Write several :class:`.periodicset.PeriodicSet` objects to a .hdf5 file.
    Reading the .hdf5 is much faster than parsing a .CIF file.

    Examples:

        Write the crystals in mycif.cif to a .hdf5 file::

            with amd.SetWriter('crystals.hdf5') as writer:

                for periodic_set in amd.CifReader('mycif.cif'):
                    writer.write(periodic_set)

                # use iwrite to write straight from an iterator
                # below is equivalent to the above loop
                writer.iwrite(amd.CifReader('mycif.cif'))

    Read the crystals back from the file with :class:`SetReader`.
    """

    _str_dtype = h5py.vlen_dtype(str)

    def __init__(self, filename: str):
        # 'w' truncates an existing file; track_order keeps write order for SetReader
        self.file = h5py.File(filename, 'w', track_order=True)

    def write(self, periodic_set: PeriodicSet, name: Optional[str] = None):
        """Write a PeriodicSet object to file.

        The set becomes an HDF5 group named ``name`` (default
        ``periodic_set.name``) with ``motif`` and ``cell`` datasets. Tags go in
        a ``tags`` subgroup:

        - ``None`` is stored as the attribute sentinel ``'__None'``.
        - Scalars (numbers and strings) are stored as attributes.
        - Arrays, lists and tuples are stored as datasets. Sequences containing
          any string are stored as variable-length strings.

        Raises:
            ValueError: if ``periodic_set`` is not a PeriodicSet, if it has
                no name and none is given, or if a tag has an unsupported type.
        """
        if not isinstance(periodic_set, PeriodicSet):
            raise ValueError(
                f'Object type {periodic_set.__class__.__name__} cannot be written with SetWriter')

        # need a name to store or you can't access items by key
        if name is None:
            if periodic_set.name is None:
                raise ValueError(
                    'Periodic set must have a name to be written. Either set the name '
                    'attribute of the PeriodicSet or pass a name to SetWriter.write()')
            name = periodic_set.name

        # this group is the PeriodicSet
        group = self.file.create_group(name)
        group.create_dataset('motif', data=periodic_set.motif)
        group.create_dataset('cell', data=periodic_set.cell)

        if not periodic_set.tags:
            return

        tags_group = group.create_group('tags')
        for tag in periodic_set.tags:
            data = periodic_set.tags[tag]

            if data is None:
                # attributes cannot hold None; SetReader maps this sentinel back
                tags_group.attrs[tag] = '__None'
            elif np.isscalar(data):
                tags_group.attrs[tag] = data
            elif isinstance(data, np.ndarray):
                tags_group.create_dataset(tag, data=data)
            elif isinstance(data, (list, tuple)):
                if any(isinstance(d, str) for d in data):
                    # string sequences need an explicit variable-length string dtype
                    tags_group.create_dataset(
                        tag, data=[str(d) for d in data], dtype=SetWriter._str_dtype)
                else:
                    # other sequences must be castable to ndarray
                    tags_group.create_dataset(tag, data=np.asarray(data))
            else:
                raise ValueError(
                    f'Cannot store tag of type {type(data)} with SetWriter')

    def iwrite(self, periodic_sets: Iterable[PeriodicSet]):
        """Write :class:`.periodicset.PeriodicSet` objects from an iterable to file."""
        for periodic_set in periodic_sets:
            self.write(periodic_set)

    def close(self):
        """Close the :class:`SetWriter`."""
        self.file.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, tb):
        # close on any exit; exceptions are not suppressed (returns None)
        self.close()
745 | |||
746 | |||
class SetReader:
    """Read :class:`.periodicset.PeriodicSet` objects from a .hdf5 file written
    with :class:`SetWriter`. Acts like a read-only dict that can be iterated
    over (preserves write order).

    Examples:

        Get PDDs (k=100) of crystals in crystals.hdf5::

            pdds = []
            with amd.SetReader('crystals.hdf5') as reader:
                for periodic_set in reader:
                    pdds.append(amd.PDD(periodic_set, 100))

            # above is equivalent to:
            pdds = [amd.PDD(pset, 100) for pset in amd.SetReader('crystals.hdf5')]
    """

    def __init__(self, filename: str):
        self.file = h5py.File(filename, 'r', track_order=True)

    def _get_set(self, name: str) -> PeriodicSet:
        """Rebuild the PeriodicSet stored under ``name``, including its tags."""
        group = self.file[name]
        periodic_set = PeriodicSet(group['motif'][:], group['cell'][:], name=name)

        if 'tags' in group:
            tags_group = group['tags']

            for tag in tags_group:
                data = tags_group[tag][:]
                # variable-length string datasets come back as bytes; decode to str
                if any(isinstance(d, (bytes, bytearray)) for d in data):
                    periodic_set.tags[tag] = [d.decode() for d in data]
                else:
                    periodic_set.tags[tag] = data

            for attr, data in tags_group.attrs.items():
                # Only a str can be the None sentinel written by SetWriter. The
                # isinstance check avoids elementwise comparisons on numpy values.
                if isinstance(data, str) and data == '__None':
                    data = None
                periodic_set.tags[attr] = data

        return periodic_set

    def close(self):
        """Close the :class:`SetReader`."""
        self.file.close()

    def family(self, refcode: str) -> Iterable[PeriodicSet]:
        """Yield any :class:`.periodicset.PeriodicSet` whose name starts with
        input refcode."""
        for name in self.keys():
            if name.startswith(refcode):
                yield self._get_set(name)

    def __getitem__(self, name):
        # index by name; a missing name raises the KeyError from h5py
        return self._get_set(name)

    def __len__(self):
        return len(self.keys())

    def __iter__(self):
        # interface to loop over the SetReader; does not close the SetReader when done
        for name in self.keys():
            yield self._get_set(name)

    def __contains__(self, item):
        return item in self.keys()

    def keys(self):
        """Return a view of the names of items in the :class:`SetReader`."""
        return self.file['/'].keys()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, tb):
        # close on any exit; exceptions are not suppressed (returns None)
        self.close()
828 | |||
829 | |||
def crystal_to_periodicset(crystal):
    """ccdc.crystal.Crystal --> amd.periodicset.PeriodicSet.
    Ignores disorder, missing sites/coords, checks & no options.
    Is a stripped-down version of the function used in CifReader.

    Parameters
    ----------
    crystal : ccdc.crystal.Crystal
        The crystal to convert. Its cell parameters, asymmetric-unit atoms,
        symmetry operators and identifier are read.

    Returns
    -------
    PeriodicSet
        The expanded motif (Cartesian) and cell, named by
        ``crystal.identifier``, with the asymmetric unit and Wyckoff
        multiplicities from the symmetry expansion.

    Raises
    ------
    ValueError
        If the asymmetric unit contains no atoms.
    """
    cell = cellpar_to_cell(*crystal.cell_lengths, *crystal.cell_angles)

    # Fractional coordinates of the asymmetric unit, one row per atom.
    asym_frac_motif = np.array([tuple(atom.fractional_coordinates)
                                for atom in crystal.asymmetric_unit_molecule.atoms])

    # Nothing can be expanded from an empty asymmetric unit, so raise rather
    # than build an empty set.
    if asym_frac_motif.shape[0] == 0:
        raise ValueError(f'{crystal.identifier} has no coordinates')

    # Reduce coordinates modulo 1 into the unit cell.
    asym_frac_motif = np.mod(asym_frac_motif, 1)

    # Fall back to the identity operator when the crystal reports none.
    sitesym = crystal.symmetry_operators
    if not sitesym:
        sitesym = ('x,y,z', )

    # _expand is a _Reader method; current_identifier is set on the
    # throwaway reader (presumably for its warning messages — TODO confirm).
    reader = _Reader()
    reader.current_identifier = crystal.identifier
    frac_motif, asym_unit, multiplicities, _ = reader._expand(asym_frac_motif, sitesym)

    # Rows of ``cell`` are lattice vectors: Cartesian = fractional @ cell.
    motif = frac_motif @ cell

    return PeriodicSet(
        motif,
        cell,
        name=crystal.identifier,
        asymmetric_unit=asym_unit,
        wyckoff_multiplicities=multiplicities,
    )
||
862 | |||
863 | |||
def cifblock_to_periodicset(block):
    """ase.io.cif.CIFBlock --> amd.periodicset.PeriodicSet.
    Ignores disorder, missing sites/coords, checks & no options.
    Is a stripped-down version of the function used in CifReader.

    Parameters
    ----------
    block : ase.io.cif.CIFBlock
        A parsed CIF data block. Fractional site coordinates are preferred;
        Cartesian coordinates are used only if fractional ones are missing.

    Returns
    -------
    PeriodicSet or None
        The expanded motif (Cartesian) and cell, named by ``block.name``.
        ``None`` (with a warning) if neither fractional nor Cartesian
        coordinates are present.

    Raises
    ------
    ValueError
        If the coordinate columns are present but contain no sites.
    """
    cell = block.get_cell().array

    # Each tag holds one coordinate axis for every site, so the stacked
    # columns have shape (3, n_sites); transpose to get one row per site.
    frac_columns = [block.get(name) for name in _Reader.atom_site_fract_tags]

    if any(column is None for column in frac_columns):
        cartn_columns = [block.get(name) for name in _Reader.atom_site_cartn_tags]
        if any(column is None for column in cartn_columns):
            warnings.warn(
                f'Skipping {block.name} as coordinates were not found')
            return None
        # Rows of ``cell`` are lattice vectors (Cartesian = fractional @ cell,
        # as used below), so fractional = Cartesian @ inv(cell). Transpose to
        # (n_sites, 3) BEFORE multiplying; multiplying the (3, n_sites) array
        # fails for n_sites != 3 and is wrong when n_sites == 3.
        asym_frac_motif = np.array(cartn_columns).T @ np.linalg.inv(cell)
    else:
        asym_frac_motif = np.array(frac_columns).T

    # Reduce coordinates modulo 1 into the unit cell.
    asym_frac_motif = np.mod(asym_frac_motif, 1)

    if asym_frac_motif.shape[0] == 0:
        raise ValueError(f'{block.name} has no coordinates')

    # Use the first symmetry-operator tag present; default to the identity.
    sitesym = ('x,y,z', )
    for tag in _Reader.symop_tags:
        if tag in block:
            sitesym = block[tag]
            break

    # A lone operator may come back as a bare string; wrap it in a list.
    if isinstance(sitesym, str):
        sitesym = [sitesym]

    # _expand is a _Reader method; current_identifier is set on the
    # throwaway reader (presumably for its warning messages — TODO confirm).
    dummy_reader = _Reader()
    dummy_reader.current_identifier = block.name
    frac_motif, asym_unit, multiplicities, _ = dummy_reader._expand(asym_frac_motif, sitesym)
    motif = frac_motif @ cell

    return PeriodicSet(
        motif,
        cell,
        name=block.name,
        asymmetric_unit=asym_unit,
        wyckoff_multiplicities=multiplicities,
    )
||
908 |
This check looks for invalid names for a range of different identifiers.
You can set regular expressions to which the identifiers must conform if the defaults do not match your requirements.
If your project includes a Pylint configuration file, the settings contained in that file take precedence.
To find out more about Pylint, please refer to their site.