Passed
Push — master ( 858f4b...acad41 )
by Daniel
01:45
created

amd._reader   F

Complexity

Total Complexity 64

Size/Duplication

Total Lines 364
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 64
eloc 243
dl 0
loc 364
rs 3.28
c 0
b 0
f 0

2 Functions

Rating   Name   Duplication   Size   Complexity  
A atom_has_disorder() 0 2 1
A _heaviest_component() 0 16 5

9 Methods

Rating   Name   Duplication   Size   Complexity  
B _Reader.expand_asym_unit() 0 53 8
C _Reader.__init__() 0 40 9
B _Reader._map() 0 14 6
A _Reader._validate_sites() 0 16 2
A _Reader.__iter__() 0 2 1
A _Reader._construct_periodic_set() 0 18 2
F _Reader._entry_to_periodicset() 0 62 15
F _Reader._cifblock_to_periodicset() 0 64 14
A _Reader.read_one() 0 3 1

How to fix   Complexity   

Complexity

Complex classes like amd._reader often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to finding such a component is to look for fields/methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
"""Contains base reader class for the io module.
2
This class implements the converters from CifBlock, Entry to PeriodicSets.
3
"""
4
5
import warnings
6
from typing import Callable, Iterable, Sequence, Tuple
7
8
import numpy as np
9
import ase.spacegroup.spacegroup    # parse_sitesym
10
import ase.io.cif
11
12
from .periodicset import PeriodicSet
13
from .utils import cellpar_to_cell
14
15
16
class _Reader:
    """Base reader class. Contains parsers for converting ase CifBlock
    and ccdc Entry objects to PeriodicSets.

    Intended use:

    First make a new method for _Reader converting an object to a PeriodicSet
    (e.g. named _X_to_PSet). Then make a subclass with this outline:

    class XReader(_Reader):
        def __init__(self, ..., **kwargs):

            super().__init__(**kwargs)

            # setup and checks

            # make 'iterable' which yields objects to be converted (e.g. CIFBlock, Entry)

            # set self._generator like this
            self._generator = self._map(self._X_to_PSet, iterable)

    Parameters
    ----------
    remove_hydrogens : bool
        Drop sites whose element symbol is 'H' or 'D'.
    disorder : str
        One of ``disorder_options``. 'skip' skips any structure with a
        disordered site, 'ordered_sites' removes the disordered sites and
        'all_sites' keeps every site (and also disables the removal of
        overlapping sites).
    heaviest_component : bool
        Keep only the heaviest component. Only used by the ccdc Entry parser.
    show_warnings : bool
        If False, warnings raised while iterating with ``_map`` are silenced.
    extract_data : dict, optional
        Maps a tag name to a callable. Each callable is applied to the raw
        object being parsed and the result is passed to the PeriodicSet under
        that name. Names in ``reserved_tags`` are rejected.
    include_if : iterable of callables, optional
        Predicates applied to each raw object; objects failing any of them
        are not parsed.
    """

    disorder_options = {'skip', 'ordered_sites', 'all_sites'}
    reserved_tags = {
        'motif',
        'cell',
        'name',
        'asymmetric_unit',
        'wyckoff_multiplicities',
        'types',
        'filename',}
    atom_site_fract_tags = [
        '_atom_site_fract_x',
        '_atom_site_fract_y',
        '_atom_site_fract_z',]
    atom_site_cartn_tags = [
        '_atom_site_cartn_x',
        '_atom_site_cartn_y',
        '_atom_site_cartn_z',]
    symop_tags = [
        '_space_group_symop_operation_xyz',
        '_space_group_symop.operation_xyz',
        '_symmetry_equiv_pos_as_xyz',]

    # fractional coordinates closer than this (modulo 1) are treated as the same site
    equiv_site_tol = 1e-3

    def __init__(
            self,
            remove_hydrogens=False,
            disorder='skip',
            heaviest_component=False,
            show_warnings=True,
            extract_data=None,
            include_if=None):

        if disorder not in _Reader.disorder_options:
            raise ValueError(
                f'disorder parameter {disorder} must be one of {_Reader.disorder_options}')

        # validate extract_data
        if extract_data is None:
            self.extract_data = {}
        else:
            if not isinstance(extract_data, dict):
                raise ValueError('extract_data must be a dict of callables')
            for key in extract_data:
                if not callable(extract_data[key]):
                    raise ValueError('extract_data must be a dict of callables')
                if key in _Reader.reserved_tags:
                    raise ValueError(f'extract_data includes reserved key {key}')
            self.extract_data = extract_data

        # validate include_if
        if include_if is None:
            self.include_if = ()
        elif not all(callable(func) for func in include_if):
            raise ValueError('include_if must be a list of callables')
        else:
            self.include_if = include_if

        self.remove_hydrogens = remove_hydrogens
        self.disorder = disorder
        self.heaviest_component = heaviest_component
        self.show_warnings = show_warnings
        # name/filename of the item currently being parsed; used in warnings and tags
        self.current_name = None
        self.current_filename = None
        # subclasses replace this with the iterable returned by _map
        self._generator = []

    def __iter__(self):
        yield from self._generator

    def read_one(self):
        """Read the next (or first) item."""
        return next(iter(self._generator))

    def _map(self, func: Callable, iterable: Iterable) -> 'Iterable[PeriodicSet]':
        """Iterates over iterable, passing items through parser and yielding the
        result if it is not None. Applies warning and include_if filter.
        """

        with warnings.catch_warnings():
            if not self.show_warnings:
                warnings.simplefilter('ignore')
            for item in iterable:
                # skip items rejected by any include_if predicate
                if any(not check(item) for check in self.include_if):
                    continue
                res = func(item)
                if res is not None:
                    yield res

    def _cifblock_to_periodicset(self, block) -> 'PeriodicSet':
        """ase.io.cif.CIFBlock --> PeriodicSet. Returns None for a "bad" set
        (no coordinates, disordered while disorder='skip', or no sites left).
        """

        self.current_name = block.name
        data = {key: func(block) for key, func in self.extract_data.items()}

        # unit cell
        cell = block.get_cell().array

        # asymmetric unit fractional coords, one row per site
        frac_columns = [block.get(name) for name in _Reader.atom_site_fract_tags]
        if any(column is None for column in frac_columns):
            cartn_columns = [block.get(name) for name in _Reader.atom_site_cartn_tags]
            if any(column is None for column in cartn_columns):
                warnings.warn(f'Skipping {self.current_name} as coordinates were not found')
                return None
            # The tags give one column per axis, i.e. a (3, n) stack; transpose to
            # (n, 3) rows before converting Cartesian --> fractional with the inverse cell.
            asym_unit = np.array(cartn_columns).T @ np.linalg.inv(cell)
        else:
            asym_unit = np.array(frac_columns).T
        asym_unit = np.mod(asym_unit, 1)

        # asymmetric unit symbols
        try:
            asym_symbols = block.get_symbols()
        except ase.io.cif.NoStructureData:
            asym_symbols = ['Unknown'] * len(asym_unit)

        # symmetry operators; fall back to the identity if no tag is present
        sitesym = ['x,y,z', ]
        for tag in _Reader.symop_tags:
            if tag in block:
                sitesym = block[tag]
                break
        if isinstance(sitesym, str):
            sitesym = [sitesym]

        # indices of asymmetric unit sites to drop
        remove_sites = set()

        # handle disorder
        occupancies = block.get('_atom_site_occupancy')
        labels = block.get('_atom_site_label')
        if occupancies is not None and self.disorder != 'all_sites':
            if labels is None:
                # no labels to inspect; judge disorder on occupancy alone
                labels = [''] * len(occupancies)
            if self.disorder == 'skip':
                if any(atom_has_disorder(lab, occ) for lab, occ in zip(labels, occupancies)):
                    warnings.warn(f'Skipping {self.current_name} as structure is disordered')
                    return None
            elif self.disorder == 'ordered_sites':
                remove_sites.update(
                    i for i, (lab, occ) in enumerate(zip(labels, occupancies))
                    if atom_has_disorder(lab, occ))

        if self.remove_hydrogens:
            # tuple membership: a plain `in 'HD'` would be a substring test
            remove_sites.update(
                i for i, sym in enumerate(asym_symbols) if sym in ('H', 'D'))

        asym_unit = np.delete(asym_unit, sorted(remove_sites), axis=0)
        asym_symbols = [s for i, s in enumerate(asym_symbols) if i not in remove_sites]

        if self.disorder != 'all_sites':
            asym_unit, asym_symbols = self._validate_sites(asym_unit, asym_symbols)

        if asym_unit.shape[0] == 0:
            warnings.warn(f'Skipping {self.current_name} as there are no sites with coordinates')
            return None

        return self._construct_periodic_set(asym_unit, asym_symbols, sitesym, cell, **data)

    def _entry_to_periodicset(self, entry) -> 'PeriodicSet':
        """ccdc.entry.Entry --> PeriodicSet. Returns None for a "bad" set
        (no 3D structure, disordered while disorder='skip', atoms without
        sites, or no sites left).
        """

        data = {key: func(entry) for key, func in self.extract_data.items()}
        self.current_name = entry.identifier
        crystal = entry.crystal

        if not entry.has_3d_structure:
            warnings.warn(f'Skipping {self.current_name} as entry has no 3D structure')
            return None

        molecule = crystal.disordered_molecule

        # handle disorder
        if self.disorder == 'skip':
            if crystal.has_disorder or entry.has_disorder or \
               any(atom_has_disorder(a.label, a.occupancy) for a in molecule.atoms):
                warnings.warn(f'Skipping {self.current_name} as structure is disordered')
                return None

        elif self.disorder == 'ordered_sites':
            molecule.remove_atoms(a for a in molecule.atoms
                                  if atom_has_disorder(a.label, a.occupancy))

        if self.remove_hydrogens:
            # tuple membership: a plain `in 'HD'` would be a substring test
            molecule.remove_atoms(a for a in molecule.atoms if a.atomic_symbol in ('H', 'D'))

        # remove all but heaviest component
        if self.heaviest_component and len(molecule.components) > 1:
            molecule = _heaviest_component(molecule)

        if not molecule.all_atoms_have_sites or \
           any(a.fractional_coordinates is None for a in molecule.atoms):
            warnings.warn(f'Skipping {self.current_name} as some atoms do not have sites')
            return None

        # presumably makes asymmetric_unit_molecule below reflect the edits above
        # NOTE(review): confirm against the ccdc Crystal API
        crystal.molecule = molecule

        # asymmetric unit fractional coords + symbols
        asym_atoms = crystal.asymmetric_unit_molecule.atoms
        asym_unit = np.array([tuple(a.fractional_coordinates) for a in asym_atoms])
        asym_unit = np.mod(asym_unit, 1)
        asym_symbols = [a.atomic_symbol for a in asym_atoms]

        # unit cell
        cell = cellpar_to_cell(*crystal.cell_lengths, *crystal.cell_angles)

        # symmetry operators; fall back to the identity if none are given
        sitesym = crystal.symmetry_operators
        if not sitesym:
            sitesym = ['x,y,z', ]

        if self.disorder != 'all_sites':
            asym_unit, asym_symbols = self._validate_sites(asym_unit, asym_symbols)

        if asym_unit.shape[0] == 0:
            warnings.warn(f'Skipping {self.current_name} as there are no sites with coordinates')
            return None

        return self._construct_periodic_set(asym_unit, asym_symbols, sitesym, cell, **data)

    def _construct_periodic_set(self, asym_unit, asym_symbols, sitesym, cell, **kwargs):
        """Asym motif + symbols + sitesym + cell (+kwargs) --> PeriodicSet"""
        frac_motif, asym_inds, multiplicities, inverses = self.expand_asym_unit(asym_unit, sitesym)
        # inverses[i] is the asymmetric unit index that motif point i was generated from
        full_types = [asym_symbols[i] for i in inverses]
        # fractional --> Cartesian
        motif = frac_motif @ cell

        tags = {
            'name': self.current_name,
            'asymmetric_unit': asym_inds,
            'wyckoff_multiplicities': multiplicities,
            'types': full_types,
            **kwargs
        }

        if self.current_filename:
            tags['filename'] = self.current_filename

        return PeriodicSet(motif, cell, **tags)

    def _validate_sites(self, asym_unit, asym_symbols):
        """Remove sites of the asymmetric unit which overlap an earlier site.

        Two sites overlap when every fractional coordinate differs by at most
        ``equiv_site_tol``, either directly or after a shift by 1. Of each
        overlapping pair the later site is dropped and a warning is issued.
        Returns the (possibly reduced) asym_unit and asym_symbols.
        """
        site_diffs1 = np.abs(asym_unit[:, None] - asym_unit)
        site_diffs2 = np.abs(site_diffs1 - 1)
        # strict upper triangle: entry [i, j] is set only for i < j
        overlapping = np.triu(np.all(
            (site_diffs1 <= _Reader.equiv_site_tol) |
            (site_diffs2 <= _Reader.equiv_site_tol),
            axis=-1), 1)

        if overlapping.any():
            warnings.warn(
                f'{self.current_name} may have overlapping sites; duplicates will be removed')
            keep_sites = ~overlapping.any(0)
            asym_unit = asym_unit[keep_sites]
            asym_symbols = [sym for sym, keep in zip(asym_symbols, keep_sites) if keep]

        return asym_unit, asym_symbols

    def expand_asym_unit(
            self,
            asym_unit: np.ndarray,
            sitesym: Sequence[str]) -> Tuple[np.ndarray, ...]:
        """
        Asymmetric unit's fractional coords + sitesyms (as strings)
        -->
        frac motif, asym unit inds, multiplicities, inverses

        Each symmetry operation is applied to each asymmetric unit site; an
        image is kept only if it does not coincide (within ``equiv_site_tol``,
        modulo 1) with a point already generated. A warning is issued when an
        image coincides with a point generated from a different site.
        """

        rotations, translations = ase.spacegroup.spacegroup.parse_sitesym(sitesym)
        all_sites = []
        asym_inds = [0]
        multiplicities = []
        inverses = []

        for inv, site in enumerate(asym_unit):
            multiplicity = 0

            for rot, trans in zip(rotations, translations):
                site_ = np.mod(np.dot(rot, site) + trans, 1)

                # nothing to compare against for the very first point
                if all_sites:
                    # check if site_ overlaps with existing sites
                    diffs1 = np.abs(site_ - all_sites)
                    diffs2 = np.abs(diffs1 - 1)
                    mask = np.all((diffs1 <= _Reader.equiv_site_tol) |
                                  (diffs2 <= _Reader.equiv_site_tol),
                                  axis=-1)

                    if np.any(mask):
                        for ind in np.flatnonzero(mask):
                            # coinciding with an image of the same site is expected
                            if inverses[ind] != inv:
                                warnings.warn(
                                    f'{self.current_name} has equivalent positions '
                                    f'{inverses[ind]} and {inv}')
                        continue

                all_sites.append(site_)
                inverses.append(inv)
                multiplicity += 1

            if multiplicity > 0:
                multiplicities.append(multiplicity)
                asym_inds.append(len(all_sites))

        frac_motif = np.array(all_sites)
        # the last entry is the total number of points, not the start of a new site
        asym_inds = np.array(asym_inds[:-1])
        multiplicities = np.array(multiplicities)
        return frac_motif, asym_inds, multiplicities, inverses
342
343
344
def atom_has_disorder(label, occupancy):
    """Return True if an atom site looks disordered.

    A site counts as disordered when its label ends with '?' or when its
    occupancy is a scalar number below 1. Occupancies that are not numeric
    scalars (None, arrays, or text such as an unparsed placeholder) carry no
    usable value and are ignored instead of being compared.
    """
    if isinstance(label, str) and label.endswith('?'):
        return True
    # np.isscalar is also True for str/bytes, which cannot be compared with 1
    if not np.isscalar(occupancy) or isinstance(occupancy, (str, bytes)):
        return False
    return bool(occupancy < 1)
346
347
348
def _heaviest_component(molecule):
    """Heaviest component (removes all but the heaviest component of the asym unit).
    Intended for removing solvents. Probably doesn't play well with disorder.

    The weight of a component is the sum of its atoms' atomic weights, each
    scaled by the atom's occupancy when that is a number. Atoms whose atomic
    weight is not a number contribute nothing. If several components tie, the
    first is returned. A molecule with no components is returned unchanged.
    """
    components = molecule.components
    if len(components) == 0:
        # nothing to choose from; np.argmax would raise on an empty sequence
        return molecule

    def _atom_weight(atom):
        if not isinstance(atom.atomic_weight, (float, int)):
            return 0
        if isinstance(atom.occupancy, (float, int)):
            return atom.occupancy * atom.atomic_weight
        return atom.atomic_weight

    component_weights = []
    for component in components:
        weight = 0
        for atom in component.atoms:
            weight += _atom_weight(atom)
        component_weights.append(weight)

    # plain int so the index works with any sequence type
    heaviest = int(np.argmax(np.array(component_weights)))
    return components[heaviest]
364