Passed
Push — master ( 8b2ed0...4823d2 )
by Daniel
03:53
created

amd.io._heaviest_component()   A

Complexity

Conditions 5

Size

Total Lines 16
Code Lines 13

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 13
dl 0
loc 16
rs 9.2833
c 0
b 0
f 0
cc 5
nop 1
1
"""Contains I/O tools, including a .CIF reader and CSD reader
2
(``csd-python-api`` only) to extract periodic set representations
3
of crystals which can be passed to :func:`.calculate.AMD` and :func:`.calculate.PDD`.
4
5
These intermediate :class:`.periodicset.PeriodicSet` representations can be written
6
to a .hdf5 file with :class:`SetWriter`, which can be read back with :class:`SetReader`.
7
This is much faster than rereading a .CIF and recomputing invariants.
8
"""
9
10
import os
11
import functools
12
import warnings
13
from typing import Callable, Iterable, Sequence, Tuple
14
15
import numpy as np
16
import ase.io.cif
17
import ase.data
18
import ase.spacegroup.spacegroup
19
20
from . import utils
21
from .periodicset import PeriodicSet
22
23
try:
24
    import ccdc.io
25
    import ccdc.search
26
    _CSD_PYTHON_API_ENABLED = True
27
except (ImportError, RuntimeError) as _:
28
    _CSD_PYTHON_API_ENABLED = False
29
30
def _custom_warning(message, category, filename, lineno, *args, **kwargs):
0 ignored issues
show
Unused Code introduced by
The argument filename seems to be unused.
Loading history...
Unused Code introduced by
The argument lineno seems to be unused.
Loading history...
Unused Code introduced by
The argument args seems to be unused.
Loading history...
Unused Code introduced by
The argument kwargs seems to be unused.
Loading history...
31
    return f'{category.__name__}: {message}\n'
32
33
# NOTE(review): this replaces the process-wide warnings formatter as an
# import-time side effect, so every warning shown in the program (not only
# those from this module) is printed as "<Category>: <message>".
warnings.formatwarning = _custom_warning

# Tolerance for treating two fractional coordinates as the same site. Each
# component is compared directly and also after shifting by 1, so values
# just below 1 match values just above 0.
_EQUIV_SITE_TOL = 1e-3

# CIF tags holding the asymmetric unit's fractional coordinates (one tag
# per axis).
_ATOM_SITE_FRACT_TAGS = [
    '_atom_site_fract_x',
    '_atom_site_fract_y',
    '_atom_site_fract_z',
]

# CIF tags holding Cartesian coordinates; consulted only when the
# fractional tags are missing.
_ATOM_SITE_CARTN_TAGS = [
    '_atom_site_cartn_x',
    '_atom_site_cartn_y',
    '_atom_site_cartn_z',
]

# CIF tags that may list the symmetry operations; the first one present
# (in this order) is used.
_SYMOP_TAGS = [
    '_space_group_symop_operation_xyz',
    '_space_group_symop.operation_xyz',
    '_symmetry_equiv_pos_as_xyz',
]
48
49
50
class _ParseError(ValueError):
51
    """Raised when an item cannot be parsed into a periodic set."""
52
    pass
0 ignored issues
show
Unused Code introduced by
Unnecessary pass statement
Loading history...
53
54
55
class _Reader:
    """Base class for readers yielding :class:`.periodicset.PeriodicSet` objects.

    A subclass builds an iterable of raw items (e.g. ``ase`` CIFBlocks or
    ``ccdc`` Entries) and a converter turning one item into a PeriodicSet.
    It then stores the wrapped generator on ``self._generator``::

        class XReader(_Reader):
            def __init__(self, ..., **kwargs):
                super().__init__(**kwargs)
                # setup and checks
                # make 'iterable' which yields objects to be converted
                self._generator = self._map(converter, iterable)

    The converter should raise :class:`_ParseError` for items it cannot
    convert. Those items are skipped with a warning.
    """

    _DISORDER_OPTIONS = {'skip', 'ordered_sites', 'all_sites'}

    def __init__(
            self,
            remove_hydrogens=False,
            disorder='skip',
            heaviest_component=False,
            show_warnings=True,
    ):

        if disorder not in _Reader._DISORDER_OPTIONS:
            raise ValueError(
                f'disorder parameter {disorder} must be one of '
                f'{_Reader._DISORDER_OPTIONS}')

        self.remove_hydrogens = remove_hydrogens
        self.disorder = disorder
        self.heaviest_component = heaviest_component
        self.show_warnings = show_warnings
        # Subclasses reading several files set this to the file currently
        # being read; _map copies it into each yielded set's tags.
        self.current_filename = None
        self._generator = []

    def __iter__(self):
        yield from self._generator

    def read_one(self):
        """Read the next (or first) item."""
        return next(iter(self._generator))

    def _map(self, func: Callable, iterable: Iterable) -> Iterable[PeriodicSet]:
        """Pass each item of ``iterable`` through ``func`` and yield the results.

        - Items for which ``func`` raises :class:`_ParseError` are skipped,
          and the error message is issued as a warning.
        - Warnings raised while converting an item are re-issued, prefixed
          with the resulting set's name.
        - If ``self.current_filename`` is set, it is stored in
          ``periodic_set.tags['filename']``.

        When ``self.show_warnings`` is False, all of these warnings (and any
        raised while advancing ``iterable``) are suppressed. The suppression
        is scoped to each step rather than installed as a global filter, so
        the caller's warning settings are left untouched.
        """
        iterator = iter(iterable)

        while True:
            # Advance the source; file parsers / generators may warn here.
            # The context is closed before yielding, so no filter leaks out.
            with warnings.catch_warnings():
                if not self.show_warnings:
                    warnings.simplefilter('ignore')
                try:
                    item = next(iterator)
                except StopIteration:
                    return

            with warnings.catch_warnings(record=True) as warning_msgs:
                if not self.show_warnings:
                    warnings.simplefilter('ignore')
                # Use an explicit sentinel: an exception with an empty
                # message must still cause the item to be skipped.
                parse_error = None
                try:
                    periodic_set = func(item)
                except _ParseError as err:
                    parse_error = err

            if parse_error is not None:
                if self.show_warnings:
                    warnings.warn(str(parse_error))
                continue

            if self.show_warnings:
                for warning in warning_msgs:
                    msg = f'{periodic_set.name}: {warning.message}'
                    warnings.warn(msg, category=warning.category)

            if self.current_filename:
                periodic_set.tags['filename'] = self.current_filename

            yield periodic_set
127
128
129
class CifReader(_Reader):
    """Read all structures in a .CIF with ``ase`` or ``ccdc``
    (``csd-python-api`` only), yielding  :class:`.periodicset.PeriodicSet`
    objects which can be passed to :func:`.calculate.AMD` or
    :func:`.calculate.PDD`.

    Parameters:

        path: path to a .CIF file, or to a folder if ``folder`` is True.
        reader: ``'ase'`` or ``'ccdc'`` (the latter needs csd-python-api).
        folder: if True, read every file in ``path`` with a recognised
            extension (``cif`` for ase, ``ccdc.io.EntryReader.known_suffixes``
            for ccdc), in sorted filename order.
        remove_hydrogens: remove hydrogen atoms.
        disorder: ``'skip'`` skips structures with disorder,
            ``'ordered_sites'`` removes disordered sites, ``'all_sites'``
            keeps every site.
        heaviest_component: keep only the heaviest component (ccdc only).
        show_warnings: if False, suppress warnings while reading.

    Examples:

        ::

            # Put all crystals in a .CIF in a list
            structures = list(amd.CifReader('mycif.cif'))

            # Reads just one if the .CIF has just one crystal
            periodic_set = amd.CifReader('mycif.cif').read_one()

            # If a folder has several .CIFs each with one crystal, use
            structures = list(amd.CifReader('path/to/folder', folder=True))

            # Make list of AMDs (with k=100) of crystals in a .CIF
            amds = [amd.AMD(periodic_set, 100) for periodic_set in amd.CifReader('mycif.cif')]
    """

    def __init__(
            self,
            path,
            reader='ase',
            folder=False,
            remove_hydrogens=False,
            disorder='skip',
            heaviest_component=False,
            show_warnings=True,
    ):

        super().__init__(
            remove_hydrogens=remove_hydrogens,
            disorder=disorder,
            heaviest_component=heaviest_component,
            show_warnings=show_warnings,
        )

        if reader not in ('ase', 'ccdc'):
            raise ValueError(f'Invalid reader {reader}; must be ase or ccdc.')

        if reader == 'ase' and heaviest_component:
            raise NotImplementedError(
                'Parameter heaviest_component not implemented for ase, only ccdc.')

        if reader == 'ase':
            extensions = {'cif'}
            file_parser = ase.io.cif.parse_cif
            converter = functools.partial(cifblock_to_periodicset,
                                          remove_hydrogens=remove_hydrogens,
                                          disorder=disorder)
        else:  # reader == 'ccdc' (validated above)
            if not _CSD_PYTHON_API_ENABLED:
                raise ImportError(
                    'Failed to import csd-python-api; check it is installed and licensed.')
            extensions = ccdc.io.EntryReader.known_suffixes
            file_parser = ccdc.io.EntryReader
            converter = functools.partial(entry_to_periodicset,
                                          remove_hydrogens=remove_hydrogens,
                                          disorder=disorder,
                                          heaviest_component=heaviest_component)

        if folder:
            generator = self._folder_generator(path, file_parser, extensions)
        else:
            generator = file_parser(path)

        self._generator = self._map(converter, generator)

    def _folder_generator(self, path, file_parser, extensions):
        """Yield items parsed from each file in ``path`` with a matching extension.

        Files are visited in sorted name order so results are reproducible
        (``os.listdir`` order is arbitrary). Sub-directories are skipped even
        if their name has a matching extension. ``self.current_filename`` is
        set to each file's name before its items are yielded.
        """
        for file in sorted(os.listdir(path)):
            file_path = os.path.join(path, file)
            suff = os.path.splitext(file)[1][1:]
            if suff.lower() in extensions and os.path.isfile(file_path):
                self.current_filename = file
                yield from file_parser(file_path)
206
207
208
class CSDReader(_Reader):
    """Read Entries from the CSD, yielding :class:`.periodicset.PeriodicSet` objects.

    The CSDReader returns :class:`.periodicset.PeriodicSet` objects which can be passed
    to :func:`.calculate.AMD` or :func:`.calculate.PDD`.

    Parameters:

        refcodes: a refcode or list of refcodes to read. ``None`` or
            ``'CSD'`` reads every entry in the database.
        families: if True, read every entry whose refcode starts with one
            of ``refcodes`` (ignored when reading the whole database).
        remove_hydrogens, disorder, heaviest_component: conversion options
            passed to :func:`entry_to_periodicset`.
        show_warnings: if False, suppress warnings (including those for
            refcodes that are not found).

    Examples:

        Get crystals with refcodes in a list::

            refcodes = ['DEBXIT01', 'DEBXIT05', 'HXACAN01']
            structures = list(amd.CSDReader(refcodes))

        Read refcode families (any whose refcode starts with strings in the list)::

            refcodes = ['ACSALA', 'HXACAN']
            structures = list(amd.CSDReader(refcodes, families=True))

        Create a generic reader, read crystals by name with :meth:`CSDReader.entry()`::

            reader = amd.CSDReader()
            debxit01 = reader.entry('DEBXIT01')

            # looping over this generic reader will yield all CSD entries
            for periodic_set in reader:
                ...

        Make list of AMD (with k=100) for crystals in these families::

            refcodes = ['ACSALA', 'HXACAN']
            amds = []
            for periodic_set in amd.CSDReader(refcodes, families=True):
                amds.append(amd.AMD(periodic_set, 100))
    """

    def __init__(
            self,
            refcodes=None,
            families=False,
            remove_hydrogens=False,
            disorder='skip',
            heaviest_component=False,
            show_warnings=True,
    ):

        super().__init__(
            remove_hydrogens=remove_hydrogens,
            disorder=disorder,
            heaviest_component=heaviest_component,
            show_warnings=show_warnings,
        )

        if not _CSD_PYTHON_API_ENABLED:
            raise ImportError(
                'Failed to import csd-python-api; check it is installed and licensed.')

        if isinstance(refcodes, str) and refcodes.lower() == 'csd':
            refcodes = None

        if refcodes is None:
            families = False
        else:
            refcodes = [refcodes] if isinstance(refcodes, str) else list(refcodes)

        # families parameter reads all crystals with ids starting with passed refcodes
        if families:
            all_refcodes = []
            for refcode in refcodes:
                query = ccdc.search.TextNumericSearch()
                query.add_identifier(refcode)
                all_refcodes.extend(hit.identifier for hit in query.search())

            # de-duplicate, keeping the first occurrence of each refcode
            # (dicts preserve insertion order)
            refcodes = list(dict.fromkeys(all_refcodes))

        self._entry_reader = ccdc.io.EntryReader('CSD')

        converter = functools.partial(entry_to_periodicset,
                                      remove_hydrogens=remove_hydrogens,
                                      disorder=disorder,
                                      heaviest_component=heaviest_component)

        generator = self._ccdc_generator(refcodes)
        self._generator = self._map(converter, generator)

    def entry(self, refcode: str, **kwargs) -> PeriodicSet:
        """Read a PeriodicSet given any CSD refcode.

        The reader's own ``remove_hydrogens``, ``disorder`` and
        ``heaviest_component`` settings are used unless overridden in
        ``kwargs``, which are passed on to :func:`entry_to_periodicset`.
        """
        options = {
            'remove_hydrogens': self.remove_hydrogens,
            'disorder': self.disorder,
            'heaviest_component': self.heaviest_component,
        }
        options.update(kwargs)
        entry = self._entry_reader.entry(refcode)
        return entry_to_periodicset(entry, **options)

    def _ccdc_generator(self, refcodes):
        """Generates ccdc Entries from CSD refcodes (all entries if None).

        Refcodes whose lookup raises RuntimeError are skipped, with a
        warning if ``self.show_warnings`` is True.
        """
        if refcodes is None:
            yield from self._entry_reader
            return

        for refcode in refcodes:
            try:
                entry = self._entry_reader.entry(refcode)
            except RuntimeError:
                if self.show_warnings:
                    warnings.warn(f'Identifier {refcode} not found in database')
                continue
            yield entry
316
317
318
def entry_to_periodicset(
        entry,
        remove_hydrogens=False,
        disorder='skip',
        heaviest_component=False
) -> PeriodicSet:
    """ccdc.entry.Entry --> PeriodicSet.

    Parameters:
        entry: a ``ccdc.entry.Entry``.
        remove_hydrogens: remove atoms whose atomic symbol is ``'H'`` or ``'D'``.
        disorder: ``'skip'`` rejects entries with any disorder,
            ``'ordered_sites'`` removes atoms judged disordered, and
            ``'all_sites'`` keeps every atom and does not remove
            overlapping sites.
        heaviest_component: keep only the heaviest component when the
            molecule has more than one.

    Returns:
        A PeriodicSet with Cartesian motif and cell. Tags: ``name``,
        ``asymmetric_unit``, ``wyckoff_multiplicities``, ``types``
        (atomic numbers).

    Raises:
        _ParseError: if the entry has no 3D structure, has disorder (with
            ``disorder='skip'``), has atoms without sites, or has no sites
            left after filtering.
    """
    crystal = entry.crystal

    if not entry.has_3d_structure:
        raise _ParseError(f'{crystal.identifier}: Has no 3D structure')

    molecule = crystal.disordered_molecule

    if disorder == 'skip':
        if crystal.has_disorder or entry.has_disorder or \
                any(_atom_has_disorder(a.label, a.occupancy) for a in molecule.atoms):
            raise _ParseError(f'{crystal.identifier}: Has disorder')

    elif disorder == 'ordered_sites':
        # Select first, then remove, so removal never iterates the molecule
        # while it is being modified.
        disordered = [a for a in molecule.atoms
                      if _atom_has_disorder(a.label, a.occupancy)]
        molecule.remove_atoms(disordered)

    if remove_hydrogens:
        # Explicit membership test (a substring test on 'HD' would also
        # accept '' and 'HD').
        hydrogens = [a for a in molecule.atoms if a.atomic_symbol in ('H', 'D')]
        molecule.remove_atoms(hydrogens)

    if heaviest_component and len(molecule.components) > 1:
        molecule = _heaviest_component(molecule)

    if not molecule.all_atoms_have_sites or \
            any(a.fractional_coordinates is None for a in molecule.atoms):
        raise _ParseError(f'{crystal.identifier}: Has atoms without sites')

    # Write the filtered molecule back so the asymmetric unit reflects it.
    crystal.molecule = molecule
    asym_atoms = crystal.asymmetric_unit_molecule.atoms
    asym_unit = np.array([tuple(a.fractional_coordinates) for a in asym_atoms])
    asym_unit = np.mod(asym_unit, 1)
    asym_types = [a.atomic_number for a in asym_atoms]
    cell = utils.cellpar_to_cell(*crystal.cell_lengths, *crystal.cell_angles)

    sitesym = crystal.symmetry_operators
    if not sitesym:
        sitesym = ['x,y,z', ]  # no operators listed: identity only

    if disorder != 'all_sites':
        keep_sites = _unique_sites(asym_unit)
        if not np.all(keep_sites):
            warnings.warn(f'{crystal.identifier}: May have overlapping sites; '
                          'duplicates will be removed')
        asym_unit = asym_unit[keep_sites]
        asym_types = [sym for sym, keep in zip(asym_types, keep_sites) if keep]

    if asym_unit.shape[0] == 0:
        raise _ParseError(f'{crystal.identifier}: Has no valid sites')

    frac_motif, asym_inds, multiplicities, inverses = expand_asym_unit(asym_unit, sitesym)
    full_types = np.array([asym_types[i] for i in inverses])
    motif = frac_motif @ cell

    tags = {
        'name': crystal.identifier,
        'asymmetric_unit': asym_inds,
        'wyckoff_multiplicities': multiplicities,
        'types': full_types
    }

    return PeriodicSet(motif, cell, **tags)
385
386
387
def cifblock_to_periodicset(
        block,
        remove_hydrogens=False,
        disorder='skip'
) -> PeriodicSet:
    """ase.io.cif.CIFBlock --> PeriodicSet.

    Parameters:
        block: an ``ase.io.cif.CIFBlock``.
        remove_hydrogens: remove sites whose atomic number is 1.
        disorder: only applied when the block has ``_atom_site_occupancy``.
            ``'skip'`` rejects blocks with any site whose label ends in
            ``'?'`` or whose occupancy is below 1, and ``'ordered_sites'``
            removes such sites. ``'all_sites'`` keeps every site and does
            not remove overlapping sites.

    Returns:
        A PeriodicSet with Cartesian motif and cell. Tags: ``name``,
        ``asymmetric_unit``, ``wyckoff_multiplicities``, ``types``
        (atomic numbers, or 0 when the block has no symbol data).

    Raises:
        _ParseError: if the block has no coordinates, has disorder (with
            ``disorder='skip'``), or has no sites left after filtering.
    """
    cell = block.get_cell().array

    # Asymmetric unit in fractional coordinates, one row per site. Each tag
    # holds one coordinate column, hence the transposes.
    frac_cols = [block.get(name) for name in _ATOM_SITE_FRACT_TAGS]
    if any(col is None for col in frac_cols):
        cart_cols = [block.get(name) for name in _ATOM_SITE_CARTN_TAGS]
        if any(col is None for col in cart_cols):
            raise _ParseError(f'{block.name}: Has no sites')
        # Rows of ``cell`` are lattice vectors (cart = frac @ cell, as used
        # for ``motif`` below), so frac = cart @ inv(cell) with sites as rows.
        asym_unit = np.array(cart_cols).T @ np.linalg.inv(cell)
    else:
        asym_unit = np.array(frac_cols).T
    asym_unit = np.mod(asym_unit, 1)

    try:
        asym_types = [ase.data.atomic_numbers[s] for s in block.get_symbols()]
    except ase.io.cif.NoStructureData as _:
        asym_types = [0] * len(asym_unit)

    sitesym = ['x,y,z', ]
    for tag in _SYMOP_TAGS:
        if tag in block:
            sitesym = block[tag]
            break
    if isinstance(sitesym, str):
        sitesym = [sitesym]

    # Indices of asymmetric-unit sites to drop (a set for O(1) lookups).
    remove_sites = set()

    occupancies = block.get('_atom_site_occupancy')
    if occupancies is not None:
        labels = block.get('_atom_site_label')
        if labels is None:
            # No labels to check for a trailing '?'; judge by occupancy only.
            labels = [''] * len(occupancies)
        if disorder == 'skip':
            if any(_atom_has_disorder(lab, occ) for lab, occ in zip(labels, occupancies)):
                raise _ParseError(f'{block.name}: Has disorder')
        elif disorder == 'ordered_sites':
            remove_sites.update(
                i for i, (lab, occ) in enumerate(zip(labels, occupancies))
                if _atom_has_disorder(lab, occ))

    if remove_hydrogens:
        # asym_types holds atomic numbers (ints), so compare with 1; a
        # substring test against 'HD' raises TypeError on ints.
        remove_sites.update(i for i, num in enumerate(asym_types) if num == 1)

    asym_unit = np.delete(asym_unit, sorted(remove_sites), axis=0)
    asym_types = [s for i, s in enumerate(asym_types) if i not in remove_sites]

    if disorder != 'all_sites':
        keep_sites = _unique_sites(asym_unit)
        if not np.all(keep_sites):
            warnings.warn(f'{block.name}: May have overlapping sites; duplicates will be removed')
        asym_unit = asym_unit[keep_sites]
        asym_types = [sym for sym, keep in zip(asym_types, keep_sites) if keep]

    if asym_unit.shape[0] == 0:
        raise _ParseError(f'{block.name}: Has no valid sites')

    frac_motif, asym_inds, multiplicities, inverses = expand_asym_unit(asym_unit, sitesym)
    full_types = np.array([asym_types[i] for i in inverses])
    motif = frac_motif @ cell

    tags = {
        'name': block.name,
        'asymmetric_unit': asym_inds,
        'wyckoff_multiplicities': multiplicities,
        'types': full_types
    }

    return PeriodicSet(motif, cell, **tags)
459
460
461
def expand_asym_unit(
        asym_unit: np.ndarray,
        sitesym: Sequence[str]
) -> Tuple[np.ndarray, ...]:
    """Apply symmetry operations to an asymmetric unit to build the full motif.

    Each site is transformed by every (rotation, translation) parsed from
    ``sitesym`` and wrapped into [0, 1). An image is kept only if it does not
    coincide (within ``_EQUIV_SITE_TOL`` per coordinate, modulo 1) with an
    image already kept. If it coincides with an image of a *different*
    asymmetric-unit site, a warning is issued.

    Parameters:
        asym_unit: fractional coordinates of the asymmetric unit, one site
            per row.
        sitesym: symmetry operations as strings, e.g. ``'x,y,z'``.

    Returns:
        frac_motif: array of all kept images; images of the same
            asymmetric-unit site are contiguous.
        asym_inds: for each asymmetric-unit site that produced at least one
            image, the index of its first image in ``frac_motif``.
        multiplicities: number of kept images for each of those sites.
        inverses: list giving, for each row of ``frac_motif``, the index in
            ``asym_unit`` of the site it came from.
    """
    rotations, translations = ase.spacegroup.spacegroup.parse_sitesym(sitesym)
    motif_sites = []
    inverses = []
    starts = [0]
    multiplicities = []

    for asym_ind, asym_site in enumerate(asym_unit):
        n_new = 0

        for rot, trans in zip(rotations, translations):
            image = np.mod(np.dot(rot, asym_site) + trans, 1)

            if motif_sites:
                # compare against every kept image, allowing wrap-around at 1
                diffs = np.abs(image - motif_sites)
                close = (diffs <= _EQUIV_SITE_TOL) | (np.abs(diffs - 1) <= _EQUIV_SITE_TOL)
                matches = np.all(close, axis=-1)
                if np.any(matches):
                    for match in np.argwhere(matches).flatten():
                        if inverses[match] != asym_ind:
                            warnings.warn(f'Equivalent sites at positions {inverses[match]}, {asym_ind}')
                    continue

            motif_sites.append(image)
            inverses.append(asym_ind)
            n_new += 1

        if n_new:
            multiplicities.append(n_new)
            starts.append(len(motif_sites))

    frac_motif = np.array(motif_sites)
    asym_inds = np.array(starts[:-1])  # drop the trailing end offset
    return frac_motif, asym_inds, np.array(multiplicities), inverses
514
515
516
def _atom_has_disorder(label, occupancy):
517
    """Return True if atom has disorder and False otherwise."""
518
    return label.endswith('?') or (np.isscalar(occupancy) and occupancy < 1)
519
520
521
def _unique_sites(asym_unit):
522
    site_diffs1 = np.abs(asym_unit[:, None] - asym_unit)
523
    site_diffs2 = np.abs(site_diffs1 - 1)
524
    overlapping = np.triu(np.all(
525
        (site_diffs1 <= _EQUIV_SITE_TOL) | (site_diffs2 <= _EQUIV_SITE_TOL),
526
        axis=-1), 1)
527
    return ~overlapping.any(axis=0)
528
529
530
def _heaviest_component(molecule):
531
    """Heaviest component (removes all but the heaviest component of the asym unit).
532
    Intended for removing solvents. Probably doesn't play well with disorder"""
533
    component_weights = []
534
    for component in molecule.components:
535
        weight = 0
536
        for a in component.atoms:
537
            if isinstance(a.atomic_weight, (float, int)):
538
                if isinstance(a.occupancy, (float, int)):
539
                    weight += a.occupancy * a.atomic_weight
540
                else:
541
                    weight += a.atomic_weight
542
        component_weights.append(weight)
543
    largest_component_ind = np.argmax(np.array(component_weights))
544
    molecule = molecule.components[largest_component_ind]
545
    return molecule
546