Passed
Push — master ( 68f603...f19e60 )
by Daniel
01:56
created

amd.io._Reader.expand()   C

Complexity

Conditions 9

Size

Total Lines 57
Code Lines 41

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 41
dl 0
loc 57
rs 6.5626
c 0
b 0
f 0
cc 9
nop 3

How to fix   Long Method   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

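Commonly applied refactorings include Extract Method and, when many parameters or temporary variables are involved, Replace Temp with Query, Introduce Parameter Object, and Preserve Whole Object.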

1
"""Contains I/O tools, including a .CIF reader and CSD reader
2
(``csd-python-api`` only) to extract periodic set representations
3
of crystals which can be passed to :func:`.calculate.AMD` and :func:`.calculate.PDD`.
4
5
These intermediate :class:`.periodicset.PeriodicSet` representations can be written
6
to a .hdf5 file with :class:`SetWriter`, which can be read back with :class:`SetReader`.
7
This is much faster than rereading a .CIF and recomputing invariants.
8
"""
9
10
import os
11
import warnings
12
from typing import Callable, Iterable, Sequence, Tuple, Optional
13
14
import numpy as np
15
import ase.spacegroup.spacegroup    # parse_sitesym
16
import ase.io.cif
17
import h5py
18
19
from .periodicset import PeriodicSet
20
from .utils import _extend_signature, cellpar_to_cell
21
22
try:
23
    import ccdc.io       # EntryReader
24
    import ccdc.search   # TextNumericSearch
25
    _CCDC_ENABLED = True
26
except (ImportError, RuntimeError) as exception:
27
    _CCDC_ENABLED = False
28
29
 
0 ignored issues
show
Coding Style introduced by
Trailing whitespace
Loading history...
30
def _warning(message, category, filename, lineno, *args, **kwargs):
    """Format a warning as one ``filename:lineno: Category: message`` line.

    Any further positional or keyword arguments supplied by the caller are
    accepted and ignored, so this can stand in for
    ``warnings.formatwarning``.
    """
    location = f'{filename}:{lineno}'
    return f'{location}: {category.__name__}: {message}\n'


# Install the compact formatter for every warning issued in this process.
warnings.formatwarning = _warning
34
35
36
def _atom_has_disorder(label, occupancy):
    """Return True if an atom site looks disordered.

    A site counts as disordered when its label ends with ``'?'`` or when its
    occupancy is a real number below 1.

    Non-numeric occupancies (``None``, or placeholder strings such as ``'?'``
    or ``'.'``) are treated as "no occupancy information" rather than being
    compared with 1. ``np.isscalar`` cannot be used for this test because it
    is also True for strings, which would make ``occupancy < 1`` raise
    ``TypeError``.
    """
    if label.endswith('?'):
        return True
    if isinstance(occupancy, (int, float, np.integer, np.floating)):
        return bool(occupancy < 1)
    return False
38
39
40
class _Reader:
    """Base Reader class. Contains parsers for converting ase CifBlock
    and ccdc Entry objects to PeriodicSets.

    Intended use:

    First make a new method for _Reader converting object to PeriodicSet
    (e.g. named _X_to_PSet). Then make this class outline:

    class XReader(_Reader):
        def __init__(self, ..., **kwargs):

        super().__init__(**kwargs)

        # setup and checks

        # make 'iterable' which yields objects to be converted (e.g. CIFBlock, Entry)

        # set self._generator like this
        self._generator = self._map(self._X_to_PSet, iterable)
    """

    disorder_options = {'skip', 'ordered_sites', 'all_sites'}
    reserved_tags = {
        'motif',
        'cell',
        'name',
        'asymmetric_unit',
        'wyckoff_multiplicities',
        'types',}
    atom_site_fract_tags = [
        '_atom_site_fract_x',
        '_atom_site_fract_y',
        '_atom_site_fract_z',]
    atom_site_cartn_tags = [
        '_atom_site_cartn_x',
        '_atom_site_cartn_y',
        '_atom_site_cartn_z',]
    symop_tags = [
        '_space_group_symop_operation_xyz',
        '_space_group_symop.operation_xyz',
        '_symmetry_equiv_pos_as_xyz',]

    # two fractional coordinates closer than this (modulo 1) are the same site
    equiv_site_tol = 1e-3

    # symbols removed by remove_hydrogens. A tuple (not the string 'HD') so
    # that membership is an exact match rather than a substring test.
    _hydrogen_symbols = ('H', 'D')

    def __init__(
            self,
            remove_hydrogens=False,
            disorder='skip',
            heaviest_component=False,
            show_warnings=True,
            extract_data=None,
            include_if=None):
        """Validate and store the reader settings.

        Raises ValueError if ``disorder`` is not one of ``disorder_options``,
        if ``extract_data`` is not a dict of callables or uses a reserved key,
        or if ``include_if`` contains a non-callable.
        """

        if disorder not in _Reader.disorder_options:
            raise ValueError(
                f'disorder parameter {disorder} must be one of {_Reader.disorder_options}')

        if extract_data:
            if not isinstance(extract_data, dict):
                raise ValueError('extract_data must be a dict with callable values')
            for key, extractor in extract_data.items():
                if not callable(extractor):
                    raise ValueError('extract_data must be a dict with callable values')
                if key in _Reader.reserved_tags:
                    raise ValueError(f'extract_data includes reserved key {key}')

        if include_if:
            # materialise first: validating a one-shot iterator would exhaust
            # it and silently disable every check afterwards
            include_if = list(include_if)
            for func in include_if:
                if not callable(func):
                    raise ValueError('include_if must be a list of callables')

        self.remove_hydrogens = remove_hydrogens
        self.disorder = disorder
        self.heaviest_component = heaviest_component
        self.extract_data = extract_data
        self.include_if = include_if
        self.show_warnings = show_warnings
        self.current_identifier = None
        self.current_filename = None
        self._generator = None

    def __iter__(self):
        yield from self._generator

    def read_one(self):
        """Read the next (or first) item."""
        return next(iter(self._generator))

    # basically the builtin map, but skips items if the function returned None.
    # The object returned by this function (Iterable of PeriodicSets) is set to
    # self._generator; then iterating over the Reader iterates over
    # self._generator.
    @staticmethod
    def _map(func: Callable, iterable: Iterable) -> Iterable['PeriodicSet']:
        """Iterates over iterable, passing items through parser and
        yielding the result if it is not None.
        """

        for item in iterable:
            res = func(item)
            if res is not None:
                yield res

    def _warn(self, message):
        """Issue ``message`` as a warning unless show_warnings is off."""
        if self.show_warnings:
            # stacklevel=2 attributes the warning to the calling line
            warnings.warn(message, stacklevel=2)

    @staticmethod
    def _equivalent_site_indices(site, sites):
        """Indices of ``sites`` equal to ``site`` modulo 1 within equiv_site_tol."""
        if not sites:
            return np.empty(0, dtype=int)
        diffs = np.abs(site - np.asarray(sites))
        # a difference close to 1 is also a match (opposite cell faces)
        wrapped = np.abs(diffs - 1)
        tol = _Reader.equiv_site_tol
        mask = np.all(np.logical_or(diffs <= tol, wrapped <= tol), axis=-1)
        return np.flatnonzero(mask)

    def expand(
            self,
            asym_frac_motif: np.ndarray,
            sitesym: Sequence[str]
    ) -> Tuple[np.ndarray, ...]:
        """
        Asymmetric unit's fractional coords + sitesyms (as strings)
        -->
        frac_motif, asym_unit, multiplicities, inverses

        Every symmetry operation is applied to every asymmetric site; images
        that coincide with a site already generated are dropped. ``inverses``
        (a list) gives, for each generated site, the index of the asymmetric
        site it came from. A warning is issued when an image coincides with a
        site generated from a different asymmetric site.
        """

        rotations, translations = ase.spacegroup.spacegroup.parse_sitesym(sitesym)
        all_sites = []
        asym_unit = [0]
        multiplicities = []
        inverses = []

        for inv, site in enumerate(asym_frac_motif):
            multiplicity = 0

            for rot, trans in zip(rotations, translations):
                site_ = np.mod(np.dot(rot, site) + trans, 1)
                matches = self._equivalent_site_indices(site_, all_sites)

                if matches.size == 0:
                    all_sites.append(site_)
                    inverses.append(inv)
                    multiplicity += 1
                    continue

                if self.show_warnings:
                    for ind in matches:
                        if inverses[ind] != inv:
                            warnings.warn(
                                f'{self.current_identifier} has equivalent positions {inverses[ind]} and {inv}')

            if multiplicity > 0:
                multiplicities.append(multiplicity)
                asym_unit.append(len(all_sites))

        frac_motif = np.array(all_sites)
        asym_unit = np.array(asym_unit[:-1])
        multiplicities = np.array(multiplicities)
        return frac_motif, asym_unit, multiplicities, inverses

    def _drop_overlapping_sites(self, asym_frac_motif, asym_symbols, asym_is_disordered):
        """Remove duplicate sites from the asymmetric unit, keeping the first.

        With disorder='all_sites', a pair is left alone if either site is
        flagged in ``asym_is_disordered``. Returns the (possibly filtered)
        motif and symbols.
        """
        diffs = np.abs(asym_frac_motif[:, None] - asym_frac_motif)
        wrapped = np.abs(diffs - 1)
        tol = _Reader.equiv_site_tol
        # upper triangle only: each overlapping pair (i, j), i < j, once
        overlapping = np.triu(np.all((diffs <= tol) | (wrapped <= tol), axis=-1), 1)

        # don't remove overlapping sites if one is disordered and disorder='all_sites'
        if self.disorder == 'all_sites':
            for i, j in np.argwhere(overlapping):
                if asym_is_disordered[i] or asym_is_disordered[j]:
                    overlapping[i, j] = False

        if overlapping.any():
            self._warn(
                f'{self.current_identifier} may have overlapping sites; duplicates will be removed')
            keep_sites = ~overlapping.any(0)
            asym_frac_motif = asym_frac_motif[keep_sites]
            asym_symbols = [sym for sym, keep in zip(asym_symbols, keep_sites) if keep]

        return asym_frac_motif, asym_symbols

    def _make_periodic_set(self, motif, cell, asym_unit, multiplicities, types, source):
        """Build the PeriodicSet, attaching filename and extract_data tags.

        ``source`` is the object (CIFBlock or Entry) handed to each
        extract_data callable.
        """
        kwargs = {
            'name': self.current_identifier,
            'asymmetric_unit': asym_unit,
            'wyckoff_multiplicities': multiplicities,
            'types': types,
        }

        if self.current_filename:
            kwargs['filename'] = self.current_filename

        if self.extract_data is not None:
            for key in self.extract_data:
                kwargs[key] = self.extract_data[key](source)

        return PeriodicSet(motif, cell, **kwargs)

    def _CIFBlock_to_PeriodicSet(self, block) -> Optional['PeriodicSet']:
        """ase.io.cif.CIFBlock --> PeriodicSet. Returns None for a "bad" set."""

        # skip if structure does not pass checks in include_if
        if self.include_if:
            if not all(check(block) for check in self.include_if):
                return None

        # read name, cell, asym motif and atomic symbols
        self.current_identifier = block.name
        cell = block.get_cell().array

        frac_columns = [block.get(name) for name in _Reader.atom_site_fract_tags]
        if any(column is None for column in frac_columns):
            cartn_columns = [block.get(name) for name in _Reader.atom_site_cartn_tags]
            if any(column is None for column in cartn_columns):
                self._warn(f'Skipping {self.current_identifier} as coordinates were not found')
                return None
            # columns are x, y, z: transpose to one row per site *before*
            # converting Cartesian -> fractional (motif = frac_motif @ cell)
            asym_frac_motif = np.array(cartn_columns).T @ np.linalg.inv(cell)
        else:
            asym_frac_motif = np.array(frac_columns).T

        n_sites = len(asym_frac_motif)

        try:
            asym_symbols = block.get_symbols()
        except ase.io.cif.NoStructureData:
            asym_symbols = ['Unknown'] * n_sites

        # indices of sites to remove
        remove = set()
        if self.remove_hydrogens:
            remove.update(
                i for i, sym in enumerate(asym_symbols)
                if sym in _Reader._hydrogen_symbols)

        # find disordered sites: one flag per asymmetric site, so the list
        # stays aligned with asym_frac_motif whatever is removed later
        asym_is_disordered = [False] * n_sites
        occupancies = block.get('_atom_site_occupancy')
        labels = block.get('_atom_site_label')
        if occupancies is not None:
            if labels is None:
                labels = [''] * len(occupancies)
            flags = [_atom_has_disorder(label, occ)
                     for occ, label in zip(occupancies, labels)]
            # pad/truncate so there is exactly one flag per site
            asym_is_disordered = (flags + [False] * n_sites)[:n_sites]

        # disordered sites which are not already being removed
        disordered = [i for i, flag in enumerate(asym_is_disordered)
                      if flag and i not in remove]

        if self.disorder == 'skip' and disordered:
            self._warn(f'Skipping {self.current_identifier} as structure is disordered')
            return None

        if self.disorder == 'ordered_sites':
            remove.update(disordered)

        # remove sites
        asym_frac_motif = np.mod(np.delete(asym_frac_motif, sorted(remove), axis=0), 1)
        asym_symbols = [s for i, s in enumerate(asym_symbols) if i not in remove]
        asym_is_disordered = [v for i, v in enumerate(asym_is_disordered) if i not in remove]

        # if there are overlapping sites in asym unit, warn and keep only one
        asym_frac_motif, asym_symbols = self._drop_overlapping_sites(
            asym_frac_motif, asym_symbols, asym_is_disordered)

        # if no points left in motif, skip structure
        if asym_frac_motif.shape[0] == 0:
            self._warn(
                f'Skipping {self.current_identifier} as there are no sites with coordinates')
            return None

        # get symmetries
        sitesym = ['x,y,z', ]
        for tag in _Reader.symop_tags:
            if tag in block:
                sitesym = block[tag]
                break

        if isinstance(sitesym, str):
            sitesym = [sitesym]

        # expand the asymmetric unit to full motif + multiplicities
        frac_motif, asym_unit, multiplicities, inverses = self.expand(asym_frac_motif, sitesym)
        motif = frac_motif @ cell

        # construct PeriodicSet
        types = [asym_symbols[i] for i in inverses]
        return self._make_periodic_set(motif, cell, asym_unit, multiplicities, types, block)

    @staticmethod
    def _heaviest_component(molecule):
        """Return the component of ``molecule`` with the largest total weight.

        Atom weights are scaled by occupancy when the occupancy is a number;
        atoms without a numeric atomic weight contribute nothing.
        """
        component_weights = []
        for component in molecule.components:
            weight = 0
            for a in component.atoms:
                if isinstance(a.atomic_weight, (float, int)):
                    if isinstance(a.occupancy, (float, int)):
                        weight += a.occupancy * a.atomic_weight
                    else:
                        weight += a.atomic_weight
            component_weights.append(weight)
        largest_component_arg = np.argmax(np.array(component_weights))
        return molecule.components[largest_component_arg]

    def _Entry_to_PeriodicSet(self, entry) -> Optional['PeriodicSet']:
        """ccdc.entry.Entry --> PeriodicSet. Returns None for a "bad" set."""

        # skip if structure does not pass checks in include_if
        if self.include_if:
            if not all(check(entry) for check in self.include_if):
                return None

        self.current_identifier = entry.identifier
        # structure must pass this test
        if not entry.has_3d_structure:
            self._warn(f'Skipping {self.current_identifier} as entry has no 3D structure')
            return None

        try:
            crystal = entry.crystal
        except RuntimeError as e:
            self._warn(f'Skipping {self.current_identifier}: {e}')
            return None

        # first disorder check, if skipping. If occ == 1 for all atoms but the entry
        # or crystal is listed as having disorder, skip (can't know where disorder is).
        # If occ != 1 for any atoms, we wait to see if we remove them before skipping.
        molecule = crystal.disordered_molecule
        if self.disorder == 'ordered_sites':
            molecule.remove_atoms(
                a for a in molecule.atoms if a.label.endswith('?'))

        may_have_disorder = False
        if self.disorder == 'skip':
            may_have_disorder = any(
                _atom_has_disorder(a.label, a.occupancy) for a in molecule.atoms)

            if not may_have_disorder:
                if crystal.has_disorder or entry.has_disorder:
                    self._warn(f'Skipping {self.current_identifier} as structure is disordered')
                    return None

        if self.remove_hydrogens:
            molecule.remove_atoms(
                a for a in molecule.atoms
                if a.atomic_symbol in _Reader._hydrogen_symbols)

        # heaviest component (removes all but the heaviest component of the asym unit)
        # intended for removing solvents. probably doesn't play well with disorder
        if self.heaviest_component and len(molecule.components) > 1:
            molecule = self._heaviest_component(molecule)

        crystal.molecule = molecule

        # by here all atoms to be removed have been (except via ordered_sites).
        # If disorder == 'skip' and there were atom(s) with occ < 1 found
        # earlier, we check if all such atoms were removed. If not, skip.
        if self.disorder == 'skip' and may_have_disorder:
            for a in crystal.disordered_molecule.atoms:
                if _atom_has_disorder(a.label, a.occupancy):
                    self._warn(f'Skipping {self.current_identifier} as structure is disordered')
                    return None

        # if disorder is all_sites, we need to know where disorder is to ignore overlaps
        asym_is_disordered = []     # True/False list same length as asym unit
        if self.disorder == 'all_sites':
            asym_is_disordered = [
                bool(_atom_has_disorder(a.label, a.occupancy))
                for a in crystal.asymmetric_unit_molecule.atoms]

        # check all atoms have coords. option/default remove unknown sites?
        if not molecule.all_atoms_have_sites or \
           any(a.fractional_coordinates is None for a in molecule.atoms):
            self._warn(f'Skipping {self.current_identifier} as some atoms do not have sites')
            return None

        # get cell & asymmetric unit
        cell = cellpar_to_cell(*crystal.cell_lengths, *crystal.cell_angles)
        asym_atoms = crystal.asymmetric_unit_molecule.atoms
        asym_frac_motif = np.array([tuple(a.fractional_coordinates) for a in asym_atoms])
        asym_frac_motif = np.mod(asym_frac_motif, 1)
        asym_symbols = [a.atomic_symbol for a in asym_atoms]

        # if there are overlapping sites in asym unit, warn and keep only one
        asym_frac_motif, asym_symbols = self._drop_overlapping_sites(
            asym_frac_motif, asym_symbols, asym_is_disordered)

        # if no points left in motif, skip structure
        if asym_frac_motif.shape[0] == 0:
            self._warn(
                f'Skipping {self.current_identifier} as there are no sites with coordinates')
            return None

        # get symmetries, expand the asymmetric unit to full motif + multiplicities
        sitesym = crystal.symmetry_operators
        if not sitesym:
            sitesym = ('x,y,z', )
        frac_motif, asym_unit, multiplicities, inverses = self.expand(asym_frac_motif, sitesym)
        motif = frac_motif @ cell

        # extract_data callables are given the entry with its molecule set
        # back to the crystal's disordered_molecule
        if self.extract_data is not None:
            entry.crystal.molecule = crystal.disordered_molecule

        # construct PeriodicSet
        types = [asym_symbols[i] for i in inverses]
        return self._make_periodic_set(motif, cell, asym_unit, multiplicities, types, entry)
484
485
486
class CifReader(_Reader):
    """Read all structures in a .CIF with ``ase`` or ``ccdc``
    (``csd-python-api`` only), yielding  :class:`.periodicset.PeriodicSet`
    objects which can be passed to :func:`.calculate.AMD` or
    :func:`.calculate.PDD`.

    Examples:

        ::

            # Put all crystals in a .CIF in a list
            structures = list(amd.CifReader('mycif.cif'))

            # Reads just one if the .CIF has just one crystal
            periodic_set = amd.CifReader('mycif.cif').read_one()

            # If a folder has several .CIFs each with one crystal, use
            structures = list(amd.CifReader('path/to/folder', folder=True))

            # Make list of AMDs (with k=100) of crystals in a .CIF
            amds = [amd.AMD(periodic_set, 100) for periodic_set in amd.CifReader('mycif.cif')]
    """

    @_extend_signature(_Reader.__init__)
    def __init__(
            self,
            path,
            reader='ase',
            folder=False,
            **kwargs):

        super().__init__(**kwargs)

        if reader not in ('ase', 'ccdc'):
            raise ValueError(f'Invalid reader {reader}; must be ase or ccdc.')

        # reader is one of exactly two values from here on
        use_ase = reader == 'ase'

        if use_ase and self.heaviest_component:
            raise NotImplementedError('Parameter heaviest_component not implimented for ase, only ccdc.')

        if use_ase:
            extensions = {'cif'}
            file_parser = ase.io.cif.parse_cif
            pset_converter = self._CIFBlock_to_PeriodicSet
        else:
            if not _CCDC_ENABLED:
                raise ImportError("Failed to import csd-python-api; check it is installed and licensed.")
            extensions = ccdc.io.EntryReader.known_suffixes
            file_parser = ccdc.io.EntryReader
            pset_converter = self._Entry_to_PeriodicSet

        # either every matching file in a folder, or the single file at path
        source = (self._folder_generator(path, file_parser, extensions)
                  if folder else file_parser(path))

        self._generator = self._map(pset_converter, source)

    def _folder_generator(self, path, file_parser, extensions):
        """Yield parsed items from each file in ``path`` whose suffix
        (lower-cased, without the dot) is in ``extensions``, recording the
        file name in ``self.current_filename`` as each file is started.
        """
        for file in os.listdir(path):
            _, dotted_suffix = os.path.splitext(file)
            if dotted_suffix[1:].lower() not in extensions:
                continue
            self.current_filename = file
            yield from file_parser(os.path.join(path, file))
550
551
552
class CSDReader(_Reader):
    """Read Entries from the CSD, yielding :class:`.periodicset.PeriodicSet` objects.

    The CSDReader returns :class:`.periodicset.PeriodicSet` objects which can be passed
    to :func:`.calculate.AMD` or :func:`.calculate.PDD`.

    Examples:

        Get crystals with refcodes in a list::

            refcodes = ['DEBXIT01', 'DEBXIT05', 'HXACAN01']
            structures = list(amd.CSDReader(refcodes))

        Read refcode families (any whose refcode starts with strings in the list)::

            refcodes = ['ACSALA', 'HXACAN']
            structures = list(amd.CSDReader(refcodes, families=True))

        Create a generic reader, read crystals by name with :meth:`CSDReader.entry()`::

            reader = amd.CSDReader()
            debxit01 = reader.entry('DEBXIT01')

            # looping over this generic reader will yield all CSD entries
            for periodic_set in reader:
                ...

        Make list of AMD (with k=100) for crystals in these families::

            refcodes = ['ACSALA', 'HXACAN']
            amds = []
            for periodic_set in amd.CSDReader(refcodes, families=True):
                amds.append(amd.AMD(periodic_set, 100))
    """

    @_extend_signature(_Reader.__init__)
    def __init__(
            self,
            refcodes=None,
            families=False,
            **kwargs):

        if not _CCDC_ENABLED:
            raise ImportError("Failed to import csd-python-api; check it is installed and licensed.")

        super().__init__(**kwargs)

        # the string 'csd' (any case) means "read the whole database"
        if isinstance(refcodes, str) and refcodes.lower() == 'csd':
            refcodes = None

        if refcodes is None:
            families = False
        else:
            refcodes = [refcodes] if isinstance(refcodes, str) else list(refcodes)

        # families parameter reads all crystals with ids starting with passed refcodes
        if families:
            all_refcodes = []
            for refcode in refcodes:
                query = ccdc.search.TextNumericSearch()
                query.add_identifier(refcode)
                all_refcodes.extend(hit.identifier for hit in query.search())

            # filter to unique refcodes, keeping first-seen order
            refcodes = list(dict.fromkeys(all_refcodes))

        self._entry_reader = ccdc.io.EntryReader('CSD')
        self._generator = self._map(
            self._Entry_to_PeriodicSet,
            self._ccdc_generator(refcodes))

    def _ccdc_generator(self, refcodes):
        """Generates ccdc Entries from CSD refcodes.

        With ``refcodes`` None, every entry of the reader is yielded.
        Otherwise each refcode is looked up; one whose lookup raises
        RuntimeError is skipped, with a warning unless show_warnings is off.
        """

        if refcodes is None:
            yield from self._entry_reader
            return

        for refcode in refcodes:
            # only the lookup is guarded, not the consumer of the entry
            try:
                entry = self._entry_reader.entry(refcode)
            except RuntimeError:
                if self.show_warnings:
                    warnings.warn(
                        f'Identifier {refcode} not found in database')
                continue
            yield entry

    def entry(self, refcode: str) -> PeriodicSet:
        """Read a PeriodicSet given any CSD refcode."""

        entry = self._entry_reader.entry(refcode)
        periodic_set = self._Entry_to_PeriodicSet(entry)
        return periodic_set
648
649
650
class SetWriter:
    """Store :class:`.periodicset.PeriodicSet` objects in a .hdf5 file.

    Loading the sets back from the .hdf5 is much faster than parsing a .CIF file.

    Examples:

        Write the crystals in mycif.cif to a .hdf5 file::

            with amd.SetWriter('crystals.hdf5') as writer:

                for periodic_set in amd.CifReader('mycif.cif'):
                    writer.write(periodic_set)

                # use iwrite to write straight from an iterator
                # below is equivalent to the above loop
                writer.iwrite(amd.CifReader('mycif.cif'))

    Read the crystals back from the file with :class:`SetReader`.
    """

    # variable-length string dtype, used for tags that are lists of strings
    _str_dtype = h5py.vlen_dtype(str)

    def __init__(self, filename: str):
        """Open ``filename`` with h5py in mode 'w', tracking insertion order."""

        self.file = h5py.File(filename, 'w', track_order=True)

    def write(self, periodic_set: PeriodicSet, name: Optional[str] = None):
        """Write a PeriodicSet object to file.

        The set becomes a group called ``name`` holding 'motif' and 'cell'
        datasets, plus a 'tags' subgroup if the set has any tags.

        Args:
            periodic_set: The PeriodicSet to store.
            name: Name of the group to store it under. If None, the name
                attribute of ``periodic_set`` is used.

        Raises:
            ValueError: If ``periodic_set`` is not a PeriodicSet, if no name
                is available, or if a tag has a type that cannot be stored.
        """

        if not isinstance(periodic_set, PeriodicSet):
            raise ValueError(
                f'Object type {periodic_set.__class__.__name__} cannot be written with SetWriter')

        # items are accessed by key when read back, so a name is mandatory
        if name is None:
            name = periodic_set.name
            if name is None:
                raise ValueError(
                    'Periodic set must have a name to be written. Either set the name '
                    'attribute of the PeriodicSet or pass a name to SetWriter.write()')

        # one group per PeriodicSet, with datasets for its motif and cell
        set_group = self.file.create_group(name)
        set_group.create_dataset('motif', data=periodic_set.motif)
        set_group.create_dataset('cell', data=periodic_set.cell)

        tags = periodic_set.tags
        if not tags:
            return

        # the tags live in their own subgroup
        tags_group = set_group.create_group('tags')
        for tag in tags:
            self._store_tag(tags_group, tag, tags[tag])

    @staticmethod
    def _store_tag(tags_group, tag, value):
        """Store a single tag value in ``tags_group`` according to its type.

        None and scalars become attributes of the group; ndarrays and lists
        become datasets. Any other type raises ValueError.
        """

        if value is None:
            # None cannot be stored directly; a nonce string stands in for it
            tags_group.attrs[tag] = '__None'
            return

        if np.isscalar(value):
            # scalars (nums and strs) are stored as attrs
            tags_group.attrs[tag] = value
            return

        if isinstance(value, np.ndarray):
            tags_group.create_dataset(tag, data=value)
            return

        if not isinstance(value, list):
            raise ValueError(
                f'Cannot store tag of type {type(value)} with SetWriter')

        if any(isinstance(item, str) for item in value):
            # a list containing any string is stored entirely as strings,
            # using the variable-length string dtype
            tags_group.create_dataset(tag,
                                      data=[str(item) for item in value],
                                      dtype=SetWriter._str_dtype)
        else:
            # any other list must be castable to an ndarray
            tags_group.create_dataset(tag, data=np.array(value))

    def iwrite(self, periodic_sets: Iterable[PeriodicSet]):
        """Write :class:`.periodicset.PeriodicSet` objects from an iterable to file.

        Each item is passed to :meth:`write` without an explicit name.
        """
        for item in periodic_sets:
            self.write(item)

    def close(self):
        """Close the :class:`SetWriter`."""
        self.file.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, tb):
        # the file is closed whether or not the with-block raised; nothing is
        # returned, so an exception from the block still propagates
        self.file.close()
741
742
743
class SetReader:
    """Read :class:`.periodicset.PeriodicSet` objects from a .hdf5 file written
    with :class:`SetWriter`. Items are looked up by name like a read-only
    dict; iterating yields the PeriodicSets themselves (preserves write order).

    Examples:

        Get PDDs (k=100) of crystals in crystals.hdf5::

            pdds = []
            with amd.SetReader('crystals.hdf5') as reader:
                for periodic_set in reader:
                    pdds.append(amd.PDD(periodic_set, 100))

            # the above gives the same list as the line below, but iterating
            # does not close the reader, so this form leaves the file open:
            pdds = [amd.PDD(pset, 100) for pset in amd.SetReader('crystals.hdf5')]
    """

    def __init__(self, filename: str):
        """Open ``filename`` with h5py in mode 'r', tracking insertion order."""

        self.file = h5py.File(filename, 'r', track_order=True)

    def _get_set(self, name: str) -> PeriodicSet:
        """Build the PeriodicSet stored in the group called ``name``.

        Datasets in the group's 'tags' subgroup and attributes of that
        subgroup are restored as tags of the returned set.
        """

        group = self.file[name]
        periodic_set = PeriodicSet(group['motif'][:], group['cell'][:], name=name)

        if 'tags' not in group:
            return periodic_set

        tags_group = group['tags']

        # datasets hold the tags that were lists or ndarrays
        for tag in tags_group:
            data = tags_group[tag][:]

            # items that come back as bytes are decoded into a list of strs
            if any(isinstance(d, (bytes, bytearray)) for d in data):
                periodic_set.tags[tag] = [d.decode() for d in data]
            else:
                periodic_set.tags[tag] = data

        # attributes hold the scalar tags and the nonce standing in for None
        tag_attrs = tags_group.attrs
        for attr in tag_attrs:
            data = tag_attrs[attr]
            # Only a str can be the None nonce. Checking the type first keeps
            # an array-valued attribute from being compared elementwise,
            # which would make the truth value of the comparison ambiguous.
            is_none = isinstance(data, str) and data == '__None'
            periodic_set.tags[attr] = None if is_none else data

        return periodic_set

    def close(self):
        """Close the :class:`SetReader`."""
        self.file.close()

    def family(self, refcode: str) -> Iterable[PeriodicSet]:
        """Yield any :class:`.periodicset.PeriodicSet` whose name starts with
        input refcode."""
        for name in self.keys():
            if name.startswith(refcode):
                yield self._get_set(name)

    def __getitem__(self, name):
        # index by name; a missing name surfaces as whatever error the
        # underlying file lookup raises
        return self._get_set(name)

    def __len__(self):
        return len(self.keys())

    def __iter__(self):
        # interface to loop over the SetReader; does not close the SetReader when done
        for name in self.keys():
            yield self._get_set(name)

    def __contains__(self, item):
        return item in self.keys()

    def keys(self):
        """Return a view of the names of items in the :class:`SetReader`."""
        return self.file['/'].keys()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, tb):
        # the file is closed whether or not the with-block raised; nothing is
        # returned, so an exception from the block still propagates
        self.file.close()
821
822
823
def crystal_to_periodicset(crystal):
    """ccdc.crystal.Crystal --> amd.periodicset.PeriodicSet.

    A stripped-down version of the function used in CifReader: disorder and
    missing sites/coords are not handled, and there are no checks or options.

    Raises:
        ValueError: If the asymmetric unit yields no coordinates.
    """

    lengths, angles = crystal.cell_lengths, crystal.cell_angles
    cell = cellpar_to_cell(*lengths, *angles)

    # fractional coordinates of the asymmetric unit, reduced modulo 1
    atoms = crystal.asymmetric_unit_molecule.atoms
    asym_frac_motif = np.mod(
        np.array([tuple(atom.fractional_coordinates) for atom in atoms]), 1)

    # a structure with no coordinates cannot be converted
    if len(asym_frac_motif) == 0:
        raise ValueError(f'{crystal.identifier} has no coordinates')

    # fall back to the identity operation if no symmetry operators are given
    sitesym = crystal.symmetry_operators or ('x,y,z', )

    # a throwaway reader gives access to expand(); it is given the identifier
    # before the call
    reader = _Reader()
    reader.current_identifier = crystal.identifier
    frac_motif, asym_unit, multiplicities, _ = reader.expand(asym_frac_motif, sitesym)

    return PeriodicSet(
        frac_motif @ cell,
        cell,
        name=crystal.identifier,
        asymmetric_unit=asym_unit,
        wyckoff_multiplicities=multiplicities,
    )
855
856
857
def cifblock_to_periodicset(block):
    """ase.io.cif.CIFBlock --> amd.periodicset.PeriodicSet.

    A stripped-down version of the function used in CifReader: disorder and
    missing sites/coords are not handled, and there are no checks or options.
    Fractional coordinates are used if every fractional coordinate tag is
    present; otherwise Cartesian coordinates are converted with the cell.

    Returns:
        The PeriodicSet, or None (after a warning) if neither the fractional
        nor the Cartesian coordinate tags are all present in the block.

    Raises:
        ValueError: If the coordinate columns contain no sites.
    """

    cell = block.get_cell().array

    # one column per coordinate tag; a None entry is treated as a missing tag
    columns = [block.get(name) for name in _Reader.atom_site_fract_tags]

    if None not in columns:
        # transpose the per-axis columns into one row per site
        asym_frac_motif = np.array(columns).T
    else:
        columns = [block.get(name) for name in _Reader.atom_site_cartn_tags]
        if None in columns:
            warnings.warn(
                f'Skipping {block.name} as coordinates were not found')
            return None

        # The columns must be transposed to one row per site *before* the
        # conversion. The motif is later rebuilt as frac_motif @ cell, so the
        # inverse mapping is cart @ inv(cell) with sites as rows; multiplying
        # the untransposed columns fails (or is wrong) for the general case.
        asym_frac_motif = np.array(columns).T @ np.linalg.inv(cell)

    asym_frac_motif = np.mod(asym_frac_motif, 1)

    if asym_frac_motif.shape[0] == 0:
        raise ValueError(f'{block.name} has no coordinates')

    # use the first symmetry operator tag found, defaulting to the identity
    sitesym = ('x,y,z', )
    for tag in _Reader.symop_tags:
        if tag in block:
            sitesym = block[tag]
            break

    # a lone operator given as a bare string is wrapped in a list
    if isinstance(sitesym, str):
        sitesym = [sitesym]

    # a throwaway reader gives access to expand(); it is given the block name
    # before the call
    dummy_reader = _Reader()
    dummy_reader.current_identifier = block.name
    frac_motif, asym_unit, multiplicities, _ = dummy_reader.expand(asym_frac_motif, sitesym)
    motif = frac_motif @ cell

    kwargs = {
        'name': block.name,
        'asymmetric_unit': asym_unit,
        'wyckoff_multiplicities': multiplicities,
    }

    return PeriodicSet(motif, cell, **kwargs)
901