Passed
Push — master ( f71bbb...e05216 )
by Daniel
01:58
created

amd._reader._validate_extract_data()   A

Complexity

Conditions 5

Size

Total Lines 8
Code Lines 8

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 8
dl 0
loc 8
rs 9.3333
c 0
b 0
f 0
cc 5
nop 1
1
"""Contains the base reader class for the io module.

This class implements the converters from ase CifBlock and ccdc Entry
objects to PeriodicSets.
"""
4
5
import warnings
6
from typing import Callable, Iterable, Sequence, Tuple
7
8
import numpy as np
9
import ase.spacegroup.spacegroup    # parse_sitesym
10
import ase.io.cif
11
12
from .periodicset import PeriodicSet
13
from .utils import cellpar_to_cell
14
15
16
def _warning(message, category, filename, lineno, *args, **kwargs):
    """Format a warning as a single ``filename:lineno: Category: message`` line.

    Drop-in replacement for :func:`warnings.formatwarning`. Any extra
    positional/keyword arguments (e.g. ``line``) are accepted for signature
    compatibility and ignored, so no source line is appended to the output.
    """
    del args, kwargs  # accepted only to match the formatwarning signature
    location = f'{filename}:{lineno}'
    return f'{location}: {category.__name__}: {message}\n'


# NOTE: this replaces the warnings formatter process-wide, not only for this module.
warnings.formatwarning = _warning
20
21
22
class _Reader:
    """Base Reader class. Contains parsers for converting ase CifBlock
    and ccdc Entry objects to PeriodicSets.

    Intended use:

    First make a new method for _Reader converting object to PeriodicSet
    (e.g. named _X_to_PSet). Then make this class outline::

        class XReader(_Reader):
            def __init__(self, ..., **kwargs):
                super().__init__(**kwargs)
                # setup and checks
                # make 'iterable' which yields objects to be converted
                # (e.g. CIFBlock, Entry), then set self._generator:
                self._generator = self._map(self._X_to_PSet, iterable)
    """

    # Accepted values of the ``disorder`` argument.
    disorder_options = {'skip', 'ordered_sites', 'all_sites'}
    # Keyword names the reader passes to PeriodicSet itself; extract_data may not use them.
    reserved_tags = {
        'motif',
        'cell',
        'name',
        'asymmetric_unit',
        'wyckoff_multiplicities',
        'types',
        'filename',
    }
    atom_site_fract_tags = [
        '_atom_site_fract_x',
        '_atom_site_fract_y',
        '_atom_site_fract_z',
    ]
    atom_site_cartn_tags = [
        '_atom_site_cartn_x',
        '_atom_site_cartn_y',
        '_atom_site_cartn_z',
    ]
    symop_tags = [
        '_space_group_symop_operation_xyz',
        '_space_group_symop.operation_xyz',
        '_symmetry_equiv_pos_as_xyz',
    ]
    # Atomic symbols removed when remove_hydrogens is set (exact matches only).
    _hydrogen_symbols = frozenset({'H', 'D'})

    # Per-coordinate tolerance (fractional units) for treating two sites as equal;
    # coordinates differing by ~1 are also treated as equal.
    equiv_site_tol = 1e-3

    def __init__(
            self,
            remove_hydrogens=False,
            disorder='skip',
            heaviest_component=False,
            show_warnings=True,
            extract_data=None,
            include_if=None):
        """
        Parameters
        ----------
        remove_hydrogens : bool
            Remove sites whose atomic symbol is 'H' or 'D'.
        disorder : str
            One of 'skip' (skip disordered structures), 'ordered_sites'
            (remove disordered sites) or 'all_sites' (keep disordered sites).
        heaviest_component : bool
            Keep only the heaviest component (used for ccdc Entry input).
        show_warnings : bool
            Emit warnings when structures are skipped or altered.
        extract_data : dict or None
            Maps extra PeriodicSet keyword names to callables applied to the
            raw item (CifBlock/Entry). Keys may not be in ``reserved_tags``.
        include_if : iterable of callables or None
            A raw item is converted only if every callable returns truthy.

        Raises
        ------
        ValueError
            If ``disorder`` is invalid, ``extract_data`` is malformed, or
            ``include_if`` contains a non-callable.
        """

        # settings
        if disorder not in _Reader.disorder_options:
            raise ValueError(
                f'disorder parameter {disorder} must be one of {_Reader.disorder_options}')

        if extract_data:
            _validate_extract_data(extract_data)

        if include_if:
            # Materialise once: a one-shot iterator (e.g. a generator) would
            # otherwise be exhausted here and then never filter anything.
            include_if = list(include_if)
            if not all(callable(func) for func in include_if):
                raise ValueError('include_if must be a list of callables')

        self.remove_hydrogens = remove_hydrogens
        self.disorder = disorder
        self.heaviest_component = heaviest_component
        self.extract_data = extract_data
        self.include_if = include_if
        self.show_warnings = show_warnings
        self.current_identifier = None
        self.current_filename = None
        self._generator = []

    def __iter__(self):
        """Iterate over the PeriodicSets produced by self._generator."""
        yield from self._generator

    def read_one(self):
        """Read the next (or first) item. Raises StopIteration if none remain."""
        return next(iter(self._generator))

    @staticmethod
    def _map(func: Callable, iterable: Iterable) -> 'Iterable[PeriodicSet]':
        """Like the builtin map, but skips items for which func returned None.

        Subclasses assign the result to self._generator; iterating the
        reader then iterates over it.
        """
        for item in iterable:
            res = func(item)
            if res is not None:
                yield res

    def _CIFBlock_to_PeriodicSet(self, block) -> 'PeriodicSet':
        """ase.io.cif.CIFBlock --> PeriodicSet. Returns None for a "bad" set."""

        # skip if structure does not pass checks in include_if
        if self.include_if and not all(check(block) for check in self.include_if):
            return None

        # read name, cell, asym motif and atomic symbols
        self.current_identifier = block.name
        cell = block.get_cell().array

        frac_columns = [block.get(tag) for tag in _Reader.atom_site_fract_tags]
        if any(col is None for col in frac_columns):
            cartn_columns = [block.get(tag) for tag in _Reader.atom_site_cartn_tags]
            if any(col is None for col in cartn_columns):
                if self.show_warnings:
                    warnings.warn(
                        f'Skipping {self.current_identifier} as coordinates were not found')
                return None
            # Rows satisfy cart = frac @ cell (as in _construct_periodic_set), so
            # frac = cart @ inv(cell); the x/y/z columns are transposed to rows first.
            asym_frac_motif = np.array(cartn_columns).T @ np.linalg.inv(cell)
        else:
            asym_frac_motif = np.array(frac_columns).T

        try:
            asym_symbols = block.get_symbols()
        except ase.io.cif.NoStructureData:
            asym_symbols = ['Unknown'] * len(asym_frac_motif)

        # indices of sites to remove (a set, so membership tests are O(1))
        remove = set()
        if self.remove_hydrogens:
            remove.update(i for i, sym in enumerate(asym_symbols)
                          if sym in _Reader._hydrogen_symbols)

        # Find disordered sites. asym_is_disordered always has one entry per
        # site so it stays aligned with the indices used by `remove`.
        n_sites = len(asym_frac_motif)
        asym_is_disordered = [False] * n_sites
        disordered = []     # disordered site indices not already marked for removal
        occupancies = block.get('_atom_site_occupancy')
        if occupancies is not None:
            labels = block.get('_atom_site_label')
            if labels is None:
                labels = [''] * len(occupancies)
            for i, occ, label in zip(range(n_sites), occupancies, labels):
                if _atom_has_disorder(label, occ):
                    asym_is_disordered[i] = True
                    if i not in remove:
                        disordered.append(i)

        if self.disorder == 'skip' and disordered:
            if self.show_warnings:
                warnings.warn(
                    f'Skipping {self.current_identifier} as structure is disordered')
            return None

        if self.disorder == 'ordered_sites':
            remove.update(disordered)

        # remove sites
        asym_frac_motif = np.mod(np.delete(asym_frac_motif, sorted(remove), axis=0), 1)
        asym_symbols = [s for i, s in enumerate(asym_symbols) if i not in remove]
        asym_is_disordered = [d for i, d in enumerate(asym_is_disordered) if i not in remove]

        keep_sites = self._validate_sites(asym_frac_motif, asym_is_disordered)
        if keep_sites is not None:
            asym_frac_motif = asym_frac_motif[keep_sites]
            asym_symbols = [sym for sym, keep in zip(asym_symbols, keep_sites) if keep]

        if self._has_no_valid_sites(asym_frac_motif):
            return None

        sitesym = ['x,y,z']
        for tag in _Reader.symop_tags:
            if tag in block:
                sitesym = block[tag]
                break

        if isinstance(sitesym, str):
            sitesym = [sitesym]

        return self._construct_periodic_set(block, asym_frac_motif, asym_symbols, sitesym, cell)

    def _Entry_to_PeriodicSet(self, entry) -> 'PeriodicSet':
        """ccdc.entry.Entry --> PeriodicSet. Returns None for a "bad" set.

        Note: this assigns to ``entry.crystal.molecule`` while filtering atoms.
        """

        # skip if structure does not pass checks in include_if
        if self.include_if and not all(check(entry) for check in self.include_if):
            return None

        self.current_identifier = entry.identifier
        # structure must pass this test
        if not entry.has_3d_structure:
            if self.show_warnings:
                warnings.warn(
                    f'Skipping {self.current_identifier} as entry has no 3D structure')
            return None

        crystal = entry.crystal

        # First disorder check, if skipping. If occ == 1 for all atoms but the entry
        # or crystal is listed as having disorder, skip (can't know where disorder is).
        # If occ != 1 for any atoms, wait to see if they get removed before skipping.
        molecule = crystal.disordered_molecule
        if self.disorder == 'ordered_sites':
            molecule.remove_atoms(a for a in molecule.atoms if a.label.endswith('?'))

        may_have_disorder = False
        if self.disorder == 'skip':
            may_have_disorder = any(
                _atom_has_disorder(a.label, a.occupancy) for a in molecule.atoms)
            if not may_have_disorder and (crystal.has_disorder or entry.has_disorder):
                if self.show_warnings:
                    warnings.warn(
                        f'Skipping {self.current_identifier} as structure is disordered')
                return None

        if self.remove_hydrogens:
            molecule.remove_atoms(
                a for a in molecule.atoms if a.atomic_symbol in _Reader._hydrogen_symbols)

        if self.heaviest_component:
            molecule = _heaviest_component(molecule)

        crystal.molecule = molecule

        # By here all atoms to be removed have been (except via ordered_sites).
        # If disorder == 'skip' and atom(s) with occ < 1 were found earlier,
        # check whether all such atoms were removed. If not, skip.
        if self.disorder == 'skip' and may_have_disorder:
            if any(_atom_has_disorder(a.label, a.occupancy)
                   for a in crystal.disordered_molecule.atoms):
                if self.show_warnings:
                    warnings.warn(
                        f'Skipping {self.current_identifier} as structure is disordered')
                return None

        # if disorder is all_sites, we need to know where disorder is to ignore overlaps
        asym_is_disordered = []     # True/False list same length as asym unit
        if self.disorder == 'all_sites':
            asym_is_disordered = [
                _atom_has_disorder(a.label, a.occupancy)
                for a in crystal.asymmetric_unit_molecule.atoms]

        # check all atoms have coords. option/default remove unknown sites?
        if not molecule.all_atoms_have_sites or \
           any(a.fractional_coordinates is None for a in molecule.atoms):
            if self.show_warnings:
                warnings.warn(
                    f'Skipping {self.current_identifier} as some atoms do not have sites')
            return None

        # get cell & asymmetric unit
        cell = cellpar_to_cell(*crystal.cell_lengths, *crystal.cell_angles)
        asym_atoms = crystal.asymmetric_unit_molecule.atoms
        asym_frac_motif = np.mod(
            np.array([tuple(a.fractional_coordinates) for a in asym_atoms]), 1)
        asym_symbols = [a.atomic_symbol for a in asym_atoms]

        # remove overlapping sites, check sites exist
        keep_sites = self._validate_sites(asym_frac_motif, asym_is_disordered)
        if keep_sites is not None:
            asym_frac_motif = asym_frac_motif[keep_sites]
            asym_symbols = [sym for sym, keep in zip(asym_symbols, keep_sites) if keep]

        if self._has_no_valid_sites(asym_frac_motif):
            return None

        sitesym = crystal.symmetry_operators
        if not sitesym:
            sitesym = ['x,y,z']

        # Presumably restores the unfiltered molecule for extract_data callables
        # (original note: "for extract_data. remove?") -- TODO confirm still needed.
        entry.crystal.molecule = crystal.disordered_molecule

        return self._construct_periodic_set(entry, asym_frac_motif, asym_symbols, sitesym, cell)

    def expand(
            self,
            asym_frac_motif: np.ndarray,
            sitesym: Sequence[str]
    ) -> Tuple[np.ndarray, ...]:
        """Apply symmetry operations to the asymmetric unit.

        Asymmetric unit's fractional coords + sitesyms (as strings)
        -->
        frac_motif, asym_unit, multiplicities, inverses

        frac_motif holds every distinct image (mod 1). asym_unit[k] is the index
        in frac_motif of the first image of the k-th retained asymmetric site,
        multiplicities[k] its number of distinct images, and inverses[j] the
        asymmetric-unit index that produced frac_motif[j].
        """
        rotations, translations = ase.spacegroup.spacegroup.parse_sitesym(sitesym)
        all_sites = []
        asym_unit = [0]
        multiplicities = []
        inverses = []

        for inv, site in enumerate(asym_frac_motif):
            multiplicity = 0
            for rot, trans in zip(rotations, translations):
                site_ = np.mod(np.dot(rot, site) + trans, 1)
                if not self._is_site_overlapping(site_, all_sites, inverses, inv):
                    all_sites.append(site_)
                    inverses.append(inv)
                    multiplicity += 1

            if multiplicity > 0:
                multiplicities.append(multiplicity)
                asym_unit.append(len(all_sites))

        frac_motif = np.array(all_sites)
        asym_unit = np.array(asym_unit[:-1])
        multiplicities = np.array(multiplicities)
        return frac_motif, asym_unit, multiplicities, inverses

    def _is_site_overlapping(self, new_site, all_sites, inverses, inv):
        """Return True if new_site is within equiv_site_tol of any site in
        all_sites (coordinates differing by ~1 also match), else False.

        ``inverses[k]`` is the asymmetric-unit index that produced
        ``all_sites[k]`` and ``inv`` the one that produced ``new_site``. If
        warnings are enabled, one is issued for each matching site that came
        from a different asymmetric-unit site. An empty all_sites never overlaps.
        """
        if len(all_sites) == 0:
            return False

        diffs1 = np.abs(new_site - all_sites)
        diffs2 = np.abs(diffs1 - 1)
        mask = np.all((diffs1 <= _Reader.equiv_site_tol) | (diffs2 <= _Reader.equiv_site_tol),
                      axis=-1)
        matches = np.flatnonzero(mask)
        if matches.size == 0:
            return False

        if self.show_warnings:
            for ind in matches:
                if inverses[ind] != inv:
                    warnings.warn(
                        f'{self.current_identifier} has equivalent positions '
                        f'{inverses[ind]} and {inv}')
        return True

    def _validate_sites(self, asym_frac_motif, asym_is_disordered):
        """Find duplicate sites in the asymmetric unit.

        Returns a boolean mask of sites to keep (a site is dropped if it
        overlaps an earlier site), or None when nothing overlaps. With
        disorder == 'all_sites', pairs involving a disordered site (per
        asym_is_disordered, indexed like asym_frac_motif) are not counted.
        """
        site_diffs1 = np.abs(asym_frac_motif[:, None] - asym_frac_motif)
        site_diffs2 = np.abs(site_diffs1 - 1)
        overlapping = np.triu(np.all(
            (site_diffs1 <= _Reader.equiv_site_tol) |
            (site_diffs2 <= _Reader.equiv_site_tol),
            axis=-1), 1)

        if self.disorder == 'all_sites':
            for i, j in np.argwhere(overlapping):
                if asym_is_disordered[i] or asym_is_disordered[j]:
                    overlapping[i, j] = False

        if not overlapping.any():
            return None

        if self.show_warnings:
            warnings.warn(
                f'{self.current_identifier} may have overlapping sites; '
                f'duplicates will be removed')
        return ~overlapping.any(0)

    def _has_no_valid_sites(self, motif):
        """Return True (warning if enabled) when motif has no sites."""
        if motif.shape[0] == 0:
            if self.show_warnings:
                warnings.warn(
                    f'Skipping {self.current_identifier} as there are no sites with coordinates')
            return True
        return False

    def _construct_periodic_set(self, raw_item, asym_frac_motif, asym_symbols, sitesym, cell):
        """Expand the asymmetric unit with sitesym and build the PeriodicSet.

        Cartesian motif is frac_motif @ cell. extract_data callables are
        applied to raw_item and their results passed as extra keywords.
        """
        frac_motif, asym_unit, multiplicities, inverses = self.expand(asym_frac_motif, sitesym)
        full_types = [asym_symbols[i] for i in inverses]
        motif = frac_motif @ cell

        kwargs = {
            'name': self.current_identifier,
            'asymmetric_unit': asym_unit,
            'wyckoff_multiplicities': multiplicities,
            'types': full_types,
        }

        if self.current_filename:
            kwargs['filename'] = self.current_filename

        if self.extract_data:
            for key, extractor in self.extract_data.items():
                kwargs[key] = extractor(raw_item)

        return PeriodicSet(motif, cell, **kwargs)
414
415
def _heaviest_component(molecule):
    """Return the heaviest component of ``molecule``, dropping all others.

    A component's weight is the sum over its atoms of ``atomic_weight``,
    multiplied by ``occupancy`` when that is a number; atoms whose
    ``atomic_weight`` is not a number contribute nothing. If there is at most
    one component, ``molecule`` itself is returned unchanged. Intended for
    removing solvents; probably doesn't play well with disorder.
    """
    components = molecule.components
    if len(components) <= 1:
        return molecule

    def _component_weight(component):
        total = 0
        for atom in component.atoms:
            atomic_weight = atom.atomic_weight
            if not isinstance(atomic_weight, (float, int)):
                continue
            occupancy = atom.occupancy
            if isinstance(occupancy, (float, int)):
                total += occupancy * atomic_weight
            else:
                total += atomic_weight
        return total

    weights = np.array([_component_weight(comp) for comp in components])
    return components[np.argmax(weights)]
433
434
def _atom_has_disorder(label, occupancy):
    """Return True if a site should be treated as disordered.

    A site is disordered when its label ends with ``'?'`` or its occupancy is
    a real number below 1. Non-numeric occupancies (e.g. ``None`` or a
    placeholder string such as ``'?'``) count as fully occupied instead of
    being compared with 1, which would raise TypeError for a string.
    """
    if str(label).endswith('?'):
        return True
    is_number = isinstance(occupancy, (int, float, np.integer, np.floating))
    return bool(is_number and occupancy < 1)
436
437
def _validate_extract_data(extract_data):
    """Check that ``extract_data`` maps non-reserved keys to callables.

    Raises:
        ValueError: if ``extract_data`` is not a dict, if a value is not
            callable, or if a key is one of ``_Reader.reserved_tags``
            (for each key, callability is checked first).
    """
    bad_shape_msg = 'extract_data must be a dict with callable values'
    if not isinstance(extract_data, dict):
        raise ValueError(bad_shape_msg)
    for key, extractor in extract_data.items():
        if not callable(extractor):
            raise ValueError(bad_shape_msg)
        if key in _Reader.reserved_tags:
            raise ValueError(f'extract_data includes reserved key {key}')
0 ignored issues
show
Coding Style introduced by
Final newline missing
Loading history...