Passed
Push — master ( c9cb02...43c976 )
by Daniel
01:44
created

amd._reader.atom_has_disorder()   A

Complexity

Conditions 1

Size

Total Lines 2
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 2
dl 0
loc 2
rs 10
c 0
b 0
f 0
cc 1
nop 2
1
"""Contains base reader class for the io module.
2
This class implements the converters from CifBlock, Entry to PeriodicSets.
3
"""
4
5
import numbers
import warnings
from typing import Callable, Iterable, Optional, Sequence, Tuple

import numpy as np
import ase.spacegroup.spacegroup    # parse_sitesym
import ase.io.cif

from .periodicset import PeriodicSet
from .utils import cellpar_to_cell
14
15
16
class _Reader:
0 ignored issues
show
best-practice introduced by
Too many instance attributes (10/7)
Loading history...
17
    """Base Reader class. Contains parsers for converting ase CifBlock
18
    and ccdc Entry objects to PeriodicSets.
19
    Intended to be inherited and then a generator set to self._generator.
20
    First make a new method for _Reader converting object to PeriodicSet
21
    (e.g. named _X_to_PSet). Then make this class outline:
22
23
    class XReader(_Reader):
24
        def __init__(self, ..., **kwargs):
25
26
        super().__init__(**kwargs)
27
28
        # setup and checks
29
30
        # make 'iterable' which yields objects to be converted (e.g. CIFBlock, Entry)
31
32
        # set self._generator like this
33
        self._generator = self._read(iterable, self._X_to_PSet)
34
    """
35
36
    equiv_site_tol = 1e-3
37
    disorder_options = {'skip', 'ordered_sites', 'all_sites'}
38
    reserved_tags = {
39
        'motif',
40
        'cell',
41
        'name',
42
        'asymmetric_unit',
43
        'wyckoff_multiplicities',
44
        'types',
45
        'filename',}
46
    atom_site_fract_tags = [
47
        '_atom_site_fract_x',
48
        '_atom_site_fract_y',
49
        '_atom_site_fract_z',]
50
    atom_site_cartn_tags = [
51
        '_atom_site_cartn_x',
52
        '_atom_site_cartn_y',
53
        '_atom_site_cartn_z',]
54
    symop_tags = [
55
        '_space_group_symop_operation_xyz',
56
        '_space_group_symop.operation_xyz',
57
        '_symmetry_equiv_pos_as_xyz',]
58
59
    def __init__(
0 ignored issues
show
best-practice introduced by
Too many arguments (7/5)
Loading history...
60
            self,
61
            remove_hydrogens=False,
62
            disorder='skip',
63
            heaviest_component=False,
64
            show_warnings=True,
65
            extract_data=None,
66
            include_if=None):
67
68
        if disorder not in _Reader.disorder_options:
69
            raise ValueError(f'disorder parameter {disorder} must be one of {_Reader.disorder_options}')
0 ignored issues
show
Coding Style introduced by
This line is too long as per the coding-style (104/100).

This check looks for lines that are too long. You can specify the maximum line length.

Loading history...
70
71
        # validate extract_data
72
        if extract_data is None:
73
            self.extract_data = {}
74
        else:
75
            if not isinstance(extract_data, dict):
76
                raise ValueError('extract_data must be a dict of callables')
77
            for key in extract_data:
78
                if not callable(extract_data[key]):
79
                    raise ValueError('extract_data must be a dict of callables')
80
                if key in _Reader.reserved_tags:
81
                    raise ValueError(f'extract_data includes reserved key {key}')
82
            self.extract_data = extract_data
83
84
        # validate include_if
85
        if include_if is None:
86
            self.include_if = ()
87
        elif not all(callable(func) for func in include_if):
88
            raise ValueError('include_if must be a list of callables')
89
        else:
90
            self.include_if = include_if
91
92
        self.remove_hydrogens = remove_hydrogens
93
        self.disorder = disorder
94
        self.heaviest_component = heaviest_component
95
        self.show_warnings = show_warnings
96
        self.current_name = None
97
        self.current_filename = None
98
        self._generator = []
99
100
    def __iter__(self):
101
        yield from self._generator
102
103
    def read_one(self):
104
        """Read the next (or first) item."""
105
        return next(iter(self._generator))
106
107
    def _map(self, converter: Callable, iterable: Iterable) -> Iterable[PeriodicSet]:
108
        """Iterates over iterable, passing items through parser and yielding the 
0 ignored issues
show
Coding Style introduced by
Trailing whitespace
Loading history...
109
        result if it is not None. Applies warning and include_if filters.
110
        """
111
112
        with warnings.catch_warnings():
113
            if not self.show_warnings:
114
                warnings.simplefilter('ignore')
115
116
            for item in iterable:
117
118
                if any(not check(item) for check in self.include_if):
119
                    continue
120
121
                periodic_set = converter(item)
122
                
0 ignored issues
show
Coding Style introduced by
Trailing whitespace
Loading history...
123
                if periodic_set is not None:
124
125
                    if self.current_filename:
126
                        periodic_set.tags['filename'] = self.current_filename
127
128
                    for key, func in self.extract_data.items():
129
                        periodic_set.tags[key] = func(item)
130
131
                    yield periodic_set
132
133
    def _cifblock_to_periodicset(self, block) -> PeriodicSet:
134
        """ase.io.cif.CIFBlock --> PeriodicSet. Returns None for a "bad" set."""
135
136
        self.current_name = block.name
137
        cell = block.get_cell().array
138
139
        # asymmetric unit fractional coords
140
        asym_unit = [block.get(name) for name in _Reader.atom_site_fract_tags]
141
        if None in asym_unit:
142
            asym_motif = [block.get(name) for name in _Reader.atom_site_cartn_tags]
143
            if None in asym_motif:
144
                warnings.warn(f'Skipping {self.current_name} as coordinates were not found')
145
                return None
146
            asym_unit = np.array(asym_motif) @ np.linalg.inv(cell)
147
        asym_unit = np.mod(np.array(asym_unit).T, 1)
148
149
        # asymmetric unit symbols
150
        try:
151
            asym_symbols = block.get_symbols()
152
        except ase.io.cif.NoStructureData as _:
153
            asym_symbols = ['Unknown' for _ in range(len(asym_unit))]
154
155
        # symmetry operators
156
        sitesym = ['x,y,z', ]
157
        for tag in _Reader.symop_tags:
158
            if tag in block:
159
                sitesym = block[tag]
160
                break
161
        if isinstance(sitesym, str):
162
            sitesym = [sitesym]
163
164
        remove_sites = []
165
166
        # handle disorder
167
        occupancies = block.get('_atom_site_occupancy')
168
        labels = block.get('_atom_site_label')
169
        if occupancies is not None:
170
            if self.disorder == 'skip':
171
                if any(atom_has_disorder(lab, occ) for lab, occ in zip(labels, occupancies)):
172
                    warnings.warn(f'Skipping {self.current_name} as structure is disordered')
173
                    return None
174
            elif self.disorder == 'ordered_sites':
175
                remove_sites.extend(
176
                    (i for i, (lab, occ) in enumerate(zip(labels, occupancies))
0 ignored issues
show
Comprehensibility Best Practice introduced by
The variable i does not seem to be defined.
Loading history...
177
                        if atom_has_disorder(lab, occ)))
0 ignored issues
show
Coding Style introduced by
Wrong continued indentation (remove 3 spaces).
Loading history...
178
179
        if self.remove_hydrogens:
180
            remove_sites.extend((i for i, sym in enumerate(asym_symbols) if sym in 'HD'))
181
182
        asym_unit = np.delete(asym_unit, remove_sites, axis=0)
183
        asym_symbols = [s for i, s in enumerate(asym_symbols) if i not in remove_sites]
184
185
        if self.disorder != 'all_sites':
186
            keep_sites = _unique_sites(asym_unit)
187
            if np.any(keep_sites == False):
0 ignored issues
show
introduced by
Comparison to False should be 'not expr'
Loading history...
188
                warnings.warn(
189
                    f'{self.current_name} may have overlapping sites; duplicates will be removed')
190
            asym_unit = asym_unit[keep_sites]
191
            asym_symbols = [sym for sym, keep in zip(asym_symbols, keep_sites) if keep]
192
        
0 ignored issues
show
Coding Style introduced by
Trailing whitespace
Loading history...
193
        if asym_unit.shape[0] == 0:
194
            warnings.warn(f'Skipping {self.current_name} as there are no sites with coordinates')
195
            return None
196
197
        frac_motif, asym_inds, multiplicities, inverses = self.expand_asym_unit(asym_unit, sitesym)
198
        full_types = [asym_symbols[i] for i in inverses]
199
        motif = frac_motif @ cell
200
201
        tags = {
202
            'name': self.current_name,
203
            'asymmetric_unit': asym_inds,
204
            'wyckoff_multiplicities': multiplicities,
205
            'types': full_types,
206
        }
207
208
        return PeriodicSet(motif, cell, **tags)
209
210
    def _entry_to_periodicset(self, entry) -> PeriodicSet:
211
        """ccdc.entry.Entry --> PeriodicSet. Returns None for a "bad" set."""
212
213
        self.current_name = entry.identifier
214
        crystal = entry.crystal
215
216
        if not entry.has_3d_structure:
217
            warnings.warn(f'Skipping {self.current_name} as entry has no 3D structure')
218
            return None
219
220
        molecule = crystal.disordered_molecule
221
222
        # handle disorder
223
        if self.disorder == 'skip':
224
            if crystal.has_disorder or entry.has_disorder or \
225
               any(atom_has_disorder(a.label, a.occupancy) for a in molecule.atoms):
226
                warnings.warn(f'Skipping {self.current_name} as structure is disordered')
227
                return None
228
229
        elif self.disorder == 'ordered_sites':
230
            molecule.remove_atoms(a for a in molecule.atoms
231
                                  if atom_has_disorder(a.label, a.occupancy))
232
233
        if self.remove_hydrogens:
234
            molecule.remove_atoms(a for a in molecule.atoms if a.atomic_symbol in 'HD')
235
236
        # remove all but heaviest component
237
        if self.heaviest_component and len(molecule.components) > 1:
238
            molecule = _heaviest_component(molecule)
239
240
        if not molecule.all_atoms_have_sites or \
241
           any(a.fractional_coordinates is None for a in molecule.atoms):
242
            warnings.warn(f'Skipping {self.current_name} as some atoms do not have sites')
243
            return None
244
245
        # asymmetric unit fractional coords + symbols
246
        crystal.molecule = molecule
247
        asym_atoms = crystal.asymmetric_unit_molecule.atoms
248
        asym_unit = np.array([tuple(a.fractional_coordinates) for a in asym_atoms])
249
        asym_unit = np.mod(asym_unit, 1)
250
        asym_symbols = [a.atomic_symbol for a in asym_atoms]
251
        cell = cellpar_to_cell(*crystal.cell_lengths, *crystal.cell_angles)
252
253
        # symmetry operators
254
        sitesym = crystal.symmetry_operators
255
        if not sitesym:
256
            sitesym = ['x,y,z', ]
257
258
        if self.disorder != 'all_sites':
259
            keep_sites = _unique_sites(asym_unit)
260
            if np.any(keep_sites == False):
0 ignored issues
show
introduced by
Comparison to False should be 'not expr'
Loading history...
261
                warnings.warn(
262
                    f'{self.current_name} may have overlapping sites; duplicates will be removed')
263
            asym_unit = asym_unit[keep_sites]
264
            asym_symbols = [sym for sym, keep in zip(asym_symbols, keep_sites) if keep]
265
266
        if asym_unit.shape[0] == 0:
267
            warnings.warn(f'Skipping {self.current_name} as there are no sites with coordinates')
268
            return None
269
        
0 ignored issues
show
Coding Style introduced by
Trailing whitespace
Loading history...
270
        frac_motif, asym_inds, multiplicities, inverses = self.expand_asym_unit(asym_unit, sitesym)
271
        full_types = [asym_symbols[i] for i in inverses]
272
        motif = frac_motif @ cell
273
274
        tags = {
275
            'name': self.current_name,
276
            'asymmetric_unit': asym_inds,
277
            'wyckoff_multiplicities': multiplicities,
278
            'types': full_types,
279
        }
280
281
        return PeriodicSet(motif, cell, **tags)
282
283
    def expand_asym_unit(
284
        self, 
0 ignored issues
show
Coding Style introduced by
Trailing whitespace
Loading history...
Coding Style introduced by
Wrong hanging indentation before block (add 4 spaces).
Loading history...
285
        asym_unit: np.ndarray, 
0 ignored issues
show
Coding Style introduced by
Trailing whitespace
Loading history...
Coding Style introduced by
Wrong hanging indentation before block (add 4 spaces).
Loading history...
286
        sitesym: Sequence[str]
0 ignored issues
show
Coding Style introduced by
Wrong hanging indentation before block (add 4 spaces).
Loading history...
287
    ) -> Tuple[np.ndarray, ...]:
288
        """
289
        Asymmetric unit's fractional coords + sitesyms (as strings)
290
        -->
291
        frac motif, asym unit inds, multiplicities, inverses
292
        """
293
294
        rotations, translations = ase.spacegroup.spacegroup.parse_sitesym(sitesym)
295
        all_sites = []
296
        asym_inds = [0]
297
        multiplicities = []
298
        inverses = []
299
300
        for inv, site in enumerate(asym_unit):
301
            multiplicity = 0
302
303
            for rot, trans in zip(rotations, translations):
304
                site_ = np.mod(np.dot(rot, site) + trans, 1)
305
306
                if not all_sites:
307
                    all_sites.append(site_)
308
                    inverses.append(inv)
309
                    multiplicity += 1
310
                    continue
311
312
                # check if site_ overlaps with existing sites
313
                diffs1 = np.abs(site_ - all_sites)
314
                diffs2 = np.abs(diffs1 - 1)
315
                mask = np.all((diffs1 <= _Reader.equiv_site_tol) |
316
                              (diffs2 <= _Reader.equiv_site_tol),
317
                              axis=-1)
318
319
                if np.any(mask):
320
                    where_equal = np.argwhere(mask).flatten()
321
                    for ind in where_equal:
322
                        if inverses[ind] == inv:
323
                            pass
324
                        else:
325
                            warnings.warn(
326
                                f'{self.current_name} has equivalent positions {inverses[ind]} and {inv}')
0 ignored issues
show
Coding Style introduced by
This line is too long as per the coding-style (106/100).

This check looks for lines that are too long. You can specify the maximum line length.

Loading history...
327
                else:
328
                    all_sites.append(site_)
329
                    inverses.append(inv)
330
                    multiplicity += 1
331
332
            if multiplicity > 0:
333
                multiplicities.append(multiplicity)
334
                asym_inds.append(len(all_sites))
335
336
        frac_motif = np.array(all_sites)
337
        asym_inds = np.array(asym_inds[:-1])
338
        multiplicities = np.array(multiplicities)
339
        return frac_motif, asym_inds, multiplicities, inverses
340
341
342
def atom_has_disorder(label, occupancy):
    """Return True if an atom site should be treated as disordered.

    A site counts as disordered when its label is a string ending in
    ``'?'``, or when its occupancy is a real number below 1. Non-numeric
    occupancies (e.g. a ``'?'`` placeholder string or ``None``) are not
    compared numerically, since that would raise ``TypeError``. They are
    treated as not indicating disorder.

    Parameters
    ----------
    label : str
        Atom site label.
    occupancy : float or other
        Site occupancy. Only real numbers (including numpy scalars) are
        considered.

    Returns
    -------
    bool
    """
    label_flagged = isinstance(label, str) and label.endswith('?')
    partially_occupied = isinstance(occupancy, numbers.Real) and occupancy < 1
    return bool(label_flagged or partially_occupied)
344
345
346
def _unique_sites(asym_unit):
347
    site_diffs1 = np.abs(asym_unit[:, None] - asym_unit)
348
    site_diffs2 = np.abs(site_diffs1 - 1)
349
    overlapping = np.triu(np.all(
350
        (site_diffs1 <= _Reader.equiv_site_tol) |
351
        (site_diffs2 <= _Reader.equiv_site_tol),
352
        axis=-1), 1)
353
    return ~overlapping.any(axis=0)
354
355
356
def _heaviest_component(molecule):
357
    """Heaviest component (removes all but the heaviest component of the asym unit).
358
    Intended for removing solvents. Probably doesn't play well with disorder"""
359
    component_weights = []
360
    for component in molecule.components:
361
        weight = 0
362
        for a in component.atoms:
363
            if isinstance(a.atomic_weight, (float, int)):
364
                if isinstance(a.occupancy, (float, int)):
365
                    weight += a.occupancy * a.atomic_weight
366
                else:
367
                    weight += a.atomic_weight
368
        component_weights.append(weight)
369
    largest_component_ind = np.argmax(np.array(component_weights))
370
    molecule = molecule.components[largest_component_ind]
371
    return molecule
372