Passed
Push — master ( 1d1d87...c4cb0d )
by Daniel
02:01
created

amd.io.SetReader.family()   A

Complexity

Conditions 3

Size

Total Lines 6
Code Lines 4

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 4
dl 0
loc 6
rs 10
c 0
b 0
f 0
cc 3
nop 2
1
"""Contains I/O tools, including a .CIF reader and CSD reader
2
(``csd-python-api`` only) to extract periodic set representations
3
of crystals which can be passed to :func:`.calculate.AMD` and :func:`.calculate.PDD`.
4
5
These intermediate :class:`.periodicset.PeriodicSet` representations can be written
6
to a .hdf5 file with :class:`SetWriter`, which can be read back with :class:`SetReader`.
7
This is much faster than rereading a .CIF and recomputing invariants.
8
"""
9
10
import os
11
import warnings
12
from typing import Callable, Iterable, Sequence, Tuple, Optional
0 ignored issues
show
Unused Code introduced by
Unused Optional imported from typing
Loading history...
13
14
import numpy as np
15
import ase.spacegroup.spacegroup    # parse_sitesym
16
import ase.io.cif
17
import h5py
0 ignored issues
show
Unused Code introduced by
The import h5py seems to be unused.
Loading history...
18
19
from .periodicset import PeriodicSet
20
from .utils import _extend_signature, cellpar_to_cell
21
22
try:
23
    import ccdc.io       # EntryReader
24
    import ccdc.search   # TextNumericSearch
25
    _CCDC_ENABLED = True
26
except (ImportError, RuntimeError) as exception:
27
    _CCDC_ENABLED = False
28
29
30
def _warning(message, category, filename, lineno, *args, **kwargs):
    """Format a warning as one ``filename:lineno: Category: message`` line.

    The signature mirrors what the ``warnings`` module passes to its formatter.
    Any extra positional or keyword arguments are accepted and deliberately
    ignored; they are not part of the output.
    """
    return '{}:{}: {}: {}\n'.format(filename, lineno, category.__name__, message)


# NOTE: this replaces the formatter for every warning in the process, not only
# those raised by this module.
warnings.formatwarning = _warning
34
35
36
def _atom_has_disorder(label, occupancy):
    """Return True if an atom's label or occupancy indicates disorder.

    An atom counts as disordered when its label ends with ``'?'`` or when it has
    a numeric occupancy below 1.

    Only real numbers are compared against 1. ``np.isscalar`` is also True for
    strings, so a non-numeric occupancy (for example a ``'?'`` placeholder kept
    as a string) used to reach ``occupancy < 1`` and raise ``TypeError``. Such
    values, and ``None``, now simply carry no disorder information.
    """
    if label.endswith('?'):
        return True
    is_number = isinstance(occupancy, (int, float, np.integer, np.floating))
    return bool(is_number and occupancy < 1)
38
39
40
class _Reader:
    """Base Reader class. Contains parsers for converting ase CIFBlock
    and ccdc Entry objects to PeriodicSets.

    Intended use:

    First add a method to _Reader converting the new object type to a
    PeriodicSet (e.g. named _X_to_PSet). Then outline the subclass like this::

        class XReader(_Reader):
            def __init__(self, ..., **kwargs):

                super().__init__(**kwargs)

                # setup and checks

                # make 'iterable' which yields objects to be converted
                # (e.g. CIFBlock, Entry)

                # set self._generator like this
                self._generator = self._map(self._X_to_PSet, iterable)
    """

    # accepted values of the ``disorder`` parameter
    disorder_options = {'skip', 'ordered_sites', 'all_sites'}
    # tag names set by the reader itself; not allowed as extract_data keys
    reserved_tags = {
        'motif',
        'cell',
        'name',
        'asymmetric_unit',
        'wyckoff_multiplicities',
        'types',
        'filename',}
    atom_site_fract_tags = [
        '_atom_site_fract_x',
        '_atom_site_fract_y',
        '_atom_site_fract_z',]
    atom_site_cartn_tags = [
        '_atom_site_cartn_x',
        '_atom_site_cartn_y',
        '_atom_site_cartn_z',]
    # CIF tags searched (in this order) for symmetry operations
    symop_tags = [
        '_space_group_symop_operation_xyz',
        '_space_group_symop.operation_xyz',
        '_symmetry_equiv_pos_as_xyz',]

    # two fractional sites closer than this in every coordinate (modulo 1)
    # are treated as the same site
    equiv_site_tol = 1e-3

    def __init__(
            self,
            remove_hydrogens=False,
            disorder='skip',
            heaviest_component=False,
            show_warnings=True,
            extract_data=None,
            include_if=None):
        """Validate and store the reader settings.

        Args:
            remove_hydrogens: if True, sites whose symbol is 'H' or 'D' are removed.
            disorder: one of ``disorder_options``. 'skip' skips disordered
                structures, 'ordered_sites' removes disordered sites,
                'all_sites' keeps every site.
            heaviest_component: if True, keep only the heaviest component
                (used by the ccdc Entry converter).
            show_warnings: if False, warnings issued by the reader are suppressed.
            extract_data: optional dict mapping a tag name to a callable; each
                callable is applied to the raw item and the result is stored
                under that name on the PeriodicSet.
            include_if: optional iterable of callables applied to the raw item;
                the item is skipped unless all of them return a truthy value.

        Raises:
            ValueError: if ``disorder`` is not a valid option, ``extract_data``
                is not a dict of callables or uses a reserved key, or
                ``include_if`` contains a non-callable.
        """
        if disorder not in _Reader.disorder_options:
            raise ValueError(
                f'disorder parameter {disorder} must be one of {_Reader.disorder_options}')

        if extract_data:
            if not isinstance(extract_data, dict):
                raise ValueError('extract_data must be a dict with callable values')
            for key in extract_data:
                if not callable(extract_data[key]):
                    raise ValueError('extract_data must be a dict with callable values')
                if key in _Reader.reserved_tags:
                    raise ValueError(f'extract_data includes reserved key {key}')

        if include_if:
            for func in include_if:
                if not callable(func):
                    raise ValueError('include_if must be a list of callables')

        self.remove_hydrogens = remove_hydrogens
        self.disorder = disorder
        self.heaviest_component = heaviest_component
        self.extract_data = extract_data
        self.include_if = include_if
        self.show_warnings = show_warnings
        self.current_identifier = None
        self.current_filename = None
        self._generator = []

    def __iter__(self):
        yield from self._generator

    def read_one(self):
        """Read the next (or first) item."""
        return next(iter(self._generator))

    # basically the builtin map, but skips items if the function returned None.
    # The object returned by this function (Iterable of PeriodicSets) is set to
    # self._generator; then iterating over the Reader iterates over
    # self._generator.
    @staticmethod
    def _map(func: Callable, iterable: Iterable) -> Iterable['PeriodicSet']:
        """Iterates over iterable, passing items through func and
        yielding the result if it is not None.
        """
        for item in iterable:
            res = func(item)
            if res is not None:
                yield res

    def _CIFBlock_to_PeriodicSet(self, block) -> Optional['PeriodicSet']:
        """ase.io.cif.CIFBlock --> PeriodicSet. Returns None for a "bad" set."""

        # skip if structure does not pass checks in include_if
        if self.include_if and not all(check(block) for check in self.include_if):
            return None

        # read name, cell, asym motif and atomic symbols
        self.current_identifier = block.name
        cell = block.get_cell().array
        asym_frac_motif = [block.get(name) for name in _Reader.atom_site_fract_tags]
        if None in asym_frac_motif:
            asym_motif = [block.get(name) for name in _Reader.atom_site_cartn_tags]
            if None in asym_motif:
                if self.show_warnings:
                    warnings.warn(
                        f'Skipping {self.current_identifier} as coordinates were not found')
                return None
            # The tags give one sequence per axis (x, y, z). Transpose to one row
            # per site *before* converting, so (n, 3) @ (3, 3) is well defined.
            asym_frac_motif = np.array(asym_motif).T @ np.linalg.inv(cell)
        else:
            asym_frac_motif = np.array(asym_frac_motif).T

        n_sites = len(asym_frac_motif)

        try:
            asym_symbols = block.get_symbols()
        except ase.io.cif.NoStructureData:
            asym_symbols = ['Unknown'] * n_sites

        # indices of sites to remove
        remove = []
        if self.remove_hydrogens:
            # tuple membership: `sym in 'HD'` would also match '' and 'HD'
            remove.extend(i for i, sym in enumerate(asym_symbols) if sym in ('H', 'D'))

        # find disordered sites. The flag list always has one entry per site so
        # that it stays aligned with site indices in the filtering below.
        asym_is_disordered = [False] * n_sites
        occupancies = block.get('_atom_site_occupancy')
        labels = block.get('_atom_site_label')
        if occupancies is not None:
            if labels is None:
                labels = [''] * len(occupancies)
            for i, (occ, label) in enumerate(zip(occupancies, labels)):
                if i < n_sites and _atom_has_disorder(label, occ):
                    asym_is_disordered[i] = True

            # disordered sites which are not already being removed
            disordered = [i for i, flag in enumerate(asym_is_disordered)
                          if flag and i not in remove]

            if self.disorder == 'skip' and disordered:
                if self.show_warnings:
                    warnings.warn(
                        f'Skipping {self.current_identifier} as structure is disordered')
                return None

            if self.disorder == 'ordered_sites':
                remove.extend(disordered)

        # remove sites
        removed = set(remove)
        asym_frac_motif = np.mod(np.delete(asym_frac_motif, remove, axis=0), 1)
        asym_symbols = [s for i, s in enumerate(asym_symbols) if i not in removed]
        asym_is_disordered = [v for i, v in enumerate(asym_is_disordered) if i not in removed]

        keep_sites = self._validate_sites(asym_frac_motif, asym_is_disordered)
        if keep_sites is not None:
            asym_frac_motif = asym_frac_motif[keep_sites]
            asym_symbols = [sym for sym, keep in zip(asym_symbols, keep_sites) if keep]

        if self._has_no_valid_sites(asym_frac_motif):
            return None

        # use the first symmetry-operation tag present, else the identity
        sitesym = ['x,y,z', ]
        for tag in _Reader.symop_tags:
            if tag in block:
                sitesym = block[tag]
                break

        if isinstance(sitesym, str):
            sitesym = [sitesym]

        return self._construct_periodic_set(block, asym_frac_motif, asym_symbols, sitesym, cell)

    def _Entry_to_PeriodicSet(self, entry) -> Optional['PeriodicSet']:
        """ccdc.entry.Entry --> PeriodicSet. Returns None for a "bad" set."""

        # skip if structure does not pass checks in include_if
        if self.include_if and not all(check(entry) for check in self.include_if):
            return None

        self.current_identifier = entry.identifier
        # structure must pass this test
        if not entry.has_3d_structure:
            if self.show_warnings:
                warnings.warn(
                    f'Skipping {self.current_identifier} as entry has no 3D structure')
            return None

        crystal = entry.crystal

        # first disorder check, if skipping. If occ == 1 for all atoms but the entry
        # or crystal is listed as having disorder, skip (can't know where disorder is).
        # If occ != 1 for any atoms, we wait to see if we remove them before skipping.
        molecule = crystal.disordered_molecule
        if self.disorder == 'ordered_sites':
            # NOTE(review): only atoms whose label ends in '?' are removed here,
            # whereas the CIFBlock converter also removes sites with occupancy < 1.
            # Confirm whether this difference is intended.
            molecule.remove_atoms(a for a in molecule.atoms if a.label.endswith('?'))

        may_have_disorder = False
        if self.disorder == 'skip':
            may_have_disorder = any(
                _atom_has_disorder(a.label, a.occupancy) for a in molecule.atoms)

            if not may_have_disorder:
                if crystal.has_disorder or entry.has_disorder:
                    if self.show_warnings:
                        warnings.warn(
                            f'Skipping {self.current_identifier} as structure is disordered')
                    return None

        if self.remove_hydrogens:
            # tuple membership: `symbol in 'HD'` would also match '' and 'HD'
            molecule.remove_atoms(
                a for a in molecule.atoms if a.atomic_symbol in ('H', 'D'))

        if self.heaviest_component:
            molecule = _Reader._heaviest_component(molecule)

        crystal.molecule = molecule

        # by here all atoms to be removed have been (except via ordered_sites).
        # If disorder == 'skip' and there were atom(s) with occ < 1 found
        # earlier, we check if all such atoms were removed. If not, skip.
        if self.disorder == 'skip' and may_have_disorder:
            still_disordered = any(
                _atom_has_disorder(a.label, a.occupancy)
                for a in crystal.disordered_molecule.atoms)
            if still_disordered:
                if self.show_warnings:
                    warnings.warn(
                        f'Skipping {self.current_identifier} as structure is disordered')
                return None

        # if disorder is all_sites, we need to know where disorder is to ignore overlaps
        asym_is_disordered = []     # True/False list same length as asym unit
        if self.disorder == 'all_sites':
            asym_is_disordered = [
                _atom_has_disorder(a.label, a.occupancy)
                for a in crystal.asymmetric_unit_molecule.atoms]

        # check all atoms have coords. option/default remove unknown sites?
        if not molecule.all_atoms_have_sites or \
           any(a.fractional_coordinates is None for a in molecule.atoms):
            if self.show_warnings:
                warnings.warn(
                    f'Skipping {self.current_identifier} as some atoms do not have sites')
            return None

        # get cell & asymmetric unit
        cell = cellpar_to_cell(*crystal.cell_lengths, *crystal.cell_angles)
        asym_frac_motif = np.array([tuple(a.fractional_coordinates)
                                    for a in crystal.asymmetric_unit_molecule.atoms])
        asym_frac_motif = np.mod(asym_frac_motif, 1)
        asym_symbols = [a.atomic_symbol for a in crystal.asymmetric_unit_molecule.atoms]

        # remove overlapping sites, check sites exist
        keep_sites = self._validate_sites(asym_frac_motif, asym_is_disordered)
        if keep_sites is not None:
            asym_frac_motif = asym_frac_motif[keep_sites]
            asym_symbols = [sym for sym, keep in zip(asym_symbols, keep_sites) if keep]

        if self._has_no_valid_sites(asym_frac_motif):
            return None

        sitesym = crystal.symmetry_operators
        if not sitesym:
            sitesym = ['x,y,z', ]

        entry.crystal.molecule = crystal.disordered_molecule    # for extract_data. remove?

        return self._construct_periodic_set(entry, asym_frac_motif, asym_symbols, sitesym, cell)

    def expand(
            self,
            asym_frac_motif: np.ndarray,
            sitesym: Sequence[str]
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, list]:
        """
        Asymmetric unit's fractional coords + sitesyms (as strings)
        -->
        frac_motif, asym_unit, multiplicities, inverses

        Every symmetry operation is applied to every asymmetric-unit site; an
        image is kept only if it does not overlap a site already collected.

        Returns:
            frac_motif: array of all kept sites (fractional, reduced modulo 1).
            asym_unit: for each asymmetric-unit site with at least one kept
                image, the index in frac_motif where its images start.
            multiplicities: number of kept images for each of those sites.
            inverses: list giving, for each row of frac_motif, the index of the
                asymmetric-unit site it was generated from.
        """
        rotations, translations = ase.spacegroup.spacegroup.parse_sitesym(sitesym)
        all_sites = []
        asym_unit = [0]
        multiplicities = []
        inverses = []

        for inv, site in enumerate(asym_frac_motif):
            multiplicity = 0

            for rot, trans in zip(rotations, translations):
                site_ = np.mod(np.dot(rot, site) + trans, 1)

                # the very first site has nothing to be compared against
                if not all_sites or \
                   not self._is_site_overlapping(site_, all_sites, inverses, inv):
                    all_sites.append(site_)
                    inverses.append(inv)
                    multiplicity += 1

            if multiplicity > 0:
                multiplicities.append(multiplicity)
                asym_unit.append(len(all_sites))

        frac_motif = np.array(all_sites)
        # the last entry is the total number of sites, not a start index
        asym_unit = np.array(asym_unit[:-1])
        multiplicities = np.array(multiplicities)
        return frac_motif, asym_unit, multiplicities, inverses

    def _is_site_overlapping(self, new_site, all_sites, inverses, inv):
        """Return True if new_site overlaps with a site in all_sites.

        Two sites overlap when every coordinate differs by at most
        ``equiv_site_tol``, directly or after wrapping by 1. A warning is
        issued (if enabled) for each overlap with a site generated from a
        different asymmetric-unit site than ``inv``.
        """
        tol = _Reader.equiv_site_tol
        diffs1 = np.abs(new_site - all_sites)
        diffs2 = np.abs(diffs1 - 1)     # distance across the cell boundary
        mask = np.all((diffs1 <= tol) | (diffs2 <= tol), axis=-1)

        if not np.any(mask):
            return False

        if self.show_warnings:
            for ind in np.flatnonzero(mask):
                if inverses[ind] != inv:
                    warnings.warn(
                        f'{self.current_identifier} has equivalent positions '
                        f'{inverses[ind]} and {inv}')
        return True

    def _validate_sites(self, asym_frac_motif, asym_is_disordered):
        """Find duplicated sites in the asymmetric unit.

        Returns a boolean array which is False for each site overlapping an
        earlier one, or None if no sites overlap. When disorder is
        'all_sites', overlaps involving a disordered site are not counted.
        """
        tol = _Reader.equiv_site_tol
        site_diffs1 = np.abs(asym_frac_motif[:, None] - asym_frac_motif)
        site_diffs2 = np.abs(site_diffs1 - 1)
        # strict upper triangle: each pair once, and no site against itself
        overlapping = np.triu(
            np.all((site_diffs1 <= tol) | (site_diffs2 <= tol), axis=-1), 1)

        if self.disorder == 'all_sites':
            for i, j in np.argwhere(overlapping):
                if asym_is_disordered[i] or asym_is_disordered[j]:
                    overlapping[i, j] = False

        if not overlapping.any():
            return None

        if self.show_warnings:
            warnings.warn(
                f'{self.current_identifier} may have overlapping sites; '
                'duplicates will be removed')
        return ~overlapping.any(0)

    def _has_no_valid_sites(self, motif):
        """Return True (and warn, if enabled) if motif contains no sites."""
        if motif.shape[0] == 0:
            if self.show_warnings:
                warnings.warn(
                    f'Skipping {self.current_identifier} as there are no sites with coordinates')
            return True
        return False

    def _construct_periodic_set(self, raw_item, asym_frac_motif, asym_symbols, sitesym, cell):
        """Expand the asymmetric unit and build the PeriodicSet with its tags.

        ``raw_item`` is the object being converted; it is only passed to the
        ``extract_data`` callables.
        """
        frac_motif, asym_unit, multiplicities, inverses = self.expand(asym_frac_motif, sitesym)
        full_types = [asym_symbols[i] for i in inverses]
        motif = frac_motif @ cell

        kwargs = {
            'name': self.current_identifier,
            'asymmetric_unit': asym_unit,
            'wyckoff_multiplicities': multiplicities,
            'types': full_types,
        }

        if self.current_filename:
            kwargs['filename'] = self.current_filename

        if self.extract_data is not None:
            for key in self.extract_data:
                kwargs[key] = self.extract_data[key](raw_item)

        return PeriodicSet(motif, cell, **kwargs)

    @staticmethod
    def _heaviest_component(molecule):
        """Heaviest component (removes all but the heaviest component of the asym unit).
        Intended for removing solvents. Probably doesn't play well with disorder.

        A component's weight is the sum of its atoms' atomic weights, each
        scaled by the atom's occupancy when that is numeric. Atoms without a
        numeric atomic weight are ignored. A molecule with at most one
        component is returned unchanged.
        """
        if len(molecule.components) > 1:
            component_weights = []
            for component in molecule.components:
                weight = 0
                for a in component.atoms:
                    if isinstance(a.atomic_weight, (float, int)):
                        if isinstance(a.occupancy, (float, int)):
                            weight += a.occupancy * a.atomic_weight
                        else:
                            weight += a.atomic_weight
                component_weights.append(weight)
            largest_component_arg = np.argmax(np.array(component_weights))
            molecule = molecule.components[largest_component_arg]

        return molecule
457
458
class CifReader(_Reader):
    """Read all structures in a .CIF with ``ase`` or ``ccdc``
    (``csd-python-api`` only), yielding  :class:`.periodicset.PeriodicSet`
    objects which can be passed to :func:`.calculate.AMD` or
    :func:`.calculate.PDD`.

    Examples:

        ::

            # Put all crystals in a .CIF in a list
            structures = list(amd.CifReader('mycif.cif'))

            # Reads just one if the .CIF has just one crystal
            periodic_set = amd.CifReader('mycif.cif').read_one()

            # If a folder has several .CIFs each with one crystal, use
            structures = list(amd.CifReader('path/to/folder', folder=True))

            # Make list of AMDs (with k=100) of crystals in a .CIF
            amds = [amd.AMD(periodic_set, 100) for periodic_set in amd.CifReader('mycif.cif')]
    """

    @_extend_signature(_Reader.__init__)
    def __init__(
            self,
            path,
            reader='ase',
            folder=False,
            **kwargs):
        """Set up reading of ``path``.

        Args:
            path: a file to read, or a directory if ``folder`` is True.
            reader: 'ase' or 'ccdc'; selects the file parser and converter.
            folder: if True, read every file in ``path`` whose extension is
                known to the chosen reader.
            **kwargs: settings passed on to ``_Reader.__init__``.

        Raises:
            ValueError: if ``reader`` is not 'ase' or 'ccdc'.
            NotImplementedError: if ``heaviest_component`` is requested with
                the ase reader.
            ImportError: if the ccdc reader is requested but ccdc could not
                be imported.
        """
        super().__init__(**kwargs)

        if reader not in ('ase', 'ccdc'):
            raise ValueError(f'Invalid reader {reader}; must be ase or ccdc.')

        if reader == 'ase' and self.heaviest_component:
            raise NotImplementedError(
                'Parameter heaviest_component not implimented for ase, only ccdc.')

        if reader == 'ase':
            extensions = {'cif'}
            file_parser = ase.io.cif.parse_cif
            pset_converter = self._CIFBlock_to_PeriodicSet
        else:
            # reader == 'ccdc' (validated above), so the names below are
            # bound on every path
            if not _CCDC_ENABLED:
                raise ImportError(
                    "Failed to import csd-python-api; check it is installed and licensed.")
            extensions = ccdc.io.EntryReader.known_suffixes
            file_parser = ccdc.io.EntryReader
            pset_converter = self._Entry_to_PeriodicSet

        if folder:
            generator = self._folder_generator(path, file_parser, extensions)
        else:
            generator = file_parser(path)

        self._generator = self._map(pset_converter, generator)

    def _folder_generator(self, path, file_parser, extensions):
        """Yield the parsed items of every file in ``path`` with a known extension.

        ``self.current_filename`` is updated before each file is parsed.
        """
        # os.listdir returns names in arbitrary order; sort so that structures
        # are read in the same order on every run and platform.
        for file in sorted(os.listdir(path)):
            suff = os.path.splitext(file)[1][1:]
            if suff.lower() in extensions:
                self.current_filename = file
                yield from file_parser(os.path.join(path, file))
522
523
524
class CSDReader(_Reader):
    """Read Entries from the CSD, yielding :class:`.periodicset.PeriodicSet` objects.

    The CSDReader returns :class:`.periodicset.PeriodicSet` objects which can be passed
    to :func:`.calculate.AMD` or :func:`.calculate.PDD`.

    Examples:

        Get crystals with refcodes in a list::

            refcodes = ['DEBXIT01', 'DEBXIT05', 'HXACAN01']
            structures = list(amd.CSDReader(refcodes))

        Read refcode families (any whose refcode starts with strings in the list)::

            refcodes = ['ACSALA', 'HXACAN']
            structures = list(amd.CSDReader(refcodes, families=True))

        Create a generic reader, read crystals by name with :meth:`CSDReader.entry()`::

            reader = amd.CSDReader()
            debxit01 = reader.entry('DEBXIT01')

            # looping over this generic reader will yield all CSD entries
            for periodic_set in reader:
                ...

        Make list of AMD (with k=100) for crystals in these families::

            refcodes = ['ACSALA', 'HXACAN']
            amds = []
            for periodic_set in amd.CSDReader(refcodes, families=True):
                amds.append(amd.AMD(periodic_set, 100))
    """

    @_extend_signature(_Reader.__init__)
    def __init__(
            self,
            refcodes=None,
            families=False,
            **kwargs):
        """Set up reading from the CSD.

        Args:
            refcodes: a refcode, an iterable of refcodes, or None / the string
                'csd' (any case) to iterate over the whole database.
            families: if True, each refcode is searched for and every
                identifier found is read. Ignored when ``refcodes`` is None.
            **kwargs: settings passed on to ``_Reader.__init__``.

        Raises:
            ImportError: if ccdc could not be imported.
        """
        if not _CCDC_ENABLED:
            raise ImportError(
                "Failed to import csd-python-api; check it is installed and licensed.")

        super().__init__(**kwargs)

        if isinstance(refcodes, str) and refcodes.lower() == 'csd':
            refcodes = None

        if refcodes is None:
            families = False
        else:
            refcodes = [refcodes] if isinstance(refcodes, str) else list(refcodes)

        # families parameter reads all crystals with ids starting with passed refcodes
        if families:
            all_refcodes = []
            for refcode in refcodes:
                query = ccdc.search.TextNumericSearch()
                query.add_identifier(refcode)
                all_refcodes.extend(hit.identifier for hit in query.search())

            # filter to unique refcodes, keeping the first occurrence of each
            seen = set()
            refcodes = []
            for refcode in all_refcodes:
                if refcode not in seen:
                    seen.add(refcode)
                    refcodes.append(refcode)

        self._entry_reader = ccdc.io.EntryReader('CSD')
        self._generator = self._map(
            self._Entry_to_PeriodicSet,
            self._ccdc_generator(refcodes))

    def _ccdc_generator(self, refcodes):
        """Generates ccdc Entries from CSD refcodes.

        With ``refcodes`` None every entry of the reader is yielded. Otherwise
        each refcode is looked up; a refcode whose lookup raises RuntimeError
        is skipped, with a warning unless ``show_warnings`` is False.
        """
        if refcodes is None:
            yield from self._entry_reader
            return

        for refcode in refcodes:
            # keep the try body to the lookup only, so the handler cannot
            # catch anything raised around the yield
            try:
                entry = self._entry_reader.entry(refcode)
            except RuntimeError:
                # respect show_warnings, like every other reader warning
                if self.show_warnings:
                    warnings.warn(
                        f'Identifier {refcode} not found in database')
                continue
            yield entry

    def entry(self, refcode: str) -> 'PeriodicSet':
        """Read a PeriodicSet given any CSD refcode.

        The result is whatever ``_Entry_to_PeriodicSet`` returns for the entry,
        which is None if the entry is skipped by the reader's settings.
        """
        entry = self._entry_reader.entry(refcode)
        periodic_set = self._Entry_to_PeriodicSet(entry)
        return periodic_set
620
621
622
def crystal_to_periodicset(crystal):
    """ccdc.crystal.Crystal --> amd.periodicset.PeriodicSet.

    A stripped-down version of the conversion used in CifReader: disorder is
    ignored, no checks for missing sites/coords are made and there are no options.

    Raises:
        ValueError: if the asymmetric unit contains no atoms.
    """
    cell = cellpar_to_cell(*crystal.cell_lengths, *crystal.cell_angles)

    # asymmetric unit fractional coordinates, reduced into the unit cell
    coords = [tuple(atom.fractional_coordinates)
              for atom in crystal.asymmetric_unit_molecule.atoms]
    asym_frac_motif = np.mod(np.array(coords), 1)

    # nothing to build a periodic set from
    if asym_frac_motif.shape[0] == 0:
        raise ValueError(f'{crystal.identifier} has no coordinates')

    # fall back to the identity operation if none are listed
    sitesym = crystal.symmetry_operators
    if not sitesym:
        sitesym = ('x,y,z', )

    # a default reader is only needed for its expand() method
    reader = _Reader()
    reader.current_identifier = crystal.identifier
    frac_motif, asym_unit, multiplicities, _ = reader.expand(asym_frac_motif, sitesym)

    return PeriodicSet(
        frac_motif @ cell,
        cell,
        name=crystal.identifier,
        asymmetric_unit=asym_unit,
        wyckoff_multiplicities=multiplicities)
654
655
656
def cifblock_to_periodicset(block):
    """ase.io.cif.CIFBlock --> amd.periodicset.PeriodicSet.

    A stripped-down version of the conversion used in CifReader: disorder is
    ignored, no checks for overlapping sites are made and there are no options.

    Returns:
        The PeriodicSet, or None (with a warning) if the block has neither
        fractional nor Cartesian coordinate tags.

    Raises:
        ValueError: if the coordinate tags are present but contain no sites.
    """
    cell = block.get_cell().array
    asym_frac_motif = [block.get(name) for name in _Reader.atom_site_fract_tags]

    if None in asym_frac_motif:
        asym_motif = [block.get(name) for name in _Reader.atom_site_cartn_tags]
        if None in asym_motif:
            warnings.warn(
                f'Skipping {block.name} as coordinates were not found')
            return None

        # The tags give one sequence per axis (x, y, z). Transpose to one row
        # per site *before* converting, so (n, 3) @ (3, 3) is well defined.
        asym_frac_motif = np.array(asym_motif).T @ np.linalg.inv(cell)
    else:
        asym_frac_motif = np.array(asym_frac_motif).T

    asym_frac_motif = np.mod(asym_frac_motif, 1)

    if asym_frac_motif.shape[0] == 0:
        raise ValueError(f'{block.name} has no coordinates')

    # use the first symmetry-operation tag present, else the identity
    sitesym = ('x,y,z', )
    for tag in _Reader.symop_tags:
        if tag in block:
            sitesym = block[tag]
            break

    if isinstance(sitesym, str):
        sitesym = [sitesym]

    # a default reader is only needed for its expand() method
    dummy_reader = _Reader()
    dummy_reader.current_identifier = block.name
    frac_motif, asym_unit, multiplicities, _ = dummy_reader.expand(asym_frac_motif, sitesym)
    motif = frac_motif @ cell

    kwargs = {
        'name': block.name,
        'asymmetric_unit': asym_unit,
        'wyckoff_multiplicities': multiplicities
    }

    periodic_set = PeriodicSet(motif, cell, **kwargs)
    return periodic_set
700