Passed
Push — master ( 93bfb7...64a86a )
by Daniel
01:42
created

amd._reader   F

Complexity

Total Complexity 71

Size/Duplication

Total Lines 375
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 71
eloc 259
dl 0
loc 375
rs 2.7199
c 0
b 0
f 0

13 Methods

Rating   Name   Duplication   Size   Complexity  
B _Reader.__init__() 0 32 5
A _Reader.__iter__() 0 2 1
A _Reader.read_one() 0 3 1
B _Reader._map() 0 14 6
B _Reader.expand() 0 38 6
A _Reader._asym_unit_from_crystal() 0 13 2
A _Reader._has_no_valid_sites() 0 6 2
B _Reader._asym_unit_from_cifblock() 0 29 7
A _Reader._is_site_overlapping() 0 19 4
A _Reader._validate_sites() 0 16 2
A _Reader._construct_periodic_set() 0 18 2
F _Reader._entry_to_periodicset() 0 46 14
B _Reader._cifblock_to_periodicset() 0 35 8

3 Functions

Rating   Name   Duplication   Size   Complexity  
A _atom_has_disorder() 0 2 1
A _validate_extract_data() 0 8 5
A _heaviest_component() 0 16 5

How to fix

Complexity

Complex classes like amd._reader often do many different things. To break such a class down, first identify a cohesive component within it. A common way to find one is to look for fields or methods that share a prefix or suffix.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
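As a rough illustration of Extract Class applied to this file, the disorder handling that `_Reader` spreads over `disorder_options`, `__init__` and `_atom_has_disorder` could move into one small class. The sketch below is an assumption, not the project's code: `DisorderPolicy` and its method names are hypothetical, and only the 'skip' / 'ordered_sites' / 'all_sites' semantics are taken from the listing.

```python
from numbers import Real


class DisorderPolicy:
    """Hypothetical extracted class: owns everything about disorder handling."""

    OPTIONS = {'skip', 'ordered_sites', 'all_sites'}

    def __init__(self, mode='skip'):
        if mode not in self.OPTIONS:
            raise ValueError(f'disorder parameter {mode} must be one of {self.OPTIONS}')
        self.mode = mode

    @staticmethod
    def atom_has_disorder(label, occupancy):
        # Close to _atom_has_disorder in the listing: a label ending in '?' or a
        # numeric occupancy below 1 marks a disordered site. The np.isscalar
        # check is swapped for numbers.Real here to stay stdlib-only.
        return label.endswith('?') or (isinstance(occupancy, Real) and occupancy < 1)

    def sites_to_remove(self, labels, occupancies):
        """Return indices of sites to drop, or None if the structure should be skipped."""
        disordered = [i for i, (lab, occ) in enumerate(zip(labels, occupancies))
                      if self.atom_has_disorder(lab, occ)]
        if self.mode == 'skip':
            return None if disordered else []
        if self.mode == 'ordered_sites':
            return disordered
        return []  # 'all_sites': keep everything


policy = DisorderPolicy('ordered_sites')
print(policy.sites_to_remove(['C1', 'O1?', 'N1'], [1.0, 1.0, 0.5]))
```

With something like this in place, the two converter methods would only need to ask the policy which sites to drop, which would take a few branches out of the reader.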

"""Contains base reader class for the io module.
This class implements the converters from CifBlock, Entry to PeriodicSets.
"""
# Issue [Coding Style]: Trailing whitespace (first docstring line)

import warnings
from typing import Callable, Iterable, Sequence, Tuple

import numpy as np
import ase.spacegroup.spacegroup    # parse_sitesym
import ase.io.cif

from .periodicset import PeriodicSet
from .utils import cellpar_to_cell


class _Reader:
    # Issue [best-practice]: Too many instance attributes (10/7)
    """Base Reader class. Contains parsers for converting ase CifBlock
    and ccdc Entry objects to PeriodicSets.

    Intended use:

    First make a new method for _Reader converting object to PeriodicSet
    (e.g. named _X_to_PSet). Then make this class outline:

    class XReader(_Reader):
        def __init__(self, ..., **kwargs):

            super().__init__(**kwargs)

            # setup and checks

            # make 'iterable' which yields objects to be converted (e.g. CIFBlock, Entry)

            # set self._generator like this
            self._generator = self._map(self._X_to_PSet, iterable)
    """

    disorder_options = {'skip', 'ordered_sites', 'all_sites'}
    reserved_tags = {
        'motif',
        'cell',
        'name',
        'asymmetric_unit',
        'wyckoff_multiplicities',
        'types',
        'filename',}
    atom_site_fract_tags = [
        '_atom_site_fract_x',
        '_atom_site_fract_y',
        '_atom_site_fract_z',]
    atom_site_cartn_tags = [
        '_atom_site_cartn_x',
        '_atom_site_cartn_y',
        '_atom_site_cartn_z',]
    symop_tags = [
        '_space_group_symop_operation_xyz',
        '_space_group_symop.operation_xyz',
        '_symmetry_equiv_pos_as_xyz',]

    equiv_site_tol = 1e-3

    # Issue [best-practice]: Too many arguments (7/5)
    def __init__(
            self,
            remove_hydrogens=False,
            disorder='skip',
            heaviest_component=False,
            show_warnings=True,
            extract_data=None,
            include_if=None):

        if disorder not in _Reader.disorder_options:
            # Issue [Coding Style]: This line is too long as per the coding-style (104/100).
            # This check looks for lines that are too long. You can specify the maximum line length.
            raise ValueError(f'disorder parameter {disorder} must be one of {_Reader.disorder_options}')

        if extract_data is None:
            self.extract_data = {}
        else:
            _validate_extract_data(extract_data)
            self.extract_data = extract_data

        if include_if is None:
            self.include_if = ()
        elif not all(callable(func) for func in include_if):
            raise ValueError('include_if must be a list of callables')
        else:
            self.include_if = include_if

        # Issue [Coding Style]: Trailing whitespace (blank line above)
        self.remove_hydrogens = remove_hydrogens
        self.disorder = disorder
        self.heaviest_component = heaviest_component
        self.show_warnings = show_warnings
        self.current_name = None
        self.current_filename = None
        self._generator = []

    def __iter__(self):
        yield from self._generator

    def read_one(self):
        """Read the next (or first) item."""
        return next(iter(self._generator))

    def _map(self, func: Callable, iterable: Iterable) -> Iterable[PeriodicSet]:
        # Issue [Coding Style]: Trailing whitespace (first docstring line)
        """Iterates over iterable, passing items through parser and yielding the
        result if it is not None. Applies warning and include_if filter.
        """

        with warnings.catch_warnings():
            if not self.show_warnings:
                warnings.simplefilter('ignore')
            for item in iterable:
                if any(not check(item) for check in self.include_if):
                    continue
                res = func(item)
                if res is not None:
                    yield res
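
The pattern used by `_map` (a lazy generator that applies the `include_if` checks, drops `None` results, and optionally silences warnings) can be tried in isolation. This is a standalone sketch and not the library's API: `filtered_map`, `parse` and the sample data are made-up names for illustration.

```python
import warnings


def filtered_map(func, iterable, include_if=(), show_warnings=True):
    """Yield func(item) for items passing every check, skipping None results."""
    with warnings.catch_warnings():
        if not show_warnings:
            # Stays in effect until the generator is exhausted or closed.
            warnings.simplefilter('ignore')
        for item in iterable:
            if any(not check(item) for check in include_if):
                continue
            res = func(item)
            if res is not None:
                yield res


def parse(text):
    """Hypothetical parser: returns an int, or warns and returns None."""
    if not text.isdigit():
        warnings.warn(f'Skipping {text!r} as it is not a number')
        return None
    return int(text)


items = ['1', 'x', '22', '333']
short_only = [lambda s: len(s) < 3]  # plays the role of include_if
result = list(filtered_map(parse, items, include_if=short_only, show_warnings=False))
print(result)
```

One thing worth keeping in mind: `warnings.catch_warnings` saves and restores the process-wide filter list, so the 'ignore' filter remains active between yields, while the caller's own code is running, until the generator finishes. That applies to `_map` as written, too.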

    def _cifblock_to_periodicset(self, block) -> PeriodicSet:
        """ase.io.cif.CIFBlock --> PeriodicSet. Returns None for a "bad" set."""

        data = {key: func(block) for key, func in self.extract_data.items()}
        self.current_name = block.name
        # _asym_unit_from_cifblock returns None when no coordinates are found,
        # so check the result before unpacking it.
        result = self._asym_unit_from_cifblock(block)
        if result is None:
            return None
        asym_unit, asym_symbols, sitesym, cell = result
        occupancies = block.get('_atom_site_occupancy')
        labels = block.get('_atom_site_label')
        remove_sites = []

        if occupancies is not None:
            if self.disorder == 'skip':
                if any(_atom_has_disorder(lab, occ) for lab, occ in zip(labels, occupancies)):
                    warnings.warn(f'Skipping {self.current_name} as structure is disordered')
                    return None

            elif self.disorder == 'ordered_sites':
                remove_sites.extend(
                    # Issue [Comprehensibility, Best Practice]: The variable i does not seem to be defined.
                    # Issue [Coding Style]: Trailing whitespace
                    (i for i, (lab, occ) in enumerate(zip(labels, occupancies))
                     if _atom_has_disorder(lab, occ)))

        if self.remove_hydrogens:
            remove_sites.extend((i for i, sym in enumerate(asym_symbols) if sym in 'HD'))

        asym_unit = np.delete(asym_unit, remove_sites, axis=0)
        asym_symbols = [s for i, s in enumerate(asym_symbols) if i not in remove_sites]

        # Issue [Unused Code]: Consider changing "not self.disorder == 'all_sites'" to "self.disorder != 'all_sites'"
        if not self.disorder == 'all_sites':
            asym_unit, asym_symbols = self._validate_sites(asym_unit, asym_symbols)

        if self._has_no_valid_sites(asym_unit):
            return None

        periodic_set = self._construct_periodic_set(asym_unit, asym_symbols, sitesym, cell, **data)
        return periodic_set

    def _asym_unit_from_cifblock(self, block):
        # Issue [Coding Style]: Trailing whitespace (first docstring line)
        """ase.io.cif.CIFBlock -->
        asymmetric unit (frac coords), asym_symbols, symops (as strings), cell"""
        # Issue [Coding Style]: Trailing whitespace (blank line after the docstring)

        cell = block.get_cell().array
        asym_unit = [block.get(name) for name in _Reader.atom_site_fract_tags]
        if None in asym_unit:
            asym_motif = [block.get(name) for name in _Reader.atom_site_cartn_tags]
            if None in asym_motif:
                warnings.warn(f'Skipping {self.current_name} as coordinates were not found')
                return None
            # asym_motif has shape (3, n): transpose to (n, 3) before converting
            # Cartesian to fractional coords, then transpose back so the shared
            # .T on the next statement applies to both branches.
            asym_unit = (np.array(asym_motif).T @ np.linalg.inv(cell)).T
        asym_unit = np.mod(np.array(asym_unit).T, 1)

        try:
            asym_symbols = block.get_symbols()
        except ase.io.cif.NoStructureData:
            asym_symbols = ['Unknown' for _ in range(len(asym_unit))]

        sitesym = ['x,y,z', ]
        for tag in _Reader.symop_tags:
            if tag in block:
                sitesym = block[tag]
                break

        if isinstance(sitesym, str):
            sitesym = [sitesym]

        return asym_unit, asym_symbols, sitesym, cell

    def _entry_to_periodicset(self, entry) -> PeriodicSet:
        """ccdc.entry.Entry --> PeriodicSet. Returns None for a "bad" set."""

        data = {key: func(entry) for key, func in self.extract_data.items()}
        self.current_name = entry.identifier
        crystal = entry.crystal

        if not entry.has_3d_structure:
            warnings.warn(f'Skipping {self.current_name} as entry has no 3D structure')
            return None

        # Issue [Coding Style]: Trailing whitespace
        molecule = crystal.disordered_molecule

        if self.disorder == 'skip':
            if crystal.has_disorder or entry.has_disorder or \
               any(_atom_has_disorder(a.label, a.occupancy) for a in molecule.atoms):
                warnings.warn(f'Skipping {self.current_name} as structure is disordered')
                return None

        elif self.disorder == 'ordered_sites':
            # Issue [Coding Style]: Trailing whitespace
            molecule.remove_atoms(a for a in molecule.atoms
                                  if _atom_has_disorder(a.label, a.occupancy))

        if self.remove_hydrogens:
            molecule.remove_atoms(a for a in molecule.atoms if a.atomic_symbol in 'HD')

        if self.heaviest_component and len(molecule.components) > 1:
            molecule = _heaviest_component(molecule)

        # check all atoms have coords. option/default remove unknown sites?
        if not molecule.all_atoms_have_sites or \
           any(a.fractional_coordinates is None for a in molecule.atoms):
            warnings.warn(f'Skipping {self.current_name} as some atoms do not have sites')
            return None

        crystal.molecule = molecule
        asym_unit, asym_symbols, sitesym, cell = self._asym_unit_from_crystal(crystal)

        # Issue [Unused Code]: Consider changing "not self.disorder == 'all_sites'" to "self.disorder != 'all_sites'"
        if not self.disorder == 'all_sites':
            asym_unit, asym_symbols = self._validate_sites(asym_unit, asym_symbols)

        if self._has_no_valid_sites(asym_unit):
            return None

        periodic_set = self._construct_periodic_set(asym_unit, asym_symbols, sitesym, cell, **data)
        return periodic_set

    # Issue [Coding Style]: This method could be written as a function/class method.
    # If a method does not access any attributes of the class, it could also be
    # implemented as a function or static method. This can help improve readability.
    # For example
    #
    #     class Foo:
    #         def some_method(self, x, y):
    #             return x + y
    #
    # could be written as
    #
    #     class Foo:
    #         @classmethod
    #         def some_method(cls, x, y):
    #             return x + y
    def _asym_unit_from_crystal(self, crystal):
        """ccdc.crystal.Crystal -->
        asymmetric unit (frac coords), asym_symbols, symops, cell"""

        asym_atoms = crystal.asymmetric_unit_molecule.atoms
        asym_unit = np.array([tuple(a.fractional_coordinates) for a in asym_atoms])
        asym_unit = np.mod(asym_unit, 1)
        asym_symbols = [a.atomic_symbol for a in asym_atoms]
        cell = cellpar_to_cell(*crystal.cell_lengths, *crystal.cell_angles)
        sitesym = crystal.symmetry_operators
        if not sitesym:
            sitesym = ['x,y,z', ]
        return asym_unit, asym_symbols, sitesym, cell

    def _is_site_overlapping(self, new_site, all_sites, inverses, inv):
        """Return True (and warn) if new_site overlaps with a site in all_sites."""
        diffs1 = np.abs(new_site - all_sites)
        diffs2 = np.abs(diffs1 - 1)
        mask = np.all((diffs1 <= _Reader.equiv_site_tol) |
                      (diffs2 <= _Reader.equiv_site_tol),
                    # Issue [Coding Style]: Wrong continued indentation (add 2 spaces).
                    axis=-1)

        # Issue [unused-code]: Unnecessary "else" after "return"
        if np.any(mask):
            where_equal = np.argwhere(mask).flatten()
            for ind in where_equal:
                if inverses[ind] == inv:
                    pass
                else:
                    warnings.warn(
                        f'{self.current_name} has equivalent positions {inverses[ind]} and {inv}')
            return True
        else:
            return False

    def _validate_sites(self, asym_unit, asym_symbols):
        site_diffs1 = np.abs(asym_unit[:, None] - asym_unit)
        site_diffs2 = np.abs(site_diffs1 - 1)
        overlapping = np.triu(np.all(
            (site_diffs1 <= _Reader.equiv_site_tol) |
            (site_diffs2 <= _Reader.equiv_site_tol),
            axis=-1), 1)

        if overlapping.any():
            warnings.warn(
                f'{self.current_name} may have overlapping sites; duplicates will be removed')
            keep_sites = ~overlapping.any(0)
            asym_unit = asym_unit[keep_sites]
            asym_symbols = [sym for sym, keep in zip(asym_symbols, keep_sites) if keep]

        return asym_unit, asym_symbols
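
The overlap test in `_validate_sites` treats two fractional coordinates as equal when they differ by at most the tolerance, or when their difference is within the tolerance of 1, so 0.0 and 0.9995 count as the same position across the cell boundary. The plain-Python sketch below restates that rule without NumPy; `sites_overlap` and `drop_duplicates` are illustrative names and the sample coordinates are made up.

```python
EQUIV_SITE_TOL = 1e-3  # same value as _Reader.equiv_site_tol in the listing


def sites_overlap(a, b, tol=EQUIV_SITE_TOL):
    """True if fractional sites a and b coincide, allowing wrap-around at the cell edge."""
    for x, y in zip(a, b):
        diff = abs(x - y)
        if not (diff <= tol or abs(diff - 1) <= tol):
            return False
    return True


def drop_duplicates(sites, symbols, tol=EQUIV_SITE_TOL):
    """Drop every site that overlaps an earlier one, as the np.triu mask does."""
    kept_sites, kept_symbols = [], []
    for j, (site, sym) in enumerate(zip(sites, symbols)):
        # Compare against all earlier sites, whether or not they were kept.
        if any(sites_overlap(sites[i], site, tol) for i in range(j)):
            continue
        kept_sites.append(site)
        kept_symbols.append(sym)
    return kept_sites, kept_symbols


sites = [(0.0, 0.5, 0.25), (0.9995, 0.5, 0.25), (0.1, 0.5, 0.25)]
symbols = ['C', 'C', 'O']
kept_sites, kept_symbols = drop_duplicates(sites, symbols)
print(kept_symbols)
```

The second site is dropped because its x coordinate sits 0.0005 below the cell edge, which is within tolerance of the first site's x = 0.0.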

    def _has_no_valid_sites(self, motif):
        if motif.shape[0] == 0:
            warnings.warn(
                f'Skipping {self.current_name} as there are no sites with coordinates')
            return True
        return False

    def _construct_periodic_set(self, asym_unit, asym_symbols, sitesym, cell, **kwargs):
        """Asym motif + symbols + sitesym + cell (+kwargs) --> PeriodicSet"""
        frac_motif, asym_inds, multiplicities, inverses = self.expand(asym_unit, sitesym)
        full_types = [asym_symbols[i] for i in inverses]
        motif = frac_motif @ cell

        tags = {
            'name': self.current_name,
            'asymmetric_unit': asym_inds,
            'wyckoff_multiplicities': multiplicities,
            'types': full_types,
            **kwargs
        }

        if self.current_filename:
            tags['filename'] = self.current_filename

        return PeriodicSet(motif, cell, **tags)

    def expand(self, asym_unit: np.ndarray, sitesym: Sequence[str]) -> Tuple[np.ndarray, ...]:
        """
        Asymmetric unit's fractional coords + sitesyms (as strings)
        -->
        frac motif, asym unit inds, multiplicities, inverses
        """

        rotations, translations = ase.spacegroup.spacegroup.parse_sitesym(sitesym)
        all_sites = []
        asym_inds = [0]
        multiplicities = []
        inverses = []

        for inv, site in enumerate(asym_unit):
            multiplicity = 0

            for rot, trans in zip(rotations, translations):
                site_ = np.mod(np.dot(rot, site) + trans, 1)

                if not all_sites:
                    all_sites.append(site_)
                    inverses.append(inv)
                    multiplicity += 1
                    continue

                if not self._is_site_overlapping(site_, all_sites, inverses, inv):
                    all_sites.append(site_)
                    inverses.append(inv)
                    multiplicity += 1

            if multiplicity > 0:
                multiplicities.append(multiplicity)
                asym_inds.append(len(all_sites))

        frac_motif = np.array(all_sites)
        asym_inds = np.array(asym_inds[:-1])
        multiplicities = np.array(multiplicities)
        return frac_motif, asym_inds, multiplicities, inverses
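
`expand` applies every symmetry operation (a rotation matrix plus a translation) to each asymmetric-unit site, wraps the image into the unit cell with mod 1, and keeps it only if it does not coincide with a site already generated. Below is a small NumPy-free sketch of that loop for a hypothetical cell with just the identity and an inversion centre. The operations are written out by hand rather than parsed with ase's `parse_sitesym`, `expand_sites` is an illustrative name, and the warning about equivalent positions and the `asym_inds` bookkeeping are left out for brevity.

```python
TOL = 1e-3

# Hand-written symmetry operations as (rotation, translation) pairs:
# identity 'x,y,z' and inversion '-x,-y,-z'. Assumed for illustration.
IDENTITY = ([[1, 0, 0], [0, 1, 0], [0, 0, 1]], [0, 0, 0])
INVERSION = ([[-1, 0, 0], [0, -1, 0], [0, 0, -1]], [0, 0, 0])


def apply_op(rot, trans, site):
    """Return rot @ site + trans, wrapped into [0, 1)."""
    return tuple((sum(r * s for r, s in zip(row, site)) + t) % 1
                 for row, t in zip(rot, trans))


def same_site(a, b, tol=TOL):
    """Coordinates match directly or across the cell boundary."""
    return all(abs(x - y) <= tol or abs(abs(x - y) - 1) <= tol for x, y in zip(a, b))


def expand_sites(asym_unit, ops):
    """Return (all_sites, multiplicities, inverses) for the expanded motif."""
    all_sites, multiplicities, inverses = [], [], []
    for inv, site in enumerate(asym_unit):
        multiplicity = 0
        for rot, trans in ops:
            image = apply_op(rot, trans, site)
            if any(same_site(image, other) for other in all_sites):
                continue  # image coincides with a site generated earlier
            all_sites.append(image)
            inverses.append(inv)
            multiplicity += 1
        if multiplicity > 0:
            multiplicities.append(multiplicity)
    return all_sites, multiplicities, inverses


# A general position gets two images; a site on the inversion centre gets one.
sites, mults, invs = expand_sites([(0.1, 0.2, 0.3), (0.5, 0.5, 0.5)],
                                  [IDENTITY, INVERSION])
print(mults, invs)
```

Here `inverses` maps each expanded site back to the index of the asymmetric-unit site it came from, which is how `_construct_periodic_set` assigns an element symbol to every motif point.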


def _atom_has_disorder(label, occupancy):
    return label.endswith('?') or (np.isscalar(occupancy) and occupancy < 1)

def _heaviest_component(molecule):
    """Heaviest component (removes all but the heaviest component of the asym unit).
    Intended for removing solvents. Probably doesn't play well with disorder"""
    component_weights = []
    for component in molecule.components:
        weight = 0
        for a in component.atoms:
            if isinstance(a.atomic_weight, (float, int)):
                if isinstance(a.occupancy, (float, int)):
                    weight += a.occupancy * a.atomic_weight
                else:
                    weight += a.atomic_weight
        component_weights.append(weight)
    largest_component_arg = np.argmax(np.array(component_weights))
    molecule = molecule.components[largest_component_arg]
    return molecule

def _validate_extract_data(extract_data):
    if not isinstance(extract_data, dict):
        raise ValueError('extract_data must be a dict of callables')
    for key in extract_data:
        if not callable(extract_data[key]):
            raise ValueError('extract_data must be a dict of callables')
        if key in _Reader.reserved_tags:
            raise ValueError(f'extract_data includes reserved key {key}')