Passed
Push — master ( 0f880f...d6a4c8 )
by Daniel
01:47
created

amd.io   F

Complexity

Total Complexity 103

Size/Duplication

Total Lines 581
Duplicated Lines 10.5 %

Importance

Changes 0
Metric Value
wmc 103
eloc 365
dl 61
loc 581
rs 2
c 0
b 0
f 0

11 Methods

Rating   Name   Duplication   Size   Complexity  
B CSDReader._map() 29 29 8
A CSDReader.read_one() 0 3 1
A CSDReader.__iter__() 0 2 1
C CSDReader.__init__() 0 57 9
A CifReader.__iter__() 0 2 1
A CSDReader._ccdc_generator() 0 13 5
C CifReader._map() 32 32 9
A CifReader._folder_generator() 0 6 3
C CifReader.__init__() 0 51 9
A CSDReader.entry() 0 6 1
A CifReader.read_one() 0 3 1

8 Functions

Rating   Name   Duplication   Size   Complexity  
B expand_asym_unit() 0 53 8
A _heaviest_component() 0 16 5
F cifblock_to_periodicset() 0 71 15
B _validate_kwargs() 0 26 8
A atom_has_disorder() 0 2 1
F entry_to_periodicset() 0 66 16
A _custom_warning() 0 2 1
A _unique_sites() 0 7 1


Duplicated Code

Duplicate code is one of the most pungent code smells. A rule that is often used is to re-structure code once it is duplicated in three or more places.
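One way this could apply to the two identical `_map` methods flagged in this file is to keep a single copy in a shared base class. The sketch below is a toy under stated assumptions: `BaseReader` and `NumberReader` are hypothetical names, the converter returns plain dicts rather than `PeriodicSet` objects, and the warning re-labelling done by the real `_map` is left out for brevity.

```python
import warnings
from typing import Callable, Iterable, Iterator


class ParseError(ValueError):
    """Stand-in for the module's ParseError."""


class BaseReader:
    """Hypothetical base class holding the only copy of _map."""

    def __init__(self, include_if=(), extract_data=None):
        self.include_if = tuple(include_if)
        self.extract_data = dict(extract_data or {})
        self._generator = iter(())

    def __iter__(self) -> Iterator[dict]:
        yield from self._generator

    def _map(self, converter: Callable, iterable: Iterable) -> Iterator[dict]:
        for item in iterable:
            # Skip items rejected by any include_if check.
            if any(not check(item) for check in self.include_if):
                continue
            try:
                result = converter(item)
            except ParseError as err:
                warnings.warn(str(err), category=UserWarning)
                continue
            # Attach extra data computed from the raw item.
            for key, extractor in self.extract_data.items():
                result[key] = extractor(item)
            yield result


class NumberReader(BaseReader):
    """Toy subclass: only the source of items differs between readers."""

    def __init__(self, numbers, **kwargs):
        super().__init__(**kwargs)
        self._generator = self._map(self._convert, numbers)

    @staticmethod
    def _convert(number):
        if number < 0:
            raise ParseError(f'{number} is negative')
        return {'value': number}


reader = NumberReader([1, 2, 3, 4],
                      include_if=[lambda n: n % 2 == 0],
                      extract_data={'square': lambda n: n * n})
results = list(reader)
```

With this shape, each concrete reader only builds its own item source and converter, and a fix to the filtering or tagging logic is made once.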


Complexity

Tip: Before tackling complexity, make sure that you eliminate any duplication first. This can often reduce the size of classes significantly.

Complex classes like amd.io often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to finding such a component is to look for fields/methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
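As a rough illustration of Extract Class, the parsing options that both reader constructors accept could be grouped into one small value object, which would also shrink the long argument lists. This is only a sketch: `ParseOptions` and the toy `Reader` are hypothetical and not part of the reviewed module.

```python
from dataclasses import dataclass

DISORDER_OPTIONS = {'skip', 'ordered_sites', 'all_sites'}


@dataclass(frozen=True)
class ParseOptions:
    """Hypothetical value object grouping options shared by both readers."""
    remove_hydrogens: bool = False
    disorder: str = 'skip'
    heaviest_component: bool = False

    def __post_init__(self):
        # Validation lives with the data it protects.
        if self.disorder not in DISORDER_OPTIONS:
            raise ValueError(
                f'disorder parameter {self.disorder} must be one of '
                f'{sorted(DISORDER_OPTIONS)}')


class Reader:
    """Toy reader taking one options object instead of three keyword arguments."""

    def __init__(self, path, options=None):
        self.path = path
        self.options = options if options is not None else ParseOptions()


reader = Reader('mycif.cif', ParseOptions(remove_hydrogens=True))
```

The trade-off is a change to the public constructor signature, so a real refactoring would probably keep the old keyword arguments for a while and build the options object internally.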

"""Contains I/O tools, including a .CIF reader and CSD reader
(``csd-python-api`` only) to extract periodic set representations
of crystals which can be passed to :func:`.calculate.AMD` and :func:`.calculate.PDD`.

These intermediate :class:`.periodicset.PeriodicSet` representations can be written
to a .hdf5 file with :class:`SetWriter`, which can be read back with :class:`SetReader`.
This is much faster than rereading a .CIF and recomputing invariants.
"""

import os
import functools
import warnings
from typing import Callable, Iterable, Sequence, Tuple

import numpy as np
import ase.io.cif
import ase.spacegroup.spacegroup

from .periodicset import PeriodicSet
from .utils import cellpar_to_cell

try:
    import ccdc.io
    import ccdc.search
    _CSD_PYTHON_API_ENABLED = True
except (ImportError, RuntimeError):
    _CSD_PYTHON_API_ENABLED = False


def _custom_warning(message, category, filename, lineno, *args, **kwargs):
    """Format a warning as ``Category: message``, without file or line details.

    The unused arguments are required by the :func:`warnings.formatwarning` signature.
    """
    return f'{category.__name__}: {message}\n'


warnings.formatwarning = _custom_warning


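A minimal sketch of what replacing `warnings.formatwarning` does, for readers unfamiliar with the hook. The name `short_format` is illustrative; the hook is restored afterwards here because, unlike in the module above, the demo should not leave a global side effect.

```python
import warnings


def short_format(message, category, filename, lineno, line=None):
    # filename, lineno and line are part of the hook's signature but unused here.
    return f'{category.__name__}: {message}\n'


# The formatter is an ordinary function, so it can be called directly.
formatted = short_format('Has disorder', UserWarning, 'io.py', 42)

# Installing the hook changes how every warning in the process is printed.
original = warnings.formatwarning
warnings.formatwarning = short_format
try:
    warnings.warn('Has disorder')  # shown on stderr in the short format
finally:
    warnings.formatwarning = original
```

Because the assignment is process-wide, importing a module that sets it changes warning output for all other code as well, which is worth keeping in mind for a library.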
class ParseError(ValueError):
    """Raised when an item cannot be converted to a PeriodicSet."""


_EQUIV_SITE_TOL = 1e-3
_DISORDER_OPTIONS = {'skip', 'ordered_sites', 'all_sites'}
_ATOM_SITE_FRACT_TAGS = [
    '_atom_site_fract_x',
    '_atom_site_fract_y',
    '_atom_site_fract_z',
]
_ATOM_SITE_CARTN_TAGS = [
    '_atom_site_cartn_x',
    '_atom_site_cartn_y',
    '_atom_site_cartn_z',
]
_SYMOP_TAGS = [
    '_space_group_symop_operation_xyz',
    '_space_group_symop.operation_xyz',
    '_symmetry_equiv_pos_as_xyz',
]


class CifReader:
    """Read all structures in a .CIF with ``ase`` or ``ccdc``
    (``csd-python-api`` only), yielding :class:`.periodicset.PeriodicSet`
    objects which can be passed to :func:`.calculate.AMD` or
    :func:`.calculate.PDD`.

    Examples:

        ::

            # Put all crystals in a .CIF in a list
            structures = list(amd.CifReader('mycif.cif'))

            # Reads just one if the .CIF has just one crystal
            periodic_set = amd.CifReader('mycif.cif').read_one()

            # If a folder has several .CIFs each with one crystal, use
            structures = list(amd.CifReader('path/to/folder', folder=True))

            # Make list of AMDs (with k=100) of crystals in a .CIF
            amds = [amd.AMD(periodic_set, 100) for periodic_set in amd.CifReader('mycif.cif')]
    """

    # Review note (best practice): too many arguments (10/5).
    def __init__(
            self,
            path,
            reader='ase',
            folder=False,
            remove_hydrogens=False,
            disorder='skip',
            heaviest_component=False,
            show_warnings=True,
            extract_data=None,
            include_if=None
    ):

        if disorder not in _DISORDER_OPTIONS:
            raise ValueError(f'disorder parameter {disorder} must be one of {_DISORDER_OPTIONS}')

        if reader not in ('ase', 'ccdc'):
            raise ValueError(f'Invalid reader {reader}; must be ase or ccdc.')

        if reader == 'ase' and heaviest_component:
            raise NotImplementedError(
                'Parameter heaviest_component not implemented for ase, only ccdc.')

        extract_data, include_if = _validate_kwargs(extract_data, include_if)

        self.show_warnings = show_warnings
        self.extract_data = extract_data
        self.include_if = include_if
        self.current_filename = None

        if reader == 'ase':
            extensions = {'cif'}
            file_parser = ase.io.cif.parse_cif
            converter = functools.partial(cifblock_to_periodicset,
                                          remove_hydrogens=remove_hydrogens,
                                          disorder=disorder)

        else:  # reader == 'ccdc'; other values were rejected above
            if not _CSD_PYTHON_API_ENABLED:
                raise ImportError(
                    'Failed to import csd-python-api; check it is installed and licensed.')
            extensions = ccdc.io.EntryReader.known_suffixes
            file_parser = ccdc.io.EntryReader
            converter = functools.partial(entry_to_periodicset,
                                          remove_hydrogens=remove_hydrogens,
                                          disorder=disorder,
                                          heaviest_component=heaviest_component)

        if folder:
            generator = self._folder_generator(path, file_parser, extensions)
        else:
            generator = file_parser(path)

        self._generator = self._map(converter, generator)

    def __iter__(self):
        yield from self._generator

    def read_one(self):
        """Read the next (usually first and only) item."""
        return next(iter(self._generator))

    def _folder_generator(self, path, file_parser, extensions):
        for file in os.listdir(path):
            suff = os.path.splitext(file)[1][1:]
            if suff.lower() in extensions:
                self.current_filename = file
                yield from file_parser(os.path.join(path, file))

    # Review note (duplication): this method is duplicated in CSDReader._map.
    def _map(self, func: Callable, iterable: Iterable) -> Iterable[PeriodicSet]:
        """Iterates over iterable, passing items through parser and yielding the result.
        Applies warning and include_if filters, catches bad structures and warns.
        """

        for item in iterable:

            with warnings.catch_warnings(record=True) as warning_msgs:

                if not self.show_warnings:
                    warnings.simplefilter('ignore')

                if any(not check(item) for check in self.include_if):
                    continue

                try:
                    periodic_set = func(item)
                except ParseError as err:
                    warnings.warn(err, category=UserWarning)
                    continue

            for warning in warning_msgs:
                msg = f'{periodic_set.name}: {warning.message}'
                warnings.warn(msg, category=warning.category)

            if self.current_filename:
                periodic_set.tags['filename'] = self.current_filename

            # A distinct loop name keeps the 'func' argument intact for the next item.
            for key, extractor in self.extract_data.items():
                periodic_set.tags[key] = extractor(item)

            yield periodic_set


class CSDReader:
    """Read Entries from the CSD, yielding :class:`.periodicset.PeriodicSet` objects.

    The CSDReader returns :class:`.periodicset.PeriodicSet` objects which can be passed
    to :func:`.calculate.AMD` or :func:`.calculate.PDD`.

    Examples:

        Get crystals with refcodes in a list::

            refcodes = ['DEBXIT01', 'DEBXIT05', 'HXACAN01']
            structures = list(amd.CSDReader(refcodes))

        Read refcode families (any whose refcode starts with strings in the list)::

            refcodes = ['ACSALA', 'HXACAN']
            structures = list(amd.CSDReader(refcodes, families=True))

        Create a generic reader, read crystals by name with :meth:`CSDReader.entry()`::

            reader = amd.CSDReader()
            debxit01 = reader.entry('DEBXIT01')

            # looping over this generic reader will yield all CSD entries
            for periodic_set in reader:
                ...

        Make list of AMD (with k=100) for crystals in these families::

            refcodes = ['ACSALA', 'HXACAN']
            amds = []
            for periodic_set in amd.CSDReader(refcodes, families=True):
                amds.append(amd.AMD(periodic_set, 100))
    """

    # Review note (best practice): too many arguments (9/5).
    def __init__(
            self,
            refcodes=None,
            families=False,
            remove_hydrogens=False,
            disorder='skip',
            heaviest_component=False,
            show_warnings=True,
            extract_data=None,
            include_if=None,
    ):

        if not _CSD_PYTHON_API_ENABLED:
            raise ImportError(
                'Failed to import csd-python-api; check it is installed and licensed.')

        if disorder not in _DISORDER_OPTIONS:
            raise ValueError(f'disorder parameter {disorder} must be one of {_DISORDER_OPTIONS}')

        extract_data, include_if = _validate_kwargs(extract_data, include_if)

        self.show_warnings = show_warnings
        self.extract_data = extract_data
        self.include_if = include_if
        self.current_filename = None

        if isinstance(refcodes, str) and refcodes.lower() == 'csd':
            refcodes = None

        if refcodes is None:
            families = False
        else:
            refcodes = [refcodes] if isinstance(refcodes, str) else list(refcodes)

        # families parameter reads all crystals with ids starting with passed refcodes
        if families:
            all_refcodes = []
            for refcode in refcodes:
                query = ccdc.search.TextNumericSearch()
                query.add_identifier(refcode)
                all_refcodes.extend(hit.identifier for hit in query.search())

            # filter to unique refcodes
            seen = set()
            seen_add = seen.add
            refcodes = [
                refcode for refcode in all_refcodes
                if not (refcode in seen or seen_add(refcode))]

        self._entry_reader = ccdc.io.EntryReader('CSD')

        converter = functools.partial(entry_to_periodicset,
                                      remove_hydrogens=remove_hydrogens,
                                      disorder=disorder,
                                      heaviest_component=heaviest_component)

        generator = self._ccdc_generator(refcodes)

        self._generator = self._map(converter, generator)

    def __iter__(self):
        yield from self._generator

    def read_one(self):
        """Read the next (usually first and only) item."""
        return next(iter(self._generator))

    def entry(self, refcode: str, **kwargs) -> PeriodicSet:
        """Read a PeriodicSet given any CSD refcode."""

        entry = self._entry_reader.entry(refcode)
        periodic_set = entry_to_periodicset(entry, **kwargs)
        return periodic_set

    def _ccdc_generator(self, refcodes):
        """Generates ccdc Entries from CSD refcodes."""

        if refcodes is None:
            for entry in self._entry_reader:
                yield entry
        else:
            for refcode in refcodes:
                try:
                    entry = self._entry_reader.entry(refcode)
                    yield entry
                except RuntimeError:
                    warnings.warn(f'Identifier {refcode} not found in database')

    # Review note (duplication): this method is duplicated in CifReader._map.
    def _map(self, func: Callable, iterable: Iterable) -> Iterable[PeriodicSet]:
        """Iterates over iterable, passing items through parser and yielding the result.
        Applies warning and include_if filters, catches bad structures and warns.
        """

        for item in iterable:

            with warnings.catch_warnings(record=True) as warning_msgs:

                if not self.show_warnings:
                    warnings.simplefilter('ignore')

                if any(not check(item) for check in self.include_if):
                    continue

                try:
                    periodic_set = func(item)
                except ParseError as err:
                    warnings.warn(err, category=UserWarning)
                    continue

            for warning in warning_msgs:
                msg = f'{periodic_set.name}: {warning.message}'
                warnings.warn(msg, category=warning.category)

            # A distinct loop name keeps the 'func' argument intact for the next item.
            for key, extractor in self.extract_data.items():
                periodic_set.tags[key] = extractor(item)

            yield periodic_set


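The refcode filtering in `CSDReader.__init__` removes repeats while keeping first-seen order. A compact alternative, shown here only as a sketch, relies on dicts preserving insertion order (guaranteed from Python 3.7). The helper name `unique_in_order` and the sample hit list are illustrative.

```python
def unique_in_order(items):
    """Drop repeats while keeping first-seen order."""
    # dict keys are unique and remember insertion order.
    return list(dict.fromkeys(items))


# Illustrative search hits, as might come back for overlapping families.
hits = ['ACSALA', 'ACSALA01', 'ACSALA', 'HXACAN', 'ACSALA01']
unique = unique_in_order(hits)
```

Both forms do a single pass over the hits; the `dict.fromkeys` version avoids the side-effecting `seen_add` call inside a comprehension condition.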
def entry_to_periodicset(
        entry,
        remove_hydrogens=False,
        disorder='skip',
        heaviest_component=False
) -> PeriodicSet:
    """ccdc.entry.Entry --> PeriodicSet."""

    crystal = entry.crystal

    if not entry.has_3d_structure:
        raise ParseError('Has no 3D structure')

    molecule = crystal.disordered_molecule

    if disorder == 'skip':
        if crystal.has_disorder or entry.has_disorder or \
                any(atom_has_disorder(a.label, a.occupancy) for a in molecule.atoms):
            raise ParseError('Has disorder')

    elif disorder == 'ordered_sites':
        molecule.remove_atoms(a for a in molecule.atoms
                              if atom_has_disorder(a.label, a.occupancy))

    if remove_hydrogens:
        # tuple membership matches exact symbols only ('' is a substring of 'HD')
        molecule.remove_atoms(a for a in molecule.atoms if a.atomic_symbol in ('H', 'D'))

    if heaviest_component and len(molecule.components) > 1:
        molecule = _heaviest_component(molecule)

    if not molecule.all_atoms_have_sites or \
            any(a.fractional_coordinates is None for a in molecule.atoms):
        raise ParseError('Has atoms without sites')

    crystal.molecule = molecule
    asym_atoms = crystal.asymmetric_unit_molecule.atoms
    asym_unit = np.array([tuple(a.fractional_coordinates) for a in asym_atoms])
    asym_unit = np.mod(asym_unit, 1)
    asym_symbols = [a.atomic_symbol for a in asym_atoms]
    cell = cellpar_to_cell(*crystal.cell_lengths, *crystal.cell_angles)

    sitesym = crystal.symmetry_operators
    if not sitesym:
        sitesym = ['x,y,z']

    if disorder != 'all_sites':
        keep_sites = _unique_sites(asym_unit)
        if not np.all(keep_sites):
            warnings.warn('May have overlapping sites; duplicates will be removed')
        asym_unit = asym_unit[keep_sites]
        asym_symbols = [sym for sym, keep in zip(asym_symbols, keep_sites) if keep]

    if asym_unit.shape[0] == 0:
        raise ParseError('Has no valid sites')

    frac_motif, asym_inds, multiplicities, inverses = expand_asym_unit(asym_unit, sitesym)
    full_types = [asym_symbols[i] for i in inverses]
    motif = frac_motif @ cell

    tags = {
        'name': entry.identifier,
        'asymmetric_unit': asym_inds,
        'wyckoff_multiplicities': multiplicities,
        'types': full_types,
    }

    return PeriodicSet(motif, cell, **tags)


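The final step `motif = frac_motif @ cell` converts fractional coordinates to Cartesian ones. The sketch below illustrates that product on made-up numbers, assuming (as the row-vector product implies) that the rows of `cell` are the lattice vectors; the orthorhombic cell and the two sites are illustrative values only.

```python
import numpy as np

# Rows of `cell` are the lattice vectors a, b, c (illustrative orthorhombic cell).
cell = np.array([[4.0, 0.0, 0.0],
                 [0.0, 5.0, 0.0],
                 [0.0, 0.0, 6.0]])

# One fractional coordinate per row.
frac_motif = np.array([[0.0, 0.0, 0.0],
                       [0.5, 0.5, 0.5]])

motif = frac_motif @ cell                # fractional -> Cartesian
recovered = motif @ np.linalg.inv(cell)  # Cartesian -> fractional
```

The inverse product is the same operation a reader needs when a file supplies only Cartesian site coordinates.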
def cifblock_to_periodicset(
        block,
        remove_hydrogens=False,
        disorder='skip'
) -> PeriodicSet:
    """ase.io.cif.CIFBlock --> PeriodicSet."""

    cell = block.get_cell().array

    # asymmetric unit fractional coords
    asym_unit = [block.get(name) for name in _ATOM_SITE_FRACT_TAGS]
    if None in asym_unit:
        asym_motif = [block.get(name) for name in _ATOM_SITE_CARTN_TAGS]
        if None in asym_motif:
            raise ParseError('Has no sites')
        # tags give x, y, z columns: transpose to (n, 3) before converting to fractional
        asym_unit = np.array(asym_motif).T @ np.linalg.inv(cell)
    else:
        asym_unit = np.array(asym_unit).T
    asym_unit = np.mod(asym_unit, 1)

    try:
        asym_symbols = block.get_symbols()
    except ase.io.cif.NoStructureData:
        asym_symbols = ['Unknown' for _ in range(len(asym_unit))]

    sitesym = ['x,y,z']
    for tag in _SYMOP_TAGS:
        if tag in block:
            sitesym = block[tag]
            break
    if isinstance(sitesym, str):
        sitesym = [sitesym]

    remove_sites = []

    occupancies = block.get('_atom_site_occupancy')
    labels = block.get('_atom_site_label')
    if occupancies is not None:
        if disorder == 'skip':
            if any(atom_has_disorder(lab, occ) for lab, occ in zip(labels, occupancies)):
                raise ParseError('Has disorder')
        elif disorder == 'ordered_sites':
            remove_sites.extend(
                i for i, (lab, occ) in enumerate(zip(labels, occupancies))
                if atom_has_disorder(lab, occ))

    if remove_hydrogens:
        # tuple membership matches exact symbols only ('' is a substring of 'HD')
        remove_sites.extend(i for i, sym in enumerate(asym_symbols) if sym in ('H', 'D'))

    asym_unit = np.delete(asym_unit, remove_sites, axis=0)
    asym_symbols = [s for i, s in enumerate(asym_symbols) if i not in remove_sites]

    if disorder != 'all_sites':
        keep_sites = _unique_sites(asym_unit)
        if not np.all(keep_sites):
            warnings.warn('May have overlapping sites; duplicates will be removed')
        asym_unit = asym_unit[keep_sites]
        asym_symbols = [sym for sym, keep in zip(asym_symbols, keep_sites) if keep]

    if asym_unit.shape[0] == 0:
        raise ParseError('Has no valid sites')

    frac_motif, asym_inds, multiplicities, inverses = expand_asym_unit(asym_unit, sitesym)
    full_types = [asym_symbols[i] for i in inverses]
    motif = frac_motif @ cell

    tags = {
        'name': block.name,
        'asymmetric_unit': asym_inds,
        'wyckoff_multiplicities': multiplicities,
        'types': full_types,
    }

    return PeriodicSet(motif, cell, **tags)


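When hydrogen and deuterium sites are selected by element symbol, it matters whether the test is a substring test against the string `'HD'` or a membership test against a tuple. The short comparison below uses a made-up symbol list that includes two edge cases, an empty symbol and the literal `'HD'`.

```python
symbols = ['C', 'H', 'D', 'O', '', 'HD']

# Substring test: '' and 'HD' are both substrings of 'HD', so they match too.
by_substring = [s for s in symbols if s in 'HD']

# Tuple membership: only exact matches are selected.
by_membership = [s for s in symbols if s in ('H', 'D')]
```

For well-formed element symbols the two tests agree; they differ only on empty or unusual symbols, which is where a silent mismatch would be hardest to spot.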
def expand_asym_unit(
        asym_unit: np.ndarray,
        sitesym: Sequence[str]
) -> Tuple[np.ndarray, ...]:
    """
    Asymmetric unit's fractional coords + sitesyms (as strings)
    -->
    frac motif, asym unit inds, multiplicities, inverses
    """

    rotations, translations = ase.spacegroup.spacegroup.parse_sitesym(sitesym)
    all_sites = []
    asym_inds = [0]
    multiplicities = []
    inverses = []

    for inv, site in enumerate(asym_unit):
        multiplicity = 0

        for rot, trans in zip(rotations, translations):
            site_ = np.mod(np.dot(rot, site) + trans, 1)

            if not all_sites:
                all_sites.append(site_)
                inverses.append(inv)
                multiplicity += 1
                continue

            # check if site_ overlaps with existing sites
            diffs1 = np.abs(site_ - all_sites)
            diffs2 = np.abs(diffs1 - 1)
            mask = np.all((diffs1 <= _EQUIV_SITE_TOL) | (diffs2 <= _EQUIV_SITE_TOL), axis=-1)

            if np.any(mask):
                where_equal = np.argwhere(mask).flatten()
                for ind in where_equal:
                    if inverses[ind] != inv:
                        warnings.warn(f'Equivalent sites at positions {inverses[ind]}, {inv}')
            else:
                all_sites.append(site_)
                inverses.append(inv)
                multiplicity += 1

        if multiplicity > 0:
            multiplicities.append(multiplicity)
            asym_inds.append(len(all_sites))

    frac_motif = np.array(all_sites)
    asym_inds = np.array(asym_inds[:-1])
    multiplicities = np.array(multiplicities)
    return frac_motif, asym_inds, multiplicities, inverses


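The expansion above can be hard to follow without a concrete case. The following is a simplified sketch of the same idea, not the function itself: it takes explicit (rotation, translation) pairs instead of parsing symmetry strings with `ase`, and it leaves out the multiplicity bookkeeping and the warnings. `expand_sites` and `TOL` are names chosen for this example.

```python
import numpy as np

TOL = 1e-3  # plays the role of _EQUIV_SITE_TOL


def expand_sites(asym_unit, operations):
    """Apply each (rotation, translation) to each site, keeping distinct images.

    Returns the expanded fractional sites and, for each one, the index of the
    asymmetric-unit site it was generated from.
    """
    sites, inverses = [], []
    for inv, site in enumerate(asym_unit):
        for rot, trans in operations:
            image = np.mod(rot @ site + trans, 1)
            if sites:
                diffs = np.abs(image - np.array(sites))
                # Coordinates 0 and 1 describe the same point in a periodic cell.
                same = np.all((diffs <= TOL) | (np.abs(diffs - 1) <= TOL), axis=-1)
                if same.any():
                    continue
            sites.append(image)
            inverses.append(inv)
    return np.array(sites), inverses


identity = (np.eye(3), np.zeros(3))
inversion = (-np.eye(3), np.zeros(3))  # the operation -x,-y,-z

asym_unit = np.array([[0.1, 0.2, 0.3],   # general position: two images
                      [0.5, 0.5, 0.5]])  # on the inversion centre: one image
frac_motif, inverses = expand_sites(asym_unit, [identity, inversion])
```

The site on the inversion centre maps onto itself, so it contributes one point to the motif while the general position contributes two; this is what the per-site multiplicities in the full function record.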
def atom_has_disorder(label, occupancy):
    """Return True if the label ends with '?' or the occupancy is a scalar below 1."""
    return label.endswith('?') or (np.isscalar(occupancy) and occupancy < 1)


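A few sample inputs make the disorder rule easier to read. The function is repeated so the snippet runs on its own; the labels and occupancies are invented for illustration.

```python
import numpy as np


def atom_has_disorder(label, occupancy):
    return label.endswith('?') or (np.isscalar(occupancy) and occupancy < 1)


samples = [
    ('C1', 1.0),    # fully occupied, plain label
    ('C2?', 1.0),   # label flagged with '?'
    ('O1', 0.5),    # partial occupancy
    ('H1', None),   # occupancy missing, so not a scalar
]
flags = [bool(atom_has_disorder(label, occ)) for label, occ in samples]
```

A missing occupancy is treated as ordered, because `np.isscalar(None)` is false and the comparison is never attempted.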
def _unique_sites(asym_unit):
    """Return a boolean mask that is False for any site overlapping an earlier
    site (within _EQUIV_SITE_TOL, modulo 1)."""
    site_diffs1 = np.abs(asym_unit[:, None] - asym_unit)
    site_diffs2 = np.abs(site_diffs1 - 1)
    overlapping = np.triu(np.all(
        (site_diffs1 <= _EQUIV_SITE_TOL) | (site_diffs2 <= _EQUIV_SITE_TOL),
        axis=-1), 1)
    return ~overlapping.any(axis=0)


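The overlap mask compares every pair of sites both directly and across the cell boundary. A small standalone copy of the logic, with invented coordinates, shows the wrap-around case; `unique_sites` and `TOL` mirror the private names above.

```python
import numpy as np

TOL = 1e-3


def unique_sites(asym_unit):
    diffs = np.abs(asym_unit[:, None] - asym_unit)  # pairwise differences, shape (n, n, 3)
    wrapped = np.abs(diffs - 1)                     # differences across the cell edge
    # Upper triangle only, so the earlier site of an overlapping pair is kept.
    overlapping = np.triu(np.all((diffs <= TOL) | (wrapped <= TOL), axis=-1), 1)
    return ~overlapping.any(axis=0)


sites = np.array([[0.0, 0.0, 0.0],
                  [0.9995, 0.0, 0.0],  # within tolerance of the first site across the edge
                  [0.5, 0.5, 0.5]])
keep = unique_sites(sites)
```

The pairwise array grows quadratically with the number of sites, which is fine for an asymmetric unit but would be costly for a very large motif.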
def _heaviest_component(molecule):
    """Heaviest component (removes all but the heaviest component of the asym unit).
    Intended for removing solvents. Probably doesn't play well with disorder."""
    component_weights = []
    for component in molecule.components:
        weight = 0
        for a in component.atoms:
            if isinstance(a.atomic_weight, (float, int)):
                if isinstance(a.occupancy, (float, int)):
                    weight += a.occupancy * a.atomic_weight
                else:
                    weight += a.atomic_weight
        component_weights.append(weight)
    largest_component_ind = np.argmax(np.array(component_weights))
    molecule = molecule.components[largest_component_ind]
    return molecule


def _validate_kwargs(extract_data, include_if):
    """Validate extract_data and include_if, replacing None with empty defaults."""

    reserved_tags = {'motif', 'cell', 'name',
                     'asymmetric_unit', 'wyckoff_multiplicities',
                     'types', 'filename'}

    if extract_data is None:
        extract_data = {}
    else:
        if not isinstance(extract_data, dict):
            raise ValueError('extract_data must be a dict of callables')
        for key in extract_data:
            if not callable(extract_data[key]):
                raise ValueError('extract_data must be a dict of callables')
            if key in reserved_tags:
                raise ValueError(f'extract_data includes reserved key {key}')

    if include_if is None:
        include_if = ()
    elif not all(callable(func) for func in include_if):
        raise ValueError('include_if must be a list of callables')

    return extract_data, include_if
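To show the shapes that `extract_data` and `include_if` are expected to have, here is a condensed, standalone variant of the validation. It is a sketch rather than the module's function: `validate` is a hypothetical name, the callables are placeholders, and unlike the original it copies `include_if` into a tuple so that a generator argument is not consumed by the check.

```python
RESERVED_TAGS = {'motif', 'cell', 'name', 'asymmetric_unit',
                 'wyckoff_multiplicities', 'types', 'filename'}


def validate(extract_data=None, include_if=None):
    extract_data = {} if extract_data is None else extract_data
    if not isinstance(extract_data, dict) or \
            not all(callable(f) for f in extract_data.values()):
        raise ValueError('extract_data must be a dict of callables')
    clashes = RESERVED_TAGS & set(extract_data)
    if clashes:
        raise ValueError(f'extract_data includes reserved key {sorted(clashes)[0]}')
    include_if = () if include_if is None else tuple(include_if)
    if not all(callable(f) for f in include_if):
        raise ValueError('include_if must be a list of callables')
    return extract_data, include_if


# Placeholder callables; in the readers each one receives the raw item being parsed.
extract_data, include_if = validate(
    extract_data={'n_chars': lambda item: len(str(item))},
    include_if=[lambda item: item is not None],
)

# A reserved tag name is rejected.
try:
    validate(extract_data={'name': str})
    error = None
except ValueError as err:
    error = str(err)
```

Each `extract_data` key becomes a tag on the resulting periodic set, which is why names already used for built-in tags are refused.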