| Total Complexity | 99 |
| Total Lines | 445 |
| Duplicated Lines | 0 % |
| Changes | 0 | ||
Complex classes such as `_Reader` in `amd._reader` often do many different things. To break such a class down, identify a cohesive component within it. A common approach to finding such a component is to look for fields and methods that share the same prefix or suffix.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
| 1 | """Contains base reader class for the io module. |
||
|
|
|||
| 2 | This class implements the converters from CifBlock, Entry to PeriodicSets. |
||
| 3 | """ |
||
| 4 | |||
| 5 | import warnings |
||
| 6 | from typing import Callable, Iterable, Sequence, Tuple |
||
| 7 | |||
| 8 | import numpy as np |
||
| 9 | import ase.spacegroup.spacegroup # parse_sitesym |
||
| 10 | import ase.io.cif |
||
| 11 | |||
| 12 | from .periodicset import PeriodicSet |
||
| 13 | from .utils import cellpar_to_cell |
||
| 14 | |||
| 15 | |||
| 16 | def _warning(message, category, filename, lineno, *args, **kwargs): |
||
|
2 ignored issues
–
show
|
|||
| 17 | return f'{filename}:{lineno}: {category.__name__}: {message}\n' |
||
| 18 | |||
| 19 | warnings.formatwarning = _warning |
||
| 20 | |||
| 21 | |||
class _Reader:
    """Base reader class containing parsers that convert ase ``CIFBlock``
    and ccdc ``Entry`` objects into :class:`PeriodicSet` objects.

    Intended use: first add a method to ``_Reader`` that converts the new
    object type to a PeriodicSet (e.g. ``_X_to_PSet``), returning None for
    items that should be skipped. Then write a subclass following this
    outline::

        class XReader(_Reader):
            def __init__(self, ..., **kwargs):

                super().__init__(**kwargs)

                # setup and checks

                # make 'iterable' which yields objects to be converted
                # (e.g. CIFBlock, Entry)

                # set self._generator like this
                self._generator = self._map(self._X_to_PSet, iterable)
    """

    disorder_options = {'skip', 'ordered_sites', 'all_sites'}
    # Names already used when building a PeriodicSet (positional motif/cell
    # and the kwargs in _construct_periodic_set); extract_data may not use them.
    reserved_tags = {
        'motif',
        'cell',
        'name',
        'asymmetric_unit',
        'wyckoff_multiplicities',
        'types',
        'filename',
    }
    atom_site_fract_tags = [
        '_atom_site_fract_x',
        '_atom_site_fract_y',
        '_atom_site_fract_z',
    ]
    atom_site_cartn_tags = [
        '_atom_site_cartn_x',
        '_atom_site_cartn_y',
        '_atom_site_cartn_z',
    ]
    # Searched in this order; the first tag present in a block is used.
    symop_tags = [
        '_space_group_symop_operation_xyz',
        '_space_group_symop.operation_xyz',
        '_symmetry_equiv_pos_as_xyz',
    ]

    # Two fractional positions within this distance in every coordinate
    # (allowing wrap-around at the cell boundary) are treated as one site.
    equiv_site_tol = 1e-3

    def __init__(
            self,
            remove_hydrogens=False,
            disorder='skip',
            heaviest_component=False,
            show_warnings=True,
            extract_data=None,
            include_if=None):
        """
        Parameters
        ----------
        remove_hydrogens : bool
            Remove sites whose symbol is 'H' or 'D'.
        disorder : str
            One of ``_Reader.disorder_options``. 'skip' skips disordered
            structures, 'ordered_sites' removes disordered sites and
            'all_sites' keeps every site (overlaps involving a disordered
            site are not removed).
        heaviest_component : bool
            Only used for ccdc Entries: keep only the heaviest component.
        show_warnings : bool
            Emit warnings when structures or sites are skipped/removed.
        extract_data : dict or None
            Maps extra PeriodicSet attribute names to callables applied to
            the raw item (CIFBlock/Entry).
        include_if : iterable of callables or None
            Predicates on the raw item; an item is skipped unless all
            return truthy.

        Raises
        ------
        ValueError
            If ``disorder`` is not a valid option, or ``extract_data`` /
            ``include_if`` are malformed.
        """

        # settings
        if disorder not in _Reader.disorder_options:
            raise ValueError(f'disorder parameter {disorder} must be one of {_Reader.disorder_options}')

        if extract_data:
            _validate_extract_data(extract_data)

        if include_if:
            for func in include_if:
                if not callable(func):
                    raise ValueError('include_if must be a list of callables')

        self.remove_hydrogens = remove_hydrogens
        self.disorder = disorder
        self.heaviest_component = heaviest_component
        self.extract_data = extract_data
        self.include_if = include_if
        self.show_warnings = show_warnings
        self.current_identifier = None
        self.current_filename = None
        self._generator = []

    def __iter__(self):
        yield from self._generator

    def read_one(self):
        """Read the next (or first) item."""
        return next(iter(self._generator))

    # Basically the builtin map, but skips items for which func returned None.
    # Subclasses assign the returned iterable to self._generator; iterating
    # over the Reader then iterates over self._generator.
    @staticmethod
    def _map(func: Callable, iterable: Iterable) -> Iterable[PeriodicSet]:
        """Iterate over ``iterable``, passing each item through ``func`` and
        yielding the result whenever it is not None.
        """

        for item in iterable:
            res = func(item)
            if res is not None:
                yield res

    def _CIFBlock_to_PeriodicSet(self, block) -> PeriodicSet:
        """ase.io.cif.CIFBlock --> PeriodicSet. Returns None for a "bad" set."""

        # skip if structure does not pass checks in include_if
        if self.include_if and not all(check(block) for check in self.include_if):
            return None

        # read name, cell, asymmetric unit motif and atomic symbols
        self.current_identifier = block.name
        cell = block.get_cell().array

        # Each tag holds one coordinate column (x, y or z values of all sites).
        frac_columns = [block.get(tag) for tag in _Reader.atom_site_fract_tags]
        if None in frac_columns:
            cartn_columns = [block.get(tag) for tag in _Reader.atom_site_cartn_tags]
            if None in cartn_columns:
                if self.show_warnings:
                    warnings.warn(
                        f'Skipping {self.current_identifier} as coordinates were not found')
                return None
            # Rows of `cell` are lattice vectors (motif = frac_motif @ cell in
            # _construct_periodic_set), so fractional rows are Cartesian rows
            # @ inv(cell). The columns must become (n_sites, 3) rows first.
            asym_frac_motif = np.array(cartn_columns).T @ np.linalg.inv(cell)
        else:
            asym_frac_motif = np.array(frac_columns).T

        try:
            asym_symbols = block.get_symbols()
        except ase.io.cif.NoStructureData:
            asym_symbols = ['Unknown'] * len(asym_frac_motif)

        # indices of sites to remove (a set: it is queried once per site below)
        remove = set()
        if self.remove_hydrogens:
            remove.update(i for i, sym in enumerate(asym_symbols) if sym in ('H', 'D'))

        # Find disordered sites. Without occupancies no site is flagged, but
        # keep one flag per site so _validate_sites can index it safely.
        occupancies = block.get('_atom_site_occupancy')
        if occupancies is None:
            asym_is_disordered = [False] * len(asym_frac_motif)
        else:
            labels = block.get('_atom_site_label')
            if labels is None:
                # no labels: only the occupancy can signal disorder
                labels = [''] * len(occupancies)
            asym_is_disordered = [_atom_has_disorder(label, occ)
                                  for occ, label in zip(occupancies, labels)]
            # disordered sites that are not already being removed
            disordered = [i for i, is_disordered in enumerate(asym_is_disordered)
                          if is_disordered and i not in remove]

            if self.disorder == 'skip' and disordered:
                if self.show_warnings:
                    warnings.warn(
                        f'Skipping {self.current_identifier} as structure is disordered')
                return None

            if self.disorder == 'ordered_sites':
                remove.update(disordered)

        # remove sites (np.delete needs a sequence of ints, not a set)
        asym_frac_motif = np.mod(np.delete(asym_frac_motif, sorted(remove), axis=0), 1)
        asym_symbols = [s for i, s in enumerate(asym_symbols) if i not in remove]
        asym_is_disordered = [v for i, v in enumerate(asym_is_disordered) if i not in remove]

        keep_sites = self._validate_sites(asym_frac_motif, asym_is_disordered)
        if keep_sites is not None:
            asym_frac_motif = asym_frac_motif[keep_sites]
            asym_symbols = [sym for sym, keep in zip(asym_symbols, keep_sites) if keep]

        if self._has_no_valid_sites(asym_frac_motif):
            return None

        sitesym = ['x,y,z']
        for tag in _Reader.symop_tags:
            if tag in block:
                sitesym = block[tag]
                break

        # a single symmetry operation may be read as a bare string
        if isinstance(sitesym, str):
            sitesym = [sitesym]

        return self._construct_periodic_set(block, asym_frac_motif, asym_symbols, sitesym, cell)

    def _Entry_to_PeriodicSet(self, entry) -> PeriodicSet:
        """ccdc.entry.Entry --> PeriodicSet. Returns None for a "bad" set."""

        # skip if structure does not pass checks in include_if
        if self.include_if and not all(check(entry) for check in self.include_if):
            return None

        self.current_identifier = entry.identifier
        # structure must pass this test
        if not entry.has_3d_structure:
            if self.show_warnings:
                warnings.warn(
                    f'Skipping {self.current_identifier} as entry has no 3D structure')
            return None

        crystal = entry.crystal

        # First disorder check, if skipping. If occ == 1 for all atoms but the
        # entry or crystal is listed as having disorder, skip (can't know where
        # the disorder is). If occ != 1 for any atoms, wait to see whether they
        # are removed before skipping.
        molecule = crystal.disordered_molecule
        if self.disorder == 'ordered_sites':
            molecule.remove_atoms(a for a in molecule.atoms if a.label.endswith('?'))

        may_have_disorder = False
        if self.disorder == 'skip':
            may_have_disorder = any(_atom_has_disorder(a.label, a.occupancy)
                                    for a in molecule.atoms)

            if not may_have_disorder and (crystal.has_disorder or entry.has_disorder):
                if self.show_warnings:
                    warnings.warn(f'Skipping {self.current_identifier} as structure is disordered')
                return None

        if self.remove_hydrogens:
            molecule.remove_atoms(a for a in molecule.atoms if a.atomic_symbol in ('H', 'D'))

        if self.heaviest_component:
            molecule = _heaviest_component(molecule)

        crystal.molecule = molecule

        # By here all atoms to be removed have been (except via ordered_sites).
        # If disorder == 'skip' and atom(s) with occ < 1 were found earlier,
        # check whether all such atoms were removed. If not, skip.
        if self.disorder == 'skip' and may_have_disorder:
            for a in crystal.disordered_molecule.atoms:
                if _atom_has_disorder(a.label, a.occupancy):
                    if self.show_warnings:
                        warnings.warn(
                            f'Skipping {self.current_identifier} as structure is disordered')
                    return None

        # if disorder is all_sites, we need to know where disorder is to ignore overlaps
        asym_is_disordered = []  # True/False list, same length as asym unit
        if self.disorder == 'all_sites':
            asym_is_disordered = [_atom_has_disorder(a.label, a.occupancy)
                                  for a in crystal.asymmetric_unit_molecule.atoms]

        # check all atoms have coords. option/default remove unknown sites?
        if not molecule.all_atoms_have_sites or \
                any(a.fractional_coordinates is None for a in molecule.atoms):
            if self.show_warnings:
                warnings.warn(
                    f'Skipping {self.current_identifier} as some atoms do not have sites')
            return None

        # get cell & asymmetric unit
        cell = cellpar_to_cell(*crystal.cell_lengths, *crystal.cell_angles)
        asym_atoms = crystal.asymmetric_unit_molecule.atoms
        asym_frac_motif = np.array([tuple(a.fractional_coordinates) for a in asym_atoms])
        asym_frac_motif = np.mod(asym_frac_motif, 1)
        asym_symbols = [a.atomic_symbol for a in asym_atoms]

        # remove overlapping sites, check sites exist
        keep_sites = self._validate_sites(asym_frac_motif, asym_is_disordered)
        if keep_sites is not None:
            asym_frac_motif = asym_frac_motif[keep_sites]
            asym_symbols = [sym for sym, keep in zip(asym_symbols, keep_sites) if keep]

        if self._has_no_valid_sites(asym_frac_motif):
            return None

        sitesym = crystal.symmetry_operators
        if not sitesym:
            sitesym = ['x,y,z']

        entry.crystal.molecule = crystal.disordered_molecule  # for extract_data. remove?

        return self._construct_periodic_set(entry, asym_frac_motif, asym_symbols, sitesym, cell)

    def expand(
            self,
            asym_frac_motif: np.ndarray,
            sitesym: Sequence[str]
    ) -> Tuple[np.ndarray, ...]:
        """Apply symmetry operations to the asymmetric unit.

        Asymmetric unit's fractional coords + sitesyms (as strings)
        --> frac_motif, asym_unit, multiplicities, inverses

        Returns
        -------
        frac_motif : np.ndarray
            Unique images (mod 1) of all asymmetric unit sites.
        asym_unit : np.ndarray
            Index into ``frac_motif`` where each kept site's images start.
        multiplicities : np.ndarray
            Number of unique images of each kept site. Sites all of whose
            images coincide with earlier sites are dropped from both arrays.
        inverses : list of int
            For each row of ``frac_motif``, the index of the asymmetric unit
            site it was generated from.
        """

        rotations, translations = ase.spacegroup.spacegroup.parse_sitesym(sitesym)
        all_sites = []
        asym_unit = [0]
        multiplicities = []
        inverses = []

        for inv, site in enumerate(asym_frac_motif):
            multiplicity = 0

            for rot, trans in zip(rotations, translations):
                site_ = np.mod(np.dot(rot, site) + trans, 1)

                # the very first image cannot overlap anything (and
                # _is_site_overlapping cannot broadcast against an empty list)
                if not all_sites or \
                        not self._is_site_overlapping(site_, all_sites, inverses, inv):
                    all_sites.append(site_)
                    inverses.append(inv)
                    multiplicity += 1

            if multiplicity > 0:
                multiplicities.append(multiplicity)
                asym_unit.append(len(all_sites))

        frac_motif = np.array(all_sites)
        # the last entry is the total count, not the start of another site
        asym_unit = np.array(asym_unit[:-1])
        multiplicities = np.array(multiplicities)
        return frac_motif, asym_unit, multiplicities, inverses

    def _is_site_overlapping(self, new_site, all_sites, inverses, inv):
        """Return True if ``new_site`` coincides with a site in ``all_sites``.

        Sites coincide when every coordinate differs by at most
        ``equiv_site_tol``, allowing for wrap-around at the cell boundary.
        A warning is emitted for each coinciding site that was generated from
        a different asymmetric unit site than ``inv``; coinciding images of
        the same site are not warned about.
        """
        diffs1 = np.abs(new_site - all_sites)
        diffs2 = np.abs(diffs1 - 1)
        mask = np.all(np.logical_or(diffs1 <= _Reader.equiv_site_tol,
                                    diffs2 <= _Reader.equiv_site_tol),
                      axis=-1)

        if not np.any(mask):
            return False

        if self.show_warnings:
            for ind in np.argwhere(mask).flatten():
                if inverses[ind] != inv:
                    warnings.warn(
                        f'{self.current_identifier} has equivalent positions {inverses[ind]} and {inv}')
        return True

    def _validate_sites(self, asym_frac_motif, asym_is_disordered):
        """Find asymmetric unit sites that duplicate an earlier site.

        Returns None when no sites overlap, otherwise a boolean mask over
        ``asym_frac_motif`` that is False for every site coinciding with an
        earlier one. With ``disorder='all_sites'``, overlaps involving a
        site flagged in ``asym_is_disordered`` are kept.
        """
        site_diffs1 = np.abs(asym_frac_motif[:, None] - asym_frac_motif)
        site_diffs2 = np.abs(site_diffs1 - 1)
        # overlapping[i, j] for i < j only (strict upper triangle)
        overlapping = np.triu(np.all(
            (site_diffs1 <= _Reader.equiv_site_tol) |
            (site_diffs2 <= _Reader.equiv_site_tol),
            axis=-1), 1)

        if self.disorder == 'all_sites':
            for i, j in np.argwhere(overlapping):
                if asym_is_disordered[i] or asym_is_disordered[j]:
                    overlapping[i, j] = False

        if not overlapping.any():
            return None

        if self.show_warnings:
            warnings.warn(
                f'{self.current_identifier} may have overlapping sites; duplicates will be removed')
        # drop site j if it overlaps any earlier site i
        return ~overlapping.any(0)

    def _has_no_valid_sites(self, motif):
        """Return True (warning if enabled) when ``motif`` has no rows."""
        if motif.shape[0] == 0:
            if self.show_warnings:
                warnings.warn(
                    f'Skipping {self.current_identifier} as there are no sites with coordinates')
            return True
        return False

    def _construct_periodic_set(self, raw_item, asym_frac_motif, asym_symbols, sitesym, cell):
        """Expand the asymmetric unit by ``sitesym`` and build the PeriodicSet.

        ``raw_item`` (CIFBlock or Entry) is passed to each ``extract_data``
        callable; the results become extra PeriodicSet keyword arguments.
        """
        frac_motif, asym_unit, multiplicities, inverses = self.expand(asym_frac_motif, sitesym)
        full_types = [asym_symbols[i] for i in inverses]
        motif = frac_motif @ cell  # fractional -> Cartesian

        kwargs = {
            'name': self.current_identifier,
            'asymmetric_unit': asym_unit,
            'wyckoff_multiplicities': multiplicities,
            'types': full_types,
        }

        if self.current_filename:
            kwargs['filename'] = self.current_filename

        if self.extract_data is not None:
            for key, func in self.extract_data.items():
                kwargs[key] = func(raw_item)

        return PeriodicSet(motif, cell, **kwargs)
||
| 414 | |||
| 415 | def _heaviest_component(molecule): |
||
| 416 | """Heaviest component (removes all but the heaviest component of the asym unit). |
||
| 417 | Intended for removing solvents. Probably doesn't play well with disorder""" |
||
| 418 | if len(molecule.components) > 1: |
||
| 419 | component_weights = [] |
||
| 420 | for component in molecule.components: |
||
| 421 | weight = 0 |
||
| 422 | for a in component.atoms: |
||
| 423 | if isinstance(a.atomic_weight, (float, int)): |
||
| 424 | if isinstance(a.occupancy, (float, int)): |
||
| 425 | weight += a.occupancy * a.atomic_weight |
||
| 426 | else: |
||
| 427 | weight += a.atomic_weight |
||
| 428 | component_weights.append(weight) |
||
| 429 | largest_component_arg = np.argmax(np.array(component_weights)) |
||
| 430 | molecule = molecule.components[largest_component_arg] |
||
| 431 | |||
| 432 | return molecule |
||
| 433 | |||
| 434 | def _atom_has_disorder(label, occupancy): |
||
| 435 | return label.endswith('?') or (np.isscalar(occupancy) and occupancy < 1) |
||
| 436 | |||
| 437 | def _validate_extract_data(extract_data): |
||
| 438 | if not isinstance(extract_data, dict): |
||
| 439 | raise ValueError('extract_data must be a dict with callable values') |
||
| 440 | for key in extract_data: |
||
| 441 | if not callable(extract_data[key]): |
||
| 442 | raise ValueError('extract_data must be a dict with callable values') |
||
| 443 | if key in _Reader.reserved_tags: |
||
| 444 | raise ValueError(f'extract_data includes reserved key {key}') |
||