amd.io.CifReader.__init__() — rating: C
last analyzed

Complexity

Conditions 9

Size

Total Lines 98
Code Lines 51

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 51
dl 0
loc 98
rs 6.2703
c 0
b 0
f 0
cc 9
nop 11

How to fix   Long Method    Many Parameters   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include Extract Method: move a cohesive part of the body into a new, well-named method.

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more, or different data.

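There are several approaches to avoid long parameter lists, such as grouping related parameters into a parameter object, passing a whole object instead of several of its fields, or letting the method obtain a value itself instead of receiving it as a parameter.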

1
"""Tools for reading crystals from files, or from the CSD with
2
``csd-python-api``. The readers return
3
:class:`amd.PeriodicSet <.periodicset.PeriodicSet>` objects representing
4
the crystal which can be passed to :func:`amd.AMD() <.calculate.AMD>`
5
and :func:`amd.PDD() <.calculate.PDD>`.
6
"""
7
8
import warnings
9
import collections
10
import os
11
import re
12
import functools
13
import errno
14
import math
15
import itertools
16
from pathlib import Path
17
from typing import Iterable, Iterator, Optional, Union, Callable, Tuple, List
18
19
import numpy as np
20
import numpy.typing as npt
21
import numba
22
import tqdm
23
from scipy.spatial.distance import pdist
24
25
from .utils import cellpar_to_cell
26
from .periodicset import DisorderGroup, DisorderAssembly, PeriodicSet
27
from .calculate import _collapse_into_groups
28
from .globals_ import SUBS_DISORDER_TOL, ATOMIC_NUMBERS
29
from ._types import FloatArray, UIntArray
30
31
32
def _custom_warning(message, category, filename, lineno, *args, **kwargs):
33
    return f"{category.__name__}: {message}\n"
34
35
36
# Install _custom_warning as the warnings formatter for the whole process:
# after this module is imported, warnings printed through the default
# handler (not only those raised here) appear as "Category: message",
# without the usual file/line prefix.
warnings.formatwarning = _custom_warning
37
38
39
_CIF_TAGS: dict = {
40
    "cellpar": [
41
        "_cell_length_a",
42
        "_cell_length_b",
43
        "_cell_length_c",
44
        "_cell_angle_alpha",
45
        "_cell_angle_beta",
46
        "_cell_angle_gamma",
47
    ],
48
    "atom_site_fract": [
49
        "_atom_site_fract_x",
50
        "_atom_site_fract_y",
51
        "_atom_site_fract_z",
52
    ],
53
    "atom_site_cartn": [
54
        "_atom_site_Cartn_x",
55
        "_atom_site_Cartn_y",
56
        "_atom_site_Cartn_z",
57
    ],
58
    "symop": [
59
        "_space_group_symop_operation_xyz",
60
        "_space_group_symop.operation_xyz",
61
        "_symmetry_equiv_pos_as_xyz",
62
    ],
63
    "spacegroup_name": ["_space_group_name_H-M_alt", "_symmetry_space_group_name_H-M"],
64
    "spacegroup_number": ["_space_group_IT_number", "_symmetry_Int_Tables_number"],
65
}
66
67
# Public API of this module. Commented-out entries are not currently exported.
__all__ = [
    # Readers (iterators over PeriodicSets)
    "CifReader",
    "CSDReader",
    # Error raised when an item cannot be converted
    "ParseError",
    # Converters for single items
    "periodicset_from_gemmi_block",
    "periodicset_from_ccdc_entry",
    "periodicset_from_ccdc_crystal",
    # "periodicset_from_ase_cifblock",
    # "periodicset_from_pymatgen_cifblock",
    # "periodicset_from_ase_atoms",
    "periodicset_from_pymatgen_structure",
]
79
80
81
class _Reader(collections.abc.Iterator):
    """Base reader class.

    Wraps an iterable of raw items (e.g. CIF blocks or CSD entries) and a
    converter that turns one item into a
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`. Iterating the
    reader yields converted items; items whose conversion raises
    :class:`ParseError <.io.ParseError>` are skipped.

    Parameters
    ----------
    iterable : Iterable
        Source of raw items passed one at a time to ``converter``.
    converter : Callable[..., PeriodicSet]
        Converts one raw item to a PeriodicSet, raising ParseError if the
        item cannot be converted.
    show_warnings : bool
        If False, warnings raised while converting items are suppressed
        (only while converting; the global warnings filters are left
        untouched).
    verbose : bool
        If True, show a tqdm progress bar counting processed items.
    """

    def __init__(
        self,
        iterable: Iterable,
        converter: Callable[..., PeriodicSet],
        show_warnings: bool,
        verbose: bool,
    ):
        self._iterator = iter(iterable)
        self._converter = converter
        self.show_warnings = show_warnings
        if verbose:
            self._progress_bar = tqdm.tqdm(desc="Reading", delay=1)
        else:
            self._progress_bar = None

    def __next__(self):
        """Iterate over self._iterator, passing items through
        self._converter and returning the first successful result.

        If :class:`ParseError <.io.ParseError>` is raised by
        self._converter, the item is skipped and (when self.show_warnings
        is True) the error message is emitted as a warning. Warnings raised
        inside self._converter are re-emitted, prefixed with the crystal's
        name, when self.show_warnings is True and suppressed otherwise.
        """

        while True:
            try:
                item = next(self._iterator)
            except StopIteration:
                if self._progress_bar is not None:
                    self._progress_bar.close()
                raise

            parse_error = None
            periodic_set = None
            with warnings.catch_warnings(record=True) as warning_msgs:
                # Suppress converter warnings only inside this context; calling
                # simplefilter outside it would permanently silence warnings
                # for the whole process, including other readers.
                if not self.show_warnings:
                    warnings.simplefilter("ignore")
                try:
                    periodic_set = self._converter(item)
                except ParseError as err:
                    parse_error = err
                finally:
                    if self._progress_bar is not None:
                        self._progress_bar.update(1)

            if parse_error is not None:
                # Emitted after leaving the recording context; inside it the
                # warning would be captured into warning_msgs and then lost.
                if self.show_warnings:
                    warnings.warn(str(parse_error))
                continue

            # Empty when show_warnings is False: ignored warnings are not recorded.
            for warning in warning_msgs:
                msg = f"(name={periodic_set.name}) {warning.message}"
                warnings.warn(msg, category=warning.category)

            return periodic_set

    def read(self) -> Union[PeriodicSet, List[PeriodicSet]]:
        """Read the crystal(s), return one
        :class:`amd.PeriodicSet <.periodicset.PeriodicSet>` if there is
        only one, otherwise return a list (empty if nothing was read).
        """
        items = list(self)
        if len(items) == 1:
            return items[0]
        return items
143
144
145
class CifReader(_Reader):
    """Read all structures in a .cif file or all files in a folder
    with gemmi or csd-python-api (if installed), yielding
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>` s.

    Parameters
    ----------
    path : str or os.PathLike
        Path to a .CIF file or directory. (Other files are accepted when
        using ``reader='ccdc'``, if csd-python-api is installed.) Files in
        a directory are read in sorted path order.
    reader : str, optional
        The backend package used to parse the CIF. The default is
        :code:`gemmi`, :code:`ccdc` is accepted if csd-python-api is
        installed. The ccdc backend should be able to read any format
        accepted by :class:`ccdc.io.EntryReader`.
    remove_hydrogens : bool, default False
        Remove Hydrogens from the crystals.
    skip_disorder : bool, default False
        Do not read disordered structures.
    missing_coords : str, default 'warn'
        How to treat sites with missing coordinates; passed to the backend
        converter. With gemmi, ``'skip'`` skips the structure, ``'warn'``
        removes the affected sites with a warning and any other value
        removes them silently.
    eq_site_tol : float, default 1e-3
        Tolerance passed to the backend converter (with gemmi, used when
        expanding the asymmetric unit by symmetry).
    show_warnings : bool, default True
        Controls whether warnings that arise during reading are printed.
    verbose : bool, default False
        If True, prints a progress bar showing the number of items
        processed.
    heaviest_component : bool, default False
        csd-python-api only. Removes all but the heaviest molecule in
        the asymmetric unit, intended for removing solvents.
    molecular_centres : bool, default False
        csd-python-api only. Extract the centres of molecules in the
        unit cell and store in the attribute molecular_centres.

    Raises
    ------
    ValueError
        If ``reader`` is not ``'gemmi'`` or ``'ccdc'``.
    NotImplementedError
        If a csd-python-api-only option is set with ``reader='gemmi'``.
    ImportError
        If ``reader='ccdc'`` and csd-python-api cannot be imported.
    FileNotFoundError
        If ``path`` is neither a file nor a directory.

    Yields
    ------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        data, e.g. the crystal's name and information about the
        asymmetric unit.

    Examples
    --------

        ::

            # Put all crystals in a .CIF in a list
            structures = list(amd.CifReader('mycif.cif'))

            # Can also accept path to a directory, reading all files inside
            structures = list(amd.CifReader('path/to/folder'))

            # Reads just one if the .CIF has just one crystal
            periodic_set = amd.CifReader('mycif.cif').read()

            # List of AMDs (k=100) of crystals in a .CIF
            amds = [amd.AMD(item, 100) for item in amd.CifReader('mycif.cif')]
    """

    def __init__(
        self,
        path: Union[str, os.PathLike],
        reader: str = "gemmi",
        remove_hydrogens: bool = False,
        skip_disorder: bool = False,
        missing_coords: str = "warn",
        eq_site_tol: float = 1e-3,
        show_warnings: bool = True,
        verbose: bool = False,
        heaviest_component: bool = False,
        molecular_centres: bool = False,
    ):
        # Validate the backend first so an unknown reader always gives
        # ValueError, whatever other options were passed.
        if reader not in ("gemmi", "ccdc"):
            raise ValueError(
                f"'reader' parameter of {self.__class__.__name__} must be one "
                f"of 'gemmi', 'ccdc' (passed '{reader}')"
            )

        # Options understood by every backend's converter.
        converter_kwargs = {
            "remove_hydrogens": remove_hydrogens,
            "skip_disorder": skip_disorder,
            "missing_coords": missing_coords,
            "eq_site_tol": eq_site_tol,
        }

        if reader == "gemmi":
            self._reject_ccdc_only_options(
                heaviest_component=heaviest_component,
                molecular_centres=molecular_centres,
            )
            extensions, file_parser, converter = self._gemmi_backend(converter_kwargs)
        else:
            extensions, file_parser, converter = self._ccdc_backend(
                converter_kwargs,
                heaviest_component=heaviest_component,
                molecular_centres=molecular_centres,
            )

        iterable = self._iterable_from_path(path, file_parser, extensions)
        super().__init__(iterable, converter, show_warnings, verbose)

    def _reject_ccdc_only_options(self, **options: bool) -> None:
        """Raise NotImplementedError for the first csd-python-api-only
        option (in the order given) that is set."""
        for name, value in options.items():
            if value:
                raise NotImplementedError(
                    f"'{name}' parameter of "
                    f"{self.__class__.__name__} only implemented with "
                    "csd-python-api, if installed pass reader='ccdc'"
                )

    @staticmethod
    def _gemmi_backend(converter_kwargs: dict) -> Tuple[set, Callable, Callable]:
        """Return ``(extensions, file_parser, converter)`` for gemmi."""
        # NOTE(review): gemmi reportedly fails on some unusual/undecodable
        # characters in CIFs.
        import gemmi

        converter = functools.partial(periodicset_from_gemmi_block, **converter_kwargs)
        return {"cif"}, gemmi.cif.read_file, converter

    @staticmethod
    def _ccdc_backend(
        converter_kwargs: dict, heaviest_component: bool, molecular_centres: bool
    ) -> Tuple[set, Callable, Callable]:
        """Return ``(extensions, file_parser, converter)`` for
        csd-python-api, raising ImportError if it is unavailable."""
        try:
            import ccdc.io
        except (ImportError, RuntimeError) as e:
            raise ImportError("Failed to import csd-python-api") from e

        extensions = set(ccdc.io.EntryReader.known_suffixes.keys())
        converter = functools.partial(
            periodicset_from_ccdc_entry,
            heaviest_component=heaviest_component,
            molecular_centres=molecular_centres,
            **converter_kwargs,
        )
        return extensions, ccdc.io.EntryReader, converter

    @staticmethod
    def _iterable_from_path(
        path: Union[str, os.PathLike], file_parser: Callable, extensions: Iterable
    ) -> Iterable:
        """Parse a single file, or lazily every matching file of a
        directory; raise FileNotFoundError if ``path`` is neither."""
        path = Path(path)
        if path.is_file():
            return file_parser(str(path))
        if path.is_dir():
            return CifReader._dir_generator(path, file_parser, extensions)
        raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), str(path))

    @staticmethod
    def _dir_generator(
        path: Path, file_parser: Callable, extensions: Iterable
    ) -> Iterator:
        """Generate items from all files with extensions in
        ``extensions`` from a directory ``path``, in sorted path order."""
        # iterdir() order is arbitrary; sort for reproducible output.
        for file_path in sorted(path.iterdir()):
            if not file_path.is_file():
                continue
            if file_path.suffix[1:].lower() not in extensions:
                continue
            try:
                yield from file_parser(str(file_path))
            except Exception as e:
                # Best effort: one unreadable file should not abort reading
                # the rest of the directory, so report it and move on.
                warnings.warn(
                    f'Error parsing "{str(file_path)}", skipping file. '
                    f"Exception: {repr(e)}"
                )
319
320
321
class CSDReader(_Reader):
    """Read structures from the CSD with csd-python-api, yielding
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>` s.

    Parameters
    ----------
    refcodes : str or List[str], optional
        Single or list of CSD refcodes to read. If None or 'CSD',
        iterates over the whole CSD.
    refcode_families : bool, optional
        Interpret ``refcodes`` as one or more refcode families, reading
        all entries whose refcode starts with the given strings. Ignored
        when reading the whole CSD.
    remove_hydrogens : bool, optional
        Remove hydrogens from the crystals.
    skip_disorder : bool, optional
        Do not read disordered structures.
    missing_coords : str, default 'warn'
        How to treat sites with missing coordinates; passed through to
        :func:`periodicset_from_ccdc_entry`.
    eq_site_tol : float, default 1e-3
        Tolerance passed through to :func:`periodicset_from_ccdc_entry`.
    show_warnings : bool, optional
        Controls whether warnings that arise during reading are printed.
    verbose : bool, default False
        If True, prints a progress bar showing the number of items
        processed.
    heaviest_component : bool, optional
        Removes all but the heaviest molecule in the asymmetric unit,
        intended for removing solvents.
    molecular_centres : bool, default False
        Extract the centres of molecules in the unit cell and store in
        attribute molecular_centres.

    Raises
    ------
    ImportError
        If csd-python-api cannot be imported.
    ValueError
        If ``refcodes`` is not None, a string or a list of strings.

    Yields
    ------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        useful data, e.g. the crystal's name and information about the
        asymmetric unit for calculation.

    Examples
    --------

        ::

            # Put these entries in a list
            refcodes = ['DEBXIT01', 'DEBXIT05', 'HXACAN01']
            structures = list(amd.CSDReader(refcodes))

            # Read refcode families (any whose refcode starts with strings in the list)
            families = ['ACSALA', 'HXACAN']
            structures = list(amd.CSDReader(families, refcode_families=True))

            # Get AMDs (k=100) for crystals in these families
            refcodes = ['ACSALA', 'HXACAN']
            amds = []
            for periodic_set in amd.CSDReader(refcodes, refcode_families=True):
                amds.append(amd.AMD(periodic_set, 100))

            # Giving the reader nothing reads from the whole CSD.
            for periodic_set in amd.CSDReader():
                ...
    """

    def __init__(
        self,
        refcodes: Optional[Union[str, List[str]]] = None,
        refcode_families: bool = False,
        remove_hydrogens: bool = False,
        skip_disorder: bool = False,
        missing_coords: str = "warn",
        eq_site_tol: float = 1e-3,
        show_warnings: bool = True,
        verbose: bool = False,
        heaviest_component: bool = False,
        molecular_centres: bool = False,
    ):

        try:
            import ccdc.search
            import ccdc.io
        except (ImportError, RuntimeError) as e:
            raise ImportError("Failed to import csd-python-api") from e

        # 'CSD' (any case) is an alias for reading the whole database.
        if isinstance(refcodes, str) and refcodes.lower() == "csd":
            refcodes = None

        # Normalise refcodes to None (whole CSD) or a list of strings.
        if refcodes is None:
            refcode_families = False
        elif isinstance(refcodes, str):
            refcodes = [refcodes]
        elif isinstance(refcodes, list):
            if not all(isinstance(refcode, str) for refcode in refcodes):
                raise ValueError(
                    f"List passed to {self.__class__.__name__} contains non-strings."
                )
        else:
            raise ValueError(
                f"{self.__class__.__name__} expects None, a string or list of "
                f"strings, got {refcodes.__class__.__name__}"
            )

        if refcode_families:
            hits = []
            for family in refcodes:
                query = ccdc.search.TextNumericSearch()
                query.add_identifier(family)
                hits.extend(hit.identifier for hit in query.search())
            # dicts keep insertion order, so this drops repeated refcodes
            # while preserving the order in which they were found.
            refcodes = list(dict.fromkeys(hits))

        converter = functools.partial(
            periodicset_from_ccdc_entry,
            remove_hydrogens=remove_hydrogens,
            skip_disorder=skip_disorder,
            missing_coords=missing_coords,
            eq_site_tol=eq_site_tol,
            heaviest_component=heaviest_component,
            molecular_centres=molecular_centres,
        )

        entry_reader = ccdc.io.EntryReader("CSD")
        if refcodes is None:
            iterable = entry_reader
        else:
            # Entries are fetched lazily, one refcode at a time.
            iterable = map(entry_reader.entry, refcodes)

        super().__init__(iterable, converter, show_warnings, verbose)
450
451
452
# class MPReader(_Reader):
453
#     """Read structures from the Materials Project API, yielding
454
#     :class:`amd.PeriodicSet <.periodicset.PeriodicSet>` s.
455
456
#     Parameters
457
#     ----------
458
#     mp_api_key : str
459
#         ...
460
#     ids : str or list of str, optional
461
#         ...
462
#     remove_hydrogens : bool, optional
463
#         Remove hydrogens from the crystals.
464
#     show_warnings : bool, optional
465
#         Controls whether warnings that arise during reading are printed.
466
#     verbose : bool, default False
467
#         If True, prints a progress bar showing the number of items
468
#         processed.
469
470
#     Yields
471
#     ------
472
#     :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
473
#         Represents the crystal as a periodic set, consisting of a finite
474
#         set of points (motif) and lattice (unit cell). Contains other
475
#         useful data, e.g. the crystal's name and information about the
476
#         asymmetric unit for calculation.
477
#     """
478
479
#     def __init__(
480
#         self,
481
#         mp_api_key: str,
482
#         ids: Optional[Union[str, List[str]]] = None,
483
#         remove_hydrogens: bool = False,
484
#         show_warnings: bool = True,
485
#         verbose: bool = False,
486
#     ):
487
#         from mp_api.client import MPRester
488
489
#         if isinstance(ids, str):
490
#             ids = [ids]
491
#         elif isinstance(ids, list):
492
#             if not all(isinstance(i, str) for i in ids):
493
#                 raise ValueError(
494
#                     f"{self.__class__.__name__} expects None, a string or "
495
#                     "list of strings."
496
#                 )
497
#         else:
498
#             raise ValueError(
499
#                 f"{self.__class__.__name__} expects None, a string or list of "
500
#                 f"strings, got {ids.__class__.__name__}"
501
#             )
502
503
#         converter = functools.partial(
504
#             self._periodicset_from_mp_api_doc, remove_hydrogens=remove_hydrogens
505
#         )
506
507
#         with MPRester(mp_api_key) as mpr:
508
#             docs = mpr.materials.summary.search(
509
#                 fields=["material_id", "structure"], material_ids=ids
510
#             )
511
512
#         super().__init__(docs, converter, show_warnings, verbose)
513
514
#     @staticmethod
515
#     def _periodicset_from_mp_api_doc(doc, remove_hydrogens: bool = False):
516
#         periodic_set = periodicset_from_pymatgen_structure(
517
#             doc.structure, remove_hydrogens=remove_hydrogens
518
#         )
519
#         periodic_set.name = doc.material_id.title().lower()
520
#         return periodic_set
521
522
523
class ParseError(ValueError):
    """Error signalling that an item (e.g. a CIF block or a CSD entry)
    cannot be converted into a periodic set. The readers in this module
    catch it and skip the offending item."""
527
528
529
def periodicset_from_gemmi_block(
    block,
    remove_hydrogens: bool = False,
    skip_disorder: bool = False,
    missing_coords: str = "warn",
    eq_site_tol: float = 1e-3,
) -> PeriodicSet:
    """Convert a :class:`gemmi.cif.Block` object to a
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`.
    :class:`gemmi.cif.Block` is the type returned by
    :func:`gemmi.cif.read_file`.

    Parameters
    ----------
    block : :class:`gemmi.cif.Block`
        A gemmi CIF block representing a crystal.
    remove_hydrogens : bool, optional
        Remove Hydrogens from the crystal.
    skip_disorder : bool, optional
        Do not read disordered structures (any site with occupancy < 1).
    missing_coords : str, default 'warn'
        How to treat sites whose fractional coordinates parse to NaN:
        ``'skip'`` raises :class:`ParseError`, ``'warn'`` removes them
        with a warning, any other value removes them silently.
    eq_site_tol : float, default 1e-3
        Tolerance passed to the symmetry expansion of the asymmetric unit
        and to the disorder analysis.

    Returns
    -------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        useful data, e.g. the crystal's name and information about the
        asymmetric unit for calculation.

    Raises
    ------
    ParseError
        Raised if the structure fails to be parsed for any of the
        following: 1. Required data is missing (cell parameters or
        fractional coordinates), 2. :code:`skip_disorder is True` and a
        site has occupancy < 1, 3. ``missing_coords == 'skip'`` and a site
        has missing coordinates, 4. The motif is empty after removing
        sites with missing coordinates, dummy sites or hydrogens, 5. The
        space group number is present but cannot be read.
    """

    import gemmi

    # Unit cell
    cellpar = [block.find_value(t) for t in _CIF_TAGS["cellpar"]]
    if not all(isinstance(par, str) for par in cellpar):
        raise ParseError(f"{block.name} has no unit cell")
    cellpar = np.array([str2float(par) for par in cellpar])
    if np.isnan(np.sum(cellpar)):
        raise ParseError(f"{block.name} has no unit cell")
    cell = cellpar_to_cell(cellpar)

    # Asymmetric unit coordinates
    xyz_loop = block.find(_CIF_TAGS["atom_site_fract"]).loop
    if xyz_loop is None:
        raise ParseError(f"{block.name} has no coordinates")

    # Loop values are stored flat, row by row; regroup into one column per tag.
    tablified_loop = [[] for _ in range(len(xyz_loop.tags))]
    for i, item in enumerate(xyz_loop.values):
        tablified_loop[i % xyz_loop.width()].append(item)
    loop_dict = {tag: col for tag, col in zip(xyz_loop.tags, tablified_loop)}
    xyz_str = [loop_dict[t] for t in _CIF_TAGS["atom_site_fract"]]
    asym_unit = np.transpose(np.array([[str2float(c) for c in xyz] for xyz in xyz_str]))
    asym_unit = np.mod(asym_unit, 1)  # wrap fractional coordinates into [0, 1)

    # Atom labels
    if "_atom_site_label" in loop_dict:
        asym_labels = [
            gemmi.cif.as_string(lab) for lab in loop_dict["_atom_site_label"]
        ]
    else:
        asym_labels = [""] * xyz_loop.length()

    # Atomic types; unrecognised symbols get atomic number 0
    if "_atom_site_type_symbol" in loop_dict:
        symbols = [
            _element_from_type_symbol(gemmi.cif.as_string(s))
            for s in loop_dict["_atom_site_type_symbol"]
        ]
    else:  # Get atomic types from label
        symbols = _atomic_symbols_from_labels(asym_labels)
    asym_types = [ATOMIC_NUMBERS[s] if s in ATOMIC_NUMBERS else 0 for s in symbols]

    # Fractional occupancies; unreadable values are treated as 1
    if "_atom_site_occupancy" in loop_dict:
        asym_occs = []
        for occ in loop_dict["_atom_site_occupancy"]:
            try:
                occ = str2float(occ)
            except ValueError:
                occ = 1
            if math.isnan(occ):
                occ = 1
            if skip_disorder and occ < 1:
                raise ParseError(f"{block.name} has disorder")
            asym_occs.append(occ)
    else:
        asym_occs = [1] * xyz_loop.length()

    # Disorder groups/assemblies; '.' and '?' are CIF placeholders for
    # inapplicable/unknown values.
    n_sites = len(asym_occs)
    groups = loop_dict.get("_atom_site_disorder_group", [None] * n_sites)
    assemblies = loop_dict.get("_atom_site_disorder_assembly", [None] * n_sites)
    assemblies = [None if x in (".", "?") else x for x in assemblies]
    groups = [None if x in (".", "?") else x for x in groups]

    # Missing coordinates
    remove_sites = []
    where_missing_atoms = np.isnan(asym_unit.min(axis=-1))
    if np.any(where_missing_atoms):
        if missing_coords == "skip":
            raise ParseError(f"{block.name} has missing coordinates")
        elif missing_coords == "warn":
            warnings.warn(f"{block.name} has missing coordinates")
        remove_sites.extend(np.nonzero(where_missing_atoms)[0])

    # Remove dummy sites
    if "_atom_site_calc_flag" in loop_dict:
        calc_flags = [gemmi.cif.as_string(f) for f in loop_dict["_atom_site_calc_flag"]]
        remove_sites.extend(i for i, f in enumerate(calc_flags) if f == "dum")

    if remove_hydrogens:
        remove_sites.extend(i for i, num in enumerate(asym_types) if num == 1)

    asym_unit = np.delete(asym_unit, remove_sites, axis=0)
    if asym_unit.shape[0] == 0:
        raise ParseError(f"{block.name} has no coordinates")

    rot, trans = _gemmi_block_symmetry_ops(block)

    frac_motif, invs = _expand_asym_unit(asym_unit, rot, trans, eq_site_tol)
    _, wyc_muls = np.unique(invs, return_counts=True)
    wyc_muls = np.array(wyc_muls, dtype=np.uint64)
    asym_inds = np.zeros_like(wyc_muls, dtype=np.uint64)
    asym_inds[1:] = np.cumsum(wyc_muls, dtype=np.uint64)[:-1]
    motif = np.matmul(frac_motif, cell)

    # Set membership keeps the per-site filtering linear in the number of sites.
    removed = {int(i) for i in remove_sites}

    def _kept(values):
        return [v for i, v in enumerate(values) if i not in removed]

    asym_types = _kept(asym_types)
    asym_occs = _kept(asym_occs)
    asym_labels = _kept(asym_labels)
    assemblies = _kept(assemblies)
    groups = _kept(groups)
    disorder = _disorder_assemblies(
        asym_unit, wyc_muls, cell, assemblies, groups, asym_occs, eq_site_tol
    )
    asym_types = np.array(asym_types, dtype=np.uint64)

    return PeriodicSet(
        motif=motif,
        cell=cell,
        name=block.name,
        asym_unit=asym_inds,
        multiplicities=wyc_muls,
        types=asym_types,
        labels=asym_labels,
        disorder=disorder,
    )


def _element_from_type_symbol(raw: str) -> str:
    """Extract an element symbol from an ``_atom_site_type_symbol`` value.

    Takes the first one or two consecutive letters, capitalised as an
    element symbol (e.g. ``"Fe3+"`` -> ``"Fe"``, ``"CL"`` -> ``"Cl"``);
    returns ``""`` if the value contains no letters.
    """
    match = re.search(r"([A-Za-z][A-Za-z]?)", raw)
    if match is None:
        return ""
    letters = match.group()
    return letters[0].upper() + letters[1:].lower()


def _gemmi_block_symmetry_ops(block) -> Tuple[np.ndarray, np.ndarray]:
    """Return the ``(rotations, translations)`` of a gemmi block's space
    group.

    Tries, in order: explicit xyz symmetry-operation strings, a space
    group name gemmi can parse, then the International Tables number.
    Falls back to P1 with a warning when none of these is present, and
    raises :class:`ParseError` if a number is present but unreadable.
    """
    import gemmi

    # Symmetry operations, try xyz strings first
    for tag in _CIF_TAGS["symop"]:
        sitesym = [v.str(0) for v in block.find([tag])]
        if sitesym:
            rot, trans = _parse_sitesyms(sitesym)
            return rot, trans

    # Try spacegroup name; can be a pair or in a loop
    spg = None
    for tag in _CIF_TAGS["spacegroup_name"]:
        for value in block.find([tag]):
            try:
                # Some names cannot be parsed by gemmi.SpaceGroup
                spg = gemmi.SpaceGroup(value.str(0))
                break
            except ValueError:
                continue
        if spg is not None:
            break

    if spg is None:
        # Try international number; '?'/'.' mean unknown, same as absent.
        spg_num = None
        for tag in _CIF_TAGS["spacegroup_number"]:
            value = block.find_value(tag)
            if value is not None and value not in ("?", "."):
                spg_num = value
                break

        if spg_num is None:
            warnings.warn(f"{block.name} has no symmetry data, defaulting to P1")
            spg = gemmi.SpaceGroup(1)
        else:
            # Report an unreadable/invalid number as a ParseError so the
            # reader skips this block instead of aborting.
            # NOTE(review): exception types assumed from gemmi's bindings.
            try:
                spg = gemmi.SpaceGroup(gemmi.cif.as_int(spg_num))
            except (ValueError, RuntimeError) as err:
                raise ParseError(
                    f"{block.name} has an unreadable space group number ({spg_num})"
                ) from err

    rot = np.array([np.array(o.rot) / o.DEN for o in spg.operations()])
    trans = np.array([np.array(o.tran) / o.DEN for o in spg.operations()])
    return rot, trans
750
751
752
def periodicset_from_ccdc_entry(
    entry,
    remove_hydrogens: bool = False,
    skip_disorder: bool = False,
    missing_coords: str = "warn",
    eq_site_tol: float = 1e-3,
    heaviest_component: bool = False,
    molecular_centres: bool = False,
) -> PeriodicSet:
    """Convert a :class:`ccdc.entry.Entry` object to a
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`.
    Entry is the type returned by :class:`ccdc.io.EntryReader`.

    Performs the entry-level checks (3D structure present, disorder flag)
    and then delegates to :func:`periodicset_from_ccdc_crystal` with
    ``entry.crystal`` and all remaining options.

    Parameters
    ----------
    entry : :class:`ccdc.entry.Entry`
        A ccdc Entry object representing a database entry.
    remove_hydrogens : bool, optional
        Remove Hydrogens from the crystal.
    skip_disorder : bool, optional
        Do not read disordered structures.
    missing_coords : str, default 'warn'
        Passed through to :func:`periodicset_from_ccdc_crystal`.
    eq_site_tol : float, default 1e-3
        Passed through to :func:`periodicset_from_ccdc_crystal`.
    heaviest_component : bool, optional
        Removes all but the heaviest molecule in the asymmetric unit,
        intended for removing solvents.
    molecular_centres : bool, default False
        Use molecular centres of mass as the motif instead of centres of
        atoms.

    Returns
    -------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        useful data, e.g. the crystal's name and information about the
        asymmetric unit for calculation.

    Raises
    ------
    ParseError
        Raised if ``entry.has_3d_structure`` is False, if
        ``skip_disorder`` is True and ``entry.has_disorder`` is True, or
        for any reason :func:`periodicset_from_ccdc_crystal` raises it.
    """

    # Entry-level checks; the 3D-structure test takes precedence.
    problem = None
    if not entry.has_3d_structure:
        problem = "has no 3D structure"
    elif skip_disorder and entry.has_disorder:
        problem = "has disorder"
    if problem is not None:
        raise ParseError(f"{entry.identifier} {problem}")

    crystal_options = dict(
        remove_hydrogens=remove_hydrogens,
        skip_disorder=skip_disorder,
        missing_coords=missing_coords,
        eq_site_tol=eq_site_tol,
        heaviest_component=heaviest_component,
        molecular_centres=molecular_centres,
    )
    return periodicset_from_ccdc_crystal(entry.crystal, **crystal_options)
817
818
819
def periodicset_from_ccdc_crystal(
    crystal,
    remove_hydrogens: bool = False,
    skip_disorder: bool = False,
    missing_coords: str = "warn",
    eq_site_tol: float = 1e-3,
    heaviest_component: bool = False,
    molecular_centres: bool = False,
) -> PeriodicSet:
    """Convert a :class:`ccdc.crystal.Crystal` object to a
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`.
    Crystal is the type returned by :class:`ccdc.io.CrystalReader`.

    Parameters
    ----------
    crystal : :class:`ccdc.crystal.Crystal`
        A ccdc Crystal object representing a crystal structure.
    remove_hydrogens : bool, optional
        Remove Hydrogen (and Deuterium, symbol ``D``) atoms from the
        crystal.
    skip_disorder : bool
        Do not read disordered structures.
    missing_coords : str, default "warn"
        How to handle atoms without fractional coordinates: ``"skip"``
        raises :class:`ParseError`, ``"warn"`` emits a warning and
        removes those atoms, any other value removes them silently.
    eq_site_tol : float, default 1e-3
        Tolerance (in fractional coordinates) within which sites are
        considered the same when applying symmetry operations and when
        uniquifying molecular centres.
    heaviest_component : bool, optional
        Removes all but the heaviest molecule in the asymmetric unit,
        intended for removing solvents.
    molecular_centres : bool, default False
        Use the geometric centres of molecules (unweighted mean of
        their atoms' fractional coordinates) as the motif instead of
        atoms.

    Returns
    -------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        useful data, e.g. the crystal's name and information about the
        asymmetric unit for calculation.

    Raises
    ------
    ParseError
        Raised if the structure fails parsing for any of the following:
        1. ``skip_disorder`` is True and the crystal has disorder,
        2. ``missing_coords == 'skip'`` and some atom has no fractional
        coordinates, 3. the unit cell is incomplete, 4. the motif is
        empty after removing H, atoms without coordinates or solvents.
    """

    if skip_disorder and crystal.has_disorder:
        raise ParseError(f"{crystal.identifier} has disorder")

    molecule = crystal.asymmetric_unit_molecule

    if remove_hydrogens:
        # Exact symbol match; a substring test on "HD" would also match "".
        molecule.remove_atoms(
            a for a in molecule.atoms if a.atomic_symbol in ("H", "D")
        )

    # Atoms without coordinates cannot be placed in the motif.
    if any(a.fractional_coordinates is None for a in molecule.atoms):
        if missing_coords == "skip":
            raise ParseError(f"{crystal.identifier} has missing coordinates")
        if missing_coords == "warn":
            warnings.warn(f"{crystal.identifier} has missing coordinates")
        molecule.remove_atoms(
            a for a in molecule.atoms if a.fractional_coordinates is None
        )

    if heaviest_component and len(molecule.components) > 1:
        molecule = _heaviest_component_ccdc(molecule)

    # Unit cell
    cellpar = crystal.cell_lengths + crystal.cell_angles
    if None in cellpar:
        raise ParseError(f"{crystal.identifier} has no unit cell")
    cell = cellpar_to_cell(np.array(cellpar))

    if molecular_centres:
        # NOTE(review): centres come from crystal.packing(), not from
        # ``molecule``, so they may not reflect the H/solvent removal
        # above — confirm against the ccdc API.
        frac_centres = _frac_molecular_centres_ccdc(crystal, eq_site_tol)
        mol_centres = np.matmul(frac_centres, cell)
        return PeriodicSet(mol_centres, cell, name=crystal.identifier)

    asym_atoms = molecule.atoms
    asym_unit = np.array([tuple(a.fractional_coordinates) for a in asym_atoms])
    if asym_unit.shape[0] == 0:
        raise ParseError(f"{crystal.identifier} has no coordinates")
    asym_unit = np.mod(asym_unit, 1)

    # Symmetry operations
    sitesym = crystal.symmetry_operators
    if not sitesym:
        warnings.warn(f"{crystal.identifier} has no symmetry data, defaulting to P1")
        sitesym = ["x,y,z"]

    # Apply symmetries to the asymmetric unit. Images of each asymmetric
    # site are stored contiguously, so asym_inds[k] is the motif index of
    # the first image of site k.
    rot, trans = _parse_sitesyms(sitesym)
    frac_motif, invs = _expand_asym_unit(asym_unit, rot, trans, eq_site_tol)
    _, wyc_muls = np.unique(invs, return_counts=True)
    wyc_muls = np.array(wyc_muls, dtype=np.uint64)
    asym_inds = np.zeros_like(wyc_muls, dtype=np.uint64)
    asym_inds[1:] = np.cumsum(wyc_muls, dtype=np.uint64)[:-1]
    motif = np.matmul(frac_motif, cell)

    asym_types = np.array([a.atomic_number for a in asym_atoms], dtype=np.uint64)
    asym_labels = [a.label for a in asym_atoms]
    asym_occs = np.array([float(a.occupancy) for a in asym_atoms])
    assemblies, groups = _ccdc_disorder_ids(crystal, len(asym_unit))

    disorder = _disorder_assemblies(
        asym_unit, wyc_muls, cell, assemblies, groups, asym_occs, eq_site_tol
    )

    return PeriodicSet(
        motif=motif,
        cell=cell,
        name=crystal.identifier,
        asym_unit=asym_inds,
        multiplicities=wyc_muls,
        types=asym_types,
        labels=asym_labels,
        disorder=disorder,
    )


def _ccdc_disorder_ids(crystal, n_sites):
    """Return per-site lists ``(assemblies, groups)`` of disorder
    assembly/group ids read from a ccdc Crystal; sites without disorder
    information get ``None``.
    """
    assemblies = [None] * n_sites
    groups = [None] * n_sites
    if crystal.has_disorder and not crystal.disorder.is_suppressed:
        for asm in crystal.disorder.assemblies:
            for grp in asm.groups:
                for atom in grp.atoms:
                    # NOTE(review): atom.index is ccdc's own atom index;
                    # confirm it still matches the asymmetric unit order
                    # after atoms were removed (H, missing coords, solvent).
                    assemblies[atom.index] = asm.id
                    groups[atom.index] = grp.id
    return assemblies, groups
951
952
953
def _parse_sitesyms(symmetries: List[str]) -> Tuple[FloatArray, FloatArray]:
    """Parse xyz-form symmetry operations into stacked arrays.

    Returns ``(rotations, translations)`` with shapes ``(n, 3, 3)`` and
    ``(n, 3)``, where ``n == len(symmetries)``; entry ``k`` holds the
    operation parsed from the k-th string by ``_parse_sitesym``.
    """
    n_ops = len(symmetries)
    all_rot = np.empty((n_ops, 3, 3), dtype=np.float64)
    all_trans = np.empty((n_ops, 3), dtype=np.float64)
    for k, op_str in enumerate(symmetries):
        all_rot[k], all_trans[k] = _parse_sitesym(op_str)
    return all_rot, all_trans
969
970
971
def memoize(f):
    """Cache the results of the single-argument function ``f`` (used for
    _parse_sitesym()).

    The cache is unbounded and keyed on the argument, so repeated calls
    return the very same result object. The wrapper keeps ``f``'s name
    and docstring.
    """
    return functools.lru_cache(maxsize=None)(f)
981
982
983
@memoize
def _parse_sitesym(sym: str) -> Tuple[FloatArray, FloatArray]:
    """Parse a single symmetry as an xyz string and return a 3x3
    rotation matrix and a 3x1 translation vector.

    Each comma-separated element of ``sym`` (e.g. ``"-x+1/2"``) gives
    one row of the rotation matrix and one translation component.
    Parsing is case-insensitive, and characters other than ``+ - /``,
    ``x y z``, digits and ``.`` (e.g. spaces) are ignored.

    Results are cached by ``memoize``, so repeated strings return the
    same array objects; callers must copy before mutating them.
    """

    rot = np.zeros((3, 3), dtype=np.float64)
    trans = np.zeros((3,), dtype=np.float64)

    # More than three comma-separated elements would index past rot/trans.
    for ind, element in enumerate(sym.split(",")):
        is_positive = True  # sign applied to the next variable / number
        is_fraction = False  # True once a "/" has been seen
        sng_trans = None  # sign in effect at the first digit of the translation
        fst_trans = []  # numerator (or whole number) digits
        snd_trans = []  # denominator digits, after "/"

        for char in element.lower():
            if char == "+":
                is_positive = True
            elif char == "-":
                is_positive = False
            elif char == "/":
                is_fraction = True
            elif char in "xyz":
                # Column 0/1/2 for x/y/z.
                rot_sgn = 1.0 if is_positive else -1.0
                rot[ind][ord(char) - ord("x")] = rot_sgn
            elif char.isdigit() or char == ".":
                # NOTE(review): a numeric coefficient such as "2x" is not
                # supported; its digits are read as a translation.
                if sng_trans is None:
                    sng_trans = 1.0 if is_positive else -1.0
                if is_fraction:
                    snd_trans.append(char)
                else:
                    fst_trans.append(char)

        if not fst_trans:
            e_trans = 0.0
        else:
            e_trans = sng_trans * float("".join(fst_trans))

        if is_fraction:
            e_trans /= float("".join(snd_trans))

        trans[ind] = e_trans

    return rot, trans
1028
1029
1030
def _expand_asym_unit(
    asym_unit: FloatArray,
    rotations: FloatArray,
    translations: FloatArray,
    tol: float,
    uniquify_sites: bool = False,
) -> Tuple[FloatArray, UIntArray]:
    """Apply the symmetry operations given by ``rotations`` and
    ``translations`` to ``asym_unit`` and drop images that coincide
    (within ``tol``) with another image of the same site.

    If ``uniquify_sites`` is True and the resulting motif still contains
    coinciding sites (i.e. two asymmetric unit sites are symmetrically
    equivalent), the expansion is reduced again removing those as well.

    Returns the fractional motif and, for each motif point, the index of
    the asymmetric unit site it came from.
    """
    all_images = _expand_sites(
        asym_unit.astype(np.float64, copy=False),
        rotations.astype(np.float64, copy=False),
        translations.astype(np.float64, copy=False),
    )
    frac_motif, invs = _reduce_expanded_sites(all_images, tol)

    if uniquify_sites and not all(_unique_sites(frac_motif, tol)):
        frac_motif, invs = _reduce_expanded_equiv_sites(all_images, tol)

    return frac_motif, invs
1052
1053
1054
@numba.njit(cache=True)
def _expand_sites(
    asym_unit: FloatArray, rotations: FloatArray, translations: FloatArray
) -> FloatArray:
    """Expand the asymmetric unit by applying ``rotations`` and
    ``translations``, without yet removing points duplicated because
    they are invariant under a symmetry.

    Returns an array of shape (#points, #syms, dims) with every
    coordinate wrapped into [0, 1). ``rotations`` is indexed as
    (#syms, dims, dims) and ``translations`` as (#syms, dims).
    """

    m, dims = asym_unit.shape
    n_syms = len(rotations)
    expanded_sites = np.empty((m, n_syms, dims), dtype=np.float64)
    for i in range(m):
        for j in range(n_syms):
            # Loop over the actual dimension so every allocated entry is
            # written (a hard-coded 3 left garbage for other dims).
            for dim in range(dims):
                v = 0.0
                for dim_ in range(dims):
                    v += rotations[j, dim, dim_] * asym_unit[i, dim_]
                expanded_sites[i, j, dim] = v + translations[j, dim]
    return np.mod(expanded_sites, 1)
1075
1076
1077
@numba.njit(cache=True)
def _reduce_expanded_sites(
    expanded_sites: FloatArray, tol: float
) -> Tuple[FloatArray, UIntArray]:
    """Collapse the symmetry images of each asymmetric unit site, keeping
    one copy of images that coincide within ``tol``. Assumes no two
    asymmetric unit sites are equivalent to each other.

    Returns the fractional motif (images of each site stored
    contiguously, in site order) and, per motif point, the index of its
    asymmetric unit site.
    """

    n_sites, _, dims = expanded_sites.shape
    keep_masks = []
    muls = np.empty(shape=(n_sites,), dtype=np.uint64)

    for site_i in range(n_sites):
        mask = _unique_sites(expanded_sites[site_i], tol)
        keep_masks.append(mask)
        muls[site_i] = np.sum(mask)

    n_motif = np.sum(muls)
    frac_motif = np.empty(shape=(n_motif, dims), dtype=np.float64)
    inverses = np.empty(shape=(n_motif,), dtype=np.uint64)

    start = 0
    for site_i in range(n_sites):
        stop = start + muls[site_i]
        frac_motif[start:stop] = expanded_sites[site_i, keep_masks[site_i]]
        inverses[start:stop] = site_i
        start = stop

    return frac_motif, inverses
1107
1108
1109
def _reduce_expanded_equiv_sites(
    expanded_sites: FloatArray, tol: float
) -> Tuple[FloatArray, UIntArray]:
    """Reduce the asymmetric unit after being expanded by symmetries by
    removing invariant points. This version also removes symmetrically
    equivalent sites in the asymmetric unit, warning when it does so.

    ``expanded_sites`` has shape (#sites, #syms, dims) with coordinates
    in [0, 1). Returns the fractional motif and, per motif point, the
    index of the asymmetric unit site it was kept for.
    """

    if len(expanded_sites) == 0:
        dims = expanded_sites.shape[-1]
        return (
            np.empty((0, dims), dtype=np.float64),
            np.empty((0,), dtype=np.uint64),
        )

    sites = expanded_sites[0]
    frac_motif = sites[_unique_sites(sites, tol)]
    inverses = [0] * len(frac_motif)

    for i in range(1, len(expanded_sites)):
        sites = expanded_sites[i]
        new_points = []
        for site in sites[_unique_sites(sites, tol)]:
            # Compare modulo 1 against every point already in the motif.
            diffs1 = np.abs(site - frac_motif)
            diffs2 = np.abs(diffs1 - 1)
            mask = np.all((diffs1 <= tol) | (diffs2 <= tol), axis=-1)

            if np.any(mask):
                warnings.warn(
                    "Asymmetric unit has equivalent sites at positions "
                    f"{inverses[np.argmax(mask)]}, {i}"
                )
            else:
                new_points.append(site)

        if new_points:
            inverses.extend([i] * len(new_points))
            frac_motif = np.concatenate((frac_motif, np.array(new_points)))

    return frac_motif, np.array(inverses, dtype=np.uint64)
1145
1146
1147
@numba.njit(cache=True)
def _unique_sites(asym_unit: FloatArray, tol: float) -> npt.NDArray[np.bool_]:
    """Flag the first occurrence of each point (within ``tol`` in every
    coordinate, comparing modulo 1) in a list of fractional coordinates.

    ``asym_unit[_unique_sites(asym_unit, tol)]`` is the uniquified list.
    """

    n_pts, n_dims = asym_unit.shape
    keep = np.full(shape=(n_pts,), fill_value=True, dtype=np.bool_)
    for a in range(1, n_pts):
        for b in range(a):
            duplicate = True
            for k in range(n_dims):
                d = np.mod(np.abs(asym_unit[a, k] - asym_unit[b, k]), 1)
                if min(d, np.abs(d - 1)) > tol:
                    duplicate = False
                    break
            if duplicate:
                keep[a] = False
                break

    return keep
1168
1169
1170
def _disorder_assemblies(
        asym_unit, multiplicities, cell, assemblies, groups, asym_occs, eq_site_tol
):
    """Group partially occupied asymmetric unit sites into disorder
    assemblies.

    Sites with occupancy exactly 1 are ignored. Sites with a given
    assembly id are grouped by their (assembly, group) ids. Sites with
    neither id are grouped heuristically:

    1. overlapping sites (within ``SUBS_DISORDER_TOL``) form a
       substitutional assembly (``d_`` prefix if not exactly coincident);
    2. all sites with occupancy 0.5 form one assembly;
    3. each set of equal occupancies > 1 forms its own assembly;
    4. occupancy groups whose occupancies sum to one form an assembly;
    5. every remaining occupancy group forms its own assembly.

    Parameters
    ----------
    asym_unit : numpy.ndarray
        Fractional coordinates of the asymmetric unit sites.
    multiplicities : numpy.ndarray
        Wyckoff multiplicities of the sites (currently unused).
    cell : numpy.ndarray
        Unit cell, used to measure distances between sites.
    assemblies, groups : list
        Per-site disorder assembly / group ids, ``None`` where not given.
    asym_occs : numpy.ndarray
        Per-site occupancies.
    eq_site_tol : float
        Currently unused.

    Returns
    -------
    list of DisorderAssembly or None
        ``None`` if no assemblies were found.
    """

    disorder = {}

    # Follow given assemblies and groups.
    for i, (asmbly, grp) in enumerate(zip(assemblies, groups)):
        if asym_occs[i] == 1:
            continue
        disorder.setdefault(asmbly, {}).setdefault(grp, []).append(i)

    # FIX: atoms with a given group but no assembly end up in
    # disorder[None][grp] and are dropped below (key None is skipped).

    # Fractional occupancies with no given groups or assemblies.
    if None in disorder and None in disorder[None]:

        inds = disorder[None][None]
        asm_count = 0
        m = len(inds)

        # Overlapping disordered atoms are assumed to be
        # substitutionally disordered.
        leftover_pdist = _pdist_pbc(asym_unit[inds], cell)
        grps = _collapse_into_groups(leftover_pdist <= SUBS_DISORDER_TOL)

        leftover = []  # atoms not overlapping any other
        for grp in grps:
            if len(grp) == 1:
                leftover.append(inds[grp[0]])
                continue
            # Max interpoint distance in the group, read from the condensed
            # distance vector. The index formula requires i < j.
            # NOTE(review): confirm _collapse_into_groups returns sorted groups.
            max_d = max(
                leftover_pdist[m * i + j - ((i + 2) * (i + 1)) // 2]
                for i, j in itertools.combinations(grp, 2)
            )
            displaced_prefix = "d_" if max_d > 1e-10 else ""
            asm_name = f"{displaced_prefix}sub_asm_{asm_count}"
            disorder[asm_name] = {i: [inds[j]] for i, j in enumerate(grp)}
            asm_count += 1

        # Not substitutional.
        if leftover:

            # Separate atoms with different occupancies. Each occupancy must
            # be its own observation (shape (k, 1)) for pdist to compare
            # them pairwise; shape (1, k) would yield no distances at all.
            leftover_occs_arr = np.array(asym_occs)[leftover]
            occ_grps_pdist = (
                pdist(leftover_occs_arr[:, None], metric="chebyshev") < 1e-6
            )
            occ_grps = _collapse_into_groups(occ_grps_pdist)
            leftover_occs = np.array([leftover_occs_arr[grp[0]] for grp in occ_grps])

            # For now, put all 1/2 occupancy atoms in one assembly.
            where_half_occ = np.argwhere(leftover_occs == 0.5)
            if where_half_occ.size > 0:
                grp_i = where_half_occ[0, 0]
                disorder[f"half_asm_{asm_count}"] = {
                    0: [leftover[j] for j in occ_grps[grp_i]]
                }
                asm_count += 1
                del occ_grps[grp_i]
                leftover_occs = np.delete(leftover_occs, grp_i)

            # Each group of atoms with occupancy > 1 is its own assembly.
            where_gt1_flat = np.argwhere(leftover_occs > 1).flatten()
            if where_gt1_flat.size > 0:
                for grp_i in where_gt1_flat:
                    disorder[f"gt1_asm_{asm_count}"] = {
                        0: [leftover[j] for j in occ_grps[grp_i]]
                    }
                    asm_count += 1
                for grp_i in sorted(where_gt1_flat, reverse=True):
                    del occ_grps[grp_i]
                leftover_occs = np.delete(leftover_occs, where_gt1_flat)

            # Occupancy groups whose occupancies sum to one form an assembly.
            unity_sum_grps = _tuples_sum_to_one(leftover_occs)
            if unity_sum_grps:
                for unity_grp in unity_sum_grps:
                    disorder[f"sum_asm_{asm_count}"] = {
                        en: [leftover[j] for j in occ_grps[grp_i]]
                        for en, grp_i in enumerate(unity_grp)
                    }
                    asm_count += 1
                all_grp_inds = [i for grp in unity_sum_grps for i in grp]
                for grp_i in sorted(all_grp_inds, reverse=True):
                    del occ_grps[grp_i]

            # Assemble all remaining groups separately.
            for grp in occ_grps:
                disorder[f"asm_{asm_count}"] = {None: [leftover[j] for j in grp]}
                asm_count += 1

    # TODO: check multiplicities?
    result = []
    for k, asm_ in disorder.items():
        if k is None:
            continue
        result.append(DisorderAssembly(
            groups=[
                DisorderGroup(indices=grp, occupancy=asym_occs[grp[0]])
                for grp in asm_.values()
            ],
            # Given assembly ids need not be strings (e.g. ints).
            is_substitutional=isinstance(k, str) and k.startswith("sub"),
        ))

    return result or None
1290
1291
1292
@numba.njit(cache=True)
def _pdist_pbc(points, cell):
    """Condensed pairwise Euclidean distances between fractional
    ``points`` under periodic boundary conditions.

    Entries are ordered like :func:`scipy.spatial.distance.pdist`: pairs
    (i, j) with i < j, row by row. Each fractional difference is wrapped
    to the nearest image component-wise, keeping its sign, and then
    mapped to Cartesian space with ``cell``.
    """
    m = len(points)
    out = np.empty((m * (m - 1) // 2,), dtype=np.float64)
    ind = 0
    for i in range(m):
        for j in range(i + 1, m):
            diff = points[i] - points[j]
            # Wrap into [-0.5, 0.5]. Keeping the sign matters for
            # non-orthogonal cells, where |dx|, |dy| @ cell != (dx, dy) @ cell.
            diff = diff - np.floor(diff + 0.5)
            diff_v = diff @ cell
            out[ind] = np.sum(diff_v**2)
            ind += 1
    return np.sqrt(out)
1305
1306
1307
def _tuples_sum_to_one(unqiue_occs):
1308
    """Given a vector ``unqiue_occs`` of numbers between 0 and 1, return
1309
    tuples of indices specifying groups of elements summing to one."""
1310
1311
    ret = []
1312
    done = set()
1313
1314
    def _combs_totalling_one(occs, m):
1315
1316
        def helper(start, index, current_combination):
1317
1318
            if index == m:
1319
                if np.sum(occs[current_combination]) == 1:
1320
                    for i in current_combination:
1321
                        done.add(i)
1322
                    yield current_combination
1323
                return
1324
1325
            for i in range(start, len(occs)):
1326
                if any(current_combination[j] in done for j in range(index)):
1327
                    return
1328
                if i in done:
1329
                    continue
1330
                current_combination[index] = i
1331
                if np.sum(occs[current_combination[: index + 1]]) > 1:
1332
                    return
1333
                yield from helper(i + 1, index + 1, current_combination)
1334
1335
        combination = np.zeros((m,), dtype=np.int64)
1336
        yield from helper(0, 0, combination)
1337
1338
    argsorted = np.argsort(unqiue_occs)
1339
    occs = unqiue_occs[argsorted]
1340
1341
    for asize in range(2, len(occs)):
1342
        for comb in _combs_totalling_one(occs, asize):
1343
            ret.append(argsorted[comb])
1344
1345
    return ret
1346
1347
1348
def _has_disorder(label: str, occupancy) -> bool:
1349
    """Return True if label ends with ? or occupancy is a number < 1."""
1350
    try:
1351
        occupancy = float(occupancy)
1352
    except Exception:
1353
        occupancy = 1
1354
    return (occupancy < 1) or label.endswith("?")
1355
1356
1357
def _atomic_symbols_from_labels(symbols: List[str]) -> List[str]:
1358
    symbols_ = []
1359
    for label in symbols:
1360
        sym = ""
1361
        if label and label not in (".", "?"):
1362
            match = re.search(r"([A-Za-z][A-Za-z]?)", label)
1363
            if match is not None:
1364
                sym = match.group()
1365
            sym = list(sym)
1366
            if len(sym) > 0:
1367
                sym[0] = sym[0].upper()
1368
            if len(sym) > 1:
1369
                sym[1] = sym[1].lower()
1370
            sym = "".join(sym)
1371
        symbols_.append(sym)
1372
    return symbols_
1373
1374
1375
def _get_syms_pymatgen(data: dict) -> Tuple[FloatArray, FloatArray]:
    """Parse symmetry operations given by data = block.data where block
    is a pymatgen CifBlock object.

    The sources tried, in order, are:

    1. explicit xyz operations (tags in ``_CIF_TAGS["symop"]``);
    2. the space group symbol, looked up in pymatgen's table and then in
       its COD table;
    3. the international space group number.

    If none is usable, warns and defaults to P1. Returns rotation and
    translation arrays of shapes (n, 3, 3) and (n, 3).
    """

    from pymatgen.symmetry.groups import SpaceGroup
    import pymatgen.io.cif

    # Try xyz symmetry operations
    for symmetry_label in _CIF_TAGS["symop"]:
        xyz = data.get(symmetry_label)
        if not xyz:
            continue
        if isinstance(xyz, str):
            xyz = [xyz]
        return _parse_sitesyms(xyz)

    symops = []

    # Try spacegroup symbol
    for symmetry_label in _CIF_TAGS["spacegroup_name"]:
        sg = data.get(symmetry_label)
        if not sg:
            continue
        sg = re.sub(r"[\s_]", "", sg)

        try:
            spg = pymatgen.io.cif.space_groups.get(sg)
            if spg:
                symops = SpaceGroup(spg).symmetry_ops
        except ValueError:
            pass
        if symops:
            break

        # Symbol not usable via pymatgen's table: fall back to the COD
        # table. _get_cod_data is private pymatgen API (may be missing or
        # change), hence the broad except.
        try:
            for d in pymatgen.io.cif._get_cod_data():
                if sg == re.sub(r"\s+", "", d["hermann_mauguin"]):
                    return _parse_sitesyms(d["symops"])
        except Exception:
            continue

    # Try international number
    if not symops:
        for symmetry_label in _CIF_TAGS["spacegroup_number"]:
            num = data.get(symmetry_label)
            if not num:
                continue
            try:
                i = int(str2float(num))
                symops = SpaceGroup.from_int_number(i).symmetry_ops
                break
            except (ValueError, TypeError):
                # TypeError: str2float could not interpret ``num``.
                continue

    if not symops:
        warnings.warn("no symmetry data found, defaulting to P1")
        return _parse_sitesyms(["x,y,z"])

    rotations = np.array(
        [op.rotation_matrix for op in symops], dtype=np.float64
    )
    translations = np.array(
        [op.translation_vector for op in symops], dtype=np.float64
    )
    return rotations, translations
1440
1441
1442
def _frac_molecular_centres_ccdc(crystal, tol: float) -> FloatArray:
    """Return the geometric centres of the molecules in the unit cell.

    Takes a ccdc Crystal object. Each centre is the unweighted mean of a
    packed component's atomic fractional coordinates, wrapped into
    [0, 1). Centres coinciding within ``tol`` are kept only once.
    Returns fractional coordinates.
    """

    centres = []
    packed = crystal.packing(inclusion="CentroidIncluded")
    for component in packed.components:
        atom_coords = [atom.fractional_coordinates for atom in component.atoms]
        n_atoms = len(atom_coords)
        centres.append([sum(axis) / n_atoms for axis in zip(*atom_coords)])
    centres_arr = np.array(centres, dtype=np.float64) % 1
    return centres_arr[_unique_sites(centres_arr, tol)]
1453
1454
1455
def _heaviest_component_ccdc(molecule):
1456
    """Remove all but the heaviest component of the asymmetric unit.
1457
    Intended for removing solvents. Expects and returns a ccdc Molecule
1458
    object.
1459
    """
1460
1461
    component_weights = []
1462
    for component in molecule.components:
1463
        weight = 0
1464
        for a in component.atoms:
1465
            try:
1466
                occ = float(a.occupancy)
1467
            except:
1468
                occ = 1
1469
            try:
1470
                weight += float(a.atomic_weight) * occ
1471
            except ValueError:
1472
                pass
1473
        component_weights.append(weight)
1474
    largest_component_ind = np.argmax(np.array(component_weights))
1475
    molecule = molecule.components[largest_component_ind]
1476
    return molecule
1477
1478
1479
def str2float(string):
    """Remove uncertainty brackets from strings and return the float,
    e.g. ``"1.234(5)"`` -> ``1.234``.

    Returns np.nan if given '.' or '?'. Numbers (int/float) are
    converted with ``float``, and a single-element list is parsed as its
    only element. Any other non-string input returns ``None``.

    Raises
    ------
    ValueError
        If a string is neither a number nor '.'/'?'.
    """
    try:
        # Strips everything from the first "(" onwards.
        return float(re.sub(r"\(.+\)*", "", string))
    except TypeError:
        if isinstance(string, (int, float, np.integer)):
            return float(string)
        if isinstance(string, list) and len(string) == 1:
            return str2float(string[0])
        return None
    except ValueError:
        if string.strip() in (".", "?"):
            return np.nan
        raise
1491
1492
1493
def _get_cif_tags(block, cif_tags):
    """Read the values of ``cif_tags`` from a gemmi CIF block.

    Each raw value is converted to an int if possible, else to a float
    via ``str2float``, else to an unquoted string via
    ``gemmi.cif.as_string``. A tag maps to ``None`` if absent, to the
    single converted value if it has one value, and otherwise to the
    list of converted values.
    """
    import gemmi

    def _convert(raw):
        try:
            return int(raw)
        except ValueError:
            pass
        try:
            return str2float(raw)
        except ValueError:
            return gemmi.cif.as_string(raw)

    data = {}
    for tag in cif_tags:
        values = [_convert(raw) for raw in block.find_values(tag)]
        if not values:
            data[tag] = None
        elif len(values) == 1:
            data[tag] = values[0]
        else:
            data[tag] = values
    return data
1519
1520
1521
def _snap_small_prec_coords(frac_coords: FloatArray, tol: float) -> FloatArray:
1522
    """Find where frac_coords is within 1e-4 of 1/3 or 2/3, change to
1523
    1/3 and 2/3. Recommended by pymatgen's CIF parser.
1524
    """
1525
    frac_coords[np.abs(1 - 3 * frac_coords) < tol] = 1 / 3.0
1526
    frac_coords[np.abs(1 - 3 * frac_coords / 2) < tol] = 2 / 3.0
1527
    return frac_coords
1528
1529
1530
# def periodicset_from_ase_cifblock(
1531
#     block,
1532
#     remove_hydrogens: bool = False,
1533
#     skip_disorder: bool = False,
1534
#     eq_site_tol: float = 1e-3,
1535
# ) -> PeriodicSet:
1536
#     """Convert a :class:`ase.io.cif.CIFBlock` object to a
1537
#     :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`.
1538
#     :class:`ase.io.cif.CIFBlock` is the type returned by
1539
#     :func:`ase.io.cif.parse_cif`.
1540
1541
#     Parameters
1542
#     ----------
1543
#     block : :class:`ase.io.cif.CIFBlock`
1544
#         An ase :class:`ase.io.cif.CIFBlock` object representing a
1545
#         crystal.
1546
#     remove_hydrogens : bool, optional
1547
#         Remove Hydrogens from the crystal.
1548
#     disorder : str, optional
1549
#         Controls how disordered structures are handled. Default is
1550
#         ``skip`` which skips any crystal with disorder, since disorder
1551
#         conflicts with the periodic set model. To read disordered
1552
#         structures anyway, choose either :code:`ordered_sites` to remove
1553
#         atoms with disorder or :code:`all_sites` include all atoms
1554
#         regardless of disorder. Note that :code:`all_sites` has
1555
#         different behaviour than :code:`periodicset_from_gemmi_block`
1556
#         and outputs of this function always have None assigned to the
1557
#         :code:`.occupancies` attribute.
1558
1559
#     Returns
1560
#     -------
1561
#     :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
1562
#         Represents the crystal as a periodic set, consisting of a finite
1563
#         set of points (motif) and lattice (unit cell). Contains other
1564
#         useful data, e.g. the crystal's name and information about the
1565
#         asymmetric unit for calculation.
1566
1567
#     Raises
1568
#     ------
1569
#     ParseError
1570
#         Raised if the structure fails to be parsed for any of the
1571
#         following: 1. Required data is missing (e.g. cell parameters),
1572
#         2. The motif is empty after removing H or disordered sites,
1573
#         3. :code:``disorder == 'skip'`` and disorder is found on any
1574
#         atom.
1575
#     """
1576
1577
#     import ase
1578
#     import ase.spacegroup
1579
1580
#     # Unit cell
1581
#     cellpar = [str2float(str(block.get(tag))) for tag in _CIF_TAGS["cellpar"]]
1582
#     if None in cellpar:
1583
#         raise ParseError(f"{block.name} has missing cell data")
1584
#     cell = cellpar_to_cell(np.array(cellpar))
1585
1586
#     # Asymmetric unit coordinates. ase removes uncertainty brackets
1587
#     asym_unit = [
1588
#         [str2float(str(n)) for n in block.get(tag)]
1589
#         for tag in _CIF_TAGS["atom_site_fract"]
1590
#     ]
1591
#     if None in asym_unit:
1592
#         asym_unit = [block.get(tag.lower()) for tag in _CIF_TAGS["atom_site_cartn"]]
1593
#         if None in asym_unit:
1594
#             raise ParseError(f"{block.name} has missing coordinates")
1595
#         else:
1596
#             raise ParseError(
1597
#                 f"{block.name} uses _atom_site_Cartn_ tags for coordinates, "
1598
#                 "only _atom_site_fract_ is supported"
1599
#             )
1600
#     asym_unit = list(zip(*asym_unit))
1601
1602
#     # Labels
1603
#     asym_labels = block.get("_atom_site_label")
1604
#     if asym_labels is None:
1605
#         asym_labels = [""] * len(asym_unit)
1606
1607
#     # Atomic types
1608
#     asym_symbols = block.get("_atom_site_type_symbol")
1609
#     if asym_symbols is not None:
1610
#         asym_symbols_ = _atomic_symbols_from_labels(asym_symbols)
1611
#     else:
1612
#         asym_symbols_ = [""] * len(asym_unit)
1613
1614
#     asym_types = []
1615
#     for s in asym_symbols_:
1616
#         if s in ATOMIC_NUMBERS:
1617
#             asym_types.append(ATOMIC_NUMBERS[s])
1618
#         else:
1619
#             asym_types.append(0)
1620
1621
#     # Find where sites have disorder if necessary
1622
#     has_disorder = []
1623
#     occupancies = block.get("_atom_site_occupancy")
1624
#     if occupancies is None:
1625
#         occupancies = [1] * len(asym_unit)
1626
#     for lab, occ in zip(asym_labels, occupancies):
1627
#         has_disorder.append(_has_disorder(lab, occ))
1628
1629
#     # Remove sites with ?, . or other invalid string for coordinates
1630
#     invalid = []
1631
#     for i, xyz in enumerate(asym_unit):
1632
#         if not all(isinstance(coord, (int, float)) for coord in xyz):
1633
#             invalid.append(i)
1634
#     if invalid:
1635
#         warnings.warn("atoms without sites or missing data will be removed")
1636
#         asym_unit = [c for i, c in enumerate(asym_unit) if i not in invalid]
1637
#         asym_types = [t for i, t in enumerate(asym_types) if i not in invalid]
1638
#         has_disorder = [d for i, d in enumerate(has_disorder) if i not in invalid]
1639
1640
#     remove_sites = []
1641
1642
#     if remove_hydrogens:
1643
#         remove_sites.extend(i for i, num in enumerate(asym_types) if num == 1)
1644
1645
#     # Remove atoms with fractional occupancy or raise ParseError
1646
#     for i, dis in enumerate(has_disorder):
1647
#         if i in remove_sites:
1648
#             continue
1649
#         if dis:
1650
#             if skip_disorder:
1651
#                 raise ParseError(
1652
#                     f"{block.name} has disorder, pass "
1653
#                     "disorder='ordered_sites' or 'all_sites' to "
1654
#                     "remove/ignore disorder"
1655
#                 )
1656
#             remove_sites.append(i)
1657
1658
#     # Asymmetric unit
1659
#     asym_unit = [c for i, c in enumerate(asym_unit) if i not in remove_sites]
1660
#     asym_types = [t for i, t in enumerate(asym_types) if i not in remove_sites]
1661
#     if len(asym_unit) == 0:
1662
#         raise ParseError(f"{block.name} has no valid sites")
1663
#     asym_unit = np.mod(np.array(asym_unit), 1)
1664
#     asym_types = np.array(asym_types, dtype=np.uint8)
1665
1666
#     # # recommended by pymatgen
1667
#     # asym_unit = _snap_small_prec_coords(asym_unit, 1e-4)
1668
1669
#     # Get symmetry operations
1670
#     sitesym = block._get_any(_CIF_TAGS["symop"])
1671
#     if sitesym is None:
1672
#         label_or_num = block._get_any([s.lower() for s in _CIF_TAGS["spacegroup_name"]])
1673
#         if label_or_num is None:
1674
#             label_or_num = block._get_any(
1675
#                 [s.lower() for s in _CIF_TAGS["spacegroup_number"]]
1676
#             )
1677
#         if label_or_num is None:
1678
#             warnings.warn("no symmetry data found, defaulting to P1")
1679
#             label_or_num = 1
1680
#         spg = ase.spacegroup.Spacegroup(label_or_num)
1681
#         rot, trans = spg.get_op()
1682
#     else:
1683
#         if isinstance(sitesym, str):
1684
#             sitesym = [sitesym]
1685
#         rot, trans = _parse_sitesyms(sitesym)
1686
1687
#     frac_motif, invs = _expand_asym_unit(asym_unit, rot, trans, eq_site_tol)
1688
#     _, wyc_muls = np.unique(invs, return_counts=True)
1689
#     asym_inds = np.zeros_like(wyc_muls, dtype=np.int64)
1690
#     asym_inds[1:] = np.cumsum(wyc_muls)[:-1]
1691
#     motif = np.matmul(frac_motif, cell)
1692
1693
#     return PeriodicSet(
1694
#         motif=motif,
1695
#         cell=cell,
1696
#         name=block.name,
1697
#         asym_unit=asym_inds,
1698
#         multiplicities=wyc_muls,
1699
#         types=asym_types,
1700
#         occupancies=None,
1701
#     )
1702
1703
1704
# def periodicset_from_pymatgen_cifblock(
1705
#     block,
1706
#     remove_hydrogens: bool = False,
1707
#     skip_disorder: bool = False,
1708
#     eq_site_tol: float = 1e-3,
1709
# ) -> PeriodicSet:
1710
#     """Convert a :class:`pymatgen.io.cif.CifBlock` object to a
1711
#     :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`.
1712
#     :class:`pymatgen.io.cif.CifBlock` is the type returned by
1713
#     :class:`pymatgen.io.cif.CifFile`.
1714
1715
#     Parameters
1716
#     ----------
1717
#     block : :class:`pymatgen.io.cif.CifBlock`
1718
#         A pymatgen CifBlock object representing a crystal.
1719
#     remove_hydrogens : bool, optional
1720
#         Remove Hydrogens from the crystal.
1721
#     disorder : str, optional
1722
#         Controls how disordered structures are handled. Default is
1723
#         ``skip`` which skips any crystal with disorder, since disorder
1724
#         conflicts with the periodic set model. To read disordered
1725
#         structures anyway, choose either :code:`ordered_sites` to remove
1726
#         atoms with disorder or :code:`all_sites` include all atoms
1727
#         regardless of disorder. Note that :code:`all_sites` has
1728
#         different behaviour than :code:`periodicset_from_gemmi_block`
1729
#         and outputs of this function always have None assigned to the
1730
#         :code:`.occupancies` attribute.
1731
1732
#     Returns
1733
#     -------
1734
#     :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
1735
#         Represents the crystal as a periodic set, consisting of a finite
1736
#         set of points (motif) and lattice (unit cell). Contains other
1737
#         useful data, e.g. the crystal's name and information about the
1738
#         asymmetric unit for calculation.
1739
1740
#     Raises
1741
#     ------
1742
#     ParseError
1743
#         Raised if the structure can/should not be parsed for the
1744
#         following reasons: 1. No sites found or motif is empty after
1745
#         removing Hydrogens & disorder, 2. A site has missing
1746
#         coordinates, 3. :code:``disorder == 'skip'`` and disorder is
1747
#         found on any atom.
1748
#     """
1749
1750
#     odict = block.data
1751
1752
#     # Unit cell
1753
#     cellpar = [odict.get(tag) for tag in _CIF_TAGS["cellpar"]]
1754
#     if any(par in (None, "?", ".") for par in cellpar):
1755
#         raise ParseError(f"{block.header} has missing cell data")
1756
1757
#     try:
1758
#         cellpar = [str2float(v) for v in cellpar]
1759
#     except ValueError:
1760
#         raise ParseError(f"{block.header} could not be parsed")
1761
#     cell = cellpar_to_cell(np.array(cellpar, dtype=np.float64))
1762
1763
#     # Asymmetric unit coordinates
1764
#     asym_unit = [odict.get(tag) for tag in _CIF_TAGS["atom_site_fract"]]
1765
#     # check for . and ?
1766
#     if None in asym_unit:
1767
#         asym_unit = [odict.get(tag) for tag in _CIF_TAGS["atom_site_cartn"]]
1768
#         if None in asym_unit:
1769
#             raise ParseError(f"{block.header} has missing coordinates")
1770
#         else:
1771
#             raise ParseError(
1772
#                 f"{block.header} uses _atom_site_Cartn_ tags for coordinates, "
1773
#                 "only _atom_site_fract_ is supported"
1774
#             )
1775
#     asym_unit = list(zip(*asym_unit))
1776
#     try:
1777
#         asym_unit = [[str2float(coord) for coord in xyz] for xyz in asym_unit]
1778
#     except ValueError:
1779
#         raise ParseError(f"{block.header} could not be parsed")
1780
1781
#     # Labels
1782
#     asym_labels = odict.get("_atom_site_label")
1783
#     if asym_labels is None:
1784
#         asym_labels = [""] * len(asym_unit)
1785
1786
#     # Atomic types
1787
#     asym_symbols = odict.get("_atom_site_type_symbol")
1788
#     if asym_symbols is not None:
1789
#         asym_symbols_ = _atomic_symbols_from_labels(asym_symbols)
1790
#     else:
1791
#         asym_symbols_ = [""] * len(asym_unit)
1792
1793
#     asym_types = []
1794
#     for s in asym_symbols_:
1795
#         if s in ATOMIC_NUMBERS:
1796
#             asym_types.append(ATOMIC_NUMBERS[s])
1797
#         else:
1798
#             asym_types.append(0)
1799
1800
#     # Find where sites have disorder if necessary
1801
#     has_disorder = []
1802
#     occupancies = odict.get("_atom_site_occupancy")
1803
#     if occupancies is None:
1804
#         occupancies = np.ones((len(asym_unit),))
1805
#     else:
1806
#         occupancies = np.array([str2float(occ) for occ in occupancies])
1807
#     labels = odict.get("_atom_site_label")
1808
#     if labels is None:
1809
#         labels = [""] * len(asym_unit)
1810
#     for lab, occ in zip(labels, occupancies):
1811
#         has_disorder.append(_has_disorder(lab, occ))
1812
1813
#     # Remove sites with ?, . or other invalid string for coordinates
1814
#     invalid = []
1815
#     for i, xyz in enumerate(asym_unit):
1816
#         if not all(isinstance(coord, (int, float)) for coord in xyz):
1817
#             invalid.append(i)
1818
1819
#     if invalid:
1820
#         warnings.warn("atoms without sites or missing data will be removed")
1821
#         asym_unit = [c for i, c in enumerate(asym_unit) if i not in invalid]
1822
#         asym_types = [c for i, c in enumerate(asym_types) if i not in invalid]
1823
#         has_disorder = [d for i, d in enumerate(has_disorder) if i not in invalid]
1824
1825
#     remove_sites = []
1826
1827
#     if remove_hydrogens:
1828
#         remove_sites.extend((i for i, n in enumerate(asym_types) if n == 1))
1829
1830
#     # Remove atoms with fractional occupancy or raise ParseError
1831
#     for i, dis in enumerate(has_disorder):
1832
#         if i in remove_sites:
1833
#             continue
1834
#         if dis:
1835
#             if skip_disorder:
1836
#                 raise ParseError(
1837
#                     f"{block.header} has disorder, pass "
1838
#                     "disorder='ordered_sites' or 'all_sites' to "
1839
#                     "remove/ignore disorder"
1840
#                 )
1841
#             remove_sites.append(i)
1842
1843
#     # Asymmetric unit
1844
#     asym_unit = [c for i, c in enumerate(asym_unit) if i not in remove_sites]
1845
#     asym_types = [t for i, t in enumerate(asym_types) if i not in remove_sites]
1846
#     if len(asym_unit) == 0:
1847
#         raise ParseError(f"{block.header} has no valid sites")
1848
#     asym_unit = np.mod(np.array(asym_unit), 1)
1849
#     asym_types = np.array(asym_types, dtype=np.uint8)
1850
1851
#     # recommended by pymatgen
1852
#     # asym_unit = _snap_small_prec_coords(asym_unit, 1e-4)
1853
1854
#     # Apply symmetries to asymmetric unit
1855
#     rot, trans = _get_syms_pymatgen(odict)
1856
#     frac_motif, invs = _expand_asym_unit(asym_unit, rot, trans, eq_site_tol)
1857
#     _, wyc_muls = np.unique(invs, return_counts=True)
1858
#     asym_inds = np.zeros_like(wyc_muls, dtype=np.int64)
1859
#     asym_inds[1:] = np.cumsum(wyc_muls)[:-1]
1860
#     motif = np.matmul(frac_motif, cell)
1861
1862
#     return PeriodicSet(
1863
#         motif=motif,
1864
#         cell=cell,
1865
#         name=block.header,
1866
#         asym_unit=asym_inds,
1867
#         multiplicities=wyc_muls,
1868
#         types=asym_types,
1869
#         occupancies=None,
1870
#     )
1871
1872
1873
# def periodicset_from_ase_atoms(
1874
#     atoms, remove_hydrogens: bool = False, eq_site_tol: float = 1e-3
1875
# ) -> PeriodicSet:
1876
#     """Convert an :class:`ase.atoms.Atoms` object to a
1877
#     :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`. Does not have
1878
#     the option to remove disorder.
1879
1880
#     Parameters
1881
#     ----------
1882
#     atoms : :class:`ase.atoms.Atoms`
1883
#         An ase :class:`ase.atoms.Atoms` object representing a crystal.
1884
#     remove_hydrogens : bool, optional
1885
#         Remove Hydrogens from the crystal.
1886
1887
#     Returns
1888
#     -------
1889
#     :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
1890
#         Represents the crystal as a periodic set, consisting of a finite
1891
#         set of points (motif) and lattice (unit cell). Contains other
1892
#         useful data, e.g. the crystal's name and information about the
1893
#         asymmetric unit for calculation.
1894
1895
#     Raises
1896
#     ------
1897
#     ParseError
1898
#         Raised if there are no valid sites in atoms.
1899
#     """
1900
1901
#     from ase.spacegroup import get_basis
1902
1903
#     cell = atoms.get_cell().array
1904
1905
#     remove_inds = []
1906
#     if remove_hydrogens:
1907
#         for i in np.where(atoms.get_atomic_numbers() == 1)[0]:
1908
#             remove_inds.append(i)
1909
#     for i in sorted(remove_inds, reverse=True):
1910
#         atoms.pop(i)
1911
1912
#     if len(atoms) == 0:
1913
#         raise ParseError("ase Atoms object has no valid sites")
1914
1915
#     # Symmetry operations from spacegroup
1916
#     spg = None
1917
#     if "spacegroup" in atoms.info:
1918
#         spg = atoms.info["spacegroup"]
1919
#         rot, trans = spg.rotations, spg.translations
1920
#     else:
1921
#         warnings.warn("no symmetry data found, defaulting to P1")
1922
#         rot = np.identity(3)[None, :]
1923
#         trans = np.zeros((1, 3))
1924
1925
#     # Asymmetric unit. ase default tol is 1e-5
1926
#     # do differently! get_basis determines a reduced asym unit from the atoms;
1927
#     # surely this is not needed!
1928
#     asym_unit = get_basis(atoms, spacegroup=spg, tol=eq_site_tol)
1929
#     frac_motif, invs = _expand_asym_unit(asym_unit, rot, trans, eq_site_tol)
1930
#     _, wyc_muls = np.unique(invs, return_counts=True)
1931
#     asym_inds = np.zeros_like(wyc_muls, dtype=np.int64)
1932
#     asym_inds[1:] = np.cumsum(wyc_muls)[:-1]
1933
#     motif = np.matmul(frac_motif, cell)
1934
#     motif_types = atoms.get_atomic_numbers()
1935
#     types = np.array([motif_types[i] for i in asym_inds], dtype=np.uint8)
1936
1937
#     return PeriodicSet(
1938
#         motif=motif,
1939
#         cell=cell,
1940
#         asym_unit=asym_inds,
1941
#         multiplicities=wyc_muls,
1942
#         types=types,
1943
#         occupancies=None,
1944
#     )
1945
1946
1947
def periodicset_from_pymatgen_structure(
    structure,
    remove_hydrogens: bool = False,
    skip_disorder: bool = False,
    missing_coords: str = "warn",
    eq_site_tol: float = 1e-3,
) -> PeriodicSet:
    """Convert a :class:`pymatgen.core.structure.Structure` object to a
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`.

    Does not set the name of the periodic set, as pymatgen Structure
    objects seem to have no name attribute. The input ``structure`` is not
    modified. When hydrogens are removed, this is done on a copy.

    Parameters
    ----------
    structure : :class:`pymatgen.core.structure.Structure`
        A pymatgen Structure object representing a crystal.
    remove_hydrogens : bool, optional
        Remove ``H`` and ``D`` species before conversion.
    skip_disorder : bool, optional
        Currently unused. Disorder handling is not implemented for
        pymatgen structures.
    missing_coords : str, optional
        Currently unused.
    eq_site_tol : float, optional
        Currently unused.

    Returns
    -------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set. It contains the motif
        points in Cartesian coordinates, the unit cell (lattice matrix),
        and the atomic numbers of the motif points. No name or asymmetric
        unit information is set.
    """

    if remove_hydrogens:
        # remove_species mutates the structure in place. Work on a copy so
        # the caller's object is left untouched.
        structure = structure.copy()
        structure.remove_species(["H", "D"])

    # TODO: honour skip_disorder (e.g. reject structures where
    # ``structure.is_ordered`` is False). It is currently ignored.

    motif = structure.cart_coords
    cell = structure.lattice.matrix
    types = np.array(structure.atomic_numbers, dtype=np.uint64)

    return PeriodicSet(
        motif=motif,
        cell=cell,
        types=types,
    )
1998