Passed
Push — master ( acad41...c85150 )
by Daniel
01:41
created

amd._reader._validate_sites()   A

Complexity

Conditions 1

Size

Total Lines 8
Code Lines 8

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 8
dl 0
loc 8
rs 10
c 0
b 0
f 0
cc 1
nop 1
1
"""Contains base reader class for the io module.
2
This class implements the converters from CifBlock, Entry to PeriodicSets.
3
"""
4
5
import warnings
6
from typing import Callable, Iterable, Sequence, Tuple
7
8
import numpy as np
9
import ase.spacegroup.spacegroup    # parse_sitesym
10
import ase.io.cif
11
12
from .periodicset import PeriodicSet
13
from .utils import cellpar_to_cell
14
15
16
class _Reader:
    """Base Reader class. Contains parsers for converting ase CifBlock
    and ccdc Entry objects to PeriodicSets.

    Intended use:

    First make a new method for _Reader converting object to PeriodicSet
    (e.g. named _X_to_PSet). Then make this class outline:

    class XReader(_Reader):
        def __init__(self, ..., **kwargs):

        super().__init__(**kwargs)

        # setup and checks

        # make 'iterable' which yields objects to be converted (e.g. CIFBlock, Entry)

        # set self._generator like this
        self._generator = self._map(self._X_to_PSet, iterable)
    """

    # Accepted values of the ``disorder`` constructor argument.
    disorder_options = {'skip', 'ordered_sites', 'all_sites'}

    # Tag names written by the reader itself; extract_data may not use them.
    reserved_tags = {
        'motif',
        'cell',
        'name',
        'asymmetric_unit',
        'wyckoff_multiplicities',
        'types',
        'filename',
    }

    # CIF tags looked up for site coordinates (fractional first, then Cartesian).
    atom_site_fract_tags = [
        '_atom_site_fract_x',
        '_atom_site_fract_y',
        '_atom_site_fract_z',
    ]
    atom_site_cartn_tags = [
        '_atom_site_cartn_x',
        '_atom_site_cartn_y',
        '_atom_site_cartn_z',
    ]

    # CIF tags looked up, in this order, for the symmetry operators.
    symop_tags = [
        '_space_group_symop_operation_xyz',
        '_space_group_symop.operation_xyz',
        '_symmetry_equiv_pos_as_xyz',
    ]

    # Per-coordinate tolerance (on fractional coordinates, modulo 1) below
    # which two sites are treated as the same position.
    equiv_site_tol = 1e-3

    def __init__(
            self,
            remove_hydrogens=False,
            disorder='skip',
            heaviest_component=False,
            show_warnings=True,
            extract_data=None,
            include_if=None):
        """Validate and store the reader options.

        remove_hydrogens: drop sites whose symbol is 'H' or 'D'.
        disorder: one of _Reader.disorder_options.
        heaviest_component: keep only the heaviest component (Entry parser only).
        show_warnings: if False, warnings raised while reading are silenced.
        extract_data: dict mapping tag name -> callable(item); each result is
            stored in the tags of the PeriodicSet made from that item.
        include_if: iterable of callables; an item is only converted if every
            callable returns a truthy value for it.

        Raises ValueError if any option fails validation.
        """

        if disorder not in _Reader.disorder_options:
            raise ValueError(
                f'disorder parameter {disorder} must be one of {_Reader.disorder_options}')

        # validate extract_data
        if extract_data is None:
            self.extract_data = {}
        else:
            if not isinstance(extract_data, dict):
                raise ValueError('extract_data must be a dict of callables')
            for key, extractor in extract_data.items():
                if not callable(extractor):
                    raise ValueError('extract_data must be a dict of callables')
                if key in _Reader.reserved_tags:
                    raise ValueError(f'extract_data includes reserved key {key}')
            self.extract_data = extract_data

        # validate include_if
        if include_if is None:
            self.include_if = ()
        elif not all(callable(check) for check in include_if):
            raise ValueError('include_if must be a list of callables')
        else:
            self.include_if = include_if

        self.remove_hydrogens = remove_hydrogens
        self.disorder = disorder
        self.heaviest_component = heaviest_component
        self.show_warnings = show_warnings
        self.current_name = None
        self.current_filename = None
        self._generator = []

    def __iter__(self):
        yield from self._generator

    def read_one(self):
        """Read the next (or first) item.

        Raises StopIteration if the underlying generator has nothing to yield.
        """
        return next(iter(self._generator))

    def _map(self, func: Callable, iterable: Iterable) -> Iterable['PeriodicSet']:
        """Iterates over iterable, passing items through parser and yielding the
        result if it is not None. Applies warning and include_if filter.
        """

        with warnings.catch_warnings():
            if not self.show_warnings:
                warnings.simplefilter('ignore')

            for item in iterable:

                if any(not check(item) for check in self.include_if):
                    continue

                periodic_set = func(item)
                if periodic_set is None:
                    continue

                if self.current_filename:
                    periodic_set.tags['filename'] = self.current_filename

                # The loop variable must not be called ``func``: rebinding the
                # parser argument here would make every later item be parsed
                # by the last extractor instead of the parser.
                for key, extractor in self.extract_data.items():
                    periodic_set.tags[key] = extractor(item)

                yield periodic_set

    def _cifblock_to_periodicset(self, block) -> 'PeriodicSet':
        """ase.io.cif.CIFBlock --> PeriodicSet. Returns None for a "bad" set."""

        self.current_name = block.name
        cell = block.get_cell().array

        # Asymmetric unit fractional coords. Each tag yields one column
        # (all x, all y, all z), so the stacked array is transposed to get
        # one row per site.
        columns = [block.get(name) for name in _Reader.atom_site_fract_tags]
        if any(col is None for col in columns):
            columns = [block.get(name) for name in _Reader.atom_site_cartn_tags]
            if any(col is None for col in columns):
                warnings.warn(f'Skipping {self.current_name} as coordinates were not found')
                return None
            # Cartesian rows -> fractional rows (inverse of motif = frac @ cell).
            # The transpose has to happen before the product, otherwise a
            # (3, n) array would be multiplied by the 3x3 inverse cell.
            asym_unit = np.array(columns).T @ np.linalg.inv(cell)
        else:
            asym_unit = np.array(columns).T
        asym_unit = np.mod(asym_unit, 1)

        # asymmetric unit symbols
        try:
            asym_symbols = block.get_symbols()
        except ase.io.cif.NoStructureData:
            asym_symbols = ['Unknown'] * len(asym_unit)

        # symmetry operators; identity only if the block lists none
        sitesym = ['x,y,z', ]
        for tag in _Reader.symop_tags:
            if tag in block:
                sitesym = block[tag]
                break
        if isinstance(sitesym, str):
            sitesym = [sitesym]

        # indices of asymmetric unit sites to drop
        remove_sites = set()

        # handle disorder
        occupancies = block.get('_atom_site_occupancy')
        labels = block.get('_atom_site_label')
        if occupancies is not None:
            if labels is None:
                # No labels to inspect: judge disorder on occupancy alone.
                labels = [''] * len(occupancies)
            if self.disorder == 'skip':
                if any(atom_has_disorder(lab, occ) for lab, occ in zip(labels, occupancies)):
                    warnings.warn(f'Skipping {self.current_name} as structure is disordered')
                    return None
            elif self.disorder == 'ordered_sites':
                remove_sites.update(
                    i for i, (lab, occ) in enumerate(zip(labels, occupancies))
                    if atom_has_disorder(lab, occ))

        if self.remove_hydrogens:
            # Tuple membership, not substring search: ``sym in 'HD'`` would
            # also match '' and 'HD'.
            remove_sites.update(
                i for i, sym in enumerate(asym_symbols) if sym in ('H', 'D'))

        if remove_sites:
            asym_unit = np.delete(asym_unit, sorted(remove_sites), axis=0)
            asym_symbols = [s for i, s in enumerate(asym_symbols) if i not in remove_sites]

        return self._make_periodicset(asym_unit, asym_symbols, cell, sitesym)

    def _entry_to_periodicset(self, entry) -> 'PeriodicSet':
        """ccdc.entry.Entry --> PeriodicSet. Returns None for a "bad" set."""

        self.current_name = entry.identifier
        crystal = entry.crystal

        if not entry.has_3d_structure:
            warnings.warn(f'Skipping {self.current_name} as entry has no 3D structure')
            return None

        molecule = crystal.disordered_molecule

        # handle disorder
        if self.disorder == 'skip':
            if crystal.has_disorder or entry.has_disorder or \
               any(atom_has_disorder(a.label, a.occupancy) for a in molecule.atoms):
                warnings.warn(f'Skipping {self.current_name} as structure is disordered')
                return None

        elif self.disorder == 'ordered_sites':
            molecule.remove_atoms(a for a in molecule.atoms
                                  if atom_has_disorder(a.label, a.occupancy))

        if self.remove_hydrogens:
            # Tuple membership, not substring search (see the CIF parser).
            molecule.remove_atoms(
                a for a in molecule.atoms if a.atomic_symbol in ('H', 'D'))

        # remove all but heaviest component
        if self.heaviest_component and len(molecule.components) > 1:
            molecule = _heaviest_component(molecule)

        if not molecule.all_atoms_have_sites or \
           any(a.fractional_coordinates is None for a in molecule.atoms):
            warnings.warn(f'Skipping {self.current_name} as some atoms do not have sites')
            return None

        # asymmetric unit fractional coords + symbols
        crystal.molecule = molecule
        asym_atoms = crystal.asymmetric_unit_molecule.atoms
        asym_unit = np.array([tuple(a.fractional_coordinates) for a in asym_atoms])
        asym_unit = np.mod(asym_unit, 1)
        asym_symbols = [a.atomic_symbol for a in asym_atoms]
        cell = cellpar_to_cell(*crystal.cell_lengths, *crystal.cell_angles)

        # symmetry operators; identity only if the crystal lists none
        sitesym = crystal.symmetry_operators
        if not sitesym:
            sitesym = ['x,y,z', ]

        return self._make_periodicset(asym_unit, asym_symbols, cell, sitesym)

    def _make_periodicset(self, asym_unit, asym_symbols, cell, sitesym):
        """Shared final step of both parsers: drop duplicate sites, expand the
        asymmetric unit with the symmetry operators and build the PeriodicSet.

        Returns None (with a warning) if there are no sites.
        """

        # Checked before validation so an empty 1-D array (no atoms at all)
        # never reaches the pairwise comparison.
        if asym_unit.shape[0] == 0:
            warnings.warn(f'Skipping {self.current_name} as there are no sites with coordinates')
            return None

        if self.disorder != 'all_sites':
            keep_sites = _validate_sites(asym_unit)
            if not np.all(keep_sites):
                warnings.warn(
                    f'{self.current_name} may have overlapping sites; duplicates will be removed')
                asym_unit = asym_unit[keep_sites]
                asym_symbols = [sym for sym, keep in zip(asym_symbols, keep_sites) if keep]

        frac_motif, asym_inds, multiplicities, inverses = \
            self.expand_asym_unit(asym_unit, sitesym)
        full_types = [asym_symbols[i] for i in inverses]
        motif = frac_motif @ cell

        tags = {
            'name': self.current_name,
            'asymmetric_unit': asym_inds,
            'wyckoff_multiplicities': multiplicities,
            'types': full_types,
        }

        return PeriodicSet(motif, cell, **tags)

    def expand_asym_unit(
            self,
            asym_unit: np.ndarray,
            sitesym: Sequence[str]) -> Tuple[np.ndarray, ...]:
        """
        Asymmetric unit's fractional coords + sitesyms (as strings)
        -->
        frac motif, asym unit inds, multiplicities, inverses

        inverses is a plain list holding, for each motif point, the index of
        the asymmetric unit site it was generated from.
        """

        rotations, translations = ase.spacegroup.spacegroup.parse_sitesym(sitesym)
        tol = _Reader.equiv_site_tol
        all_sites = []
        asym_inds = [0]
        multiplicities = []
        inverses = []

        for inv, site in enumerate(asym_unit):
            multiplicity = 0

            for rot, trans in zip(rotations, translations):
                site_ = np.mod(np.dot(rot, site) + trans, 1)

                # nothing to compare against yet
                if not all_sites:
                    all_sites.append(site_)
                    inverses.append(inv)
                    multiplicity += 1
                    continue

                # check if site_ overlaps with existing sites; diffs2 catches
                # coordinates that wrap around the cell boundary (0 vs 1)
                diffs1 = np.abs(site_ - all_sites)
                diffs2 = np.abs(diffs1 - 1)
                mask = np.all((diffs1 <= tol) | (diffs2 <= tol), axis=-1)

                if not np.any(mask):
                    all_sites.append(site_)
                    inverses.append(inv)
                    multiplicity += 1
                    continue

                # An overlap with an image of the same site is expected; only
                # overlaps between different asymmetric unit sites are reported.
                for ind in np.argwhere(mask).flatten():
                    if inverses[ind] != inv:
                        warnings.warn(
                            f'{self.current_name} has equivalent positions '
                            f'{inverses[ind]} and {inv}')

            if multiplicity > 0:
                multiplicities.append(multiplicity)
                asym_inds.append(len(all_sites))

        frac_motif = np.array(all_sites)
        # the last entry is the total number of points, not a start index
        asym_inds = np.array(asym_inds[:-1])
        multiplicities = np.array(multiplicities)
        return frac_motif, asym_inds, multiplicities, inverses
339
340
341
def atom_has_disorder(label, occupancy):
    """Return True if an atom site looks disordered.

    A site counts as disordered when its label ends with '?' or when its
    occupancy is a scalar that compares below 1. Anything that cannot be
    compared with a number (None, arrays, strings such as the CIF
    placeholders '?' and '.') is treated as "no occupancy information"
    rather than raising.
    """
    if isinstance(label, str) and label.endswith('?'):
        return True
    if not np.isscalar(occupancy):
        return False
    try:
        return bool(occupancy < 1)
    except TypeError:
        # non-numeric scalar, e.g. an unparsed CIF placeholder string
        return False
343
344
345
def _validate_sites(asym_unit, tol=None):
    """Return a boolean mask over sites that is False for duplicates.

    Two sites overlap when, in every coordinate, they differ by at most
    ``tol`` either directly or after wrapping by 1 (so 0.9999 matches 0.0).
    Of a group of overlapping sites only the first is kept: site j is
    marked False if it overlaps any site i with i < j.

    asym_unit: array with one site per row.
    tol: comparison tolerance; defaults to _Reader.equiv_site_tol.
    """
    if tol is None:
        tol = _Reader.equiv_site_tol

    asym_unit = np.asarray(asym_unit)
    if asym_unit.shape[0] == 0:
        # No sites (possibly a 1-D empty array): nothing to compare.
        return np.ones(0, dtype=bool)

    # pairwise per-coordinate differences, shape (n, n, dims)
    site_diffs1 = np.abs(asym_unit[:, None] - asym_unit)
    site_diffs2 = np.abs(site_diffs1 - 1)
    close = (site_diffs1 <= tol) | (site_diffs2 <= tol)

    # strict upper triangle: entry [i, j] is only considered for i < j
    overlapping = np.triu(np.all(close, axis=-1), 1)
    return ~overlapping.any(0)
353
354
355
def _heaviest_component(molecule):
    """Return the component of ``molecule`` with the largest total weight.

    A component's weight is the sum over its atoms of the atomic weight,
    scaled by the occupancy when the occupancy is a float or int. Atoms whose
    atomic weight is not a float or int contribute nothing. On a tie the
    earliest component is returned.

    Intended for removing solvents. Probably doesn't play well with disorder.
    """
    numeric = (float, int)

    def total_weight(component):
        # Accumulated by hand, in atom order, starting from integer zero.
        total = 0
        for atom in component.atoms:
            if not isinstance(atom.atomic_weight, numeric):
                continue
            if isinstance(atom.occupancy, numeric):
                total += atom.occupancy * atom.atomic_weight
            else:
                total += atom.atomic_weight
        return total

    weights = [total_weight(part) for part in molecule.components]
    heaviest = np.argmax(np.array(weights))
    return molecule.components[heaviest]
371