Passed
Push — master ( 6a3dea...1d1d87 )
by Daniel
01:50
created

amd.io._Reader._validate_sites()   B

Complexity

Conditions 7

Size

Total Lines 19
Code Lines 17

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 17
dl 0
loc 19
rs 8
c 0
b 0
f 0
cc 7
nop 3
1
"""Contains I/O tools, including a .CIF reader and CSD reader
2
(``csd-python-api`` only) to extract periodic set representations
3
of crystals which can be passed to :func:`.calculate.AMD` and :func:`.calculate.PDD`.
4
5
These intermediate :class:`.periodicset.PeriodicSet` representations can be written
6
to a .hdf5 file with :class:`SetWriter`, which can be read back with :class:`SetReader`.
7
This is much faster than rereading a .CIF and recomputing invariants.
8
"""
9
10
import os
11
import warnings
12
from typing import Callable, Iterable, Sequence, Tuple, Optional
13
14
import numpy as np
15
import ase.spacegroup.spacegroup    # parse_sitesym
16
import ase.io.cif
17
import h5py
18
19
from .periodicset import PeriodicSet
20
from .utils import _extend_signature, cellpar_to_cell
21
22
try:
23
    import ccdc.io       # EntryReader
24
    import ccdc.search   # TextNumericSearch
25
    _CCDC_ENABLED = True
26
except (ImportError, RuntimeError) as exception:
27
    _CCDC_ENABLED = False
28
29
30
def _warning(message, category, filename, lineno, *args, **kwargs):
0 ignored issues
show
Unused Code introduced by
The argument args seems to be unused.
Loading history...
Unused Code introduced by
The argument kwargs seems to be unused.
Loading history...
31
    return f'{filename}:{lineno}: {category.__name__}: {message}\n'
32
33
# replace the module-wide warning formatter with the compact one-line _warning
warnings.formatwarning = _warning
34
35
36
def _atom_has_disorder(label, occupancy):
37
    return label.endswith('?') or (np.isscalar(occupancy) and occupancy < 1)
38
39
40
class _Reader:
0 ignored issues
show
best-practice introduced by
Too many instance attributes (9/7)
Loading history...
41
    """Base Reader class. Contains parsers for converting ase CifBlock
42
    and ccdc Entry objects to PeriodicSets.
43
44
    Intended use:
45
46
    First make a new method for _Reader converting object to PeriodicSet
47
    (e.g. named _X_to_PSet). Then make this class outline:
48
49
    class XReader(_Reader):
50
        def __init__(self, ..., **kwargs):
51
52
        super().__init__(**kwargs)
53
54
        # setup and checks
55
56
        # make 'iterable' which yields objects to be converted (e.g. CIFBlock, Entry)
57
58
        # set self._generator like this
59
        self._generator = self._read(iterable, self._X_to_PSet)
60
    """
61
62
    # accepted values of the ``disorder`` argument (checked in __init__)
    disorder_options = {'skip', 'ordered_sites', 'all_sites'}
    # tag names set by the reader itself; __init__ rejects them as extract_data keys
    reserved_tags = {
        'motif',
        'cell',
        'name',
        'asymmetric_unit',
        'wyckoff_multiplicities',
        'types',
        'filename',}
    # CIF tags read for the fractional coordinates of the asymmetric unit
    atom_site_fract_tags = [
        '_atom_site_fract_x',
        '_atom_site_fract_y',
        '_atom_site_fract_z',]
    # CIF tags read for Cartesian coordinates when a fractional tag is missing
    atom_site_cartn_tags = [
        '_atom_site_cartn_x',
        '_atom_site_cartn_y',
        '_atom_site_cartn_z',]
    # CIF tags searched, in this order, for the symmetry operators
    symop_tags = [
        '_space_group_symop_operation_xyz',
        '_space_group_symop.operation_xyz',
        '_symmetry_equiv_pos_as_xyz',]

    # per-coordinate tolerance (on fractional coordinates) under which two
    # sites are considered to be at the same position
    equiv_site_tol = 1e-3
85
86
    def __init__(
0 ignored issues
show
best-practice introduced by
Too many arguments (7/5)
Loading history...
87
            self,
88
            remove_hydrogens=False,
89
            disorder='skip',
90
            heaviest_component=False,
91
            show_warnings=True,
92
            extract_data=None,
93
            include_if=None):
94
95
        # settings
96
        if disorder not in _Reader.disorder_options:
97
            raise ValueError(f'disorder parameter {disorder} must be one of {_Reader.disorder_options}')
0 ignored issues
show
Coding Style introduced by
This line is too long as per the coding-style (104/100).

This check looks for lines that are too long. You can specify the maximum line length.

Loading history...
98
99
        if extract_data:
100
            if not isinstance(extract_data, dict):
101
                raise ValueError('extract_data must be a dict with callable values')
102
            for key in extract_data:
103
                if not callable(extract_data[key]):
104
                    raise ValueError('extract_data must be a dict with callable values')
105
                if key in _Reader.reserved_tags:
106
                    raise ValueError(f'extract_data includes reserved key {key}')
107
108
        if include_if:
109
            for func in include_if:
110
                if not callable(func):
111
                    raise ValueError('include_if must be a list of callables')
112
113
        self.remove_hydrogens = remove_hydrogens
114
        self.disorder = disorder
115
        self.heaviest_component = heaviest_component
116
        self.extract_data = extract_data
117
        self.include_if = include_if
118
        self.show_warnings = show_warnings
119
        self.current_identifier = None
120
        self.current_filename = None
121
        self._generator = []
122
123
    def __iter__(self):
124
        yield from self._generator
125
126
    def read_one(self):
127
        """Read the next (or first) item."""
128
        return next(iter(self._generator))
129
130
    # basically the builtin map, but skips items if the function returned None.
131
    # The object returned by this function (Iterable of PeriodicSets) is set to
132
    # self._generator; then iterating over the Reader iterates over
133
    # self._generator.
134
    @staticmethod
135
    def _map(func: Callable, iterable: Iterable) -> Iterable[PeriodicSet]:
136
        """Iterates over iterable, passing items through parser and
137
        yielding the result if it is not None.
138
        """
139
        
0 ignored issues
show
Coding Style introduced by
Trailing whitespace
Loading history...
140
        for item in iterable:
141
            res = func(item)
142
            if res is not None:
143
                yield res
144
145
    def _CIFBlock_to_PeriodicSet(self, block) -> PeriodicSet:
        """ase.io.cif.CIFBlock --> PeriodicSet. Returns None for a "bad" set.

        A block is "bad" (None is returned, with a warning if
        ``self.show_warnings``) when it fails an ``include_if`` check, has no
        fractional or Cartesian coordinates, is disordered while
        ``self.disorder == 'skip'``, or has no sites left after filtering.
        """

        # skip if structure does not pass checks in include_if
        if self.include_if:
            if not all(check(block) for check in self.include_if):
                return None

        # read name, cell, asym motif and atomic symbols
        self.current_identifier = block.name
        cell = block.get_cell().array
        frac_columns = [block.get(name) for name in _Reader.atom_site_fract_tags]
        if None in frac_columns:
            cart_columns = [block.get(name) for name in _Reader.atom_site_cartn_tags]
            if None in cart_columns:
                if self.show_warnings:
                    warnings.warn(
                        f'Skipping {self.current_identifier} as coordinates were not found')
                return None
            # The tags give one column per axis (x, y, z). Transpose to one row
            # per site *before* multiplying by the inverse cell, so the product
            # is (n, 3) @ (3, 3) and yields fractional coordinates per site.
            asym_frac_motif = np.array(cart_columns).T @ np.linalg.inv(cell)
        else:
            asym_frac_motif = np.array(frac_columns).T

        n_sites = len(asym_frac_motif)

        try:
            asym_symbols = block.get_symbols()
        except ase.io.cif.NoStructureData:
            asym_symbols = ['Unknown'] * n_sites

        # indices of sites to remove
        remove = set()
        if self.remove_hydrogens:
            # exact match: `sym in 'HD'` would also match '' and 'HD'
            remove.update(i for i, sym in enumerate(asym_symbols) if sym in ('H', 'D'))

        # Find disordered sites. The flag list always has one entry per site so
        # that it stays aligned with the motif; disorder is only looked for
        # when the block has an occupancy column.
        asym_is_disordered = [False] * n_sites
        occupancies = block.get('_atom_site_occupancy')
        labels = block.get('_atom_site_label')
        if occupancies is not None:
            if labels is None:
                labels = [''] * len(occupancies)

            disordered = []     # indices of disordered sites not already removed
            for i, occ, label in zip(range(n_sites), occupancies, labels):
                # labels made only of digits may have been parsed as numbers
                if _atom_has_disorder(str(label), occ):
                    asym_is_disordered[i] = True
                    if i not in remove:
                        disordered.append(i)

            if self.disorder == 'skip' and disordered:
                if self.show_warnings:
                    warnings.warn(
                        f'Skipping {self.current_identifier} as structure is disordered')
                return None

            if self.disorder == 'ordered_sites':
                remove.update(disordered)

        # remove sites (motif, symbols and disorder flags are filtered together)
        keep = np.array([i not in remove for i in range(n_sites)], dtype=bool)
        asym_frac_motif = np.mod(asym_frac_motif[keep], 1)
        asym_symbols = [s for s, k in zip(asym_symbols, keep) if k]
        asym_is_disordered = [d for d, k in zip(asym_is_disordered, keep) if k]

        # drop sites that overlap an earlier site
        keep_sites = self._validate_sites(asym_frac_motif, asym_is_disordered)
        if keep_sites is not None:
            asym_frac_motif = asym_frac_motif[keep_sites]
            asym_symbols = [sym for sym, k in zip(asym_symbols, keep_sites) if k]

        if self._has_no_valid_sites(asym_frac_motif):
            return None

        # symmetry operators: first matching tag wins, identity if none is present
        sitesym = ['x,y,z', ]
        for tag in _Reader.symop_tags:
            if tag in block:
                sitesym = block[tag]
                break

        if isinstance(sitesym, str):
            sitesym = [sitesym]

        return self._construct_periodic_set(block, asym_frac_motif, asym_symbols, sitesym, cell)
223
224
225
    def _Entry_to_PeriodicSet(self, entry) -> PeriodicSet:
        """ccdc.entry.Entry --> PeriodicSet. Returns None for a "bad" set.

        None is returned (with a warning if ``self.show_warnings``, except for
        a failed ``include_if`` check) when the entry fails an ``include_if``
        check, has no 3D structure, is disordered while
        ``self.disorder == 'skip'``, has atoms without sites, or has no sites
        left in its asymmetric unit.
        """

        # skip if structure does not pass checks in include_if
        if self.include_if:
            if not all(check(entry) for check in self.include_if):
                return None

        self.current_identifier = entry.identifier
        # structure must pass this test
        if not entry.has_3d_structure:
            if self.show_warnings:
                warnings.warn(
                    f'Skipping {self.current_identifier} as entry has no 3D structure')
            return None

        crystal = entry.crystal

        # first disorder check, if skipping. If occ == 1 for all atoms but the entry
        # or crystal is listed as having disorder, skip (can't know where disorder is).
        # If occ != 1 for any atoms, we wait to see if we remove them before skipping.
        molecule = crystal.disordered_molecule
        if self.disorder == 'ordered_sites':
            molecule.remove_atoms(a for a in molecule.atoms if a.label.endswith('?'))

        may_have_disorder = False
        if self.disorder == 'skip':
            for a in molecule.atoms:
                occ = a.occupancy
                if _atom_has_disorder(a.label, occ):
                    may_have_disorder = True
                    break

            if not may_have_disorder:
                if crystal.has_disorder or entry.has_disorder:
                    if self.show_warnings:
                        warnings.warn(f'Skipping {self.current_identifier} as structure is disordered')
                    return None

        if self.remove_hydrogens:
            # NOTE(review): `in 'HD'` is a substring test, so an empty symbol or
            # the string 'HD' would also match -- confirm whether exact
            # matching against 'H' and 'D' is intended.
            molecule.remove_atoms(a for a in molecule.atoms if a.atomic_symbol in 'HD')

        if self.heaviest_component:
            molecule = self._heaviest_component(molecule)

        # write the edited molecule back so the crystal reflects the removals above
        crystal.molecule = molecule

        # by here all atoms to be removed have been (except via ordered_sites).
        # If disorder == 'skip' and there were atom(s) with occ < 1 found
        # earlier, we check if all such atoms were removed. If not, skip.
        if self.disorder == 'skip' and may_have_disorder:
            for a in crystal.disordered_molecule.atoms:
                occ = a.occupancy
                if _atom_has_disorder(a.label, occ):
                    if self.show_warnings:
                        warnings.warn(
                            f'Skipping {self.current_identifier} as structure is disordered')
                    return None

        # if disorder is all_sites, we need to know where disorder is to ignore overlaps
        asym_is_disordered = []     # True/False list same length as asym unit
        if self.disorder == 'all_sites':
            for a in crystal.asymmetric_unit_molecule.atoms:
                occ = a.occupancy
                if _atom_has_disorder(a.label, occ):
                    asym_is_disordered.append(True)
                else:
                    asym_is_disordered.append(False)

        # check all atoms have coords. option/default remove unknown sites?
        if not molecule.all_atoms_have_sites or \
           any(a.fractional_coordinates is None for a in molecule.atoms):
            if self.show_warnings:
                warnings.warn(
                    f'Skipping {self.current_identifier} as some atoms do not have sites')
            return None

        # get cell & asymmetric unit
        cell = cellpar_to_cell(*crystal.cell_lengths, *crystal.cell_angles)
        asym_frac_motif = np.array([tuple(a.fractional_coordinates)
                                    for a in crystal.asymmetric_unit_molecule.atoms])
        # wrap the fractional coordinates into [0, 1)
        asym_frac_motif = np.mod(asym_frac_motif, 1)
        asym_symbols = [a.atomic_symbol for a in crystal.asymmetric_unit_molecule.atoms]

        # remove overlapping sites, check sites exist
        keep_sites = self._validate_sites(asym_frac_motif, asym_is_disordered)
        if keep_sites is not None:
            asym_frac_motif = asym_frac_motif[keep_sites]
            asym_symbols = [sym for sym, keep in zip(asym_symbols, keep_sites) if keep]

        if self._has_no_valid_sites(asym_frac_motif):
            return None

        # fall back to the identity operation when no symmetry operators are given
        sitesym = crystal.symmetry_operators
        if not sitesym:
            sitesym = ['x,y,z', ]

        entry.crystal.molecule = crystal.disordered_molecule    # for extract_data. remove?

        return self._construct_periodic_set(entry, asym_frac_motif, asym_symbols, sitesym, cell)
325
326
    def expand(
            self,
            asym_frac_motif: np.ndarray,
            sitesym: Sequence[str]
    ) -> Tuple[np.ndarray, ...]:
        """Apply symmetry operations to an asymmetric unit to build the motif.

        Takes the fractional coordinates of the asymmetric unit and the
        symmetry operations (as strings) and returns

        frac_motif, asym_unit, multiplicities, inverses

        where ``asym_unit[k]`` is the index in ``frac_motif`` of the first
        image of the k-th kept asymmetric site, ``multiplicities[k]`` is how
        many images it contributed, and ``inverses[j]`` is the index of the
        asymmetric site that ``frac_motif[j]`` came from. Images overlapping
        an already generated site are dropped.
        """

        rotations, translations = ase.spacegroup.spacegroup.parse_sitesym(sitesym)
        frac_sites = []
        first_indices = [0]
        counts = []
        inverses = []

        for asym_ind, asym_site in enumerate(asym_frac_motif):
            n_before = len(frac_sites)

            for rot, trans in zip(rotations, translations):
                image = np.mod(np.dot(rot, asym_site) + trans, 1)

                # the very first image has nothing to be compared with
                if frac_sites and self._is_site_overlapping(
                        image, frac_sites, inverses, asym_ind):
                    continue

                frac_sites.append(image)
                inverses.append(asym_ind)

            n_new = len(frac_sites) - n_before
            if n_new > 0:
                counts.append(n_new)
                first_indices.append(len(frac_sites))

        # the last entry of first_indices is the total count, not a start index
        frac_motif = np.array(frac_sites)
        asym_unit = np.array(first_indices[:-1])
        multiplicities = np.array(counts)
        return frac_motif, asym_unit, multiplicities, inverses
368
369
    def _is_site_overlapping(self, new_site, all_sites, inverses, inv):
370
        """Return True (and warn) if new_site overlaps with a site in all_sites."""
371
        diffs1 = np.abs(new_site - all_sites)
372
        diffs2 = np.abs(diffs1 - 1)
373
        mask = np.all(np.logical_or(diffs1 <= _Reader.equiv_site_tol,
374
                                    diffs2 <= _Reader.equiv_site_tol),
375
                        axis=-1)
0 ignored issues
show
Coding Style introduced by
Wrong continued indentation (remove 2 spaces).
Loading history...
376
377
        if np.any(mask):
0 ignored issues
show
unused-code introduced by
Unnecessary "else" after "return"
Loading history...
378
            where_equal = np.argwhere(mask).flatten()
379
            for ind in where_equal:
380
                if inverses[ind] == inv:
381
                    pass
382
                else:
383
                    if self.show_warnings:
384
                        warnings.warn(
385
                            f'{self.current_identifier} has equivalent positions {inverses[ind]} and {inv}')
0 ignored issues
show
Coding Style introduced by
This line is too long as per the coding-style (108/100).

This check looks for lines that are too long. You can specify the maximum line length.

Loading history...
386
            return True
387
        else:
388
            return False
389
390
    def _validate_sites(self, asym_frac_motif, asym_is_disordered):
0 ignored issues
show
Unused Code introduced by
Either all return statements in a function should return an expression, or none of them should.
Loading history...
391
        site_diffs1 = np.abs(asym_frac_motif[:, None] - asym_frac_motif)
392
        site_diffs2 = np.abs(site_diffs1 - 1)
393
        overlapping = np.triu(np.all(
394
            (site_diffs1 <= _Reader.equiv_site_tol) |
395
            (site_diffs2 <= _Reader.equiv_site_tol),
396
            axis=-1), 1)
397
398
        if self.disorder == 'all_sites':
399
            for i, j in np.argwhere(overlapping):
400
                if asym_is_disordered[i] or asym_is_disordered[j]:
401
                    overlapping[i, j] = False
402
403
        if overlapping.any():
404
            if self.show_warnings:
405
                warnings.warn(
406
                    f'{self.current_identifier} may have overlapping sites; duplicates will be removed')
0 ignored issues
show
Coding Style introduced by
This line is too long as per the coding-style (104/100).

This check looks for lines that are too long. You can specify the maximum line length.

Loading history...
407
            keep_sites = ~overlapping.any(0)
408
            return keep_sites
409
410
    def _has_no_valid_sites(self, motif):
411
        if motif.shape[0] == 0:
412
            if self.show_warnings:
413
                warnings.warn(
414
                    f'Skipping {self.current_identifier} as there are no sites with coordinates')
415
            return True
416
        return False
417
418
    def _construct_periodic_set(self, raw_item, asym_frac_motif, asym_symbols, sitesym, cell):
        """Expand the asymmetric unit with ``sitesym`` and build the PeriodicSet.

        The set is tagged with the current identifier, the asymmetric unit
        indices, the multiplicities and the atomic types of the full motif.
        The current filename is added if one is set, and so is the result of
        each ``extract_data`` callable applied to ``raw_item``.
        """
        frac_motif, asym_unit, multiplicities, inverses = self.expand(asym_frac_motif, sitesym)
        # every site of the full motif takes the symbol of the site it came from
        types = [asym_symbols[ind] for ind in inverses]
        motif = frac_motif @ cell   # fractional --> Cartesian coordinates

        tags = {
            'name': self.current_identifier,
            'asymmetric_unit': asym_unit,
            'wyckoff_multiplicities': multiplicities,
            'types': types,
        }

        if self.current_filename:
            tags['filename'] = self.current_filename

        if self.extract_data is not None:
            for key, extractor in self.extract_data.items():
                tags[key] = extractor(raw_item)

        return PeriodicSet(motif, cell, **tags)
438
439
    def _heaviest_component(self, molecule):
0 ignored issues
show
Coding Style introduced by
This method could be written as a function/class method.

If a method does not access any attributes of the class, it could also be implemented as a function or static method. This can help improve readability. For example

class Foo:
    def some_method(self, x, y):
        return x + y;

could be written as

class Foo:
    @classmethod
    def some_method(cls, x, y):
        return x + y;
Loading history...
440
        """Heaviest component (removes all but the heaviest component of the asym unit).
441
        Intended for removing solvents. Probably doesn't play well with disorder"""
442
        if len(molecule.components) > 1:
443
            component_weights = []
444
            for component in molecule.components:
445
                weight = 0
446
                for a in component.atoms:
447
                    if isinstance(a.atomic_weight, (float, int)):
448
                        if isinstance(a.occupancy, (float, int)):
449
                            weight += a.occupancy * a.atomic_weight
450
                        else:
451
                            weight += a.atomic_weight
452
                component_weights.append(weight)
453
            largest_component_arg = np.argmax(np.array(component_weights))
454
            molecule = molecule.components[largest_component_arg]
455
456
        return molecule
457
458
class CifReader(_Reader):
459
    """Read all structures in a .CIF with ``ase`` or ``ccdc``
460
    (``csd-python-api`` only), yielding  :class:`.periodicset.PeriodicSet`
461
    objects which can be passed to :func:`.calculate.AMD` or
462
    :func:`.calculate.PDD`.
463
464
    Examples:
465
466
        ::
467
468
            # Put all crystals in a .CIF in a list
469
            structures = list(amd.CifReader('mycif.cif'))
470
471
            # Reads just one if the .CIF has just one crystal
472
            periodic_set = amd.CifReader('mycif.cif').read_one()
473
474
            # If a folder has several .CIFs each with one crystal, use
475
            structures = list(amd.CifReader('path/to/folder', folder=True))
476
477
            # Make list of AMDs (with k=100) of crystals in a .CIF
478
            amds = [amd.AMD(periodic_set, 100) for periodic_set in amd.CifReader('mycif.cif')]
479
    """
480
481
    @_extend_signature(_Reader.__init__)
    def __init__(
            self,
            path,
            reader='ase',
            folder=False,
            **kwargs):

        # the shared reader settings are validated and stored by _Reader
        super().__init__(**kwargs)

        if reader not in ('ase', 'ccdc'):
            raise ValueError(f'Invalid reader {reader}; must be ase or ccdc.')

        # choose the file parser and the matching converter to PeriodicSet
        if reader == 'ase':
            if self.heaviest_component:
                raise NotImplementedError(
                    'Parameter heaviest_component not implimented for ase, only ccdc.')
            extensions = {'cif'}
            file_parser = ase.io.cif.parse_cif
            pset_converter = self._CIFBlock_to_PeriodicSet
        else:
            if not _CCDC_ENABLED:
                raise ImportError(
                    "Failed to import csd-python-api; check it is installed and licensed.")
            extensions = ccdc.io.EntryReader.known_suffixes
            file_parser = ccdc.io.EntryReader
            pset_converter = self._Entry_to_PeriodicSet

        # path is either a folder of files or a single file
        if folder:
            items = self._folder_generator(path, file_parser, extensions)
        else:
            items = file_parser(path)

        self._generator = self._map(pset_converter, items)
0 ignored issues
show
introduced by
The variable pset_converter does not seem to be defined for all execution paths.
Loading history...
515
516
    def _folder_generator(self, path, file_parser, extensions):
517
        for file in os.listdir(path):
518
            suff = os.path.splitext(file)[1][1:]
519
            if suff.lower() in extensions:
520
                self.current_filename = file
521
                yield from file_parser(os.path.join(path, file))
522
523
524
class CSDReader(_Reader):
525
    """Read Entries from the CSD, yielding :class:`.periodicset.PeriodicSet` objects.
526
527
    The CSDReader returns :class:`.periodicset.PeriodicSet` objects which can be passed
528
    to :func:`.calculate.AMD` or :func:`.calculate.PDD`.
529
530
    Examples:
531
532
        Get crystals with refcodes in a list::
533
534
            refcodes = ['DEBXIT01', 'DEBXIT05', 'HXACAN01']
535
            structures = list(amd.CSDReader(refcodes))
536
537
        Read refcode families (any whose refcode starts with strings in the list)::
538
539
            refcodes = ['ACSALA', 'HXACAN']
540
            structures = list(amd.CSDReader(refcodes, families=True))
541
542
        Create a generic reader, read crystals by name with :meth:`CSDReader.entry()`::
543
544
            reader = amd.CSDReader()
545
            debxit01 = reader.entry('DEBXIT01')
546
547
            # looping over this generic reader will yield all CSD entries
548
            for periodic_set in reader:
549
                ...
550
551
        Make list of AMD (with k=100) for crystals in these families::
552
553
            refcodes = ['ACSALA', 'HXACAN']
554
            amds = []
555
            for periodic_set in amd.CSDReader(refcodes, families=True):
556
                amds.append(amd.AMD(periodic_set, 100))
557
    """
558
559
    @_extend_signature(_Reader.__init__)
    def __init__(
            self,
            refcodes=None,
            families=False,
            **kwargs):

        if not _CCDC_ENABLED:
            raise ImportError(
                "Failed to import csd-python-api; check it is installed and licensed.")

        super().__init__(**kwargs)

        # the string 'csd' (any case) means the same as passing no refcodes
        if isinstance(refcodes, str) and refcodes.lower() == 'csd':
            refcodes = None

        if refcodes is None:
            families = False
        elif isinstance(refcodes, str):
            refcodes = [refcodes]
        else:
            refcodes = list(refcodes)

        # families parameter reads all crystals with ids starting with passed refcodes
        if families:
            matches = []
            for refcode in refcodes:
                query = ccdc.search.TextNumericSearch()
                query.add_identifier(refcode)
                matches.extend(hit.identifier for hit in query.search())

            # dict keys keep first-seen order, so repeated refcodes are dropped
            refcodes = list(dict.fromkeys(matches))

        self._entry_reader = ccdc.io.EntryReader('CSD')
        entries = self._ccdc_generator(refcodes)
        self._generator = self._map(self._Entry_to_PeriodicSet, entries)
598
599
    def _ccdc_generator(self, refcodes):
600
        """Generates ccdc Entries from CSD refcodes"""
601
602
        if refcodes is None:
603
            for entry in self._entry_reader:
604
                yield entry
605
        else:
606
            for refcode in refcodes:
607
                try:
608
                    entry = self._entry_reader.entry(refcode)
609
                    yield entry
610
                except RuntimeError:
611
                    warnings.warn(
612
                        f'Identifier {refcode} not found in database')
613
614
    def entry(self, refcode: str) -> PeriodicSet:
615
        """Read a PeriodicSet given any CSD refcode."""
616
617
        entry = self._entry_reader.entry(refcode)
618
        periodic_set = self._Entry_to_PeriodicSet(entry)
619
        return periodic_set
620
621
622
class SetWriter:
    """Write several :class:`.periodicset.PeriodicSet` objects to a .hdf5 file.
    Reading the .hdf5 is much faster than parsing a .CIF file.

    Examples:

        Write the crystals in mycif.cif to a .hdf5 file::

            with amd.SetWriter('crystals.hdf5') as writer:

                for periodic_set in amd.CifReader('mycif.cif'):
                    writer.write(periodic_set)

                # use iwrite to write straight from an iterator
                # below is equivalent to the above loop
                writer.iwrite(amd.CifReader('mycif.cif'))

    Read the crystals back from the file with :class:`SetReader`.
    """

    # variable-length string dtype used for lists containing strings
    _str_dtype = h5py.vlen_dtype(str)

    def __init__(self, filename: str):
        # opened in 'w' mode with track_order=True
        self.file = h5py.File(filename, 'w', track_order=True)

    @staticmethod
    def _write_tag(tags_group, tag, data):
        """Store one tag: None and scalars as attributes, arrays and lists as datasets."""
        if data is None:
            # nonce to handle None
            tags_group.attrs[tag] = '__None'
            return

        if np.isscalar(data):
            # scalars (nums and strs) stored as attrs
            tags_group.attrs[tag] = data
            return

        if isinstance(data, np.ndarray):
            tags_group.create_dataset(tag, data=data)
            return

        if not isinstance(data, list):
            raise ValueError(
                f'Cannot store tag of type {type(data)} with SetWriter')

        if any(isinstance(item, str) for item in data):
            # a list containing any string is stored entirely as strings
            tags_group.create_dataset(
                tag,
                data=[str(item) for item in data],
                dtype=SetWriter._str_dtype)
        else:
            # other lists must be castable to ndarray
            tags_group.create_dataset(tag, data=np.array(data))

    def write(self, periodic_set: PeriodicSet, name: Optional[str] = None):
        """Write a PeriodicSet object to file.

        The set is stored in a group called ``name``, or called after the
        set's own name if ``name`` is None; ValueError is raised if neither
        is given, or if ``periodic_set`` is not a PeriodicSet.
        """

        if not isinstance(periodic_set, PeriodicSet):
            raise ValueError(
                f'Object type {periodic_set.__class__.__name__} cannot be written with SetWriter')

        # need a name to store or you can't access items by key
        if name is None:
            name = periodic_set.name
            if name is None:
                raise ValueError(
                    'Periodic set must have a name to be written. Either set the name '
                    'attribute of the PeriodicSet or pass a name to SetWriter.write()')

        # this group is the PeriodicSet, with datasets for motif and cell
        group = self.file.create_group(name)
        group.create_dataset('motif', data=periodic_set.motif)
        group.create_dataset('cell', data=periodic_set.cell)

        tags = periodic_set.tags
        if not tags:
            return

        # a subgroup holds the tags
        tags_group = group.create_group('tags')
        for tag in tags:
            self._write_tag(tags_group, tag, tags[tag])

    def iwrite(self, periodic_sets: Iterable[PeriodicSet]):
        """Write :class:`.periodicset.PeriodicSet` objects from an iterable to file."""
        for item in periodic_sets:
            self.write(item)

    def close(self):
        """Close the :class:`SetWriter`."""
        self.file.close()

    def __enter__(self):
        return self

    # handle exceptions?
    def __exit__(self, exc_type, exc_value, tb):
        self.file.close()
712
713
714
class SetReader:
    """Reader for .hdf5 files of :class:`.periodicset.PeriodicSet` objects
    written with :class:`SetWriter`. Behaves like a read-only dict keyed by
    name; iterating over it yields the stored sets (preserves write order).

    Examples:

        Get PDDs (k=100) of crystals in crystals.hdf5::

            pdds = []
            with amd.SetReader('crystals.hdf5') as reader:
                for periodic_set in reader:
                    pdds.append(amd.PDD(periodic_set, 100))

            # the same PDDs in one line (iterating does not close the reader):
            pdds = [amd.PDD(pset, 100) for pset in amd.SetReader('crystals.hdf5')]
    """

    def __init__(self, filename: str):
        # opened read-only, with h5py asked to track item order
        self.file = h5py.File(filename, 'r', track_order=True)

    def _get_set(self, name: str) -> PeriodicSet:
        """Rebuild the PeriodicSet stored in the group called ``name``."""
        entry = self.file[name]
        pset = PeriodicSet(entry['motif'][:], entry['cell'][:], name=name)

        if 'tags' not in entry:
            return pset

        stored_tags = entry['tags']

        # datasets hold list/array tags; byte strings are decoded back to str
        for key in stored_tags:
            values = stored_tags[key][:]
            has_bytes = any(isinstance(v, (bytes, bytearray)) for v in values)
            pset.tags[key] = [v.decode() for v in values] if has_bytes else values

        # attributes hold scalar tags; '__None' is the stand-in for None
        for key in stored_tags.attrs:
            value = stored_tags.attrs[key]
            pset.tags[key] = None if value == '__None' else value

        return pset

    def close(self):
        """Close the file held by this :class:`SetReader`."""
        self.file.close()

    def family(self, refcode: str) -> Iterable[PeriodicSet]:
        """Yield each stored :class:`.periodicset.PeriodicSet` whose name
        begins with ``refcode``."""
        for name in self.keys():
            if not name.startswith(refcode):
                continue
            yield self._get_set(name)

    def __getitem__(self, name):
        """Return the :class:`.periodicset.PeriodicSet` stored under ``name``."""
        return self._get_set(name)

    def __len__(self):
        """Number of items stored in the file."""
        return len(self.keys())

    def __iter__(self):
        """Yield every stored set in key order; the reader is left open
        once iteration finishes."""
        yield from map(self._get_set, self.keys())

    def __contains__(self, item):
        """Whether ``item`` is the name of a stored set."""
        return item in self.keys()

    def keys(self):
        """Return the names of the items in the :class:`SetReader`, i.e. the
        keys of the file's root group."""
        root = self.file['/']
        return root.keys()

    def __enter__(self):
        """Enter a ``with`` block; the reader itself is the context object."""
        return self

    def __exit__(self, exc_type, exc_value, tb):
        """Close the file on leaving a ``with`` block. Exceptions raised
        inside the block are not suppressed."""
        self.file.close()
792
793
794
def crystal_to_periodicset(crystal):
    """Convert a ``ccdc.crystal.Crystal`` to a :class:`.periodicset.PeriodicSet`.

    A pared-down counterpart of the conversion used in CifReader: no handling
    of disorder or missing sites/coords, no checks and no options.

    Raises ValueError if the asymmetric unit contains no atoms.
    """

    cell = cellpar_to_cell(*crystal.cell_lengths, *crystal.cell_angles)

    # fractional coordinates of the asymmetric unit, wrapped into [0, 1)
    atoms = crystal.asymmetric_unit_molecule.atoms
    coords = np.array([tuple(atom.fractional_coordinates) for atom in atoms])
    asym_frac_motif = np.mod(coords, 1)

    # nothing to build a motif from
    if len(asym_frac_motif) == 0:
        raise ValueError(f'{crystal.identifier} has no coordinates')

    # fall back to the identity operation if no symmetry operators are given
    symops = crystal.symmetry_operators or ('x,y,z', )

    reader = _Reader()
    reader.current_identifier = crystal.identifier
    frac_motif, asym_unit, multiplicities, _ = reader.expand(asym_frac_motif, symops)

    # fractional -> Cartesian coordinates
    return PeriodicSet(
        frac_motif @ cell,
        cell,
        name=crystal.identifier,
        asymmetric_unit=asym_unit,
        wyckoff_multiplicities=multiplicities,
    )
826
827
828
def cifblock_to_periodicset(block):
    """Convert an ``ase.io.cif.CIFBlock`` to a :class:`.periodicset.PeriodicSet`.

    A stripped-down version of the conversion used in CifReader: no handling
    of disorder or missing sites/coords, no checks and no options.

    Fractional coordinates are read if all three columns are present;
    otherwise Cartesian coordinates are read and converted to fractional
    coordinates with the inverse of the cell.

    Returns:
        The PeriodicSet, or ``None`` (after a warning) if the block has
        neither a full set of fractional nor of Cartesian coordinates.

    Raises:
        ValueError: if the coordinate columns are present but empty.
    """

    cell = block.get_cell().array

    # one entry per coordinate tag (a column of values), None if the tag is absent
    columns = [block.get(name) for name in _Reader.atom_site_fract_tags]
    cartesian = any(column is None for column in columns)

    if cartesian:
        columns = [block.get(name) for name in _Reader.atom_site_cartn_tags]
        if any(column is None for column in columns):
            warnings.warn(
                f'Skipping {block.name} as coordinates were not found')
            return None

    # three columns of length n --> one row per site, shape (n, 3)
    asym_motif = np.array(columns).T

    if asym_motif.shape[0] == 0:
        raise ValueError(f'{block.name} has no coordinates')

    if cartesian:
        # Rows must be sites here: (n, 3) @ (3, 3). Multiplying the untransposed
        # (3, n) columns by the inverse cell fails for n != 3 and is wrong for n == 3.
        asym_motif = asym_motif @ np.linalg.inv(cell)

    asym_frac_motif = np.mod(asym_motif, 1)

    # first symmetry-operator tag found wins; default to the identity operation
    sitesym = ('x,y,z', )
    for tag in _Reader.symop_tags:
        if tag in block:
            sitesym = block[tag]
            break

    # a lone operator may come back as a bare string
    if isinstance(sitesym, str):
        sitesym = [sitesym]

    dummy_reader = _Reader()
    dummy_reader.current_identifier = block.name
    frac_motif, asym_unit, multiplicities, _ = dummy_reader.expand(asym_frac_motif, sitesym)

    # fractional -> Cartesian coordinates
    motif = frac_motif @ cell

    return PeriodicSet(
        motif,
        cell,
        name=block.name,
        asymmetric_unit=asym_unit,
        wyckoff_multiplicities=multiplicities,
    )
872