Passed
Push — master ( e05216...ca5a49 )
by Daniel
01:45
created

amd._reader._Reader._cifblock_to_periodicset()   F

Complexity

Conditions 18

Size

Total Lines 73
Code Lines 55

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 55
dl 0
loc 73
rs 1.2
c 0
b 0
f 0
cc 18
nop 2

How to fix   Long Method    Complexity   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include Extract Method.

Complexity

Complex classes like amd._reader._Reader._cifblock_to_periodicset() often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes, or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
"""Contains base reader class for the io module. 
0 ignored issues
show
Coding Style introduced by
Trailing whitespace
Loading history...
2
This class implements the converters from CifBlock, Entry to PeriodicSets.
3
"""
4
5
import warnings
6
from typing import Callable, Iterable, Sequence, Tuple
7
8
import numpy as np
9
import ase.spacegroup.spacegroup    # parse_sitesym
10
import ase.io.cif
11
12
from .periodicset import PeriodicSet
13
from .utils import cellpar_to_cell
14
15
16
class _Reader:
0 ignored issues
show
best-practice introduced by
Too many instance attributes (9/7)
Loading history...
17
    """Base Reader class. Contains parsers for converting ase CifBlock
18
    and ccdc Entry objects to PeriodicSets.
19
20
    Intended use:
21
22
    First make a new method for _Reader converting object to PeriodicSet
23
    (e.g. named _X_to_PSet). Then make this class outline:
24
25
    class XReader(_Reader):
26
        def __init__(self, ..., **kwargs):
27
28
        super().__init__(**kwargs)
29
30
        # setup and checks
31
32
        # make 'iterable' which yields objects to be converted (e.g. CIFBlock, Entry)
33
34
        # set self._generator like this
35
        self._generator = self._map(self._X_to_PSet, iterable)
36
    """
37
38
    disorder_options = {'skip', 'ordered_sites', 'all_sites'}
39
    reserved_tags = {
40
        'motif',
41
        'cell',
42
        'name',
43
        'asymmetric_unit',
44
        'wyckoff_multiplicities',
45
        'types',
46
        'filename',}
47
    atom_site_fract_tags = [
48
        '_atom_site_fract_x',
49
        '_atom_site_fract_y',
50
        '_atom_site_fract_z',]
51
    atom_site_cartn_tags = [
52
        '_atom_site_cartn_x',
53
        '_atom_site_cartn_y',
54
        '_atom_site_cartn_z',]
55
    symop_tags = [
56
        '_space_group_symop_operation_xyz',
57
        '_space_group_symop.operation_xyz',
58
        '_symmetry_equiv_pos_as_xyz',]
59
60
    equiv_site_tol = 1e-3
61
62
    def __init__(
0 ignored issues
show
best-practice introduced by
Too many arguments (7/5)
Loading history...
63
            self,
64
            remove_hydrogens=False,
65
            disorder='skip',
66
            heaviest_component=False,
67
            show_warnings=True,
68
            extract_data=None,
69
            include_if=None):
70
71
        if disorder not in _Reader.disorder_options:
72
            raise ValueError(f'disorder parameter {disorder} must be one of {_Reader.disorder_options}')
0 ignored issues
show
Coding Style introduced by
This line is too long as per the coding-style (104/100).

This check looks for lines that are too long. You can specify the maximum line length.

Loading history...
73
74
        if extract_data is None:
75
            self.extract_data = {}
76
        else:
77
            _validate_extract_data(extract_data)
78
            self.extract_data = extract_data
79
80
        if include_if is None:
81
            self.include_if = ()
82
        elif not all(callable(func) for func in include_if):
83
            raise ValueError('include_if must be a list of callables')
84
        else:
85
            self.include_if = include_if
86
        
0 ignored issues
show
Coding Style introduced by
Trailing whitespace
Loading history...
87
        self.remove_hydrogens = remove_hydrogens
88
        self.disorder = disorder
89
        self.heaviest_component = heaviest_component
90
        self.show_warnings = show_warnings
91
        self.current_identifier = None
92
        self.current_filename = None
93
        self._generator = []
94
95
    def __iter__(self):
        """Yield, in order, every item produced by iterating self._generator."""
        yield from self._generator
97
98
    def read_one(self):
99
        """Read the next (or first) item."""
100
        return next(iter(self._generator))
101
102
    # basically the builtin map, but skips items if the function returned None.
103
    # The object returned by this function (Iterable of PeriodicSets) is set to
104
    # self._generator; then iterating over the Reader iterates over
105
    # self._generator.
106
    def _map(self, func: Callable, iterable: Iterable) -> Iterable[PeriodicSet]:
107
        """Iterates over iterable, passing items through parser and
108
        yielding the result if it is not None.
109
        """
110
        
0 ignored issues
show
Coding Style introduced by
Trailing whitespace
Loading history...
111
        with warnings.catch_warnings():
112
            if not self.show_warnings:
113
                warnings.simplefilter('ignore')
114
            for item in iterable:
115
                res = func(item)
116
                if res is not None:
117
                    yield res
118
    
0 ignored issues
show
Coding Style introduced by
Trailing whitespace
Loading history...
119
    def _cifblock_to_periodicset(self, block) -> PeriodicSet:
        """Convert an ase.io.cif.CIFBlock to a PeriodicSet.

        Returns None (usually after a warning) when the block is rejected:
        an include_if check fails, no coordinates are found, the structure is
        disordered and self.disorder == 'skip', or no sites remain after
        filtering.
        """

        if not all(check(block) for check in self.include_if):
            return None

        self.current_identifier = block.name
        cell = block.get_cell().array

        # Coordinates are read as three per-axis columns. Prefer fractional
        # coordinates and fall back to Cartesian ones if any column is missing.
        frac_columns = [block.get(tag) for tag in _Reader.atom_site_fract_tags]
        if any(column is None for column in frac_columns):
            cart_columns = [block.get(tag) for tag in _Reader.atom_site_cartn_tags]
            if any(column is None for column in cart_columns):
                warnings.warn(f'Skipping {self.current_identifier} as coordinates were not found')
                return None
            # Sites are rows and cartesian = fractional @ cell (as used when
            # the motif is built), so fractional = cartesian @ inv(cell). The
            # columns must be transposed to (n, 3) *before* the product.
            asym_frac_motif = np.array(cart_columns).T @ np.linalg.inv(cell)
        else:
            asym_frac_motif = np.array(frac_columns).T

        n_sites = len(asym_frac_motif)

        try:
            asym_symbols = block.get_symbols()
        except ase.io.cif.NoStructureData:
            asym_symbols = ['Unknown'] * n_sites

        # Indices of sites to remove.
        remove = set()
        if self.remove_hydrogens:
            # Tuple membership, not `in 'HD'`: the latter is a substring test
            # that also matches '' and 'HD'.
            remove.update(i for i, sym in enumerate(asym_symbols) if sym in ('H', 'D'))

        # One flag per site, aligned with site indices. Without occupancy data
        # no site is known to be disordered.
        asym_is_disordered = [False] * n_sites
        occupancies = block.get('_atom_site_occupancy')
        labels = block.get('_atom_site_label')
        if occupancies is not None:
            if labels is None:
                # No labels to inspect; decide on occupancy alone.
                labels = [''] * len(occupancies)
            asym_is_disordered = [
                bool(_atom_has_disorder(label, occ))
                for occ, label in zip(occupancies, labels)]

            # Disordered sites which are not already being removed.
            disordered = [i for i, flag in enumerate(asym_is_disordered)
                          if flag and i not in remove]

            if self.disorder == 'skip' and disordered:
                warnings.warn(f'Skipping {self.current_identifier} as structure is disordered')
                return None

            if self.disorder == 'ordered_sites':
                remove.update(disordered)

        # Remove sites, and wrap the remaining coordinates into [0, 1).
        asym_frac_motif = np.mod(np.delete(asym_frac_motif, sorted(remove), axis=0), 1)
        asym_symbols = [s for i, s in enumerate(asym_symbols) if i not in remove]
        asym_is_disordered = [v for i, v in enumerate(asym_is_disordered) if i not in remove]

        # Drop duplicate (overlapping) sites, if any were found.
        keep_sites = self._validate_sites(asym_frac_motif, asym_is_disordered)
        if keep_sites is not None:
            asym_frac_motif = asym_frac_motif[keep_sites]
            asym_symbols = [sym for sym, keep in zip(asym_symbols, keep_sites) if keep]

        if self._has_no_valid_sites(asym_frac_motif):
            return None

        # Use the first symmetry-operation tag present; identity otherwise.
        sitesym = ['x,y,z', ]
        for tag in _Reader.symop_tags:
            if tag in block:
                sitesym = block[tag]
                break

        if isinstance(sitesym, str):
            sitesym = [sitesym]

        data = {key: func(block) for key, func in self.extract_data.items()}
        periodic_set = self._construct_periodic_set(
            asym_frac_motif, asym_symbols, sitesym, cell, **data)
        return periodic_set
192
193
    def _entry_to_periodicset(self, entry) -> PeriodicSet:
        """Convert a ccdc.entry.Entry to a PeriodicSet.

        Returns None for a "bad" entry: an include_if check fails, the entry
        has no 3D structure, the structure is disordered while
        self.disorder == 'skip', some atoms have no coordinates, or no sites
        remain after duplicate removal. All but the first case emit a warning.
        """

        if not all(check(entry) for check in self.include_if):
            return None

        crystal = entry.crystal
        self.current_identifier = entry.identifier
        cell = cellpar_to_cell(*crystal.cell_lengths, *crystal.cell_angles)

        if not entry.has_3d_structure:
            warnings.warn(f'Skipping {self.current_identifier} as entry has no 3D structure')
            return None

        # first disorder check, if skipping. If occ == 1 for all atoms but the entry
        # or crystal is listed as having disorder, skip (can't know where disorder is).
        # If occ != 1 for any atoms, we wait to see if we remove them before skipping.
        molecule = crystal.disordered_molecule
        # NOTE(review): only atoms whose label ends with '?' are removed here;
        # atoms flagged solely by occupancy < 1 are kept -- confirm intended.
        if self.disorder == 'ordered_sites':
            molecule.remove_atoms(a for a in molecule.atoms if a.label.endswith('?'))

        may_have_disorder = False
        if self.disorder == 'skip':
            for a in molecule.atoms:
                occ = a.occupancy
                if _atom_has_disorder(a.label, occ):
                    may_have_disorder = True
                    break

            if not may_have_disorder:
                if crystal.has_disorder or entry.has_disorder:
                    warnings.warn(f'Skipping {self.current_identifier} as structure is disordered')
                    return None

        # NOTE(review): `in 'HD'` is a substring test, so an empty symbol (or
        # the literal 'HD') would also match -- confirm symbols are never empty.
        if self.remove_hydrogens:
            molecule.remove_atoms(a for a in molecule.atoms if a.atomic_symbol in 'HD')

        if self.heaviest_component and len(molecule.components) > 1:
            molecule = _heaviest_component(molecule)

        # Write the filtered molecule back to the crystal; the asymmetric unit
        # read below is presumably derived from it -- verify against ccdc docs.
        crystal.molecule = molecule

        # by here all atoms to be removed have been (except via ordered_sites).
        # If disorder == 'skip' and there were atom(s) with occ < 1 found
        # earlier, we check if all such atoms were removed. If not, skip.
        if self.disorder == 'skip' and may_have_disorder:
            for a in crystal.disordered_molecule.atoms:
                occ = a.occupancy
                if _atom_has_disorder(a.label, occ):
                    warnings.warn(f'Skipping {self.current_identifier} as structure is disordered')
                    return None

        # if disorder is all_sites, we need to know where disorder is to ignore overlaps
        asym_is_disordered = []     # True/False list same length as asym unit
        if self.disorder == 'all_sites':
            for a in crystal.asymmetric_unit_molecule.atoms:
                occ = a.occupancy
                if _atom_has_disorder(a.label, occ):
                    asym_is_disordered.append(True)
                else:
                    asym_is_disordered.append(False)

        # check all atoms have coords. option/default remove unknown sites?
        if not molecule.all_atoms_have_sites or \
           any(a.fractional_coordinates is None for a in molecule.atoms):
            warnings.warn(f'Skipping {self.current_identifier} as some atoms do not have sites')
            return None

        # Fractional coordinates of the asymmetric unit, wrapped into [0, 1).
        asym_frac_motif = np.array([tuple(a.fractional_coordinates)
                                    for a in crystal.asymmetric_unit_molecule.atoms])
        asym_frac_motif = np.mod(asym_frac_motif, 1)
        asym_symbols = [a.atomic_symbol for a in crystal.asymmetric_unit_molecule.atoms]

        # remove overlapping sites, check sites exist
        keep_sites = self._validate_sites(asym_frac_motif, asym_is_disordered)
        if keep_sites is not None:
            asym_frac_motif = asym_frac_motif[keep_sites]
            asym_symbols = [sym for sym, keep in zip(asym_symbols, keep_sites) if keep]

        if self._has_no_valid_sites(asym_frac_motif):
            return None

        # Fall back to the identity operation if no symmetry operators are given.
        sitesym = crystal.symmetry_operators
        if not sitesym:
            sitesym = ['x,y,z', ]

        # Reassign the crystal's molecule before the extract_data callables
        # receive the entry.
        entry.crystal.molecule = crystal.disordered_molecule
        data = {key: func(entry) for key, func in self.extract_data.items()}
        periodic_set = self._construct_periodic_set(asym_frac_motif, asym_symbols, sitesym, cell, **data)
        return periodic_set
283
284
    def _is_site_overlapping(self, new_site, all_sites, inverses, inv):
285
        """Return True (and warn) if new_site overlaps with a site in all_sites."""
286
        diffs1 = np.abs(new_site - all_sites)
287
        diffs2 = np.abs(diffs1 - 1)
288
        mask = np.all(np.logical_or(diffs1 <= _Reader.equiv_site_tol,
289
                                    diffs2 <= _Reader.equiv_site_tol),
290
                        axis=-1)
0 ignored issues
show
Coding Style introduced by
Wrong continued indentation (remove 2 spaces).
Loading history...
291
292
        if np.any(mask):
0 ignored issues
show
unused-code introduced by
Unnecessary "else" after "return"
Loading history...
293
            where_equal = np.argwhere(mask).flatten()
294
            for ind in where_equal:
295
                if inverses[ind] == inv:
296
                    pass
297
                else:
298
                    warnings.warn(
299
                        f'{self.current_identifier} has equivalent positions {inverses[ind]} and {inv}')
0 ignored issues
show
Coding Style introduced by
This line is too long as per the coding-style (104/100).

This check looks for lines that are too long. You can specify the maximum line length.

Loading history...
300
            return True
301
        else:
302
            return False
303
304
    def _validate_sites(self, asym_frac_motif, asym_is_disordered):
0 ignored issues
show
Unused Code introduced by
Either all return statements in a function should return an expression, or none of them should.
Loading history...
305
        site_diffs1 = np.abs(asym_frac_motif[:, None] - asym_frac_motif)
306
        site_diffs2 = np.abs(site_diffs1 - 1)
307
        overlapping = np.triu(np.all(
308
            (site_diffs1 <= _Reader.equiv_site_tol) |
309
            (site_diffs2 <= _Reader.equiv_site_tol),
310
            axis=-1), 1)
311
312
        if self.disorder == 'all_sites':
313
            for i, j in np.argwhere(overlapping):
314
                if asym_is_disordered[i] or asym_is_disordered[j]:
315
                    overlapping[i, j] = False
316
317
        if overlapping.any():
318
            warnings.warn(
319
                f'{self.current_identifier} may have overlapping sites; duplicates will be removed')
320
            keep_sites = ~overlapping.any(0)
321
            return keep_sites
322
323
    def _has_no_valid_sites(self, motif):
324
        if motif.shape[0] == 0:
325
            warnings.warn(
326
                f'Skipping {self.current_identifier} as there are no sites with coordinates')
327
            return True
328
        return False
329
330
    def _construct_periodic_set(self, asym_frac_motif, asym_symbols, sitesym, cell, **kwargs):
        """Build a PeriodicSet from an asymmetric unit and symmetry operations.

        The asymmetric unit is expanded with self.expand, the resulting
        fractional coordinates are multiplied by cell to give the motif, and
        the name, asymmetric unit indices, multiplicities and per-site types
        are attached as tags, followed by any extra kwargs. A 'filename' tag
        is added when self.current_filename is set.
        """
        expanded = self.expand(asym_frac_motif, sitesym)
        frac_motif, asym_unit, multiplicities, inverses = expanded

        # inverses maps every expanded site back to its asymmetric-unit index.
        tags = dict(
            name=self.current_identifier,
            asymmetric_unit=asym_unit,
            wyckoff_multiplicities=multiplicities,
            types=[asym_symbols[i] for i in inverses],
        )
        tags.update(kwargs)

        if self.current_filename:
            tags['filename'] = self.current_filename

        return PeriodicSet(frac_motif @ cell, cell, **tags)
348
349
    def expand(
            self,
            asym_frac_motif: np.ndarray,
            sitesym: Sequence[str]
    ) -> Tuple[np.ndarray, ...]:
        """Apply symmetry operations to an asymmetric unit.

        Every operation parsed from sitesym is applied to every site of
        asym_frac_motif (results wrapped into [0, 1)); an image is kept only
        if it does not overlap a site already collected.

        Returns:
            frac_motif: array of all collected fractional sites.
            asym_unit: for each asymmetric site that contributed at least one
                image, the index in frac_motif of its first image.
            multiplicities: number of images kept for each such site.
            inverses: list giving, for each row of frac_motif, the index of
                the asymmetric site it was generated from.
        """
        rotations, translations = ase.spacegroup.spacegroup.parse_sitesym(sitesym)

        collected = []
        inverses = []
        first_indices = []
        multiplicities = []

        for asym_index, point in enumerate(asym_frac_motif):
            start = len(collected)

            for rot, trans in zip(rotations, translations):
                image = np.mod(np.dot(rot, point) + trans, 1)
                # The very first image has nothing to be compared against.
                if collected and self._is_site_overlapping(
                        image, collected, inverses, asym_index):
                    continue
                collected.append(image)
                inverses.append(asym_index)

            n_images = len(collected) - start
            if n_images > 0:
                first_indices.append(start)
                multiplicities.append(n_images)

        frac_motif = np.array(collected)
        asym_unit = np.array(first_indices)
        multiplicities = np.array(multiplicities)
        return frac_motif, asym_unit, multiplicities, inverses
391
392
393
def _atom_has_disorder(label, occupancy):
394
    return label.endswith('?') or (np.isscalar(occupancy) and occupancy < 1)
395
396
def _heaviest_component(molecule):
397
    """Heaviest component (removes all but the heaviest component of the asym unit).
398
    Intended for removing solvents. Probably doesn't play well with disorder"""
399
    component_weights = []
400
    for component in molecule.components:
401
        weight = 0
402
        for a in component.atoms:
403
            if isinstance(a.atomic_weight, (float, int)):
404
                if isinstance(a.occupancy, (float, int)):
405
                    weight += a.occupancy * a.atomic_weight
406
                else:
407
                    weight += a.atomic_weight
408
        component_weights.append(weight)
409
    largest_component_arg = np.argmax(np.array(component_weights))
410
    molecule = molecule.components[largest_component_arg]
411
    return molecule
412
413
def _validate_extract_data(extract_data):
414
    if not isinstance(extract_data, dict):
415
        raise ValueError('extract_data must be a dict of callables')
416
    for key in extract_data:
417
        if not callable(extract_data[key]):
418
            raise ValueError('extract_data must be a dict of callables')
419
        if key in _Reader.reserved_tags:
420
            raise ValueError(f'extract_data includes reserved key {key}')
421
422
def _warning(message, category, filename, lineno, *args, **kwargs):
423
    return f'{filename}:{lineno}: {category.__name__}: {message}\n'
424