Passed
Push — master ( 40aa73...93bfb7 )
by Daniel
02:13
created

amd._reader._Reader._asym_unit_from_cifblock()   B

Complexity

Conditions 7

Size

Total Lines 29
Code Lines 22

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 22
dl 0
loc 29
rs 7.952
c 0
b 0
f 0
cc 7
nop 2
1
"""Contains base reader class for the io module. 
0 ignored issues
show
Coding Style introduced by
Trailing whitespace
Loading history...
2
This class implements the converters from CifBlock, Entry to PeriodicSets.
3
"""
4
5
import warnings
6
from typing import Callable, Iterable, Sequence, Tuple
7
8
import numpy as np
9
import ase.spacegroup.spacegroup    # parse_sitesym
10
import ase.io.cif
11
12
from .periodicset import PeriodicSet
13
from .utils import cellpar_to_cell
14
15
16
class _Reader:
    """Base Reader class. Contains parsers for converting ase CifBlock
    and ccdc Entry objects to PeriodicSets.

    Intended use:

    First make a new method for _Reader converting object to PeriodicSet
    (e.g. named _X_to_PSet). Then make this class outline:

    class XReader(_Reader):
        def __init__(self, ..., **kwargs):

            super().__init__(**kwargs)

            # setup and checks

            # make 'iterable' which yields objects to be converted (e.g. CIFBlock, Entry)

            # set self._generator like this
            self._generator = self._map(self._X_to_PSet, iterable)
    """

    # accepted values of the 'disorder' parameter
    disorder_options = {'skip', 'ordered_sites', 'all_sites'}
    # tag names set by the reader itself; extract_data keys may not use them
    reserved_tags = {
        'motif',
        'cell',
        'name',
        'asymmetric_unit',
        'wyckoff_multiplicities',
        'types',
        'filename',
    }
    atom_site_fract_tags = [
        '_atom_site_fract_x',
        '_atom_site_fract_y',
        '_atom_site_fract_z',
    ]
    atom_site_cartn_tags = [
        '_atom_site_cartn_x',
        '_atom_site_cartn_y',
        '_atom_site_cartn_z',
    ]
    # CIF tags that may hold the symmetry operations, tried in this order
    symop_tags = [
        '_space_group_symop_operation_xyz',
        '_space_group_symop.operation_xyz',
        '_symmetry_equiv_pos_as_xyz',
    ]

    # two fractional coordinates closer than this (modulo 1) count as equal
    equiv_site_tol = 1e-3

    def __init__(
            self,
            remove_hydrogens=False,
            disorder='skip',
            heaviest_component=False,
            show_warnings=True,
            extract_data=None,
            include_if=None):
        """Store and validate the reader options.

        Parameters:
            remove_hydrogens: if True, sites with symbol 'H' or 'D' are removed.
            disorder: one of ``disorder_options``.
            heaviest_component: if True, only the heaviest component of a
                ccdc molecule is kept (used by ``_entry_to_periodicset``).
            show_warnings: if False, warnings raised while reading are ignored.
            extract_data: dict mapping tag name -> callable; each callable is
                applied to the parsed object and the result stored under
                that tag on the PeriodicSet.
            include_if: iterable of callables; an item is converted only if
                every callable returns a truthy value for it.

        Raises:
            ValueError: on an unknown ``disorder`` value, an invalid
                ``extract_data`` or a non-callable entry in ``include_if``.
        """

        if disorder not in _Reader.disorder_options:
            raise ValueError(
                f'disorder parameter {disorder} must be one of {_Reader.disorder_options}')

        if extract_data is None:
            self.extract_data = {}
        else:
            _validate_extract_data(extract_data)
            self.extract_data = extract_data

        if include_if is None:
            self.include_if = ()
        elif not all(callable(func) for func in include_if):
            raise ValueError('include_if must be a list of callables')
        else:
            self.include_if = include_if

        self.remove_hydrogens = remove_hydrogens
        self.disorder = disorder
        self.heaviest_component = heaviest_component
        self.show_warnings = show_warnings
        self.current_name = None
        self.current_filename = None
        self._generator = []

    def __iter__(self):
        yield from self._generator

    def read_one(self):
        """Read the next (or first) item."""
        return next(iter(self._generator))

    def _map(self, func: Callable, iterable: Iterable) -> Iterable['PeriodicSet']:
        """Iterates over iterable, passing items through parser and yielding the
        result if it is not None. Applies warning and include_if filter.
        """

        with warnings.catch_warnings():
            if not self.show_warnings:
                warnings.simplefilter('ignore')
            for item in iterable:
                # skip the item as soon as one include_if check fails
                if any(not check(item) for check in self.include_if):
                    continue
                res = func(item)
                if res is not None:
                    yield res

    def _cifblock_to_periodicset(self, block) -> 'PeriodicSet':
        """ase.io.cif.CIFBlock --> PeriodicSet. Returns None for a "bad" set."""

        self.current_name = block.name
        parsed = self._asym_unit_from_cifblock(block)
        if parsed is None:
            # no coordinates found; a warning has already been issued
            return None
        asym_unit, asym_symbols, sitesym, cell = parsed

        # indices of sites to remove
        remove = set()
        if self.remove_hydrogens:
            # tuple membership, not substring: '' and 'HD' must not match
            remove.update(i for i, sym in enumerate(asym_symbols) if sym in ('H', 'D'))

        # One disorder flag per site of the asymmetric unit, so that the flags
        # stay aligned with the site indices used in 'remove'.
        occupancies = block.get('_atom_site_occupancy')
        if occupancies is None:
            asym_is_disordered = [False] * len(asym_unit)
        else:
            labels = block.get('_atom_site_label')
            if labels is None:
                # no labels: disorder can only be judged from the occupancy
                labels = [''] * len(occupancies)
            asym_is_disordered = [
                bool(_atom_has_disorder(label, occ))
                for occ, label in zip(occupancies, labels)]

        # disordered sites which are not being removed anyway
        disordered = [
            i for i, flag in enumerate(asym_is_disordered)
            if flag and i not in remove]

        if self.disorder == 'skip' and disordered:
            warnings.warn(f'Skipping {self.current_name} as structure is disordered')
            return None

        if self.disorder == 'ordered_sites':
            remove.update(disordered)

        # remove sites
        if remove:
            asym_unit = np.delete(asym_unit, sorted(remove), axis=0)
            asym_symbols = [s for i, s in enumerate(asym_symbols) if i not in remove]
            asym_is_disordered = [
                v for i, v in enumerate(asym_is_disordered) if i not in remove]

        # remove overlapping sites, check sites exist
        keep_sites = self._validate_sites(asym_unit, asym_is_disordered)
        if keep_sites is not None:
            asym_unit = asym_unit[keep_sites]
            asym_symbols = [sym for sym, keep in zip(asym_symbols, keep_sites) if keep]

        if self._has_no_valid_sites(asym_unit):
            return None

        data = {key: func(block) for key, func in self.extract_data.items()}
        return self._construct_periodic_set(asym_unit, asym_symbols, sitesym, cell, **data)

    def _asym_unit_from_cifblock(self, block):
        """ase.io.cif.CIFBlock -->
        asymmetric unit (frac coords), asym_symbols, symops (as strings), cell.

        Returns None (after warning) if neither fractional nor Cartesian
        coordinates are present in the block.
        """

        cell = block.get_cell().array
        frac_columns = [block.get(name) for name in _Reader.atom_site_fract_tags]
        if any(column is None for column in frac_columns):
            cart_columns = [block.get(name) for name in _Reader.atom_site_cartn_tags]
            if any(column is None for column in cart_columns):
                warnings.warn(f'Skipping {self.current_name} as coordinates were not found')
                return None
            # The x/y/z columns stack to shape (3, n); transpose to one row per
            # site before converting. Elsewhere cart = frac @ cell, hence
            # frac = cart @ inv(cell).
            asym_unit = np.array(cart_columns).T @ np.linalg.inv(cell)
        else:
            asym_unit = np.array(frac_columns).T
        asym_unit = np.mod(asym_unit, 1)

        try:
            asym_symbols = block.get_symbols()
        except ase.io.cif.NoStructureData:
            asym_symbols = ['Unknown'] * len(asym_unit)

        # default to the identity operation if no symmetry tag is present
        sitesym = ['x,y,z', ]
        for tag in _Reader.symop_tags:
            if tag in block:
                sitesym = block[tag]
                break

        if isinstance(sitesym, str):
            sitesym = [sitesym]

        return asym_unit, asym_symbols, sitesym, cell

    def _entry_to_periodicset(self, entry) -> 'PeriodicSet':
        """ccdc.entry.Entry --> PeriodicSet. Returns None for a "bad" set."""

        self.current_name = entry.identifier
        crystal = entry.crystal

        if not entry.has_3d_structure:
            warnings.warn(f'Skipping {self.current_name} as entry has no 3D structure')
            return None

        # first disorder check, if skipping. If occ == 1 for all atoms but the entry
        # or crystal is listed as having disorder, skip (can't know where disorder is).
        # If occ != 1 for any atoms, we wait to see if we remove them before skipping.
        molecule = crystal.disordered_molecule
        if self.disorder == 'ordered_sites':
            molecule.remove_atoms([a for a in molecule.atoms if a.label.endswith('?')])

        may_have_disorder = False
        if self.disorder == 'skip':
            may_have_disorder = any(
                _atom_has_disorder(a.label, a.occupancy) for a in molecule.atoms)

            if not may_have_disorder:
                if crystal.has_disorder or entry.has_disorder:
                    warnings.warn(f'Skipping {self.current_name} as structure is disordered')
                    return None

        # make same as cifblock version??
        if self.remove_hydrogens:
            # tuple membership, not substring: '' and 'HD' must not match
            molecule.remove_atoms(
                [a for a in molecule.atoms if a.atomic_symbol in ('H', 'D')])

        if self.heaviest_component and len(molecule.components) > 1:
            molecule = _heaviest_component(molecule)

        crystal.molecule = molecule

        # by here all atoms to be removed have been (except via ordered_sites).
        # If disorder == 'skip' and there were atom(s) with occ < 1 found
        # earlier, we check if all such atoms were removed. If not, skip.
        if self.disorder == 'skip' and may_have_disorder:
            for a in crystal.disordered_molecule.atoms:
                if _atom_has_disorder(a.label, a.occupancy):
                    warnings.warn(f'Skipping {self.current_name} as structure is disordered')
                    return None

        # if disorder is all_sites, we need to know where disorder is to ignore overlaps
        asym_is_disordered = []     # True/False list same length as asym unit
        if self.disorder == 'all_sites':
            asym_is_disordered = [
                bool(_atom_has_disorder(a.label, a.occupancy))
                for a in crystal.asymmetric_unit_molecule.atoms]

        # check all atoms have coords. option/default remove unknown sites?
        if not molecule.all_atoms_have_sites or \
           any(a.fractional_coordinates is None for a in molecule.atoms):
            warnings.warn(f'Skipping {self.current_name} as some atoms do not have sites')
            return None

        asym_unit, asym_symbols, sitesym, cell = self._asym_unit_from_crystal(crystal)

        # remove overlapping sites, check sites exist
        keep_sites = self._validate_sites(asym_unit, asym_is_disordered)
        if keep_sites is not None:
            asym_unit = asym_unit[keep_sites]
            asym_symbols = [sym for sym, keep in zip(asym_symbols, keep_sites) if keep]

        if self._has_no_valid_sites(asym_unit):
            return None

        entry.crystal.molecule = crystal.disordered_molecule
        data = {key: func(entry) for key, func in self.extract_data.items()}
        return self._construct_periodic_set(asym_unit, asym_symbols, sitesym, cell, **data)

    def _asym_unit_from_crystal(self, crystal):
        """ccdc crystal (as found on entry.crystal) -->
        asymmetric unit (frac coords), asym_symbols, symops, cell"""

        asym_atoms = crystal.asymmetric_unit_molecule.atoms
        asym_unit = np.array([tuple(a.fractional_coordinates) for a in asym_atoms])
        asym_unit = np.mod(asym_unit, 1)
        asym_symbols = [a.atomic_symbol for a in asym_atoms]
        cell = cellpar_to_cell(*crystal.cell_lengths, *crystal.cell_angles)
        sitesym = crystal.symmetry_operators
        if not sitesym:
            # default to the identity operation
            sitesym = ['x,y,z', ]
        return asym_unit, asym_symbols, sitesym, cell

    def _is_site_overlapping(self, new_site, all_sites, inverses, inv):
        """Return True (and warn) if new_site overlaps with a site in all_sites.

        A warning is issued only when the overlapping site was generated from
        a different asymmetric unit site, i.e. inverses[index] != inv.
        """
        diffs1 = np.abs(new_site - all_sites)
        # coordinates near 0 and near 1 are the same point modulo 1
        diffs2 = np.abs(diffs1 - 1)
        mask = np.all((diffs1 <= _Reader.equiv_site_tol) |
                      (diffs2 <= _Reader.equiv_site_tol),
                      axis=-1)

        if not np.any(mask):
            return False

        for ind in np.flatnonzero(mask):
            if inverses[ind] != inv:
                warnings.warn(
                    f'{self.current_name} has equivalent positions {inverses[ind]} and {inv}')
        return True

    def _validate_sites(self, asym_unit, asym_is_disordered):
        """Look for overlapping sites in the asymmetric unit.

        Returns a boolean mask over the sites (True = keep) and warns if any
        sites overlap, otherwise returns None. When disorder == 'all_sites',
        an overlap involving at least one disordered site is not counted.
        """
        if len(asym_unit) == 0:
            # nothing to compare; an empty array may be 1-D and not indexable below
            return None

        site_diffs1 = np.abs(asym_unit[:, None] - asym_unit)
        site_diffs2 = np.abs(site_diffs1 - 1)
        # strict upper triangle: each pair once, and only the later site is flagged
        overlapping = np.triu(np.all(
            (site_diffs1 <= _Reader.equiv_site_tol) |
            (site_diffs2 <= _Reader.equiv_site_tol),
            axis=-1), 1)

        if self.disorder == 'all_sites':
            for i, j in np.argwhere(overlapping):
                if asym_is_disordered[i] or asym_is_disordered[j]:
                    overlapping[i, j] = False

        if overlapping.any():
            warnings.warn(
                f'{self.current_name} may have overlapping sites; duplicates will be removed')
            return ~overlapping.any(0)

        return None

    def _has_no_valid_sites(self, motif):
        """Return True (and warn) if motif contains no sites."""
        if motif.shape[0] == 0:
            warnings.warn(
                f'Skipping {self.current_name} as there are no sites with coordinates')
            return True
        return False

    def _construct_periodic_set(self, asym_unit, asym_symbols, sitesym, cell, **kwargs):
        """Asym motif + symbols + sitesym + cell (+kwargs) --> PeriodicSet"""
        frac_motif, asym_inds, multiplicities, inverses = self.expand(asym_unit, sitesym)
        full_types = [asym_symbols[i] for i in inverses]
        motif = frac_motif @ cell

        tags = {
            'name': self.current_name,
            'asymmetric_unit': asym_inds,
            'wyckoff_multiplicities': multiplicities,
            'types': full_types,
            **kwargs
        }

        if self.current_filename:
            tags['filename'] = self.current_filename

        return PeriodicSet(motif, cell, **tags)

    def expand(self, asym_unit: np.ndarray, sitesym: Sequence[str]) -> Tuple[np.ndarray, ...]:
        """
        Asymmetric unit's fractional coords + sitesyms (as strings)
        -->
        frac motif, asym unit inds, multiplicities, inverses

        frac motif, asym unit inds and multiplicities are arrays; inverses is
        a list giving, for each motif point, the index of the asymmetric unit
        site it was generated from.
        """

        rotations, translations = ase.spacegroup.spacegroup.parse_sitesym(sitesym)
        all_sites = []
        asym_inds = [0]
        multiplicities = []
        inverses = []

        for inv, site in enumerate(asym_unit):
            multiplicity = 0

            for rot, trans in zip(rotations, translations):
                site_ = np.mod(np.dot(rot, site) + trans, 1)

                # the very first site cannot overlap with anything
                if not all_sites or \
                   not self._is_site_overlapping(site_, all_sites, inverses, inv):
                    all_sites.append(site_)
                    inverses.append(inv)
                    multiplicity += 1

            if multiplicity > 0:
                multiplicities.append(multiplicity)
                asym_inds.append(len(all_sites))

        frac_motif = np.array(all_sites)
        # the last entry is the total number of sites, not a start index
        asym_inds = np.array(asym_inds[:-1])
        multiplicities = np.array(multiplicities)
        return frac_motif, asym_inds, multiplicities, inverses
394
395
396
def _atom_has_disorder(label, occupancy):
    """Return True if label ends with '?' or occupancy is a number below 1.

    A non-numeric occupancy (None, or a string such as '?' or '.') is not
    treated as disorder. np.isscalar is not used for this check because it
    is True for strings, which then cannot be compared with 1.
    """
    if label.endswith('?'):
        return True
    is_number = isinstance(occupancy, (int, float, np.integer, np.floating))
    return bool(is_number and occupancy < 1)
398
399
def _heaviest_component(molecule):
    """Return the heaviest component of molecule.

    The weight of a component is the sum of its atoms' atomic weights, each
    scaled by the atom's occupancy when that occupancy is a float or int.
    Atoms whose atomic weight is not a float or int contribute nothing.
    Intended for removing solvents. Probably doesn't play well with disorder.
    """
    numeric = (float, int)
    totals = []
    for part in molecule.components:
        total = 0
        for atom in part.atoms:
            if not isinstance(atom.atomic_weight, numeric):
                continue
            if isinstance(atom.occupancy, numeric):
                total += atom.occupancy * atom.atomic_weight
            else:
                total += atom.atomic_weight
        totals.append(total)
    heaviest = np.argmax(np.array(totals))
    return molecule.components[heaviest]
415
416
def _validate_extract_data(extract_data):
    """Check that extract_data is a dict mapping non-reserved keys to callables.

    Raises:
        ValueError: if extract_data is not a dict, if any value is not
            callable, or if any key is one of _Reader.reserved_tags.
    """
    not_callables = 'extract_data must be a dict of callables'
    if not isinstance(extract_data, dict):
        raise ValueError(not_callables)
    for key, func in extract_data.items():
        if not callable(func):
            raise ValueError(not_callables)
        if key in _Reader.reserved_tags:
            raise ValueError(f'extract_data includes reserved key {key}')
424