Passed
Push — master ( dfd883...6a3dea )
by Daniel
01:56
created

amd.io._Reader.read_one()   A

Complexity

Conditions 1

Size

Total Lines 3
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 2
dl 0
loc 3
rs 10
c 0
b 0
f 0
cc 1
nop 1
1
"""Contains I/O tools, including a .CIF reader and CSD reader
2
(``csd-python-api`` only) to extract periodic set representations
3
of crystals which can be passed to :func:`.calculate.AMD` and :func:`.calculate.PDD`.
4
5
These intermediate :class:`.periodicset.PeriodicSet` representations can be written
6
to a .hdf5 file with :class:`SetWriter`, which can be read back with :class:`SetReader`.
7
This is much faster than rereading a .CIF and recomputing invariants.
8
"""
9
10
import os
11
import warnings
12
from typing import Callable, Iterable, Sequence, Tuple, Optional
13
14
import numpy as np
15
import ase.spacegroup.spacegroup    # parse_sitesym
16
import ase.io.cif
17
import h5py
18
19
from .periodicset import PeriodicSet
20
from .utils import _extend_signature, cellpar_to_cell
21
22
try:
23
    import ccdc.io       # EntryReader
24
    import ccdc.search   # TextNumericSearch
25
    _CCDC_ENABLED = True
26
except (ImportError, RuntimeError) as exception:
27
    _CCDC_ENABLED = False
28
29
30
def _warning(message, category, filename, lineno, *args, **kwargs):
0 ignored issues
show
Unused Code introduced by
The argument args seems to be unused.
Loading history...
Unused Code introduced by
The argument kwargs seems to be unused.
Loading history...
31
    return f'{filename}:{lineno}: {category.__name__}: {message}\n'
32
33
warnings.formatwarning = _warning
34
35
36
def _atom_has_disorder(label, occupancy):
37
    return label.endswith('?') or (np.isscalar(occupancy) and occupancy < 1)
38
39
40
class _Reader:
    """Base Reader class. Contains parsers for converting ase CifBlock
    and ccdc Entry objects to PeriodicSets.

    Intended use:

    First make a new method for _Reader converting object to PeriodicSet
    (e.g. named _X_to_PSet). Then make this class outline:

    class XReader(_Reader):
        def __init__(self, ..., **kwargs):

        super().__init__(**kwargs)

        # setup and checks

        # make 'iterable' which yields objects to be converted (e.g. CIFBlock, Entry)

        # set self._generator like this
        self._generator = self._map(self._X_to_PSet, iterable)

    Converters return None for items that should be skipped; ``_map`` drops
    those, so iterating a reader yields only PeriodicSets.
    """

    # accepted values of the ``disorder`` parameter
    disorder_options = {'skip', 'ordered_sites', 'all_sites'}
    # keys set by the readers themselves, so not allowed in ``extract_data``
    reserved_tags = {
        'motif',
        'cell',
        'name',
        'asymmetric_unit',
        'wyckoff_multiplicities',
        'types',
    }
    # CIF tags holding fractional coordinates of the asymmetric unit
    atom_site_fract_tags = [
        '_atom_site_fract_x',
        '_atom_site_fract_y',
        '_atom_site_fract_z',
    ]
    # CIF tags holding Cartesian coordinates, used if fractional ones are missing
    atom_site_cartn_tags = [
        '_atom_site_cartn_x',
        '_atom_site_cartn_y',
        '_atom_site_cartn_z',
    ]
    # CIF tags holding symmetry operations, tried in this order
    symop_tags = [
        '_space_group_symop_operation_xyz',
        '_space_group_symop.operation_xyz',
        '_symmetry_equiv_pos_as_xyz',
    ]

    # two sites closer than this in every fractional coordinate (mod 1) are equal
    equiv_site_tol = 1e-3

    def __init__(
            self,
            remove_hydrogens=False,
            disorder='skip',
            heaviest_component=False,
            show_warnings=True,
            extract_data=None,
            include_if=None):
        """Validate and store the options shared by all readers.

        Raises:
            ValueError: if ``disorder`` is not one of ``disorder_options``, if
                ``extract_data`` is not a dict with callable values or uses a
                reserved key, or if ``include_if`` contains a non-callable.
        """

        # settings
        if disorder not in _Reader.disorder_options:
            raise ValueError(f'disorder parameter {disorder} must be one of '
                             f'{_Reader.disorder_options}')

        if extract_data:
            if not isinstance(extract_data, dict):
                raise ValueError('extract_data must be a dict with callable values')
            for key, func in extract_data.items():
                if not callable(func):
                    raise ValueError('extract_data must be a dict with callable values')
                if key in _Reader.reserved_tags:
                    raise ValueError(f'extract_data includes reserved key {key}')

        if include_if:
            for func in include_if:
                if not callable(func):
                    raise ValueError('include_if must be a list of callables')

        self.remove_hydrogens = remove_hydrogens
        self.disorder = disorder
        self.heaviest_component = heaviest_component
        self.extract_data = extract_data
        self.include_if = include_if
        self.show_warnings = show_warnings
        self.current_identifier = None
        self.current_filename = None
        self._generator = []

    def __iter__(self):
        # when _generator comes from _map it is a generator, so it is consumed once
        yield from self._generator

    def read_one(self):
        """Read the next (or first) item.

        Raises:
            StopIteration: if there are no more items to read.
        """
        return next(iter(self._generator))

    @staticmethod
    def _map(func: Callable, iterable: Iterable) -> Iterable[PeriodicSet]:
        """Like the builtin map, but skips items for which ``func`` returns None.

        The result is assigned to ``self._generator`` by subclasses; iterating
        the reader then iterates this generator.
        """
        for item in iterable:
            res = func(item)
            if res is not None:
                yield res

    def expand(
            self,
            asym_frac_motif: np.ndarray,
            sitesym: Sequence[str]
    ) -> Tuple[np.ndarray, ...]:
        """Expand an asymmetric unit to a full motif using symmetry operations.

        Args:
            asym_frac_motif: fractional coordinates of the asymmetric unit,
                one row per site.
            sitesym: symmetry operations as strings (e.g. ``'x,y,z'``), parsed
                with ``ase.spacegroup.spacegroup.parse_sitesym``.

        Returns:
            frac_motif: fractional coordinates (wrapped into [0, 1)) of every
                distinct image.
            asym_unit: index into frac_motif of the first image of each
                asymmetric site that contributed at least one image.
            multiplicities: number of distinct images of each such site.
            inverses: list giving, for each row of frac_motif, the index of
                the asymmetric site it was generated from.
        """
        rotations, translations = ase.spacegroup.spacegroup.parse_sitesym(sitesym)
        all_sites = []
        asym_unit = [0]
        multiplicities = []
        inverses = []

        for inv, site in enumerate(asym_frac_motif):
            multiplicity = 0
            for rot, trans in zip(rotations, translations):
                site_ = np.mod(np.dot(rot, site) + trans, 1)
                # the very first image is always kept; later images only if new
                if not all_sites or \
                        not self._is_site_overlapping(site_, all_sites, inverses, inv):
                    all_sites.append(site_)
                    inverses.append(inv)
                    multiplicity += 1

            if multiplicity > 0:
                multiplicities.append(multiplicity)
                asym_unit.append(len(all_sites))

        frac_motif = np.array(all_sites)
        asym_unit = np.array(asym_unit[:-1])
        multiplicities = np.array(multiplicities)
        return frac_motif, asym_unit, multiplicities, inverses

    def _is_site_overlapping(self, new_site, all_sites, inverses, inv):
        """Return True if new_site coincides (mod 1) with any site in all_sites.

        If it coincides with an image of a *different* asymmetric site, a
        warning is issued (when show_warnings is set).
        """
        diffs1 = np.abs(new_site - all_sites)
        diffs2 = np.abs(diffs1 - 1)
        mask = np.all(
            np.logical_or(diffs1 <= _Reader.equiv_site_tol,
                          diffs2 <= _Reader.equiv_site_tol),
            axis=-1)

        matches = np.flatnonzero(mask)
        if matches.size == 0:
            return False

        if self.show_warnings:
            for ind in matches:
                if inverses[ind] != inv:
                    warnings.warn(
                        f'{self.current_identifier} has equivalent positions '
                        f'{inverses[ind]} and {inv}')
        return True

    def _remove_overlapping_sites(self, asym_frac_motif, asym_symbols, asym_is_disordered):
        """Drop asymmetric-unit sites that coincide with an earlier site.

        With ``disorder='all_sites'`` a pair is not treated as overlapping if
        either site is disordered (``asym_is_disordered`` must then have one
        flag per site). Warns if any duplicates are removed. Returns the
        filtered ``(asym_frac_motif, asym_symbols)``.
        """
        site_diffs1 = np.abs(asym_frac_motif[:, None] - asym_frac_motif)
        site_diffs2 = np.abs(site_diffs1 - 1)
        # upper triangle only, so the first of each overlapping pair is kept
        overlapping = np.triu(np.all(
            (site_diffs1 <= _Reader.equiv_site_tol) |
            (site_diffs2 <= _Reader.equiv_site_tol),
            axis=-1), 1)

        # don't remove overlapping sites if one is disordered and disorder='all_sites'
        if self.disorder == 'all_sites':
            for i, j in np.argwhere(overlapping):
                if asym_is_disordered[i] or asym_is_disordered[j]:
                    overlapping[i, j] = False

        if overlapping.any():
            if self.show_warnings:
                warnings.warn(
                    f'{self.current_identifier} may have overlapping sites; '
                    'duplicates will be removed')
            keep_sites = ~overlapping.any(0)
            asym_frac_motif = asym_frac_motif[keep_sites]
            asym_symbols = [sym for sym, keep in zip(asym_symbols, keep_sites) if keep]

        return asym_frac_motif, asym_symbols

    def _periodic_set_kwargs(self, asym_unit, multiplicities, asym_symbols, inverses):
        """Keyword arguments shared by every PeriodicSet the converters build."""
        kwargs = {
            'name': self.current_identifier,
            'asymmetric_unit': asym_unit,
            'wyckoff_multiplicities': multiplicities,
            'types': [asym_symbols[i] for i in inverses],
        }
        if self.current_filename:
            kwargs['filename'] = self.current_filename
        return kwargs

    @staticmethod
    def _component_weight(component):
        """Occupancy-weighted sum of atomic weights of a ccdc molecule component.

        Atoms whose atomic_weight is not a float/int are ignored; an atom whose
        occupancy is not a float/int counts with occupancy 1.
        """
        weight = 0
        for atom in component.atoms:
            if isinstance(atom.atomic_weight, (float, int)):
                if isinstance(atom.occupancy, (float, int)):
                    weight += atom.occupancy * atom.atomic_weight
                else:
                    weight += atom.atomic_weight
        return weight

    def _CIFBlock_to_PeriodicSet(self, block) -> Optional[PeriodicSet]:
        """ase.io.cif.CIFBlock --> PeriodicSet. Returns None for a "bad" set.

        The block is skipped (None, with a warning if show_warnings) if it fails
        an ``include_if`` check, has no coordinates, is disordered while
        ``disorder='skip'``, or has no sites left after removals.
        """

        # skip if structure does not pass checks in include_if
        if self.include_if and not all(check(block) for check in self.include_if):
            return None

        # read name, cell, asym motif and atomic symbols
        self.current_identifier = block.name
        cell = block.get_cell().array
        frac_columns = [block.get(tag) for tag in _Reader.atom_site_fract_tags]
        if any(col is None for col in frac_columns):
            # no fractional coordinates; fall back to Cartesian coordinates
            cartn_columns = [block.get(tag) for tag in _Reader.atom_site_cartn_tags]
            if any(col is None for col in cartn_columns):
                if self.show_warnings:
                    warnings.warn(
                        f'Skipping {self.current_identifier} as coordinates were not found')
                return None
            # columns are (x's, y's, z's): transpose to one row per site first. Cell
            # rows are lattice vectors (motif = frac @ cell below), so frac = cart @ inv(cell)
            asym_frac_motif = np.array(cartn_columns).T @ np.linalg.inv(cell)
        else:
            asym_frac_motif = np.array(frac_columns).T

        try:
            asym_symbols = block.get_symbols()
        except ase.io.cif.NoStructureData:
            asym_symbols = ['Unknown'] * len(asym_frac_motif)

        n_sites = len(asym_frac_motif)

        # indices of sites to remove
        remove = set()
        if self.remove_hydrogens:
            remove.update(i for i, sym in enumerate(asym_symbols) if sym in ('H', 'D'))

        # one disorder flag per site, aligned with asym_frac_motif
        asym_is_disordered = [False] * n_sites
        occupancies = block.get('_atom_site_occupancy')
        if occupancies is not None:
            labels = block.get('_atom_site_label')
            if labels is None:
                labels = [''] * len(occupancies)
            for i, (occ, label) in enumerate(zip(occupancies, labels)):
                if i < n_sites and _atom_has_disorder(str(label), occ):
                    asym_is_disordered[i] = True

            # disordered sites that are not already being removed
            disordered = [i for i, is_dis in enumerate(asym_is_disordered)
                          if is_dis and i not in remove]

            if self.disorder == 'skip' and disordered:
                if self.show_warnings:
                    warnings.warn(
                        f'Skipping {self.current_identifier} as structure is disordered')
                return None

            if self.disorder == 'ordered_sites':
                remove.update(disordered)

        # remove sites
        keep = np.array([i not in remove for i in range(n_sites)], dtype=bool)
        asym_frac_motif = np.mod(asym_frac_motif[keep], 1)
        asym_symbols = [s for i, s in enumerate(asym_symbols) if i not in remove]
        asym_is_disordered = [v for i, v in enumerate(asym_is_disordered) if i not in remove]

        # if there are overlapping sites in asym unit, warn and keep only one
        asym_frac_motif, asym_symbols = self._remove_overlapping_sites(
            asym_frac_motif, asym_symbols, asym_is_disordered)

        # if no points left in motif, skip structure
        if asym_frac_motif.shape[0] == 0:
            if self.show_warnings:
                warnings.warn(
                    f'Skipping {self.current_identifier} as there are no sites with coordinates')
            return None

        # get symmetries (identity only if the block lists none)
        sitesym = ['x,y,z', ]
        for tag in _Reader.symop_tags:
            if tag in block:
                sitesym = block[tag]
                break

        if isinstance(sitesym, str):
            sitesym = [sitesym]

        # expand the asymmetric unit to full motif + multiplicities
        frac_motif, asym_unit, multiplicities, inverses = self.expand(asym_frac_motif, sitesym)
        motif = frac_motif @ cell

        # construct PeriodicSet
        kwargs = self._periodic_set_kwargs(asym_unit, multiplicities, asym_symbols, inverses)

        if self.extract_data is not None:
            for key, func in self.extract_data.items():
                kwargs[key] = func(block)

        return PeriodicSet(motif, cell, **kwargs)

    def _Entry_to_PeriodicSet(self, entry) -> Optional[PeriodicSet]:
        """ccdc.entry.Entry --> PeriodicSet. Returns None for a "bad" set.

        The entry is skipped (None, with a warning if show_warnings) if it fails
        an ``include_if`` check, has no 3D structure, its crystal cannot be read,
        it is disordered while ``disorder='skip'``, some atoms lack coordinates,
        or no sites are left after removals.
        """

        # skip if structure does not pass checks in include_if
        if self.include_if and not all(check(entry) for check in self.include_if):
            return None

        self.current_identifier = entry.identifier
        # structure must pass this test
        if not entry.has_3d_structure:
            if self.show_warnings:
                warnings.warn(
                    f'Skipping {self.current_identifier} as entry has no 3D structure')
            return None

        try:
            crystal = entry.crystal
        except RuntimeError as e:
            if self.show_warnings:
                warnings.warn(f'Skipping {self.current_identifier}: {e}')
            return None

        molecule = crystal.disordered_molecule

        # remove disordered atoms with the same test the CIF parser uses
        if self.disorder == 'ordered_sites':
            molecule.remove_atoms(
                [a for a in molecule.atoms if _atom_has_disorder(a.label, a.occupancy)])

        # first disorder check, if skipping. If occ == 1 for all atoms but the entry
        # or crystal is listed as having disorder, skip (can't know where disorder is).
        # If occ != 1 for any atoms, we wait to see if we remove them before skipping.
        may_have_disorder = False
        if self.disorder == 'skip':
            may_have_disorder = any(
                _atom_has_disorder(a.label, a.occupancy) for a in molecule.atoms)
            if not may_have_disorder and (crystal.has_disorder or entry.has_disorder):
                if self.show_warnings:
                    warnings.warn(
                        f'Skipping {self.current_identifier} as structure is disordered')
                return None

        if self.remove_hydrogens:
            molecule.remove_atoms(
                [a for a in molecule.atoms if a.atomic_symbol in ('H', 'D')])

        # heaviest component (removes all but the heaviest component of the asym unit)
        # intended for removing solvents. probably doesn't play well with disorder
        if self.heaviest_component:
            components = molecule.components
            if len(components) > 1:
                weights = [self._component_weight(c) for c in components]
                molecule = components[int(np.argmax(weights))]

        crystal.molecule = molecule

        # by here all atoms to be removed have been. If disorder == 'skip' and
        # atom(s) with occ < 1 were found earlier, check they were all removed.
        if self.disorder == 'skip' and may_have_disorder:
            if any(_atom_has_disorder(a.label, a.occupancy)
                   for a in crystal.disordered_molecule.atoms):
                if self.show_warnings:
                    warnings.warn(
                        f'Skipping {self.current_identifier} as structure is disordered')
                return None

        # check all atoms have coords. option/default remove unknown sites?
        if not molecule.all_atoms_have_sites or \
                any(a.fractional_coordinates is None for a in molecule.atoms):
            if self.show_warnings:
                warnings.warn(
                    f'Skipping {self.current_identifier} as some atoms do not have sites')
            return None

        asym_atoms = crystal.asymmetric_unit_molecule.atoms

        # if disorder is all_sites, we need to know where disorder is to ignore overlaps
        asym_is_disordered = []     # True/False per asym unit site (all_sites only)
        if self.disorder == 'all_sites':
            asym_is_disordered = [_atom_has_disorder(a.label, a.occupancy)
                                  for a in asym_atoms]

        # get cell & asymmetric unit
        cell = cellpar_to_cell(*crystal.cell_lengths, *crystal.cell_angles)
        asym_frac_motif = np.mod(
            np.array([tuple(a.fractional_coordinates) for a in asym_atoms]), 1)
        asym_symbols = [a.atomic_symbol for a in asym_atoms]

        # if there are overlapping sites in asym unit, warn and keep only one
        asym_frac_motif, asym_symbols = self._remove_overlapping_sites(
            asym_frac_motif, asym_symbols, asym_is_disordered)

        # if no points left in motif, skip structure
        if asym_frac_motif.shape[0] == 0:
            if self.show_warnings:
                warnings.warn(
                    f'Skipping {self.current_identifier} as there are no sites with coordinates')
            return None

        # get symmetries, expand the asymmetric unit to full motif + multiplicities
        sitesym = crystal.symmetry_operators
        if not sitesym:
            sitesym = ('x,y,z', )
        frac_motif, asym_unit, multiplicities, inverses = self.expand(asym_frac_motif, sitesym)
        motif = frac_motif @ cell

        # construct PeriodicSet
        kwargs = self._periodic_set_kwargs(asym_unit, multiplicities, asym_symbols, inverses)

        if self.extract_data is not None:
            # give extract_data functions the entry with its full (disordered) molecule
            entry.crystal.molecule = crystal.disordered_molecule
            for key, func in self.extract_data.items():
                kwargs[key] = func(entry)

        return PeriodicSet(motif, cell, **kwargs)
490
491
492
class CifReader(_Reader):
    """Read all structures in a .CIF with ``ase`` or ``ccdc``
    (``csd-python-api`` only), yielding :class:`.periodicset.PeriodicSet`
    objects which can be passed to :func:`.calculate.AMD` or
    :func:`.calculate.PDD`.

    Args:
        path: path to a .CIF (or other file readable by the chosen reader), or
            to a folder if ``folder=True``.
        reader: ``'ase'`` or ``'ccdc'``.
        folder: if True, read every file in ``path`` whose extension is
            supported by the reader, in sorted filename order.
        **kwargs: options passed to the base reader (remove_hydrogens,
            disorder, heaviest_component, show_warnings, extract_data,
            include_if).

    Examples:

        ::

            # Put all crystals in a .CIF in a list
            structures = list(amd.CifReader('mycif.cif'))

            # Reads just one if the .CIF has just one crystal
            periodic_set = amd.CifReader('mycif.cif').read_one()

            # If a folder has several .CIFs each with one crystal, use
            structures = list(amd.CifReader('path/to/folder', folder=True))

            # Make list of AMDs (with k=100) of crystals in a .CIF
            amds = [amd.AMD(periodic_set, 100) for periodic_set in amd.CifReader('mycif.cif')]
    """

    @_extend_signature(_Reader.__init__)
    def __init__(
            self,
            path,
            reader='ase',
            folder=False,
            **kwargs):

        super().__init__(**kwargs)

        if reader not in ('ase', 'ccdc'):
            raise ValueError(f'Invalid reader {reader}; must be ase or ccdc.')

        if reader == 'ase' and self.heaviest_component:
            raise NotImplementedError(
                'Parameter heaviest_component not implimented for ase, only ccdc.')

        if reader == 'ase':
            extensions = {'cif'}
            file_parser = ase.io.cif.parse_cif
            pset_converter = self._CIFBlock_to_PeriodicSet
        else:  # reader == 'ccdc' (validated above)
            if not _CCDC_ENABLED:
                raise ImportError(
                    "Failed to import csd-python-api; check it is installed and licensed.")
            extensions = ccdc.io.EntryReader.known_suffixes
            file_parser = ccdc.io.EntryReader
            pset_converter = self._Entry_to_PeriodicSet

        if folder:
            generator = self._folder_generator(path, file_parser, extensions)
        else:
            generator = file_parser(path)

        self._generator = self._map(pset_converter, generator)

    def _folder_generator(self, path, file_parser, extensions):
        """Yield everything ``file_parser`` produces for each file in folder ``path``
        whose extension (case-insensitive, without the dot) is in ``extensions``.

        Files are visited in sorted name order so results are reproducible;
        directories are skipped. ``self.current_filename`` is set to each
        file's name before it is parsed.
        """
        for file in sorted(os.listdir(path)):
            full_path = os.path.join(path, file)
            suffix = os.path.splitext(file)[1][1:]
            if suffix.lower() in extensions and os.path.isfile(full_path):
                self.current_filename = file
                yield from file_parser(full_path)
556
557
558
class CSDReader(_Reader):
    """Read Entries from the CSD, yielding :class:`.periodicset.PeriodicSet` objects.

    The CSDReader returns :class:`.periodicset.PeriodicSet` objects which can be passed
    to :func:`.calculate.AMD` or :func:`.calculate.PDD`.

    Args:
        refcodes: a refcode or iterable of refcodes to read. None (or the
            string ``'CSD'``, any case) reads every entry in the database.
        families: if True, read every entry whose refcode starts with one of
            ``refcodes`` (ignored when reading the whole database).
        **kwargs: options passed to the base reader.

    Examples:

        Get crystals with refcodes in a list::

            refcodes = ['DEBXIT01', 'DEBXIT05', 'HXACAN01']
            structures = list(amd.CSDReader(refcodes))

        Read refcode families (any whose refcode starts with strings in the list)::

            refcodes = ['ACSALA', 'HXACAN']
            structures = list(amd.CSDReader(refcodes, families=True))

        Create a generic reader, read crystals by name with :meth:`CSDReader.entry()`::

            reader = amd.CSDReader()
            debxit01 = reader.entry('DEBXIT01')

            # looping over this generic reader will yield all CSD entries
            for periodic_set in reader:
                ...

        Make list of AMD (with k=100) for crystals in these families::

            refcodes = ['ACSALA', 'HXACAN']
            amds = []
            for periodic_set in amd.CSDReader(refcodes, families=True):
                amds.append(amd.AMD(periodic_set, 100))
    """

    @_extend_signature(_Reader.__init__)
    def __init__(
            self,
            refcodes=None,
            families=False,
            **kwargs):

        if not _CCDC_ENABLED:
            raise ImportError(
                "Failed to import csd-python-api; check it is installed and licensed.")

        super().__init__(**kwargs)

        if isinstance(refcodes, str) and refcodes.lower() == 'csd':
            refcodes = None

        if refcodes is None:
            families = False
        else:
            refcodes = [refcodes] if isinstance(refcodes, str) else list(refcodes)

        # families parameter reads all crystals with ids starting with passed refcodes
        if families:
            all_refcodes = []
            for refcode in refcodes:
                query = ccdc.search.TextNumericSearch()
                query.add_identifier(refcode)
                all_refcodes.extend(hit.identifier for hit in query.search())

            # filter to unique refcodes, keeping first-seen order
            refcodes = list(dict.fromkeys(all_refcodes))

        self._entry_reader = ccdc.io.EntryReader('CSD')
        self._generator = self._map(
            self._Entry_to_PeriodicSet,
            self._ccdc_generator(refcodes))

    def _ccdc_generator(self, refcodes):
        """Generate ccdc Entries from CSD refcodes (every entry if refcodes is None).

        Refcodes whose lookup raises RuntimeError are skipped, with a warning
        if show_warnings is set.
        """
        if refcodes is None:
            yield from self._entry_reader
            return

        for refcode in refcodes:
            try:
                entry = self._entry_reader.entry(refcode)
            except RuntimeError:
                if self.show_warnings:
                    warnings.warn(
                        f'Identifier {refcode} not found in database')
                continue
            yield entry

    def entry(self, refcode: str) -> Optional[PeriodicSet]:
        """Read a PeriodicSet given any CSD refcode.

        Returns None if the entry is rejected by the reader's settings (see
        _Entry_to_PeriodicSet). Lookup errors from the ccdc EntryReader are not
        caught here (presumably RuntimeError for an unknown refcode, as handled
        in _ccdc_generator).
        """
        entry = self._entry_reader.entry(refcode)
        return self._Entry_to_PeriodicSet(entry)
654
655
656
class SetWriter:
    """Write several :class:`.periodicset.PeriodicSet` objects to a .hdf5 file.
    Reading the .hdf5 is much faster than parsing a .CIF file.

    Examples:

        Write the crystals in mycif.cif to a .hdf5 file::

            with amd.SetWriter('crystals.hdf5') as writer:
                for periodic_set in amd.CifReader('mycif.cif'):
                    writer.write(periodic_set)

        Equivalently, use iwrite to write straight from an iterator::

            with amd.SetWriter('crystals.hdf5') as writer:
                writer.iwrite(amd.CifReader('mycif.cif'))

    Read the crystals back from the file with :class:`SetReader`.
    """

    # h5py variable-length string type, used for tag lists containing strings.
    _str_dtype = h5py.vlen_dtype(str)

    def __init__(self, filename: str):
        # 'w' mode creates the file, truncating any existing one; track_order
        # records insertion order so sets are read back in write order.
        self.file = h5py.File(filename, 'w', track_order=True)

    def write(self, periodic_set: PeriodicSet, name: Optional[str] = None):
        """Write a PeriodicSet object to file.

        The set is stored as a group named ``name`` (default:
        ``periodic_set.name``) holding ``motif`` and ``cell`` datasets plus,
        if the set has tags, a ``tags`` subgroup. If storing the contents of
        the group fails, the partially-written group is removed before the
        error is re-raised.

        Args:
            periodic_set: The set to write.
            name: Key to store the set under; overrides ``periodic_set.name``.

        Raises:
            ValueError: If ``periodic_set`` is not a PeriodicSet, if no name is
                available, or if a tag has a type that cannot be stored.
        """

        if not isinstance(periodic_set, PeriodicSet):
            raise ValueError(
                f'Object type {periodic_set.__class__.__name__} cannot be written with SetWriter')

        # need a name to store or you can't access items by key
        if name is None:
            if periodic_set.name is None:
                raise ValueError(
                    'Periodic set must have a name to be written. Either set the name '
                    'attribute of the PeriodicSet or pass a name to SetWriter.write()')
            name = periodic_set.name

        # this group is the PeriodicSet. Created outside the try so a name
        # clash never deletes an entry that was already in the file.
        group = self.file.create_group(name)
        try:
            group.create_dataset('motif', data=periodic_set.motif)
            group.create_dataset('cell', data=periodic_set.cell)
            if periodic_set.tags:
                SetWriter._write_tags(group.create_group('tags'), periodic_set.tags)
        except Exception:
            # Don't leave a half-written entry that SetReader would later
            # return with missing data; the caller still sees the error.
            del self.file[name]
            raise

    @staticmethod
    def _write_tags(tags_group, tags):
        """Store each entry of the ``tags`` mapping in the hdf5 group ``tags_group``."""
        for tag in tags:
            data = tags[tag]

            if data is None:               # nonce to handle None
                tags_group.attrs[tag] = '__None'
            elif np.isscalar(data):        # scalars (nums and strs) stored as attrs
                tags_group.attrs[tag] = data
            elif isinstance(data, np.ndarray):
                tags_group.create_dataset(tag, data=data)
            elif isinstance(data, (list, tuple)):
                if any(isinstance(d, str) for d in data):
                    # sequences containing strings are stored as variable-length
                    # strings, with every element converted to str
                    tags_group.create_dataset(
                        tag, data=[str(d) for d in data], dtype=SetWriter._str_dtype)
                else:
                    # other sequences must be castable to ndarray
                    tags_group.create_dataset(tag, data=np.asarray(data))
            else:
                raise ValueError(
                    f'Cannot store tag of type {type(data)} with SetWriter')

    def iwrite(self, periodic_sets: Iterable[PeriodicSet]):
        """Write :class:`.periodicset.PeriodicSet` objects from an iterable to file.

        Each set is stored under its own ``name`` attribute."""
        for periodic_set in periodic_sets:
            self.write(periodic_set)

    def close(self):
        """Close the :class:`SetWriter`."""
        self.file.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, tb):
        # Always close the file; returning None lets any exception propagate.
        self.file.close()
746
747
748
class SetReader:
    """Read :class:`.periodicset.PeriodicSet` objects from a .hdf5 file written
    with :class:`SetWriter`. Acts like a read-only dict that can be iterated
    over (preserves write order).

    Examples:

        Get PDDs (k=100) of crystals in crystals.hdf5::

            pdds = []
            with amd.SetReader('crystals.hdf5') as reader:
                for periodic_set in reader:
                    pdds.append(amd.PDD(periodic_set, 100))

            # above is equivalent to:
            with amd.SetReader('crystals.hdf5') as reader:
                pdds = [amd.PDD(pset, 100) for pset in reader]
    """

    def __init__(self, filename: str):

        self.file = h5py.File(filename, 'r', track_order=True)

    def _get_set(self, name: str) -> PeriodicSet:
        """Rebuild the PeriodicSet stored under ``name``, including its tags."""
        group = self.file[name]
        periodic_set = PeriodicSet(group['motif'][:], group['cell'][:], name=name)

        if 'tags' in group:
            tags_group = group['tags']

            # list/array tags are stored as datasets
            for tag in tags_group:
                # [()] reads a dataset of any shape, including the 0-d datasets
                # SetWriter creates for 0-d ndarray tags, where [:] fails.
                data = tags_group[tag][()]
                # string lists may come back as bytes; decode them to str
                if (isinstance(data, np.ndarray) and data.ndim > 0
                        and any(isinstance(d, (bytes, bytearray)) for d in data)):
                    data = [d.decode() for d in data]
                periodic_set.tags[tag] = data

            # scalar tags are stored as attributes; '__None' stands for None
            for attr in tags_group.attrs:
                data = tags_group.attrs[attr]
                # type check first so a non-scalar attribute never hits an
                # elementwise comparison
                if isinstance(data, str) and data == '__None':
                    data = None
                periodic_set.tags[attr] = data

        return periodic_set

    def close(self):
        """Close the :class:`SetReader`."""
        self.file.close()

    def family(self, refcode: str) -> Iterable[PeriodicSet]:
        """Yield any :class:`.periodicset.PeriodicSet` whose name starts with
        input refcode."""
        for name in self.keys():
            if name.startswith(refcode):
                yield self._get_set(name)

    def __getitem__(self, name):
        # index by name; lookup errors for a missing name come from h5py
        return self._get_set(name)

    def __len__(self):
        return len(self.keys())

    def __iter__(self):
        # interface to loop over the SetReader; does not close the SetReader when done
        for name in self.keys():
            yield self._get_set(name)

    def __contains__(self, item):
        return item in self.keys()

    def keys(self):
        """Return a view of the names of the sets in the :class:`SetReader`."""
        return self.file['/'].keys()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, tb):
        # Always close the file; returning None lets any exception propagate.
        self.file.close()
826
827
828
def crystal_to_periodicset(crystal):
    """ccdc.crystal.Crystal --> amd.periodicset.PeriodicSet.
    Is a stripped-down version of the function used in CifReader: no
    disorder handling, no checks and no options. Atoms of the asymmetric
    unit without fractional coordinates are removed with a warning.

    Raises:
        ValueError: if no atom of the asymmetric unit has coordinates.
    """

    cell = cellpar_to_cell(*crystal.cell_lengths, *crystal.cell_angles)

    # asymmetric unit fractional coordinates; atoms lacking them are dropped
    # rather than failing later on tuple(None)
    coords = [a.fractional_coordinates
              for a in crystal.asymmetric_unit_molecule.atoms]
    if any(c is None for c in coords):
        warnings.warn(
            f'{crystal.identifier} has atoms without coordinates, which were removed')
        coords = [c for c in coords if c is not None]

    # if the above removed everything, there is nothing to build a set from
    if not coords:
        raise ValueError(f'{crystal.identifier} has no coordinates')

    asym_frac_motif = np.mod(np.array([tuple(c) for c in coords]), 1)

    # no listed symmetry operators means only the identity
    sitesym = crystal.symmetry_operators
    if not sitesym:
        sitesym = ('x,y,z', )
    r = _Reader()
    r.current_identifier = crystal.identifier
    frac_motif, asym_unit, multiplicities, _ = r.expand(asym_frac_motif, sitesym)
    motif = frac_motif @ cell

    kwargs = {
        'name': crystal.identifier,
        'asymmetric_unit': asym_unit,
        'wyckoff_multiplicities': multiplicities,
    }

    return PeriodicSet(motif, cell, **kwargs)
860
861
862
def cifblock_to_periodicset(block):
    """ase.io.cif.CIFBlock --> amd.periodicset.PeriodicSet.
    Ignores disorder, missing sites/coords, checks & no options.
    Is a stripped-down version of the function used in CifReader.

    Coordinates are read from the fractional atom-site tags, falling back to
    the Cartesian ones (converted to fractional with the block's cell).

    Returns:
        The PeriodicSet, or None (after a warning) if the block has neither
        fractional nor Cartesian coordinates.

    Raises:
        ValueError: if the coordinate columns are present but empty.
    """

    cell = block.get_cell().array
    # one column (list of values) per axis, in the order of _Reader's tag list
    asym_frac_motif = [block.get(name) for name in _Reader.atom_site_fract_tags]

    if None in asym_frac_motif:
        asym_motif = [block.get(name) for name in _Reader.atom_site_cartn_tags]
        if None in asym_motif:
            warnings.warn(
                f'Skipping {block.name} as coordinates were not found')
            return None

        # The columns form a (3, n) array. Cartesian rows satisfy
        # cart = frac @ cell (as used for ``motif`` below), so convert with one
        # row per atom, then transpose back to the (3, n) column layout shared
        # with the fractional branch.
        asym_frac_motif = (np.array(asym_motif).T @ np.linalg.inv(cell)).T

    # one row per atom, wrapped into the unit cell
    asym_frac_motif = np.mod(np.array(asym_frac_motif).T, 1)

    if asym_frac_motif.shape[0] == 0:
        raise ValueError(f'{block.name} has no coordinates')

    # symmetry operations from the first tag present, else identity only
    sitesym = ('x,y,z', )
    for tag in _Reader.symop_tags:
        if tag in block:
            sitesym = block[tag]
            break

    # a single operation may come back as a bare string
    if isinstance(sitesym, str):
        sitesym = [sitesym]

    dummy_reader = _Reader()
    dummy_reader.current_identifier = block.name
    frac_motif, asym_unit, multiplicities, _ = dummy_reader.expand(asym_frac_motif, sitesym)
    motif = frac_motif @ cell

    kwargs = {
        'name': block.name,
        'asymmetric_unit': asym_unit,
        'wyckoff_multiplicities': multiplicities
    }

    return PeriodicSet(motif, cell, **kwargs)
906