Test Failed
Push — master ( 37d7fb...c02a6e )
by Daniel
07:38
created

amd.io._Reader.__next__()   C

Complexity

Conditions 10

Size

Total Lines 39
Code Lines 25

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 25
dl 0
loc 39
rs 5.9999
c 0
b 0
f 0
cc 10
nop 1

How to fix   Complexity   

Complexity

Complex classes like amd.io._Reader.__next__() often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes, or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
"""Tools for reading crystals from files, or from the CSD with
2
``csd-python-api``. The readers return
3
:class:`amd.PeriodicSet <.periodicset.PeriodicSet>` objects representing
4
the crystal which can be passed to :func:`amd.AMD() <.calculate.AMD>`
5
and :func:`amd.PDD() <.calculate.PDD>`.
6
"""
7
8
import warnings
9
import os
10
import re
11
import functools
12
import errno
13
import math
14
import json
15
from pathlib import Path
16
from typing import Iterable, Iterator, Optional, Union, Callable, Tuple, List
17
18
import numpy as np
19
import numpy.typing as npt
20
import numba
21
import tqdm
22
23
from .utils import cellpar_to_cell
24
from .periodicset import PeriodicSet
25
26
27
def _custom_warning(message, category, filename, lineno, *args, **kwargs):
28
    return f'{category.__name__}: {message}\n'
29
30
# Print warnings in the short '<Category>: <message>' form.
warnings.formatwarning = _custom_warning


# Directory containing this module; the JSON data files are read from it.
_pkg_path = Path(__file__).absolute().parent

# JSON text is UTF-8, so do not depend on the platform's default encoding.
with open(str(_pkg_path / 'atomic_numbers.json'), encoding='utf-8') as f:
    _ATOMIC_NUMBERS = json.load(f)

with open(str(_pkg_path / 'cif_tags.json'), encoding='utf-8') as f:
    _CIF_TAGS = json.load(f)

# Tolerance handed to _unique_sites and _expand_asym_unit by the converters.
_EQ_SITE_TOL = 1e-3
41
42
43
class _Reader:
    """Base reader class.

    Iterating pulls raw items from ``iterable`` and passes each one
    through ``converter``. Items for which the converter raises
    :class:`ParseError <.io.ParseError>` are skipped with a warning.
    """

    def __init__(
            self,
            iterable: Iterable,
            converter: Callable[..., PeriodicSet],
            show_warnings: bool,
            verbose: bool
    ):

        self._iterator = iter(iterable)
        self._converter = converter
        self.show_warnings = show_warnings
        # The progress bar only exists when verbose output was requested.
        if verbose:
            self._progress_bar = tqdm.tqdm(desc='Reading', delay=1)
        else:
            self._progress_bar = None

    def __iter__(self):
        return self

    def __next__(self):
        """Iterate over self._iterator, passing items through
        self._converter and yielding. If
        :class:`ParseError <.io.ParseError>` is raised in a call to
        self._converter, the item is skipped. Warnings raised in
        self._converter are printed if self.show_warnings is True.
        """

        if self.show_warnings:
            return self._next_converted()

        # Silence warnings for the duration of this call only. Setting
        # the 'ignore' filter outside a catch_warnings() context would
        # leave every warning in the process suppressed afterwards.
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            return self._next_converted()

    def _next_converted(self):
        """Return the next successfully converted item, skipping any
        whose conversion raises :class:`ParseError <.io.ParseError>`.
        Raises StopIteration when the underlying iterator is exhausted.
        """

        while True:

            try:
                item = next(self._iterator)
            except StopIteration:
                if self._progress_bar is not None:
                    self._progress_bar.close()
                raise

            # Record warnings instead of showing them so they can be
            # re-issued below with the structure's name attached.
            with warnings.catch_warnings(record=True) as warning_msgs:
                error = None
                try:
                    periodic_set = self._converter(item)
                except ParseError as err:
                    error = err

            if self._progress_bar is not None:
                self._progress_bar.update(1)

            # Test the exception itself rather than its message, so a
            # ParseError with an empty message still skips the item.
            if error is not None:
                warnings.warn(str(error))
                continue

            for warning in warning_msgs:
                msg = f'(name={periodic_set.name}) {warning.message}'
                warnings.warn(msg, category=warning.category)

            return periodic_set

    def read(self) -> Union[PeriodicSet, List[PeriodicSet]]:
        """Read the crystal(s), return one
        :class:`amd.PeriodicSet <.periodicset.PeriodicSet>` if there is
        only one, otherwise return a list. If no structure could be
        read, the list is empty. (Note that the return type therefore
        depends on the number of structures read.)
        """
        items = list(self)
        if len(items) == 1:
            return items[0]
        return items
116
117
118
class CifReader(_Reader):
    """Read all structures in a .cif file or all files in a folder
    with gemmi, pymatgen, ase or csd-python-api (if installed), yielding
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>` s.

    Parameters
    ----------
    path : str
        Path to a .CIF file or directory. (Other files are accepted when
        using ``reader='ccdc'``, if csd-python-api is installed.)
    reader : str, optional
        The backend package used to parse the CIF. The default is
        :code:`gemmi`, :code:`pymatgen`, :code:`ase` and
        :code:`pycodcif` (passed through to ase's CIF parser) are also
        accepted, as well as :code:`ccdc` if csd-python-api is
        installed. The ccdc reader should be able to read any format
        accepted by :class:`ccdc.io.EntryReader`, though only CIFs have
        been tested.
    remove_hydrogens : bool, optional
        Remove Hydrogens from the crystals.
    disorder : str, optional
        Controls how disordered structures are handled. Default is
        ``skip`` which skips any crystal with disorder, since disorder
        conflicts with the periodic set model. To read disordered
        structures anyway, choose either :code:`ordered_sites` to remove
        atoms with disorder or :code:`all_sites` include all atoms
        regardless of disorder.
    heaviest_component : bool, optional
        csd-python-api only. Removes all but the heaviest molecule in
        the asymmetric unit, intended for removing solvents.
    molecular_centres : bool, default False
        csd-python-api only. Extract the centres of molecules in the
        unit cell and store in the attribute molecular_centres.
    show_warnings : bool, optional
        Controls whether warnings that arise during reading are printed.
    verbose : bool, default False
        If True, prints a progress bar showing the number of items
        processed.

    Yields
    ------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        data, e.g. the crystal's name and information about the
        asymmetric unit.

    Raises
    ------
    ValueError
        If ``disorder`` or ``reader`` is not one of the accepted values.
    NotImplementedError
        If ``heaviest_component`` or ``molecular_centres`` is requested
        with a reader other than ``'ccdc'``.
    ImportError
        If ``reader='ccdc'`` and csd-python-api cannot be imported.
    FileNotFoundError
        If ``path`` is neither an existing file nor a directory.

    Examples
    --------

        ::

            # Put all crystals in a .CIF in a list
            structures = list(amd.CifReader('mycif.cif'))

            # Can also accept path to a directory, reading all files inside
            structures = list(amd.CifReader('path/to/folder'))

            # Reads just one if the .CIF has just one crystal
            periodic_set = amd.CifReader('mycif.cif').read()

            # List of AMDs (k=100) of crystals in a .CIF
            amds = [amd.AMD(item, 100) for item in amd.CifReader('mycif.cif')]
    """

    def __init__(
            self,
            path: Union[str, os.PathLike],
            reader: str = 'gemmi',
            remove_hydrogens: bool = False,
            disorder: str = 'skip',
            heaviest_component: bool = False,
            molecular_centres: bool = False,
            show_warnings: bool = True,
            verbose: bool = False
    ):

        if disorder not in ('skip', 'ordered_sites', 'all_sites'):
            raise ValueError(
                f"'disorder' parameter of {self.__class__.__name__} must be "
                f"one of 'skip', 'ordered_sites' or 'all_sites' (passed "
                f"'{disorder}')"
            )

        # These two options are only passed to the ccdc converter below.
        if reader != 'ccdc':
            if heaviest_component:
                raise NotImplementedError(
                    "'heaviest_component' parameter of "
                    f"{self.__class__.__name__} only implemented with "
                    "csd-python-api, if installed pass reader='ccdc'"
                )
            if molecular_centres:
                raise NotImplementedError(
                    "'molecular_centres' parameter of "
                    f"{self.__class__.__name__} only implemented with "
                    "csd-python-api, if installed pass reader='ccdc'"
                )

        # Each branch selects: the file extensions read from a directory,
        # a parser turning a file path into an iterable of raw items, and
        # a converter turning one raw item into a PeriodicSet. Backends
        # are imported lazily so only the selected one must be installed.

        # NOTE: gemmi cannot handle some unusual characters in cifs
        if reader == 'gemmi':
            import gemmi
            extensions = {'cif'}
            file_parser = gemmi.cif.read_file
            converter = functools.partial(
                periodicset_from_gemmi_block,
                remove_hydrogens=remove_hydrogens,
                disorder=disorder
            )

        elif reader in ('ase', 'pycodcif'):
            from ase.io.cif import parse_cif
            extensions = {'cif'}
            file_parser = functools.partial(parse_cif, reader=reader)
            converter = functools.partial(
                periodicset_from_ase_cifblock,
                remove_hydrogens=remove_hydrogens,
                disorder=disorder
            )

        elif reader == 'pymatgen':

            def _pymatgen_cif_parser(path):
                from pymatgen.io.cif import CifFile
                return CifFile.from_file(path).data.values()

            extensions = {'cif'}
            file_parser = _pymatgen_cif_parser
            converter = functools.partial(
                periodicset_from_pymatgen_cifblock,
                remove_hydrogens=remove_hydrogens,
                disorder=disorder
            )

        elif reader == 'ccdc':
            try:
                import ccdc.io
            except (ImportError, RuntimeError) as e:
                raise ImportError('Failed to import csd-python-api') from e

            extensions = ccdc.io.EntryReader.known_suffixes
            file_parser = ccdc.io.EntryReader
            converter = functools.partial(
                periodicset_from_ccdc_entry,
                remove_hydrogens=remove_hydrogens,
                disorder=disorder,
                molecular_centres=molecular_centres,
                heaviest_component=heaviest_component
            )

        else:
            raise ValueError(
                f"'reader' parameter of {self.__class__.__name__} must be one "
                f"of 'gemmi', 'pymatgen', 'ccdc', 'ase', or 'pycodcif' "
                f"(passed '{reader}')"
            )

        path = Path(path)
        if path.is_file():
            iterable = file_parser(str(path))
        elif path.is_dir():
            iterable = CifReader._dir_generator(path, file_parser, extensions)
        else:
            raise FileNotFoundError(
                errno.ENOENT, os.strerror(errno.ENOENT), path
            )

        super().__init__(iterable, converter, show_warnings, verbose)

    @staticmethod
    def _dir_generator(
            path: os.PathLike,
            file_parser: Callable,
            extensions: Iterable
    ) -> Iterator:
        """Yield the parsed items of every file directly inside ``path``
        whose extension is in ``extensions``. Files that fail to parse
        are skipped with a warning.
        """

        # Path.suffix includes the leading dot ('.cif') whereas the
        # extensions above are written without it, so compare both sides
        # with any leading dot removed and case ignored.
        allowed = {str(ext).lower().lstrip('.') for ext in extensions}

        for file_path in path.iterdir():
            if not file_path.is_file():
                continue
            if file_path.suffix.lower().lstrip('.') not in allowed:
                continue
            try:
                yield from file_parser(str(file_path))
            except Exception as e:
                # Deliberately broad: one unreadable file should not stop
                # the remaining files in the directory from being read.
                warnings.warn(
                    f'Error parsing "{str(file_path)}", skipping: {str(e)}'
                )
302
303
304
class CSDReader(_Reader):
    """Read structures from the CSD with csd-python-api, yielding
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>` s.

    Parameters
    ----------
    refcodes : str or List[str], optional
        Single or list of CSD refcodes to read. If None or 'CSD',
        iterates over the whole CSD.
    families : bool, optional
        Read all entries whose refcode starts with the given strings, or
        'families' (e.g. giving 'DEBXIT' reads all entries starting with
        DEBXIT). Ignored when iterating over the whole CSD.
    remove_hydrogens : bool, optional
        Remove hydrogens from the crystals.
    disorder : str, optional
        Controls how disordered structures are handled. Default is
        ``skip`` which skips any crystal with disorder, since disorder
        conflicts with the periodic set model. To read disordered
        structures anyway, choose either :code:`ordered_sites` to remove
        atoms with disorder or :code:`all_sites` include all atoms
        regardless of disorder.
    heaviest_component : bool, optional
        Removes all but the heaviest molecule in the asymmetric unit,
        intended for removing solvents.
    molecular_centres : bool, default False
        Extract the centres of molecules in the unit cell and store in
        attribute molecular_centres.
    show_warnings : bool, optional
        Controls whether warnings that arise during reading are printed.
    verbose : bool, default False
        If True, prints a progress bar showing the number of items
        processed.

    Yields
    ------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        useful data, e.g. the crystal's name and information about the
        asymmetric unit for calculation.

    Raises
    ------
    ValueError
        If ``disorder`` is not one of the accepted values.
    ImportError
        If csd-python-api cannot be imported (raised on construction
        when ``families=True``, otherwise when iteration starts).

    Examples
    --------

        ::

            # Put these entries in a list
            refcodes = ['DEBXIT01', 'DEBXIT05', 'HXACAN01']
            structures = list(amd.CSDReader(refcodes))

            # Read refcode families (any whose refcode starts with strings in the list)
            refcode_families = ['ACSALA', 'HXACAN']
            structures = list(amd.CSDReader(refcode_families, families=True))

            # Get AMDs (k=100) for crystals in these families
            refcodes = ['ACSALA', 'HXACAN']
            amds = []
            for periodic_set in amd.CSDReader(refcodes, families=True):
                amds.append(amd.AMD(periodic_set, 100))

            # Giving the reader nothing reads from the whole CSD.
            for periodic_set in amd.CSDReader():
                ...
    """

    def __init__(
            self,
            refcodes: Optional[Union[str, Iterable[str]]] = None,
            families: bool = False,
            remove_hydrogens: bool = False,
            disorder: str = 'skip',
            heaviest_component: bool = False,
            molecular_centres: bool = False,
            show_warnings: bool = True,
            verbose: bool = False
    ):

        if disorder not in ('skip', 'ordered_sites', 'all_sites'):
            raise ValueError(
                f"'disorder' parameter of {self.__class__.__name__} must be "
                f"one of 'skip', 'ordered_sites' or 'all_sites' (passed "
                f"'{disorder}')"
            )

        # 'CSD' (any case) is an alias for reading the whole database.
        if isinstance(refcodes, str) and refcodes.lower() == 'csd':
            refcodes = None

        # Normalise refcodes to None (whole CSD) or a list of strings.
        if refcodes is None:
            families = False
        elif isinstance(refcodes, str):
            refcodes = [refcodes]
        else:
            refcodes = [str(refcode) for refcode in refcodes]

        if families:
            refcodes = self._refcodes_from_families_ccdc(refcodes)

        converter = functools.partial(
            periodicset_from_ccdc_entry,
            remove_hydrogens=remove_hydrogens,
            disorder=disorder,
            molecular_centres=molecular_centres,
            heaviest_component=heaviest_component
        )
        iterable = self._ccdc_generator(refcodes)
        super().__init__(iterable, converter, show_warnings, verbose)

    @staticmethod
    def _ccdc_generator(refcodes: Optional[Union[str, List[str]]]) -> Iterator:
        """Generates ccdc Entries from CSD refcodes, or every entry in
        the CSD if ``refcodes`` is None.
        """

        try:
            import ccdc.io
        except (ImportError, RuntimeError) as e:
            raise ImportError('Failed to import csd-python-api') from e

        entry_reader = ccdc.io.EntryReader('CSD')

        if refcodes is None:
            yield from entry_reader
        else:
            for refcode in refcodes:
                # NOTE(review): csd-python-api appears to raise
                # RuntimeError for an identifier that is not in the
                # database - confirm. Skip it with a warning so one bad
                # refcode does not abort the remaining ones.
                try:
                    entry = entry_reader.entry(refcode)
                except RuntimeError:
                    warnings.warn(f'{refcode} not found in database')
                    continue
                yield entry

    @staticmethod
    def _refcodes_from_families_ccdc(refcode_families: List[str]) -> List[str]:
        """List of strings --> all CSD refcodes starting with any of the
        strings. Intended to be passed a list of families and return all
        refcodes in them, without duplicates and in the order found.
        """

        try:
            import ccdc.search
        except (ImportError, RuntimeError) as e:
            raise ImportError('Failed to import csd-python-api') from e

        all_refcodes = []
        for refcode in refcode_families:
            query = ccdc.search.TextNumericSearch()
            query.add_identifier(refcode)
            all_refcodes.extend(hit.identifier for hit in query.search())

        # dicts keep insertion order, so this filters to unique refcodes
        # while keeping the order of first appearance.
        return list(dict.fromkeys(all_refcodes))
459
460
461
class ParseError(ValueError):
    """Signals that an item could not be parsed into a periodic set.

    Subclasses :class:`ValueError`, so handlers written for ValueError
    also catch it.
    """
465
466
467
def periodicset_from_gemmi_block(
        block,
        remove_hydrogens: bool = False,
        disorder: str = 'skip'
) -> PeriodicSet:
    """Convert a :class:`gemmi.cif.Block` object to a
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`.
    :class:`gemmi.cif.Block` is the type returned by
    :func:`gemmi.cif.read_file`.

    Parameters
    ----------
    block : :class:`gemmi.cif.Block`
        A gemmi :class:`gemmi.cif.Block` object representing a crystal.
    remove_hydrogens : bool, optional
        Remove Hydrogens from the crystal.
    disorder : str, optional
        Controls how disordered structures are handled. Default is
        ``skip`` which skips any crystal with disorder, since disorder
        conflicts with the periodic set model. To read disordered
        structures anyway, choose either :code:`ordered_sites` to remove
        atoms with disorder or :code:`all_sites` include all atoms
        regardless of disorder.

    Returns
    -------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        useful data, e.g. the crystal's name and information about the
        asymmetric unit for calculation.

    Raises
    ------
    ParseError
        Raised if the structure fails to be parsed for any of the
        following: 1. Required data is missing (e.g. cell parameters),
        2. ``disorder == 'skip'`` and disorder is found on any
        atom, 3. The motif is empty after removing H or disordered
        sites.
    """

    import gemmi
    from gemmi.cif import as_number, as_string, as_int

    def _gemmi_loop_to_dict(gemmi_loop) -> dict:
        """Convert a gemmi Loop object to a dict of tag -> column."""
        # Loop values are stored flat, row by row; deal them out into
        # one list per column.
        tablified_loop = [[] for _ in range(len(gemmi_loop.tags))]
        n_cols = gemmi_loop.width()
        for i, item in enumerate(gemmi_loop.values):
            tablified_loop[i % n_cols].append(item)
        return dict(zip(gemmi_loop.tags, tablified_loop))

    # Unit cell
    cellpar = [block.find_value(t) for t in _CIF_TAGS['cellpar']]
    if not all(isinstance(par, str) for par in cellpar):
        raise ParseError(f'{block.name} has missing cell data')
    cellpar = np.array([as_number(par) for par in cellpar])
    if np.isnan(np.sum(cellpar)):
        raise ParseError(f'{block.name} has missing cell data')
    cell = cellpar_to_cell(cellpar)

    # Asymmetric unit coordinates. Prefer fractional coordinates and
    # fall back to Cartesian; remember which tags matched so the columns
    # are looked up under the tags that are actually in the loop.
    cartesian = False
    xyz_tags = _CIF_TAGS['atom_site_fract']
    xyz_loop = block.find(xyz_tags).loop
    if xyz_loop is None:
        xyz_tags = _CIF_TAGS['atom_site_cartn']
        xyz_loop = block.find(xyz_tags).loop
        if xyz_loop is None:
            raise ParseError(f'{block.name} has missing coordinate data')
        cartesian = True
    loop_dict = _gemmi_loop_to_dict(xyz_loop)
    xyz_str = [loop_dict[t] for t in xyz_tags]
    asym_unit = np.transpose(np.array(
        [[as_number(c) for c in coords] for coords in xyz_str]
    ))
    if cartesian:
        asym_unit = asym_unit @ np.linalg.inv(cell)
    asym_unit = np.mod(asym_unit, 1)
    # recommended by pymatgen
    # asym_unit = _snap_small_prec_coords(asym_unit, 1e-4)

    # Labels
    if '_atom_site_label' in loop_dict:
        labels = [as_string(label) for label in loop_dict['_atom_site_label']]
    else:
        labels = [''] * xyz_loop.length()

    # Atomic types
    # NOTE(review): a non-empty label with no element-like prefix makes
    # re.search return None (AttributeError on .group()), and a symbol
    # that is not a key of _ATOMIC_NUMBERS raises KeyError; neither is
    # converted to ParseError here.
    if '_atom_site_type_symbol' in loop_dict:
        asym_syms = [as_string(s) for s in loop_dict['_atom_site_type_symbol']]
    else:
        asym_syms = []
        for label in labels:
            if label:
                sym = re.search(r'([A-Z][a-z]?)', label).group()
            else:
                sym = ''
            asym_syms.append(sym)
    asym_types = [_ATOMIC_NUMBERS[s] for s in asym_syms]

    # Occupancies; values that do not parse as a number count as 1
    if '_atom_site_occupancy' in loop_dict:
        occs = [as_number(occ) for occ in loop_dict['_atom_site_occupancy']]
        occupancies = [occ if not math.isnan(occ) else 1 for occ in occs]
    else:
        occupancies = [1] * xyz_loop.length()

    # Remove sites with missing coordinates, disorder and Hydrogens if needed
    remove_sites = set(np.nonzero(np.isnan(asym_unit.min(axis=-1)))[0])

    if disorder == 'skip':
        if any(
            _has_disorder(lab, occ) for lab, occ in zip(labels, occupancies)
        ):
            raise ParseError(
                f"{block.name} has disorder, pass disorder='ordered_sites' or "
                "'all_sites' to remove/ignore disorder"
            )
    elif disorder == 'ordered_sites':
        for i, (label, occ) in enumerate(zip(labels, occupancies)):
            if _has_disorder(label, occ):
                remove_sites.add(i)

    if remove_hydrogens:
        remove_sites.update(
            i for i, num in enumerate(asym_types) if num == 1
        )

    asym_unit = np.delete(asym_unit, sorted(remove_sites), axis=0)
    asym_types = [s for i, s in enumerate(asym_types) if i not in remove_sites]
    if asym_unit.shape[0] == 0:
        raise ParseError(f'{block.name} has no valid sites')

    # Remove overlapping sites unless disorder == 'all_sites'
    if disorder != 'all_sites':
        keep_sites = _unique_sites(asym_unit, _EQ_SITE_TOL)
        if not np.all(keep_sites):
            warnings.warn(
                'may have overlapping sites; duplicates will be removed'
            )
        asym_unit = asym_unit[keep_sites]
        asym_types = [sym for sym, keep in zip(asym_types, keep_sites) if keep]

    # Symmetry operations: use the first symop tag that has a loop
    sitesym = []
    for tag in _CIF_TAGS['symop']:
        symop_loop = block.find_loop(tag).get_loop()
        if symop_loop is not None:
            sitesym = _gemmi_loop_to_dict(symop_loop)[tag]
            break

    if not sitesym:
        # No explicit operations: fall back to the space group name,
        # then its number, then P1.
        # TODO: what can gemmi accept here?
        label_or_num = None  # stays None if no tag below has a value
        for tag in _CIF_TAGS['spacegroup_name']:
            label_or_num = block.find_value(tag)
            if label_or_num is not None:
                label_or_num = as_string(label_or_num)
                break
        if label_or_num is None:
            for tag in _CIF_TAGS['spacegroup_number']:
                label_or_num = block.find_value(tag)
                if label_or_num is not None:
                    label_or_num = as_int(label_or_num)
                    break
        if label_or_num is None:
            warnings.warn('no symmetry data found, defaulting to P1')
            label_or_num = 1
        ops = list(gemmi.SpaceGroup(label_or_num).operations())
        # gemmi stores operation entries as integers scaled by DEN
        rot = np.array([np.array(o.rot) / o.DEN for o in ops])
        trans = np.array([np.array(o.tran) / o.DEN for o in ops])
    else:
        rot, trans = _parse_sitesyms(sitesym)

    # Apply the symmetry operations to the asymmetric unit. invs gives,
    # for each motif point, the index of the asymmetric unit site it
    # came from; counting those gives the multiplicities, and their
    # running sum gives where each site's points start in the motif.
    frac_motif, invs = _expand_asym_unit(asym_unit, rot, trans, _EQ_SITE_TOL)
    _, wyc_muls = np.unique(invs, return_counts=True)
    asym_inds = np.zeros_like(wyc_muls)
    asym_inds[1:] = np.cumsum(wyc_muls)[:-1]
    types = np.array([asym_types[i] for i in invs])
    motif = frac_motif @ cell

    return PeriodicSet(
        motif=motif,
        cell=cell,
        name=block.name,
        asymmetric_unit=asym_inds,
        wyckoff_multiplicities=wyc_muls,
        types=types
    )
652
653
654
def periodicset_from_ase_cifblock(
        block,
        remove_hydrogens: bool = False,
        disorder: str = 'skip'
) -> PeriodicSet:
    """Convert a :class:`ase.io.cif.CIFBlock` object to a
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`.
    :class:`ase.io.cif.CIFBlock` is the type returned by
    :func:`ase.io.cif.parse_cif`.

    Parameters
    ----------
    block : :class:`ase.io.cif.CIFBlock`
        An ase :class:`ase.io.cif.CIFBlock` object representing a
        crystal.
    remove_hydrogens : bool, optional
        Remove Hydrogens from the crystal.
    disorder : str, optional
        Controls how disordered structures are handled. Default is
        ``skip`` which skips any crystal with disorder, since disorder
        conflicts with the periodic set model. To read disordered
        structures anyway, choose either :code:`ordered_sites` to remove
        atoms with disorder or :code:`all_sites` include all atoms
        regardless of disorder.

    Returns
    -------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        useful data, e.g. the crystal's name and information about the
        asymmetric unit for calculation.

    Raises
    ------
    ParseError
        Raised if the structure fails to be parsed for any of the
        following: 1. Required data is missing (e.g. cell parameters),
        2. The motif is empty after removing H or disordered sites,
        3. ``disorder == 'skip'`` and disorder is found on any
        atom.
    """

    import ase
    import ase.spacegroup

    # Unit cell
    cellpar = [block.get(tag) for tag in _CIF_TAGS['cellpar']]
    if None in cellpar:
        raise ParseError(f'{block.name} has missing cell data')
    cell = cellpar_to_cell(np.array(cellpar))

    # Asymmetric unit coordinates. ase removes uncertainty brackets
    cartesian = False
    asym_unit = [block.get(name) for name in _CIF_TAGS['atom_site_fract']]
    if None in asym_unit:
        asym_unit = [block.get(name) for name in _CIF_TAGS['atom_site_cartn']]
        if None in asym_unit:
            raise ParseError(f'{block.name} has missing coordinates')
        cartesian = True
    # Columns (x, y, z) -> one tuple of coordinates per site
    asym_unit = list(zip(*asym_unit))

    # Atomic types
    asym_symbols = block._get_any(_CIF_TAGS['atom_symbol'])
    if asym_symbols is None:
        warnings.warn('missing atomic types will be labelled 0')
        asym_types = [0] * len(asym_unit)
    else:
        asym_types = []
        for label in asym_symbols:
            if label in ('.', '?'):
                warnings.warn('missing atomic types will be labelled 0')
                num = 0
            else:
                # Take the leading element-like part of the symbol
                sym = re.search(r'([A-Z][a-z]?)', label).group(0)
                num = _ATOMIC_NUMBERS[sym]
            asym_types.append(num)

    # Find where sites have disorder if necessary
    has_disorder = []
    if disorder != 'all_sites':
        occupancies = block.get('_atom_site_occupancy')
        if occupancies is None:
            occupancies = [1] * len(asym_unit)
        labels = block.get('_atom_site_label')
        if labels is None:
            labels = [''] * len(asym_unit)
        for lab, occ in zip(labels, occupancies):
            has_disorder.append(_has_disorder(lab, occ))

    # Remove sites with ?, . or other invalid string for coordinates.
    # The parallel lists are filtered together so indices stay aligned.
    invalid = {
        i for i, xyz in enumerate(asym_unit)
        if not all(isinstance(coord, (int, float)) for coord in xyz)
    }
    if invalid:
        warnings.warn('atoms without sites or missing data will be removed')
        asym_unit = [c for i, c in enumerate(asym_unit) if i not in invalid]
        asym_types = [t for i, t in enumerate(asym_types) if i not in invalid]
        if disorder != 'all_sites':
            has_disorder = [d for i, d in enumerate(has_disorder)
                            if i not in invalid]

    # Indices (into the filtered lists) of sites still to be dropped
    remove_sites = set()

    if remove_hydrogens:
        remove_sites.update(i for i, num in enumerate(asym_types) if num == 1)

    # Remove atoms with fractional occupancy or raise ParseError
    if disorder != 'all_sites':
        for i, dis in enumerate(has_disorder):
            # Sites already being removed cannot trigger a skip
            if i in remove_sites:
                continue
            if dis:
                if disorder == 'skip':
                    raise ParseError(
                        f'{block.name} has disorder, pass '
                        "disorder='ordered_sites' or 'all_sites' to "
                        'remove/ignore disorder'
                    )
                elif disorder == 'ordered_sites':
                    remove_sites.add(i)

    # Asymmetric unit
    asym_unit = [c for i, c in enumerate(asym_unit) if i not in remove_sites]
    asym_types = [t for i, t in enumerate(asym_types) if i not in remove_sites]
    if len(asym_unit) == 0:
        raise ParseError(f'{block.name} has no valid sites')
    asym_unit = np.array(asym_unit)

    # If Cartesian coords were given, convert to scaled
    if cartesian:
        asym_unit = asym_unit @ np.linalg.inv(cell)
    asym_unit = np.mod(asym_unit, 1)

    # recommended by pymatgen
    # asym_unit = _snap_small_prec_coords(asym_unit, 1e-4)

    # Remove overlapping sites unless disorder == 'all_sites'
    if disorder != 'all_sites':
        keep_sites = _unique_sites(asym_unit, _EQ_SITE_TOL)
        if not np.all(keep_sites):
            warnings.warn(
                'may have overlapping sites, duplicates will be removed'
            )
            asym_unit = asym_unit[keep_sites]
            asym_types = [t for t, keep in zip(asym_types, keep_sites) if keep]

    # Get symmetry operations: explicit operations if present, otherwise
    # from the space group name, then its number, then P1.
    sitesym = block._get_any(_CIF_TAGS['symop'])
    if sitesym is None:
        label_or_num = block._get_any(_CIF_TAGS['spacegroup_name'])
        if label_or_num is None:
            label_or_num = block._get_any(_CIF_TAGS['spacegroup_number'])
        if label_or_num is None:
            warnings.warn('no symmetry data found, defaulting to P1')
            label_or_num = 1
        spg = ase.spacegroup.Spacegroup(label_or_num)
        rot, trans = spg.get_op()
    else:
        # A single operation may come back as a bare string
        if isinstance(sitesym, str):
            sitesym = [sitesym]
        rot, trans = _parse_sitesyms(sitesym)

    # Apply the symmetry operations to the asymmetric unit. invs gives,
    # for each motif point, the index of the asymmetric unit site it
    # came from; counting those gives the multiplicities, and their
    # running sum gives where each site's points start in the motif.
    frac_motif, invs = _expand_asym_unit(asym_unit, rot, trans, _EQ_SITE_TOL)
    _, wyc_muls = np.unique(invs, return_counts=True)
    asym_inds = np.zeros_like(wyc_muls)
    asym_inds[1:] = np.cumsum(wyc_muls)[:-1]
    types = np.array([asym_types[i] for i in invs])
    motif = frac_motif @ cell

    return PeriodicSet(
        motif=motif,
        cell=cell,
        name=block.name,
        asymmetric_unit=asym_inds,
        wyckoff_multiplicities=wyc_muls,
        types=types
    )
833
834
835
def periodicset_from_ase_atoms(
        atoms,
        remove_hydrogens: bool = False
) -> PeriodicSet:
    """Convert an :class:`ase.atoms.Atoms` object to a
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`. Does not have
    the option to remove disorder.

    The object passed as ``atoms`` is left unmodified; if Hydrogens
    have to be removed, they are removed from a copy.

    Parameters
    ----------
    atoms : :class:`ase.atoms.Atoms`
        An ase :class:`ase.atoms.Atoms` object representing a crystal.
    remove_hydrogens : bool, optional
        Remove Hydrogens from the crystal.

    Returns
    -------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        useful data, e.g. the crystal's name and information about the
        asymmetric unit for calculation.

    Raises
    ------
    ParseError
        Raised if there are no valid sites in atoms.
    """

    from ase.spacegroup import get_basis

    cell = atoms.get_cell().array

    if remove_hydrogens:
        hydrogen_inds = np.where(atoms.get_atomic_numbers() == 1)[0]
        if len(hydrogen_inds) > 0:
            # Work on a copy so the caller's Atoms object keeps its
            # Hydrogens.
            atoms = atoms.copy()
            # Pop from the highest index down so that earlier pops do not
            # shift the positions of the indices still to be removed.
            for i in sorted(hydrogen_inds, reverse=True):
                atoms.pop(i)

    if len(atoms) == 0:
        raise ParseError('ase Atoms object has no valid sites')

    # Symmetry operations from spacegroup
    spg = None
    if 'spacegroup' in atoms.info:
        spg = atoms.info['spacegroup']
        rot, trans = spg.rotations, spg.translations
    else:
        warnings.warn('no symmetry data found, defaulting to P1')
        rot = np.identity(3)[None, :]
        trans = np.zeros((1, 3))

    # Asymmetric unit. ase default tol is 1e-5
    # TODO: get_basis determines a reduced asymmetric unit from the atoms;
    # this step may be unnecessary and should be revisited.
    asym_unit = get_basis(atoms, spacegroup=spg, tol=_EQ_SITE_TOL)
    frac_motif, invs = _expand_asym_unit(asym_unit, rot, trans, _EQ_SITE_TOL)

    # invs maps each motif point to its asymmetric unit site; counting the
    # occurrences of each value gives the Wyckoff multiplicities, and their
    # running total gives the index where each orbit starts in the motif.
    _, wyc_muls = np.unique(invs, return_counts=True)
    asym_inds = np.zeros_like(wyc_muls)
    asym_inds[1:] = np.cumsum(wyc_muls)[:-1]
    motif = frac_motif @ cell

    # NOTE(review): types are taken from every atom in ``atoms`` while the
    # motif is rebuilt from the basis returned by get_basis; confirm the two
    # always agree in length and order.
    return PeriodicSet(
        motif=motif,
        cell=cell,
        asymmetric_unit=asym_inds,
        wyckoff_multiplicities=wyc_muls,
        types=atoms.get_atomic_numbers()
    )
905
906
907
def periodicset_from_pymatgen_cifblock(
        block,
        remove_hydrogens: bool = False,
        disorder: str = 'skip'
) -> PeriodicSet:
    """Convert a :class:`pymatgen.io.cif.CifBlock` object to a
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`.
    :class:`pymatgen.io.cif.CifBlock` is the type returned by
    :class:`pymatgen.io.cif.CifFile`.

    Parameters
    ----------
    block : :class:`pymatgen.io.cif.CifBlock`
        A pymatgen CifBlock object representing a crystal.
    remove_hydrogens : bool, optional
        Remove Hydrogens from the crystal.
    disorder : str, optional
        Controls how disordered structures are handled. Default is
        ``skip`` which skips any crystal with disorder, since disorder
        conflicts with the periodic set model. To read disordered
        structures anyway, choose either :code:`ordered_sites` to remove
        atoms with disorder or :code:`all_sites` include all atoms
        regardless of disorder.

    Returns
    -------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        useful data, e.g. the crystal's name and information about the
        asymmetric unit for calculation.

    Raises
    ------
    ParseError
        Raised if the structure can/should not be parsed for the
        following reasons: 1. Cell parameters are missing, 2. Neither
        fractional nor Cartesian coordinate columns are present,
        3. :code:`disorder == 'skip'` and disorder is found on any
        atom, 4. No sites remain after removing sites with unreadable
        coordinates, Hydrogens & disorder.
    """

    from pymatgen.io.cif import str2float

    def _to_float(text):
        # str2float can raise on text it cannot parse (e.g. '?'). Report
        # that as None so the site is filtered out below rather than
        # aborting the whole read with a non-ParseError exception.
        try:
            return str2float(text)
        except (TypeError, ValueError):
            return None

    odict = block.data

    # Unit cell
    cellpar = [odict.get(tag) for tag in _CIF_TAGS['cellpar']]
    if any(par in (None, '?', '.') for par in cellpar):
        raise ParseError(f'{block.header} has missing cell data')
    cell = cellpar_to_cell(np.array([str2float(v) for v in cellpar]))

    # Asymmetric unit coordinates: prefer fractional, fall back to Cartesian
    cartesian = False
    asym_unit = [odict.get(tag) for tag in _CIF_TAGS['atom_site_fract']]
    if None in asym_unit:
        asym_unit = [odict.get(tag) for tag in _CIF_TAGS['atom_site_cartn']]
        if None in asym_unit:
            raise ParseError(f'{block.header} has no coordinates')
        cartesian = True

    # Transpose the three coordinate columns into one row per site
    asym_unit = [
        [_to_float(coord) for coord in xyz] for xyz in zip(*asym_unit)
    ]

    # Atomic types: use the first symbol/label column that is present
    for tag in _CIF_TAGS['atom_symbol']:
        asym_symbols = odict.get(tag)
        if asym_symbols is not None:
            asym_types = []
            for label in asym_symbols:
                num = None
                if label not in ('.', '?'):
                    match = re.search(r'([A-Z][a-z]?)', label)
                    if match is not None:
                        num = _ATOMIC_NUMBERS.get(match.group(0))
                if num is None:
                    # Missing, unparseable or unknown symbol
                    warnings.warn('missing atomic types will be labelled 0')
                    num = 0
                asym_types.append(num)
            break
    else:
        warnings.warn('missing atomic types will be labelled 0')
        asym_types = [0] * len(asym_unit)

    # Find where sites have disorder if necessary
    has_disorder = []
    if disorder != 'all_sites':
        occupancies = odict.get('_atom_site_occupancy')
        if occupancies is None:
            occupancies = [1] * len(asym_unit)
        else:
            # Unreadable occupancies become None, which _has_disorder
            # treats as full occupancy.
            occupancies = [_to_float(occ) for occ in occupancies]
        labels = odict.get('_atom_site_label')
        if labels is None:
            labels = [''] * len(asym_unit)
        for lab, occ in zip(labels, occupancies):
            has_disorder.append(_has_disorder(lab, occ))

    # Remove sites whose coordinates could not be read as numbers
    invalid = {
        i for i, xyz in enumerate(asym_unit)
        if not all(isinstance(coord, (int, float)) for coord in xyz)
    }

    if invalid:
        warnings.warn('atoms without sites or missing data will be removed')
        asym_unit = [c for i, c in enumerate(asym_unit) if i not in invalid]
        asym_types = [c for i, c in enumerate(asym_types) if i not in invalid]
        if disorder != 'all_sites':
            has_disorder = [
                d for i, d in enumerate(has_disorder) if i not in invalid
            ]

    remove_sites = set()

    if remove_hydrogens:
        remove_sites.update(i for i, n in enumerate(asym_types) if n == 1)

    # Remove atoms with fractional occupancy or raise ParseError. Sites
    # already marked for removal (Hydrogens) do not trigger the error.
    if disorder != 'all_sites':
        for i, dis in enumerate(has_disorder):
            if i in remove_sites:
                continue
            if dis:
                if disorder == 'skip':
                    raise ParseError(
                        f'{block.header} has disorder, pass '
                        "disorder='ordered_sites' or 'all_sites' to "
                        'remove/ignore disorder'
                    )
                elif disorder == 'ordered_sites':
                    remove_sites.add(i)

    # Asymmetric unit
    asym_unit = [c for i, c in enumerate(asym_unit) if i not in remove_sites]
    asym_types = [t for i, t in enumerate(asym_types) if i not in remove_sites]
    if len(asym_unit) == 0:
        raise ParseError(f'{block.header} has no valid sites')
    asym_unit = np.array(asym_unit)

    # If Cartesian coords were given, convert to scaled
    if cartesian:
        asym_unit = asym_unit @ np.linalg.inv(cell)
    asym_unit = np.mod(asym_unit, 1)

    # recommended by pymatgen
    # asym_unit = _snap_small_prec_coords(asym_unit, 1e-4)

    # Remove overlapping sites unless disorder == 'all_sites'
    if disorder != 'all_sites':
        keep_sites = _unique_sites(asym_unit, _EQ_SITE_TOL)
        if not np.all(keep_sites):
            warnings.warn(
                'may have overlapping sites; duplicates will be removed'
            )
        asym_unit = asym_unit[keep_sites]
        asym_types = [sym for sym, keep in zip(asym_types, keep_sites) if keep]

    # Apply symmetries to asymmetric unit
    rot, trans = _get_syms_pymatgen(odict)
    frac_motif, invs = _expand_asym_unit(asym_unit, rot, trans, _EQ_SITE_TOL)

    # invs maps each motif point to its asymmetric unit site; counting the
    # occurrences of each value gives the Wyckoff multiplicities, and their
    # running total gives the index where each orbit starts in the motif.
    _, wyc_muls = np.unique(invs, return_counts=True)
    asym_inds = np.zeros_like(wyc_muls)
    asym_inds[1:] = np.cumsum(wyc_muls)[:-1]
    types = np.array([asym_types[i] for i in invs])
    motif = frac_motif @ cell

    return PeriodicSet(
        motif=motif,
        cell=cell,
        name=block.header,
        asymmetric_unit=asym_inds,
        wyckoff_multiplicities=wyc_muls,
        types=types
    )
1081
1082
1083
def periodicset_from_pymatgen_structure(
        structure,
        remove_hydrogens: bool = False,
        disorder: str = 'skip'
) -> PeriodicSet:
    """Convert a :class:`pymatgen.core.structure.Structure` object to a
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`. Does not set
    the name of the periodic set, as pymatgen Structure objects seem to
    have no name attribute.

    The object passed as ``structure`` is left unmodified; species and
    sites are removed from a copy.

    Parameters
    ----------
    structure : :class:`pymatgen.core.structure.Structure`
        A pymatgen Structure object representing a crystal.
    remove_hydrogens : bool, optional
        Remove Hydrogens from the crystal.
    disorder : str, optional
        Controls how disordered structures are handled. Default is
        ``skip`` which skips any crystal with disorder, since disorder
        conflicts with the periodic set model. To read disordered
        structures anyway, choose either :code:`ordered_sites` to remove
        atoms with disorder or :code:`all_sites` include all atoms
        regardless of disorder.

    Returns
    -------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        useful data, e.g. the crystal's name and information about the
        asymmetric unit for calculation.

    Raises
    ------
    ParseError
        Raised if the :code:`disorder == 'skip'` and
        :code:`not structure.is_ordered`
    """

    from pymatgen.symmetry.analyzer import SpacegroupAnalyzer

    # remove_species() and remove_sites() edit the Structure in place, so
    # take a copy first whenever one of them is about to be called.
    if remove_hydrogens or disorder == 'ordered_sites':
        structure = structure.copy()

    if remove_hydrogens:
        structure.remove_species(['H', 'D'])

    # Disorder
    if disorder == 'skip':
        if not structure.is_ordered:
            raise ParseError(
                'pymatgen Structure has disorder, pass '
                "disorder='ordered_sites' or 'all_sites' to remove/ignore "
                'disorder'
            )
    elif disorder == 'ordered_sites':
        # Drop every site whose total occupancy is below one
        remove_inds = []
        for i, comp in enumerate(structure.species_and_occu):
            if comp.num_atoms < 1:
                remove_inds.append(i)
        structure.remove_sites(remove_inds)

    motif = structure.cart_coords
    cell = structure.lattice.matrix
    sym_structure = SpacegroupAnalyzer(structure).get_symmetrized_structure()

    # One list of site indices per group of symmetry-equivalent sites: the
    # first index represents the group, its length is the multiplicity.
    eq_inds = sym_structure.equivalent_indices
    asym_unit = np.array([ind_list[0] for ind_list in eq_inds])
    wyc_muls = np.array([len(ind_list) for ind_list in eq_inds])
    types = np.array(sym_structure.atomic_numbers)

    return PeriodicSet(
        motif=motif,
        cell=cell,
        asymmetric_unit=asym_unit,
        wyckoff_multiplicities=wyc_muls,
        types=types
    )
1157
1158
1159
def periodicset_from_ccdc_entry(
        entry,
        remove_hydrogens: bool = False,
        disorder: str = 'skip',
        heaviest_component: bool = False,
        molecular_centres: bool = False
) -> PeriodicSet:
    """Build a :class:`amd.PeriodicSet <.periodicset.PeriodicSet>` from
    a :class:`ccdc.entry.Entry`, the type returned by
    :class:`ccdc.io.EntryReader`.

    Entry-level checks are made here, then the entry's crystal is
    handed to :func:`periodicset_from_ccdc_crystal` with the same
    options.

    Parameters
    ----------
    entry : :class:`ccdc.entry.Entry`
        A ccdc Entry object representing a database entry.
    remove_hydrogens : bool, optional
        Remove Hydrogens from the crystal.
    disorder : str, optional
        Controls how disordered structures are handled. Default is
        ``skip`` which skips any crystal with disorder, since disorder
        conflicts with the periodic set model. To read disordered
        structures anyway, choose either :code:`ordered_sites` to remove
        atoms with disorder or :code:`all_sites` include all atoms
        regardless of disorder.
    heaviest_component : bool, optional
        Removes all but the heaviest molecule in the asymmetric unit,
        intended for removing solvents.
    molecular_centres : bool, default False
        Extract the centres of molecules in the unit cell and store in
        the attribute molecular_centres of the returned PeriodicSet.

    Returns
    -------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        useful data, e.g. the crystal's name and information about the
        asymmetric unit for calculation.

    Raises
    ------
    ParseError
        Raised here if 1. entry.has_3d_structure is False, or
        2. :code:`disorder == 'skip'` and entry.has_disorder is True.
        Also raised by :func:`periodicset_from_ccdc_crystal` for the
        reasons listed in its documentation.
    """

    # An entry without 3D structure has nothing to convert
    if not entry.has_3d_structure:
        raise ParseError(f'{entry.identifier} has no 3D structure')

    # Only consult the entry's disorder flag when disorder is to be skipped
    if disorder == 'skip':
        if entry.has_disorder:
            raise ParseError(
                f"{entry.identifier} has disorder, pass disorder='ordered_sites' "
                "or 'all_sites' to remove/ignore disorder"
            )

    options = {
        'remove_hydrogens': remove_hydrogens,
        'disorder': disorder,
        'heaviest_component': heaviest_component,
        'molecular_centres': molecular_centres,
    }
    return periodicset_from_ccdc_crystal(entry.crystal, **options)
1228
1229
1230
def periodicset_from_ccdc_crystal(
        crystal,
        remove_hydrogens: bool = False,
        disorder: str = 'skip',
        heaviest_component: bool = False,
        molecular_centres: bool = False
) -> PeriodicSet:
    """Convert a :class:`ccdc.crystal.Crystal` object to a
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`.
    Crystal is the type returned by :class:`ccdc.io.CrystalReader`.

    Note that ``crystal.molecule`` is reassigned to the filtered
    molecule while converting.

    Parameters
    ----------
    crystal : :class:`ccdc.crystal.Crystal`
        A ccdc Crystal object representing a crystal structure.
    remove_hydrogens : bool, optional
        Remove Hydrogens from the crystal.
    disorder : str, optional
        Controls how disordered structures are handled. Default is
        ``skip`` which skips any crystal with disorder, since disorder
        conflicts with the periodic set model. To read disordered
        structures anyway, choose either :code:`ordered_sites` to remove
        atoms with disorder or :code:`all_sites` include all atoms
        regardless of disorder.
    heaviest_component : bool, optional
        Removes all but the heaviest molecule in the asymmetric unit,
        intended for removing solvents.
    molecular_centres : bool, default False
        If True, the returned PeriodicSet has the centres of the
        molecules in the unit cell as its motif, instead of the atoms.

    Returns
    -------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        useful data, e.g. the crystal's name and information about the
        asymmetric unit for calculation.

    Raises
    ------
    ParseError
        Raised if the structure fails parsing for any of the following:
        1. :code:`disorder == 'skip'` and disorder is found on any
        atom, 2. a cell length or angle is None, 3. The motif is empty
        after removing H, disordered sites, solvents or atoms without
        coordinates (the latter are removed with a warning).
    """

    molecule = crystal.disordered_molecule

    # Disorder
    if disorder == 'skip':
        if crystal.has_disorder or \
         any(_has_disorder(a.label, a.occupancy) for a in molecule.atoms):
            raise ParseError(
                f"{crystal.identifier} has disorder, pass "
                "disorder='ordered_sites' or 'all_sites' to remove/ignore "
                "disorder"
            )
    elif disorder == 'ordered_sites':
        molecule.remove_atoms(
            a for a in molecule.atoms if _has_disorder(a.label, a.occupancy)
        )

    if remove_hydrogens:
        # Compare whole symbols against a tuple: a substring test against
        # the string 'HD' would also match an empty symbol.
        molecule.remove_atoms(
            a for a in molecule.atoms if a.atomic_symbol in ('H', 'D')
        )

    if heaviest_component and len(molecule.components) > 1:
        molecule = _heaviest_component_ccdc(molecule)

    # Remove atoms with missing coordinates and warn
    if any(a.fractional_coordinates is None for a in molecule.atoms):
        warnings.warn('atoms without sites or missing data will be removed')
        molecule.remove_atoms(
            a for a in molecule.atoms if a.fractional_coordinates is None
        )

    # Put the filtered molecule back on the crystal so the packing and
    # asymmetric unit read from the crystal below reflect the removals.
    crystal.molecule = molecule
    cellpar = crystal.cell_lengths + crystal.cell_angles
    if None in cellpar:
        raise ParseError(f'{crystal.identifier} has missing cell data')
    cell = cellpar_to_cell(np.array(cellpar))

    if molecular_centres:
        frac_centres = _frac_molecular_centres_ccdc(crystal, _EQ_SITE_TOL)
        mol_centres = frac_centres @ cell
        return PeriodicSet(mol_centres, cell, name=crystal.identifier)

    asym_atoms = crystal.asymmetric_unit_molecule.atoms
    # NOTE(review): assumes no atom here has fractional_coordinates of None
    asym_unit = np.array([tuple(a.fractional_coordinates) for a in asym_atoms])

    if asym_unit.shape[0] == 0:
        raise ParseError(f'{crystal.identifier} has no valid sites')

    asym_unit = np.mod(asym_unit, 1)

    # recommended by pymatgen
    # asym_unit = _snap_small_prec_coords(asym_unit, 1e-4)

    asym_types = [a.atomic_number for a in asym_atoms]

    # Remove overlapping sites unless disorder == 'all_sites'
    if disorder != 'all_sites':
        keep_sites = _unique_sites(asym_unit, _EQ_SITE_TOL)
        if not np.all(keep_sites):
            warnings.warn(
                'may have overlapping sites; duplicates will be removed'
            )
        asym_unit = asym_unit[keep_sites]
        asym_types = [sym for sym, keep in zip(asym_types, keep_sites) if keep]

    # Symmetry operations
    sitesym = crystal.symmetry_operators
    # TODO: try the spacegroup number before defaulting to P1?
    if not sitesym:
        warnings.warn('no symmetry data found, defaulting to P1')
        sitesym = ['x,y,z']

    # Apply symmetries to asymmetric unit
    rot, trans = _parse_sitesyms(sitesym)
    frac_motif, invs = _expand_asym_unit(asym_unit, rot, trans, _EQ_SITE_TOL)

    # invs maps each motif point to its asymmetric unit site; counting the
    # occurrences of each value gives the Wyckoff multiplicities, and their
    # running total gives the index where each orbit starts in the motif.
    _, wyc_muls = np.unique(invs, return_counts=True)
    asym_inds = np.zeros_like(wyc_muls)
    asym_inds[1:] = np.cumsum(wyc_muls)[:-1]
    motif = frac_motif @ cell
    types = np.array([asym_types[i] for i in invs])

    return PeriodicSet(
        motif=motif,
        cell=cell,
        name=crystal.identifier,
        asymmetric_unit=asym_inds,
        wyckoff_multiplicities=wyc_muls,
        types=types
    )
1370
1371
1372
def _parse_sitesyms(
        symmetries: List[str]
) -> Tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]:
    """Turn symmetry operations written in xyz form (e.g.
    ``'-x+1/2,y,z'``) into an array of rotation matrices, shape
    (n, 3, 3), and an array of translation vectors, shape (n, 3).
    """

    axis_column = {'x': 0, 'y': 1, 'z': 2}
    count = len(symmetries)
    rotations = np.zeros((count, 3, 3), dtype=np.float64)
    translations = np.zeros((count, 3), dtype=np.float64)

    for op_ind, operation in enumerate(symmetries):
        # Each comma-separated term describes one row of the operation
        for row, term in enumerate(operation.split(',')):

            sign = 1.0            # sign most recently read in this term
            shift_sign = None     # sign in force at the first digit seen
            after_slash = False
            numerator = ''
            denominator = ''

            # Any character not handled below (e.g. whitespace) is ignored
            for ch in term.lower():
                if ch == '+':
                    sign = 1.0
                elif ch == '-':
                    sign = -1.0
                elif ch == '/':
                    after_slash = True
                elif ch in axis_column:
                    rotations[op_ind, row, axis_column[ch]] = sign
                elif ch.isdigit() or ch == '.':
                    if shift_sign is None:
                        shift_sign = sign
                    if after_slash:
                        denominator += ch
                    else:
                        numerator += ch

            shift = shift_sign * float(numerator) if numerator else 0.0
            if after_slash:
                shift /= float(denominator)

            translations[op_ind, row] = shift

    return rotations, translations
1421
1422
1423
def _expand_asym_unit(
        asym_unit: npt.NDArray,
        rotations: npt.NDArray,
        translations: npt.NDArray,
        tol: float
) -> Tuple[npt.NDArray[np.float64], npt.NDArray[np.int32]]:
    """Apply the symmetry operations given by ``rotations`` and
    ``translations`` to the asymmetric unit. Returns the resulting
    points and, for each of them, the index of the asymmetric unit site
    it was generated from.
    """

    # Convert all three inputs to float64 before expanding
    coords, rots, shifts = (
        arr.astype(np.float64, copy=False)
        for arr in (asym_unit, rotations, translations)
    )
    images = _expand_sites(coords, rots, shifts)
    frac_motif, invs = _reduce_expanded_sites(images, tol)

    # The fast reduction only removes duplicates among the images of a
    # single site. If the result is already free of duplicates it is final.
    if np.all(_unique_sites(frac_motif, tol)):
        return frac_motif, invs

    # Otherwise redo the reduction with the slower version, which also
    # compares images of different sites.
    return _reduce_expanded_equiv_sites(images, tol)
1443
1444
1445
@numba.njit(cache=True)
def _expand_sites(
        asym_unit: npt.NDArray[np.float64],
        rotations: npt.NDArray[np.float64],
        translations: npt.NDArray[np.float64]
) -> npt.NDArray[np.float64]:
    """Apply every symmetry operation (``rotations[j]`` followed by
    ``translations[j]``) to every point of the asymmetric unit and wrap
    the results modulo 1. Points left invariant by a symmetry appear
    more than once; they are not removed here. Returns a 3D array of
    shape (#points, #syms, dims).
    """

    n_points, dims = asym_unit.shape
    n_ops = len(rotations)
    images = np.empty((n_points, n_ops, dims), dtype=np.float64)

    for pt_ind in range(n_points):
        point = asym_unit[pt_ind]
        for op_ind in range(n_ops):
            moved = np.dot(rotations[op_ind], point) + translations[op_ind]
            images[pt_ind, op_ind] = moved

    # Wrap all images back into the unit cell
    return np.mod(images, 1)
1466
1467
1468
@numba.njit(cache=True)
def _reduce_expanded_sites(
        expanded_sites: npt.NDArray[np.float64],
        tol: float
) -> Tuple[npt.NDArray[np.float64], npt.NDArray[np.int32]]:
    """Reduce the asymmetric unit after being expanded by symmetries by
    removing invariant points. This is the fast version which works in
    the case that no two sites in the asymmetric unit are equivalent.
    If they are, the reduction is re-run with
    _reduce_expanded_equiv_sites() to account for it.

    ``expanded_sites`` has shape (#sites, #syms, dims). Returns the
    reduced points and, for each of them, the index of the site it
    came from.
    """

    all_unique_inds = []
    n_sites, _, dims = expanded_sites.shape
    # Integer dtype: these counts are used as slice bounds below, and
    # slice bounds must be integers.
    multiplicities = np.zeros(shape=(n_sites, ), dtype=np.int64)

    # First pass: find the distinct images of each site and count them
    for i, sites in enumerate(expanded_sites):
        unique_inds = _unique_sites(sites, tol)
        all_unique_inds.append(unique_inds)
        multiplicities[i] = np.sum(unique_inds)

    m = int(np.sum(multiplicities))
    frac_motif = np.zeros(shape=(m, dims))
    inverses = np.zeros(shape=(m, ), dtype=np.int32)

    # Second pass: write each site's distinct images into consecutive rows
    s = 0
    for i in range(n_sites):
        t = s + multiplicities[i]
        frac_motif[s:t, :] = expanded_sites[i][all_unique_inds[i]]
        inverses[s:t] = i
        s = t

    return frac_motif, inverses
1501
1502
1503
def _reduce_expanded_equiv_sites(
        expanded_sites: npt.NDArray[np.float64],
        tol: float
) -> Tuple[npt.NDArray[np.float64], npt.NDArray[np.int32]]:
    """Slower reduction of the symmetry-expanded asymmetric unit, used
    after the fast version when its output still contains equivalent
    points. Besides removing repeated images of one site, an image is
    dropped (with a warning) if it coincides, within ``tol`` and modulo
    1, with a point already kept from an earlier site.
    """

    # The images of the first site seed the motif
    first_images = expanded_sites[0]
    frac_motif = first_images[_unique_sites(first_images, tol)]
    inverses = [0 for _ in range(len(frac_motif))]

    for asym_ind, images in enumerate(expanded_sites[1:], start=1):
        candidates = images[_unique_sites(images, tol)]

        fresh = []
        for site in candidates:
            # A coordinate matches if it differs by ~0 or by ~1 (wrap-around)
            gaps = np.abs(site - frac_motif)
            wrapped_gaps = np.abs(gaps - 1)
            clashes = np.all((gaps <= tol) | (wrapped_gaps <= tol), axis=-1)

            if np.any(clashes):
                warnings.warn(
                    'has equivalent sites at positions '
                    f'{inverses[np.argmax(clashes)]}, {asym_ind}'
                )
                continue
            fresh.append(site)

        # Points accepted for this site join the motif only after all of
        # its candidates have been checked.
        if fresh:
            inverses += [asym_ind] * len(fresh)
            frac_motif = np.concatenate((frac_motif, np.array(fresh)))

    return frac_motif, np.array(inverses, dtype=np.int32)
1541
1542
1543
@numba.njit(cache=True)
def _unique_sites(
        asym_unit: npt.NDArray[np.float64], tol: float
) -> npt.NDArray[np.bool_]:
    """Uniquify (within tol) a list of fractional coordinates,
    considering all points modulo 1. Return an array of bools such that
    asym_unit[_unique_sites(asym_unit, tol)] is the uniquified list.

    A point is marked False if every one of its coordinates is within
    ``tol`` of, or within ``tol`` of being exactly 1 away from, the
    matching coordinate of some earlier point in the list. The first
    point is always marked True.
    """

    m, _ = asym_unit.shape
    where_unique = np.full(shape=(m, ), fill_value=True)

    for i in range(1, m):
        # Coordinate-wise distance from point i to every earlier point,
        # directly and across the cell boundary.
        site_diffs1 = np.abs(asym_unit[:i, :] - asym_unit[i])
        site_diffs2 = np.abs(site_diffs1 - 1)
        sites_neq_mask = (site_diffs1 > tol) & (site_diffs2 > tol)
        # Each row sum counts the coordinates that differ from one earlier
        # point; a zero means that earlier point coincides with point i.
        if not np.all(np.sum(sites_neq_mask, axis=-1)):
            where_unique[i] = False

    return where_unique
1564
1565
1566
def _has_disorder(label: str, occupancy) -> bool:
    """Decide whether a site counts as disordered: True when the
    occupancy converts to a number below 1, or when the label ends with
    '?'. An occupancy that cannot be converted to a float is treated
    as 1.
    """
    try:
        occ_value = float(occupancy)
    except Exception:
        # Unreadable occupancy: assume the site is fully occupied
        occ_value = 1

    if occ_value < 1:
        return True
    return label.endswith('?')
1573
1574
1575
def _get_syms_pymatgen(
        data: dict
) -> Tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]:
    """Parse symmetry operations given by data = block.data where block
    is a pymatgen CifBlock object. If the symops are not present the
    space group symbol/international number is parsed and symops are
    generated. If nothing usable is found, a warning is issued and the
    identity operation (P1) is returned.

    Returns a pair (rotations, translations) of float64 arrays.
    """

    from pymatgen.symmetry.groups import SpaceGroup
    import pymatgen.io.cif

    # Try xyz symmetry operations
    for symmetry_label in _CIF_TAGS['symop']:
        xyz = data.get(symmetry_label)
        if not xyz:
            continue
        if isinstance(xyz, str):
            xyz = [xyz]
        return _parse_sitesyms(xyz)

    symops = []
    # Try spacegroup symbol
    for symmetry_label in _CIF_TAGS['spacegroup_name']:
        sg = data.get(symmetry_label)
        if not sg:
            continue
        sg = re.sub(r'[\s_]', '', sg)

        # First look the symbol up in pymatgen's own table
        try:
            spg = pymatgen.io.cif.space_groups.get(sg)
            if spg:
                symops = SpaceGroup(spg).symmetry_ops
                break
        except ValueError:
            pass

        # The symbol was not found (or not accepted) above, so fall back
        # to the Hermann-Mauguin symbols in pymatgen's COD data.
        try:
            for d in pymatgen.io.cif._get_cod_data():
                if sg == re.sub(r'\s+', '', d['hermann_mauguin']):
                    return _parse_sitesyms(d['symops'])
        except Exception:
            # Best effort only: _get_cod_data is a private pymatgen helper,
            # so any failure here just moves on to the next tag.
            continue

    # Try international number
    if not symops:
        for symmetry_label in _CIF_TAGS['spacegroup_number']:
            num = data.get(symmetry_label)
            if not num:
                continue
            try:
                i = int(pymatgen.io.cif.str2float(num))
                symops = SpaceGroup.from_int_number(i).symmetry_ops
                break
            except ValueError:
                continue

    if not symops:
        warnings.warn('no symmetry data found, defaulting to P1')
        return _parse_sitesyms(['x,y,z'])

    rotations = [op.rotation_matrix for op in symops]
    translations = [op.translation_vector for op in symops]
    rotations = np.array(rotations, dtype=np.float64)
    translations = np.array(translations, dtype=np.float64)

    return rotations, translations
1643
1644
1645
def _frac_molecular_centres_ccdc(
        crystal, tol: float
) -> npt.NDArray[np.float64]:
    """Return the geometric centres of molecules in the unit cell.

    Expects a ccdc Crystal object and returns fractional coordinates.

    Parameters
    ----------
    crystal :
        Object providing ``packing(inclusion='CentroidIncluded')``,
        whose ``components`` each expose ``atoms`` with a
        ``fractional_coordinates`` attribute.
    tol : float
        Tolerance passed through to ``_unique_sites`` when selecting
        which centres to keep.

    Returns
    -------
    numpy.ndarray
        Array with one row per retained centre, with every coordinate
        reduced modulo 1.
    """

    frac_centres = []
    for comp in crystal.packing(inclusion='CentroidIncluded').components:
        coords = [a.fractional_coordinates for a in comp.atoms]
        # Materialise the per-axis means as a list. Appending a bare
        # generator expression would store an unevaluated generator
        # object, which np.array(..., dtype=np.float64) cannot convert.
        frac_centres.append([sum(ax) / len(coords) for ax in zip(*coords)])
    frac_centres = np.mod(np.array(frac_centres, dtype=np.float64), 1)
    # NOTE(review): _unique_sites is defined elsewhere in this module;
    # presumably it selects one representative per group of centres
    # closer than tol -- confirm against its definition.
    return frac_centres[_unique_sites(frac_centres, tol)]
1658
1659
1660
def _heaviest_component_ccdc(molecule):
    """Return the heaviest component of ``molecule``.

    Intended for stripping solvents from the asymmetric unit. Expects a
    ccdc Molecule object and returns the entry of
    ``molecule.components`` with the largest occupancy-weighted atomic
    weight. If several components tie, the first one is returned.
    """

    def _weighted_masses(atoms):
        # Yield occupancy * atomic weight for every atom whose weight
        # converts to float; an unparsable occupancy counts as 1.
        for atom in atoms:
            try:
                occupancy = float(atom.occupancy)
            except ValueError:
                occupancy = 1
            try:
                mass = float(atom.atomic_weight)
            except ValueError:
                continue
            yield mass * occupancy

    totals = [sum(_weighted_masses(part.atoms)) for part in molecule.components]
    heaviest = np.argmax(np.array(totals))
    return molecule.components[heaviest]
1682
1683
1684
def _snap_small_prec_coords(
        frac_coords: npt.NDArray[np.float64], tol: float
) -> npt.NDArray[np.float64]:
    """Snap fractional coordinates lying close to 1/3 or 2/3 onto those
    exact values, as recommended by pymatgen's CIF parser.

    An entry ``x`` is set to 1/3 when ``|1 - 3x| < tol``, and afterwards
    to 2/3 when ``|1 - 3x/2| < tol``. ``frac_coords`` is modified in
    place and the same array is returned.
    """

    near_one_third = np.abs(1 - 3 * frac_coords) < tol
    frac_coords[near_one_third] = 1 / 3.
    # This mask is computed only after the 1/3 snap has been applied,
    # so it sees the already-updated values.
    near_two_thirds = np.abs(1 - 3 * frac_coords / 2) < tol
    frac_coords[near_two_thirds] = 2 / 3.
    return frac_coords
1694