Passed
Push — master ( 7f7a74...2c655b )
by Daniel
05:50
created

amd.io._reduce_expanded_equiv_sites()   B

Complexity

Conditions 5

Size

Total Lines 37
Code Lines 22

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 22
dl 0
loc 37
rs 8.8853
c 0
b 0
f 0
cc 5
nop 2
1
"""Tools for reading crystals from files, or from the CSD with
2
``csd-python-api``. The readers return
3
:class:`amd.PeriodicSet <.periodicset.PeriodicSet>` objects representing
4
the crystal which can be passed to :func:`amd.AMD() <.calculate.AMD>`
5
and :func:`amd.PDD() <.calculate.PDD>`.
6
"""
7
8
import warnings
9
import collections
10
import os
11
import re
12
import functools
13
import errno
14
import math
15
import json
16
from pathlib import Path
17
from typing import (
18
    Iterable,
19
    Iterator,
20
    Optional,
21
    Union,
22
    Callable,
23
    Tuple,
24
    List
25
)
26
27
import numpy as np
28
import numpy.typing as npt
29
import numba
30
import tqdm
31
32
from .utils import cellpar_to_cell
33
from .periodicset import PeriodicSet
34
35
__all__ = [
36
    'CifReader',
37
    'CSDReader',
38
    'ParseError',
39
    'periodicset_from_gemmi_block',
40
    'periodicset_from_ase_cifblock',
41
    'periodicset_from_pymatgen_cifblock',
42
    'periodicset_from_ase_atoms',
43
    'periodicset_from_pymatgen_structure',
44
    'periodicset_from_ccdc_entry',
45
    'periodicset_from_ccdc_crystal',
46
]
47
48
FloatArray = npt.NDArray[np.floating]
49
IntArray = npt.NDArray[np.integer]
50
51
def _custom_warning(message, category, filename, lineno, *args, **kwargs):
52
    return f'{category.__name__}: {message}\n'
53
54
warnings.formatwarning = _custom_warning
55
56
# Load the lookup table used to turn atomic symbols into the integer types
# stored on each periodic set. The JSON file sits next to this module.
# The encoding is given explicitly so reading does not depend on the
# platform's default text encoding.
with open(
    Path(__file__).absolute().parent / 'atomic_numbers.json',
    encoding='utf-8'
) as f:
    _ATOMIC_NUMBERS = json.load(f)
58
59
_EQ_SITE_TOL: float = 1e-3
60
_CIF_TAGS: dict = {
61
    'cellpar': [
62
        '_cell_length_a',
63
        '_cell_length_b',
64
        '_cell_length_c',
65
        '_cell_angle_alpha',
66
        '_cell_angle_beta',
67
        '_cell_angle_gamma'
68
    ],
69
    'atom_site_fract': [
70
        '_atom_site_fract_x',
71
        '_atom_site_fract_y',
72
        '_atom_site_fract_z'
73
    ],
74
    'atom_site_cartn': [
75
        '_atom_site_Cartn_x',
76
        '_atom_site_Cartn_y',
77
        '_atom_site_Cartn_z'
78
    ],
79
    'symop': [
80
        '_space_group_symop_operation_xyz',
81
        '_space_group_symop.operation_xyz',
82
        '_symmetry_equiv_pos_as_xyz'
83
    ],
84
    'spacegroup_name': [
85
        '_space_group_name_H-M_alt',
86
        '_symmetry_space_group_name_H-M'
87
    ],
88
    'spacegroup_number': [
89
        '_space_group_IT_number',
90
        '_symmetry_Int_Tables_number'
91
    ],
92
}
93
94
__all__ = [
95
    'CifReader',
96
    'CSDReader',
97
    'ParseError',
98
    'periodicset_from_gemmi_block',
99
    'periodicset_from_ase_cifblock',
100
    'periodicset_from_pymatgen_cifblock',
101
    'periodicset_from_ase_atoms',
102
    'periodicset_from_pymatgen_structure',
103
    'periodicset_from_ccdc_entry',
104
    'periodicset_from_ccdc_crystal',
105
]
106
107
class _Reader(collections.abc.Iterator):
    """Base reader class.

    Wraps an iterable of raw items together with a converter that turns
    each item into a
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`. Iterating over
    the reader yields the converted items, skipping any item for which
    the converter raises :class:`ParseError <.io.ParseError>`.

    Parameters
    ----------
    iterable : Iterable
        Source of raw items passed one at a time to ``converter``.
    converter : Callable
        Called with one item, returns a periodic set or raises
        :class:`ParseError <.io.ParseError>`.
    show_warnings : bool
        If False, warnings arising while reading are suppressed.
    verbose : bool
        If True, a tqdm progress bar counts the items processed.
    """

    def __init__(
            self,
            iterable: Iterable,
            converter: Callable[..., PeriodicSet],
            show_warnings: bool,
            verbose: bool
    ):

        self._iterator = iter(iterable)
        self._converter = converter
        self.show_warnings = show_warnings
        if verbose:
            self._progress_bar = tqdm.tqdm(desc='Reading', delay=1)
        else:
            self._progress_bar = None

    def __next__(self):
        """Iterate over self._iterator, passing items through
        self._converter and yielding. If
        :class:`ParseError <.io.ParseError>` is raised in a call to
        self._converter, the item is skipped and the error message is
        emitted as a warning. Warnings raised in self._converter are
        re-emitted prefixed with the name of the periodic set. Nothing
        is emitted if self.show_warnings is False.
        """

        while True:

            # Fetching the next item can itself emit warnings (the
            # underlying iterable may parse files lazily). Suppression
            # is scoped to this context so the caller's global warning
            # filters are left untouched.
            with warnings.catch_warnings():
                if not self.show_warnings:
                    warnings.simplefilter('ignore')
                try:
                    item = next(self._iterator)
                except StopIteration:
                    if self._progress_bar is not None:
                        self._progress_bar.close()
                    raise

            parse_error = None
            with warnings.catch_warnings(record=True) as warning_msgs:
                if not self.show_warnings:
                    warnings.simplefilter('ignore')
                try:
                    periodic_set = self._converter(item)
                except ParseError as err:
                    # Do not warn here: inside a recording context the
                    # warning would be captured and then discarded.
                    parse_error = err
                finally:
                    if self._progress_bar is not None:
                        self._progress_bar.update(1)

            if parse_error is not None:
                if self.show_warnings:
                    warnings.warn(str(parse_error))
                continue

            if self.show_warnings:
                for warning in warning_msgs:
                    msg = f'(name={periodic_set.name}) {warning.message}'
                    warnings.warn(msg, category=warning.category)

            return periodic_set

    def read(self) -> Union[PeriodicSet, List[PeriodicSet]]:
        """Read the crystal(s), return one
        :class:`amd.PeriodicSet <.periodicset.PeriodicSet>` if there is
        only one, otherwise return a list.
        """
        items = list(self)
        if len(items) == 1:
            return items[0]
        return items
171
172
173
class CifReader(_Reader):
    """Read all structures in a .cif file or all files in a folder,
    yielding :class:`amd.PeriodicSet <.periodicset.PeriodicSet>` s.
    The file(s) are parsed with gemmi by default; pymatgen, ase or
    csd-python-api (if installed) can be selected with ``reader``.

    Parameters
    ----------
    path : str
        Path to a .CIF file or directory. (Other files are accepted when
        using ``reader='ccdc'``, if csd-python-api is installed.)
    reader : str, optional
        The backend package used to parse the CIF. The default is
        :code:`gemmi`, :code:`pymatgen` and :code:`ase` are also
        accepted, as well as :code:`ccdc` if csd-python-api is
        installed. The ccdc reader should be able to read any format
        accepted by :class:`ccdc.io.EntryReader`, though only CIFs have
        been tested.
    remove_hydrogens : bool, optional
        Remove Hydrogens from the crystals.
    disorder : str, optional
        Controls how disordered structures are handled. Default is
        ``skip`` which skips any crystal with disorder, since disorder
        conflicts with the periodic set model. To read disordered
        structures anyway, choose either :code:`ordered_sites` to remove
        atoms with disorder or :code:`all_sites` include all atoms
        regardless of disorder.
    heaviest_component : bool, optional
        csd-python-api only. Removes all but the heaviest molecule in
        the asymmetric unit, intended for removing solvents.
    molecular_centres : bool, default False
        csd-python-api only. Extract the centres of molecules in the
        unit cell and store in the attribute molecular_centres.
    show_warnings : bool, optional
        Controls whether warnings that arise during reading are printed.
    verbose : bool, default False
        If True, prints a progress bar showing the number of items
        processed.

    Yields
    ------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        data, e.g. the crystal's name and information about the
        asymmetric unit.

    Raises
    ------
    ValueError
        If ``disorder`` or ``reader`` is not one of the accepted values.
    NotImplementedError
        If ``heaviest_component`` or ``molecular_centres`` is requested
        with a reader other than ``'ccdc'``.
    ImportError
        If ``reader='ccdc'`` and csd-python-api cannot be imported.
    FileNotFoundError
        If ``path`` is neither an existing file nor a directory.

    Examples
    --------

        ::

            # Put all crystals in a .CIF in a list
            structures = list(amd.CifReader('mycif.cif'))

            # Can also accept path to a directory, reading all files inside
            structures = list(amd.CifReader('path/to/folder'))

            # Reads just one if the .CIF has just one crystal
            periodic_set = amd.CifReader('mycif.cif').read()

            # List of AMDs (k=100) of crystals in a .CIF
            amds = [amd.AMD(item, 100) for item in amd.CifReader('mycif.cif')]
    """

    def __init__(
            self,
            path: Union[str, os.PathLike],
            reader: str = 'gemmi',
            remove_hydrogens: bool = False,
            disorder: str = 'skip',
            heaviest_component: bool = False,
            molecular_centres: bool = False,
            show_warnings: bool = True,
            verbose: bool = False
    ):

        if disorder not in ('skip', 'ordered_sites', 'all_sites'):
            raise ValueError(
                f"'disorder' parameter of {self.__class__.__name__} must be "
                f"one of 'skip', 'ordered_sites' or 'all_sites' (passed "
                f"'{disorder}')"
            )

        # These two options rely on functionality only used through the
        # ccdc converter below.
        if reader != 'ccdc':
            if heaviest_component:
                raise NotImplementedError(
                    "'heaviest_component' parameter of "
                    f"{self.__class__.__name__} only implemented with "
                    "csd-python-api, if installed pass reader='ccdc'"
                )
            if molecular_centres:
                raise NotImplementedError(
                    "'molecular_centres' parameter of "
                    f"{self.__class__.__name__} only implemented with "
                    "csd-python-api, if installed pass reader='ccdc'"
                )

        # Each branch selects three things: the file extensions to pick
        # up from a directory, a parser turning a path into an iterable
        # of blocks/entries, and a converter turning one block/entry
        # into a PeriodicSet. Backends are imported lazily so only the
        # chosen one has to be installed.

        # cannot handle some characters (�) in cifs
        if reader == 'gemmi':
            import gemmi
            extensions = {'cif'}
            file_parser = gemmi.cif.read_file
            converter = functools.partial(
                periodicset_from_gemmi_block,
                remove_hydrogens=remove_hydrogens,
                disorder=disorder
            )

        elif reader in ('ase', 'pycodcif'):
            from ase.io.cif import parse_cif
            extensions = {'cif'}
            file_parser = functools.partial(parse_cif, reader=reader)
            converter = functools.partial(
                periodicset_from_ase_cifblock,
                remove_hydrogens=remove_hydrogens,
                disorder=disorder
            )

        elif reader == 'pymatgen':

            def _pymatgen_cif_parser(path):
                from pymatgen.io.cif import CifFile
                return CifFile.from_file(path).data.values()

            extensions = {'cif'}
            file_parser = _pymatgen_cif_parser
            converter = functools.partial(
                periodicset_from_pymatgen_cifblock,
                remove_hydrogens=remove_hydrogens,
                disorder=disorder
            )

        elif reader == 'ccdc':
            try:
                import ccdc.io
            except (ImportError, RuntimeError) as e:
                raise ImportError('Failed to import csd-python-api') from e

            extensions = set(ccdc.io.EntryReader.known_suffixes.keys())
            file_parser = ccdc.io.EntryReader
            converter = functools.partial(
                periodicset_from_ccdc_entry,
                remove_hydrogens=remove_hydrogens,
                disorder=disorder,
                molecular_centres=molecular_centres,
                heaviest_component=heaviest_component
            )

        else:
            raise ValueError(
                f"'reader' parameter of {self.__class__.__name__} must be one "
                f"of 'gemmi', 'pymatgen', 'ccdc', 'ase', or 'pycodcif' "
                f"(passed '{reader}')"
            )

        path = Path(path)
        if path.is_file():
            iterable = file_parser(str(path))
        elif path.is_dir():
            iterable = CifReader._dir_generator(path, file_parser, extensions)
        else:
            raise FileNotFoundError(
                errno.ENOENT, os.strerror(errno.ENOENT), str(path)
            )

        super().__init__(iterable, converter, show_warnings, verbose)

    @staticmethod
    def _dir_generator(
            path: Path, file_parser: Callable, extensions: Iterable
    ) -> Iterator:
        """Generate items from all files with extensions in
        ``extensions`` from a directory ``path``. Sub-directories are
        not descended into. A file that fails to parse is reported with
        a warning and skipped, so one bad file does not stop the read.
        """
        for file_path in path.iterdir():
            if not file_path.is_file():
                continue
            # suffix includes the leading dot; extensions do not
            if file_path.suffix[1:].lower() not in extensions:
                continue
            try:
                yield from file_parser(str(file_path))
            except Exception as e:
                # Deliberately broad: each backend raises its own
                # exception types and this is a best-effort skip.
                warnings.warn(
                    f'Error parsing "{str(file_path)}", skipping file. '
                    f'Exception: {repr(e)}'
                )
358
359
360
class CSDReader(_Reader):
    """Read structures from the CSD with csd-python-api, yielding
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>` s.

    Parameters
    ----------
    refcodes : str or List[str], optional
        Single or list of CSD refcodes to read. If None or 'CSD',
        iterates over the whole CSD.
    refcode_families : bool, optional
        Interpret ``refcodes`` as one or more refcode families, reading
        all entries whose refcode starts with the given strings.
    remove_hydrogens : bool, optional
        Remove hydrogens from the crystals.
    disorder : str, optional
        Controls how disordered structures are handled. Default is
        ``skip`` which skips any crystal with disorder, since disorder
        conflicts with the periodic set model. To read disordered
        structures anyway, choose either :code:`ordered_sites` to remove
        atoms with disorder or :code:`all_sites` include all atoms
        regardless of disorder.
    heaviest_component : bool, optional
        Removes all but the heaviest molecule in the asymmetric unit,
        intended for removing solvents.
    molecular_centres : bool, default False
        Extract the centres of molecules in the unit cell and store in
        attribute molecular_centres.
    show_warnings : bool, optional
        Controls whether warnings that arise during reading are printed.
    verbose : bool, default False
        If True, prints a progress bar showing the number of items
        processed.

    Yields
    ------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        useful data, e.g. the crystal's name and information about the
        asymmetric unit for calculation.

    Raises
    ------
    ValueError
        If ``disorder`` is not an accepted value, or ``refcodes`` is not
        None, a string or a list of strings.
    ImportError
        If csd-python-api cannot be imported.

    Examples
    --------

        ::

            # Put these entries in a list
            refcodes = ['DEBXIT01', 'DEBXIT05', 'HXACAN01']
            structures = list(amd.CSDReader(refcodes))

            # Read refcode families (any whose refcode starts with strings in the list)
            families = ['ACSALA', 'HXACAN']
            structures = list(amd.CSDReader(families, refcode_families=True))

            # Get AMDs (k=100) for crystals in these families
            refcodes = ['ACSALA', 'HXACAN']
            amds = []
            for periodic_set in amd.CSDReader(refcodes, refcode_families=True):
                amds.append(amd.AMD(periodic_set, 100))

            # Giving the reader nothing reads from the whole CSD.
            for periodic_set in amd.CSDReader():
                ...
    """

    def __init__(
            self,
            refcodes: Optional[Union[str, List[str]]] = None,
            refcode_families: bool = False,
            remove_hydrogens: bool = False,
            disorder: str = 'skip',
            heaviest_component: bool = False,
            molecular_centres: bool = False,
            show_warnings: bool = True,
            verbose: bool = False
    ):

        if disorder not in ('skip', 'ordered_sites', 'all_sites'):
            raise ValueError(
                f"'disorder' parameter of {self.__class__.__name__} must be "
                f"one of 'skip', 'ordered_sites' or 'all_sites' (passed "
                f"'{disorder}')"
            )

        try:
            import ccdc.search
            import ccdc.io
        except (ImportError, RuntimeError) as e:
            raise ImportError('Failed to import csd-python-api') from e

        # Normalise ``refcodes`` to either None (whole CSD) or a list
        # of strings.
        if isinstance(refcodes, str) and refcodes.lower() == 'csd':
            refcodes = None
        if refcodes is None:
            # Families are meaningless when reading everything
            refcode_families = False
        elif isinstance(refcodes, str):
            refcodes = [refcodes]
        elif isinstance(refcodes, list):
            if not all(isinstance(refcode, str) for refcode in refcodes):
                raise ValueError(
                    f'{self.__class__.__name__} expects None, a string or '
                    'list of strings.'
                )
        else:
            raise ValueError(
                f'{self.__class__.__name__} expects None, a string or list of '
                f'strings, got {refcodes.__class__.__name__}'
            )

        if refcode_families:
            # Expand each family into the identifiers returned by a
            # text search on it.
            all_refcodes = []
            for refcode in refcodes:
                query = ccdc.search.TextNumericSearch()
                query.add_identifier(refcode)
                all_refcodes.extend(hit.identifier for hit in query.search())
            # filter to unique refcodes while keeping order (dict keys
            # preserve insertion order)
            refcodes = list(dict.fromkeys(all_refcodes))

        converter = functools.partial(
            periodicset_from_ccdc_entry,
            remove_hydrogens=remove_hydrogens,
            disorder=disorder,
            molecular_centres=molecular_centres,
            heaviest_component=heaviest_component
        )

        entry_reader = ccdc.io.EntryReader('CSD')
        if refcodes is None:
            iterable = entry_reader
        else:
            # Lazy: entries are looked up one at a time during iteration
            iterable = map(entry_reader.entry, refcodes)

        super().__init__(iterable, converter, show_warnings, verbose)
498
499
500
class ParseError(ValueError):
    """Signals that an item could not be converted to a periodic set.

    Subclasses :class:`ValueError`, so handlers written for
    ``ValueError`` also catch it.
    """
503
504
505
def periodicset_from_gemmi_block(
        block, remove_hydrogens: bool = False, disorder: str = 'skip'
) -> PeriodicSet:
    """Convert a :class:`gemmi.cif.Block` object to a
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`.
    :class:`gemmi.cif.Block` is the type returned by
    :func:`gemmi.cif.read_file`.

    Parameters
    ----------
    block : :class:`gemmi.cif.Block`
        A gemmi Block object representing a crystal.
    remove_hydrogens : bool, optional
        Remove Hydrogens from the crystal.
    disorder : str, optional
        Controls how disordered structures are handled. Default is
        ``skip`` which skips any crystal with disorder, since disorder
        conflicts with the periodic set model. To read disordered
        structures anyway, choose either :code:`ordered_sites` to remove
        atoms with disorder or :code:`all_sites` include all atoms
        regardless of disorder.

    Returns
    -------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        useful data, e.g. the crystal's name and information about the
        asymmetric unit for calculation.

    Raises
    ------
    ParseError
        Raised if the structure fails to be parsed for any of the
        following: 1. Required data is missing (e.g. cell parameters),
        2. ``disorder == 'skip'`` and disorder is found on any atom,
        3. The motif is empty after removing H or disordered sites.
    """

    import gemmi

    # Unit cell
    cellpar = [block.find_value(tag) for tag in _CIF_TAGS['cellpar']]
    if not all(isinstance(par, str) for par in cellpar):
        raise ParseError(f'{block.name} has missing cell data')
    cellpar = np.array([str2float(par) for par in cellpar])
    if np.isnan(np.sum(cellpar)):
        raise ParseError(f'{block.name} has missing cell data')
    cell = cellpar_to_cell(cellpar)

    # Asymmetric unit coordinates
    xyz_loop = block.find(_CIF_TAGS['atom_site_fract']).loop
    if xyz_loop is None:
        if block.find(_CIF_TAGS['atom_site_cartn']).loop is None:
            raise ParseError(f'{block.name} has missing coordinate data')
        raise ParseError(
            f'{block.name} uses _atom_site_Cartn_ tags for coordinates, '
            'only _atom_site_fract_ is supported'
        )

    # The loop's values are stored flat, row by row; regroup them into
    # one list per tag (column).
    width = xyz_loop.width()
    columns = [[] for _ in range(len(xyz_loop.tags))]
    for i, item in enumerate(xyz_loop.values):
        columns[i % width].append(item)
    loop_dict = dict(zip(xyz_loop.tags, columns))

    xyz_str = [loop_dict[tag] for tag in _CIF_TAGS['atom_site_fract']]
    asym_unit = np.transpose(np.array(
        [[str2float(c) for c in xyz] for xyz in xyz_str]
    ))
    asym_unit = np.mod(asym_unit, 1)

    # Labels
    if '_atom_site_label' in loop_dict:
        labels = [
            gemmi.cif.as_string(lab) for lab in loop_dict['_atom_site_label']
        ]
    else:
        labels = [''] * xyz_loop.length()

    # Atomic types
    if '_atom_site_type_symbol' in loop_dict:
        symbols = []
        for s in loop_dict['_atom_site_type_symbol']:
            # Take the first one or two letters and normalise the case,
            # e.g. 'FE3+' -> 'Fe'.
            match = re.search(r'([A-Za-z][A-Za-z]?)', gemmi.cif.as_string(s))
            symbols.append('' if match is None else match.group().capitalize())
    else:  # Get atomic types from label
        symbols = _atomic_symbols_from_labels(labels)

    # Unrecognised symbols are given type 0
    asym_types = [_ATOMIC_NUMBERS.get(s, 0) for s in symbols]

    # Occupancies, missing values are treated as fully occupied
    if '_atom_site_occupancy' in loop_dict:
        occs = [str2float(oc) for oc in loop_dict['_atom_site_occupancy']]
        occupancies = [occ if not math.isnan(occ) else 1 for occ in occs]
    else:
        occupancies = [1] * xyz_loop.length()

    # Remove sites with missing coordinates, disorder and Hydrogens if needed.
    # A set gives constant-time membership tests and absorbs duplicates.
    remove_sites = set(
        np.nonzero(np.isnan(asym_unit.min(axis=-1)))[0].tolist()
    )

    if disorder == 'skip':
        if any(_has_disorder(l, o) for l, o in zip(labels, occupancies)):
            raise ParseError(
                f"{block.name} has disorder, pass disorder='ordered_sites' or "
                "'all_sites' to remove/ignore disorder"
            )
    elif disorder == 'ordered_sites':
        remove_sites.update(
            i for i, (label, occ) in enumerate(zip(labels, occupancies))
            if _has_disorder(label, occ)
        )

    if remove_hydrogens:
        remove_sites.update(i for i, num in enumerate(asym_types) if num == 1)

    asym_unit = np.delete(
        asym_unit, np.array(sorted(remove_sites), dtype=np.intp), axis=0
    )
    asym_types = [t for i, t in enumerate(asym_types) if i not in remove_sites]
    if asym_unit.shape[0] == 0:
        raise ParseError(f'{block.name} has no valid sites')

    # Remove overlapping sites unless disorder == 'all_sites'
    if disorder != 'all_sites':
        keep_sites = _unique_sites(asym_unit, _EQ_SITE_TOL)
        if not np.all(keep_sites):
            warnings.warn(
                'may have overlapping sites; duplicates will be removed'
            )
        asym_unit = asym_unit[keep_sites]
        asym_types = [sym for sym, keep in zip(asym_types, keep_sites) if keep]

    # Symmetry operations, try xyz strings first
    for tag in _CIF_TAGS['symop']:
        sitesym = [v.str(0) for v in block.find([tag])]
        if sitesym:
            rot, trans = _parse_sitesyms(sitesym)
            break
    else:
        # Try spacegroup name; can be a pair or in a loop
        spg = None
        for tag in _CIF_TAGS['spacegroup_name']:
            for value in block.find([tag]):
                try:
                    # Some names cannot be parsed by gemmi.SpaceGroup
                    spg = gemmi.SpaceGroup(value.str(0))
                    break
                except ValueError:
                    continue
            if spg is not None:
                break

        if spg is None:
            # Try international number
            for tag in _CIF_TAGS['spacegroup_number']:
                spg_num = block.find_value(tag)
                if spg_num is not None:
                    spg_num = gemmi.cif.as_int(spg_num)
                    break
            else:
                warnings.warn('no symmetry data found, defaulting to P1')
                spg_num = 1
            spg = gemmi.SpaceGroup(spg_num)

        # Entries are divided by each operation's DEN attribute to get
        # fractional values.
        ops = list(spg.operations())
        rot = np.array([np.array(op.rot) / op.DEN for op in ops])
        trans = np.array([np.array(op.tran) / op.DEN for op in ops])

    frac_motif, invs = _expand_asym_unit(asym_unit, rot, trans, _EQ_SITE_TOL)
    _, wyc_muls = np.unique(invs, return_counts=True)
    # Each entry is the cumulative count of motif points before that
    # asymmetric site (first entry 0).
    asym_inds = np.zeros_like(wyc_muls, dtype=np.int32)
    asym_inds[1:] = np.cumsum(wyc_muls)[:-1]
    types = np.array([asym_types[i] for i in invs], dtype=np.uint8)
    # Fractional -> Cartesian coordinates
    motif = np.matmul(frac_motif, cell)

    return PeriodicSet(
        motif=motif,
        cell=cell,
        name=block.name,
        asym_unit=asym_inds,
        multiplicities=wyc_muls,
        types=types
    )
706
707
708
def periodicset_from_ase_cifblock(
        block, remove_hydrogens: bool = False, disorder: str = 'skip'
) -> PeriodicSet:
    """Convert a :class:`ase.io.cif.CIFBlock` object to a
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`.
    :class:`ase.io.cif.CIFBlock` is the type returned by
    :func:`ase.io.cif.parse_cif`.

    Parameters
    ----------
    block : :class:`ase.io.cif.CIFBlock`
        An ase :class:`ase.io.cif.CIFBlock` object representing a
        crystal.
    remove_hydrogens : bool, optional
        Remove Hydrogens from the crystal.
    disorder : str, optional
        Controls how disordered structures are handled. Default is
        ``skip`` which skips any crystal with disorder, since disorder
        conflicts with the periodic set model. To read disordered
        structures anyway, choose either :code:`ordered_sites` to remove
        atoms with disorder or :code:`all_sites` include all atoms
        regardless of disorder.

    Returns
    -------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        useful data, e.g. the crystal's name and information about the
        asymmetric unit for calculation.

    Raises
    ------
    ParseError
        Raised if the structure fails to be parsed for any of the
        following: 1. Required data is missing (e.g. cell parameters
        or coordinates), 2. The motif is empty after removing H or
        disordered sites, 3. ``disorder == 'skip'`` and disorder is
        found on any atom.
    """

    import ase
    import ase.spacegroup

    # Unit cell. Check for missing tags on the raw values: once passed
    # through str() a missing value is the string 'None', not None.
    raw_cellpar = [block.get(tag) for tag in _CIF_TAGS['cellpar']]
    if any(par is None for par in raw_cellpar):
        raise ParseError(f'{block.name} has missing cell data')
    cellpar = [str2float(str(par)) for par in raw_cellpar]
    # NOTE(review): assumes str2float may give NaN (or None) for values
    # it cannot convert; either is treated as missing data.
    if any(par is None or math.isnan(par) for par in cellpar):
        raise ParseError(f'{block.name} has missing cell data')
    cell = cellpar_to_cell(np.array(cellpar))

    # Asymmetric unit coordinates. ase removes uncertainty brackets.
    # Missing tags are detected before converting, so an absent column
    # gives a ParseError instead of a TypeError from iterating None.
    raw_coords = [block.get(tag) for tag in _CIF_TAGS['atom_site_fract']]
    if any(column is None for column in raw_coords):
        cartn = [
            block.get(tag.lower()) for tag in _CIF_TAGS['atom_site_cartn']
        ]
        if any(column is None for column in cartn):
            raise ParseError(f'{block.name} has missing coordinates')
        raise ParseError(
            f'{block.name} uses _atom_site_Cartn_ tags for coordinates, '
            'only _atom_site_fract_ is supported'
        )
    # Columns (x, y, z) -> one (x, y, z) tuple per site
    asym_unit = list(zip(
        *([str2float(str(n)) for n in column] for column in raw_coords)
    ))

    # Labels
    asym_labels = block.get('_atom_site_label')
    if asym_labels is None:
        asym_labels = [''] * len(asym_unit)

    # Atomic types. Without explicit type symbols, fall back to the
    # site labels, as periodicset_from_gemmi_block does.
    asym_symbols = block.get('_atom_site_type_symbol')
    if asym_symbols is None:
        asym_symbols = asym_labels
    asym_symbols_ = _atomic_symbols_from_labels(asym_symbols)

    # Unrecognised symbols are given type 0
    asym_types = [_ATOMIC_NUMBERS.get(s, 0) for s in asym_symbols_]

    # Find where sites have disorder if necessary
    has_disorder = []
    if disorder != 'all_sites':
        occupancies = block.get('_atom_site_occupancy')
        if occupancies is None:
            occupancies = [1] * len(asym_unit)
        has_disorder = [
            _has_disorder(lab, occ)
            for lab, occ in zip(asym_labels, occupancies)
        ]

    # Remove sites with ?, . or other invalid string for coordinates.
    # NaN is a float, so it is rejected explicitly.
    invalid = {
        i for i, xyz in enumerate(asym_unit)
        if not all(
            isinstance(coord, (int, float)) and not math.isnan(coord)
            for coord in xyz
        )
    }
    if invalid:
        warnings.warn('atoms without sites or missing data will be removed')
        asym_unit = [c for i, c in enumerate(asym_unit) if i not in invalid]
        asym_types = [t for i, t in enumerate(asym_types) if i not in invalid]
        if disorder != 'all_sites':
            has_disorder = [
                d for i, d in enumerate(has_disorder) if i not in invalid
            ]

    remove_sites = set()

    if remove_hydrogens:
        remove_sites.update(i for i, num in enumerate(asym_types) if num == 1)

    # Remove atoms with fractional occupancy or raise ParseError. Sites
    # already marked for removal (Hydrogens) do not count as disorder.
    if disorder != 'all_sites':
        for i, dis in enumerate(has_disorder):
            if not dis or i in remove_sites:
                continue
            if disorder == 'skip':
                raise ParseError(
                    f'{block.name} has disorder, pass '
                    "disorder='ordered_sites' or 'all_sites' to "
                    'remove/ignore disorder'
                )
            elif disorder == 'ordered_sites':
                remove_sites.add(i)

    # Asymmetric unit
    asym_unit = [c for i, c in enumerate(asym_unit) if i not in remove_sites]
    asym_types = [t for i, t in enumerate(asym_types) if i not in remove_sites]
    if len(asym_unit) == 0:
        raise ParseError(f'{block.name} has no valid sites')
    asym_unit = np.mod(np.array(asym_unit), 1)

    # Remove overlapping sites unless disorder == 'all_sites'
    if disorder != 'all_sites':
        keep_sites = _unique_sites(asym_unit, _EQ_SITE_TOL)
        if not np.all(keep_sites):
            warnings.warn(
                'may have overlapping sites, duplicates will be removed'
            )
            asym_unit = asym_unit[keep_sites]
            asym_types = [t for t, keep in zip(asym_types, keep_sites) if keep]

    # Get symmetry operations: explicit xyz strings first, then the
    # spacegroup name, then its number, finally defaulting to P1.
    sitesym = block._get_any(_CIF_TAGS['symop'])
    if sitesym is None:
        label_or_num = block._get_any(
            [s.lower() for s in _CIF_TAGS['spacegroup_name']]
        )
        if label_or_num is None:
            label_or_num = block._get_any(
                [s.lower() for s in _CIF_TAGS['spacegroup_number']]
            )
        if label_or_num is None:
            warnings.warn('no symmetry data found, defaulting to P1')
            label_or_num = 1
        spg = ase.spacegroup.Spacegroup(label_or_num)
        rot, trans = spg.get_op()
    else:
        if isinstance(sitesym, str):
            sitesym = [sitesym]
        rot, trans = _parse_sitesyms(sitesym)

    frac_motif, invs = _expand_asym_unit(asym_unit, rot, trans, _EQ_SITE_TOL)
    _, wyc_muls = np.unique(invs, return_counts=True)
    # Each entry is the cumulative count of motif points before that
    # asymmetric site (first entry 0).
    asym_inds = np.zeros_like(wyc_muls, dtype=np.int32)
    asym_inds[1:] = np.cumsum(wyc_muls)[:-1]
    types = np.array([asym_types[i] for i in invs], dtype=np.uint8)
    # Fractional -> Cartesian coordinates
    motif = np.matmul(frac_motif, cell)

    return PeriodicSet(
        motif=motif,
        cell=cell,
        name=block.name,
        asym_unit=asym_inds,
        multiplicities=wyc_muls,
        types=types
    )
893
894
895
def periodicset_from_pymatgen_cifblock(
        block, remove_hydrogens: bool = False, disorder: str = 'skip'
) -> PeriodicSet:
    """Convert a :class:`pymatgen.io.cif.CifBlock` object to a
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`.
    :class:`pymatgen.io.cif.CifBlock` is the type returned by
    :class:`pymatgen.io.cif.CifFile`.

    Parameters
    ----------
    block : :class:`pymatgen.io.cif.CifBlock`
        A pymatgen CifBlock object representing a crystal.
    remove_hydrogens : bool, optional
        Remove Hydrogens from the crystal.
    disorder : str, optional
        Controls how disordered structures are handled. Default is
        ``skip`` which skips any crystal with disorder, since disorder
        conflicts with the periodic set model. To read disordered
        structures anyway, choose either :code:`ordered_sites` to remove
        atoms with disorder or :code:`all_sites` include all atoms
        regardless of disorder.

    Returns
    -------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        useful data, e.g. the crystal's name and information about the
        asymmetric unit for calculation.

    Raises
    ------
    ParseError
        Raised if the structure can/should not be parsed for the
        following reasons: 1. No sites found or motif is empty after
        removing Hydrogens & disorder, 2. A site has missing
        coordinates, 3. :code:`disorder == 'skip'` and disorder is
        found on any atom.
    """

    odict = block.data

    # Unit cell
    cellpar = [odict.get(tag) for tag in _CIF_TAGS['cellpar']]
    if any(par in (None, '?', '.') for par in cellpar):
        raise ParseError(f'{block.header} has missing cell data')

    try:
        cellpar = [str2float(v) for v in cellpar]
    except ValueError as err:
        raise ParseError(f'{block.header} could not be parsed') from err
    cell = cellpar_to_cell(np.array(cellpar, dtype=np.float64))

    # Asymmetric unit coordinates. Only fractional coordinates are
    # supported; the Cartesian tags are looked up purely to give a more
    # specific error message.
    asym_unit = [odict.get(tag) for tag in _CIF_TAGS['atom_site_fract']]
    if None in asym_unit:
        asym_unit = [odict.get(tag) for tag in _CIF_TAGS['atom_site_cartn']]
        if None in asym_unit:
            raise ParseError(f'{block.header} has missing coordinates')
        raise ParseError(
            f'{block.header} uses _atom_site_Cartn_ tags for coordinates, '
            'only _atom_site_fract_ is supported'
        )

    # Transpose the per-axis columns into one (x, y, z) row per site.
    asym_unit = list(zip(*asym_unit))
    try:
        asym_unit = [[str2float(coord) for coord in xyz] for xyz in asym_unit]
    except ValueError as err:
        raise ParseError(f'{block.header} could not be parsed') from err

    # Atomic types; symbols not found in _ATOMIC_NUMBERS map to 0.
    asym_symbols = odict.get('_atom_site_type_symbol')
    if asym_symbols is not None:
        asym_symbols_ = _atomic_symbols_from_labels(asym_symbols)
    else:
        asym_symbols_ = [''] * len(asym_unit)
    asym_types = [_ATOMIC_NUMBERS.get(s, 0) for s in asym_symbols_]

    # Find where sites have disorder if necessary
    has_disorder = []
    if disorder != 'all_sites':
        occupancies = odict.get('_atom_site_occupancy')
        if occupancies is None:
            occupancies = np.ones((len(asym_unit), ))
        else:
            occupancies = np.array([str2float(occ) for occ in occupancies])
        labels = odict.get('_atom_site_label')
        if labels is None:
            labels = [''] * len(asym_unit)
        has_disorder = [
            _has_disorder(lab, occ) for lab, occ in zip(labels, occupancies)
        ]

    # Remove sites with ?, . or other invalid string for coordinates.
    # A set gives O(1) membership tests in the filters below.
    invalid = {
        i for i, xyz in enumerate(asym_unit)
        if not all(isinstance(coord, (int, float)) for coord in xyz)
    }

    if invalid:
        warnings.warn('atoms without sites or missing data will be removed')
        asym_unit = [c for i, c in enumerate(asym_unit) if i not in invalid]
        asym_types = [c for i, c in enumerate(asym_types) if i not in invalid]
        if disorder != 'all_sites':
            has_disorder = [
                d for i, d in enumerate(has_disorder) if i not in invalid
            ]

    remove_sites = set()

    if remove_hydrogens:
        remove_sites.update(i for i, n in enumerate(asym_types) if n == 1)

    # Remove atoms with fractional occupancy or raise ParseError. Sites
    # already marked for removal (hydrogens) never trigger the error.
    if disorder != 'all_sites':
        for i, dis in enumerate(has_disorder):
            if not dis or i in remove_sites:
                continue
            if disorder == 'skip':
                raise ParseError(
                    f'{block.header} has disorder, pass '
                    "disorder='ordered_sites' or 'all_sites' to "
                    'remove/ignore disorder'
                )
            elif disorder == 'ordered_sites':
                remove_sites.add(i)

    # Asymmetric unit
    asym_unit = [c for i, c in enumerate(asym_unit) if i not in remove_sites]
    asym_types = [t for i, t in enumerate(asym_types) if i not in remove_sites]
    if not asym_unit:
        raise ParseError(f'{block.header} has no valid sites')
    asym_unit = np.mod(np.array(asym_unit), 1)

    # Remove overlapping sites unless disorder == 'all_sites'
    if disorder != 'all_sites':
        keep_sites = _unique_sites(asym_unit, _EQ_SITE_TOL)
        if not np.all(keep_sites):
            warnings.warn(
                'may have overlapping sites; duplicates will be removed'
            )
        asym_unit = asym_unit[keep_sites]
        asym_types = [sym for sym, keep in zip(asym_types, keep_sites) if keep]

    # Apply symmetries to asymmetric unit
    rot, trans = _get_syms_pymatgen(odict)
    frac_motif, invs = _expand_asym_unit(asym_unit, rot, trans, _EQ_SITE_TOL)
    # wyc_muls[k] counts the motif points generated from asymmetric site k;
    # asym_inds[k] is the running offset of that site's first motif point.
    _, wyc_muls = np.unique(invs, return_counts=True)
    asym_inds = np.zeros_like(wyc_muls, dtype=np.int32)
    asym_inds[1:] = np.cumsum(wyc_muls)[:-1]
    types = np.array([asym_types[i] for i in invs], dtype=np.uint8)
    motif = np.matmul(frac_motif, cell)

    return PeriodicSet(
        motif=motif,
        cell=cell,
        name=block.header,
        asym_unit=asym_inds,
        multiplicities=wyc_muls,
        types=types
    )
1073
1074
1075
def periodicset_from_ase_atoms(
        atoms, remove_hydrogens: bool = False
) -> PeriodicSet:
    """Convert an :class:`ase.atoms.Atoms` object to a
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`. Does not have
    the option to remove disorder. The ``atoms`` object passed in is
    not modified.

    Parameters
    ----------
    atoms : :class:`ase.atoms.Atoms`
        An ase :class:`ase.atoms.Atoms` object representing a crystal.
    remove_hydrogens : bool, optional
        Remove Hydrogens from the crystal.

    Returns
    -------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        useful data, e.g. the crystal's name and information about the
        asymmetric unit for calculation.

    Raises
    ------
    ParseError
        Raised if there are no valid sites in atoms.
    """

    from ase.spacegroup import get_basis

    cell = atoms.get_cell().array

    if remove_hydrogens:
        hydrogen_inds = np.where(atoms.get_atomic_numbers() == 1)[0]
        if len(hydrogen_inds) > 0:
            # Work on a copy so the caller's Atoms object is left intact.
            atoms = atoms.copy()
            # Pop from the highest index down so earlier removals do not
            # shift the indices still to be removed.
            for i in sorted(hydrogen_inds, reverse=True):
                atoms.pop(i)

    if len(atoms) == 0:
        raise ParseError('ase Atoms object has no valid sites')

    # Symmetry operations from spacegroup
    spg = None
    if 'spacegroup' in atoms.info:
        spg = atoms.info['spacegroup']
        rot, trans = spg.rotations, spg.translations
    else:
        warnings.warn('no symmetry data found, defaulting to P1')
        rot = np.identity(3)[None, :]
        trans = np.zeros((1, 3))

    # Asymmetric unit. ase default tol is 1e-5
    # TODO: get_basis determines a reduced asymmetric unit from the atoms,
    # which is then re-expanded below; this round trip may be unnecessary.
    asym_unit = get_basis(atoms, spacegroup=spg, tol=_EQ_SITE_TOL)
    frac_motif, invs = _expand_asym_unit(asym_unit, rot, trans, _EQ_SITE_TOL)
    _, wyc_muls = np.unique(invs, return_counts=True)
    asym_inds = np.zeros_like(wyc_muls, dtype=np.int32)
    asym_inds[1:] = np.cumsum(wyc_muls)[:-1]
    # NOTE(review): types come from the atoms object while the motif is
    # rebuilt from the reduced basis; presumably these agree in length and
    # order for well-formed input -- confirm.
    types = atoms.get_atomic_numbers().astype(np.uint8)
    motif = np.matmul(frac_motif, cell)

    return PeriodicSet(
        motif=motif,
        cell=cell,
        asym_unit=asym_inds,
        multiplicities=wyc_muls,
        types=types
    )
1145
1146
1147
def periodicset_from_pymatgen_structure(
        structure, remove_hydrogens: bool = False, disorder: str = 'skip'
) -> PeriodicSet:
    """Convert a :class:`pymatgen.core.structure.Structure` object to a
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`. Does not set
    the name of the periodic set, as pymatgen Structure objects seem to
    have no name attribute. The ``structure`` passed in is not modified.

    Parameters
    ----------
    structure : :class:`pymatgen.core.structure.Structure`
        A pymatgen Structure object representing a crystal.
    remove_hydrogens : bool, optional
        Remove Hydrogens from the crystal.
    disorder : str, optional
        Controls how disordered structures are handled. Default is
        ``skip`` which skips any crystal with disorder, since disorder
        conflicts with the periodic set model. To read disordered
        structures anyway, choose either :code:`ordered_sites` to remove
        atoms with disorder or :code:`all_sites` include all atoms
        regardless of disorder.

    Returns
    -------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        useful data, e.g. the crystal's name and information about the
        asymmetric unit for calculation.

    Raises
    ------
    ParseError
        Raised if the :code:`disorder == 'skip'` and
        :code:`not structure.is_ordered`
    """

    from pymatgen.symmetry.analyzer import SpacegroupAnalyzer

    # remove_species() and remove_sites() edit the structure in place, so
    # take a copy first to avoid changing the caller's object.
    if remove_hydrogens or disorder == 'ordered_sites':
        structure = structure.copy()

    if remove_hydrogens:
        structure.remove_species(['H', 'D'])

    # Disorder
    if disorder == 'skip':
        if not structure.is_ordered:
            raise ParseError(
                'pymatgen Structure has disorder, pass '
                "disorder='ordered_sites' or 'all_sites' to remove/ignore "
                'disorder'
            )
    elif disorder == 'ordered_sites':
        # Drop every site whose total occupancy is below 1.
        remove_inds = [
            i for i, comp in enumerate(structure.species_and_occu)
            if comp.num_atoms < 1
        ]
        structure.remove_sites(remove_inds)

    spg = SpacegroupAnalyzer(structure)
    sym_structure = spg.get_conventional_standard_structure()
    motif = sym_structure.cart_coords
    cell = sym_structure.lattice.matrix

    # No symmetry reduction is attempted here: every motif point is its
    # own asymmetric unit site with multiplicity 1.
    asym_inds = np.arange(len(motif))
    wyc_muls = np.ones((len(motif), ), dtype=np.int32)
    types = np.array(sym_structure.atomic_numbers, dtype=np.uint8)

    return PeriodicSet(
        motif=motif,
        cell=cell,
        asym_unit=asym_inds,
        multiplicities=wyc_muls,
        types=types
    )
1237
1238
1239
def periodicset_from_ccdc_entry(
        entry,
        remove_hydrogens: bool = False,
        disorder: str = 'skip',
        heaviest_component: bool = False,
        molecular_centres: bool = False
) -> PeriodicSet:
    """Convert a :class:`ccdc.entry.Entry` object to a
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`.
    Entry is the type returned by :class:`ccdc.io.EntryReader`.

    Performs the checks specific to an entry, then hands
    ``entry.crystal`` and all options on to
    :func:`periodicset_from_ccdc_crystal`.

    Parameters
    ----------
    entry : :class:`ccdc.entry.Entry`
        A ccdc Entry object representing a database entry.
    remove_hydrogens : bool, optional
        Remove Hydrogens from the crystal.
    disorder : str, optional
        Controls how disordered structures are handled. Default is
        ``skip`` which skips any crystal with disorder, since disorder
        conflicts with the periodic set model. To read disordered
        structures anyway, choose either :code:`ordered_sites` to remove
        atoms with disorder or :code:`all_sites` include all atoms
        regardless of disorder.
    heaviest_component : bool, optional
        Removes all but the heaviest molecule in the asymmetric unit,
        intended for removing solvents.
    molecular_centres : bool, default False
        Use molecular centres of mass as the motif instead of centres of
        atoms.

    Returns
    -------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        useful data, e.g. the crystal's name and information about the
        asymmetric unit for calculation.

    Raises
    ------
    ParseError
        Raised if the structure fails parsing for any of the following:
        1. entry.has_3d_structure is False, 2.
        :code:`disorder == 'skip'` and disorder is found on any atom,
        3. entry.crystal.molecule.all_atoms_have_sites is False,
        4. a.fractional_coordinates is None for any a in
        entry.crystal.disordered_molecule, 5. The motif is empty after
        removing Hydrogens and disordered sites.
    """

    # Entry-level guard: nothing can be read without 3D coordinates.
    if not entry.has_3d_structure:
        raise ParseError(f'{entry.identifier} has no 3D structure')

    # Entry-level disorder flag, only consulted when disorder is skipped.
    if disorder == 'skip':
        if entry.has_disorder:
            raise ParseError(
                f"{entry.identifier} has disorder, pass "
                "disorder='ordered_sites' "
                "or 'all_sites' to remove/ignore disorder"
            )

    crystal_options = {
        'remove_hydrogens': remove_hydrogens,
        'disorder': disorder,
        'heaviest_component': heaviest_component,
        'molecular_centres': molecular_centres,
    }
    return periodicset_from_ccdc_crystal(entry.crystal, **crystal_options)
1308
1309
1310
def periodicset_from_ccdc_crystal(
        crystal,
        remove_hydrogens: bool = False,
        disorder: str = 'skip',
        heaviest_component: bool = False,
        molecular_centres: bool = False
) -> PeriodicSet:
    """Convert a :class:`ccdc.crystal.Crystal` object to a
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`.
    Crystal is the type returned by :class:`ccdc.io.CrystalReader`.

    Note that ``crystal.molecule`` is reassigned to the filtered
    molecule while parsing.

    Parameters
    ----------
    crystal : :class:`ccdc.crystal.Crystal`
        A ccdc Crystal object representing a crystal structure.
    remove_hydrogens : bool, optional
        Remove Hydrogens from the crystal.
    disorder : str, optional
        Controls how disordered structures are handled. Default is
        ``skip`` which skips any crystal with disorder, since disorder
        conflicts with the periodic set model. To read disordered
        structures anyway, choose either :code:`ordered_sites` to remove
        atoms with disorder or :code:`all_sites` include all atoms
        regardless of disorder.
    heaviest_component : bool, optional
        Removes all but the heaviest molecule in the asymmetric unit,
        intended for removing solvents.
    molecular_centres : bool, default False
        Use molecular centres of mass as the motif instead of centres of
        atoms.

    Returns
    -------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        useful data, e.g. the crystal's name and information about the
        asymmetric unit for calculation.

    Raises
    ------
    ParseError
        Raised if the structure fails parsing for any of the following:
        1. :code:`disorder == 'skip'` and disorder is found on any
        atom, 2. crystal.molecule.all_atoms_have_sites is False,
        3. a.fractional_coordinates is None for any a in
        crystal.disordered_molecule, 4. The motif is empty after
        removing H, disordered sites or solvents.
    """

    molecule = crystal.disordered_molecule

    # Disorder
    if disorder == 'skip':
        if crystal.has_disorder or any(
            _has_disorder(a.label, a.occupancy) for a in molecule.atoms
        ):
            raise ParseError(
                f"{crystal.identifier} has disorder, pass "
                "disorder='ordered_sites' or 'all_sites' to remove/ignore "
                "disorder"
            )
    elif disorder == 'ordered_sites':
        molecule.remove_atoms(
            a for a in molecule.atoms if _has_disorder(a.label, a.occupancy)
        )

    if remove_hydrogens:
        # Tuple membership rather than ``in 'HD'``: the substring test
        # would also match an empty symbol or the literal string 'HD'.
        molecule.remove_atoms(
            a for a in molecule.atoms if a.atomic_symbol in ('H', 'D')
        )

    # Remove atoms with missing coordinates and warn
    if any(a.fractional_coordinates is None for a in molecule.atoms):
        warnings.warn('atoms without sites or missing data will be removed')
        molecule.remove_atoms(
            a for a in molecule.atoms if a.fractional_coordinates is None
        )

    if heaviest_component and len(molecule.components) > 1:
        molecule = _heaviest_component_ccdc(molecule)

    # Write the filtered molecule back so that the crystal's
    # asymmetric_unit_molecule (read below) reflects the removals.
    crystal.molecule = molecule
    cellpar = crystal.cell_lengths + crystal.cell_angles
    if None in cellpar:
        raise ParseError(f'{crystal.identifier} has missing cell data')
    cell = cellpar_to_cell(np.array(cellpar))

    if molecular_centres:
        frac_centres = _frac_molecular_centres_ccdc(crystal, _EQ_SITE_TOL)
        mol_centres = np.matmul(frac_centres, cell)
        return PeriodicSet(mol_centres, cell, name=crystal.identifier)

    asym_atoms = crystal.asymmetric_unit_molecule.atoms
    asym_unit = np.array([tuple(a.fractional_coordinates) for a in asym_atoms])

    if asym_unit.shape[0] == 0:
        raise ParseError(f'{crystal.identifier} has no valid sites')

    asym_unit = np.mod(asym_unit, 1)
    asym_types = [a.atomic_number for a in asym_atoms]

    # Remove overlapping sites unless disorder == 'all_sites'
    if disorder != 'all_sites':
        keep_sites = _unique_sites(asym_unit, _EQ_SITE_TOL)
        if not np.all(keep_sites):
            warnings.warn(
                'may have overlapping sites; duplicates will be removed'
            )
        asym_unit = asym_unit[keep_sites]
        asym_types = [sym for sym, keep in zip(asym_types, keep_sites) if keep]

    # Symmetry operations
    sitesym = crystal.symmetry_operators
    if not sitesym:
        warnings.warn('no symmetry data found, defaulting to P1')
        sitesym = ['x,y,z']

    # Apply symmetries to asymmetric unit
    rot, trans = _parse_sitesyms(sitesym)
    frac_motif, invs = _expand_asym_unit(asym_unit, rot, trans, _EQ_SITE_TOL)
    # wyc_muls[k] counts the motif points generated from asymmetric site k;
    # asym_inds[k] is the running offset of that site's first motif point.
    _, wyc_muls = np.unique(invs, return_counts=True)
    asym_inds = np.zeros_like(wyc_muls, dtype=np.int32)
    asym_inds[1:] = np.cumsum(wyc_muls)[:-1]
    types = np.array([asym_types[i] for i in invs], dtype=np.uint8)
    motif = np.matmul(frac_motif, cell)

    return PeriodicSet(
        motif=motif,
        cell=cell,
        name=crystal.identifier,
        asym_unit=asym_inds,
        multiplicities=wyc_muls,
        types=types
    )
1450
1451
1452
def memoize(f):
    """Cache for _parse_sitesym().

    Wrap the single-argument function ``f`` so that each distinct
    (hashable) argument is evaluated once; later calls with the same
    argument return the stored result object. The wrapper keeps the
    name and docstring of ``f``.
    """
    cache = {}

    @functools.wraps(f)
    def wrapper(arg):
        if arg not in cache:
            cache[arg] = f(arg)
        return cache[arg]

    return wrapper
1460
1461
1462
@memoize
def _parse_sitesym(sym: str) -> Tuple[FloatArray, FloatArray]:
    """Parse one symmetry operation written as an xyz string (e.g.
    ``'-x,y+1/2,z'``) into a 3x3 rotation matrix and a translation
    vector of shape (3,).
    """

    rot = np.zeros((3, 3), dtype=np.float64)
    trans = np.zeros((3, ), dtype=np.float64)

    # Each comma-separated term fills one row of rot and one entry of trans.
    for row, term in enumerate(sym.split(',')):

        positive = True           # sign set by the most recent '+' or '-'
        seen_slash = False        # True once a '/' has been read
        shift_sign = None         # sign in effect at the first numeric char
        numerator_chars = []      # numeric chars read before any '/'
        denominator_chars = []    # numeric chars read after a '/'

        for ch in term.lower():
            if ch == '+':
                positive = True
            elif ch == '-':
                positive = False
            elif ch == '/':
                seen_slash = True
            elif ch in 'xyz':
                # Column index 0, 1, 2 for x, y, z respectively.
                rot[row][ord(ch) - ord('x')] = 1.0 if positive else -1.0
            elif ch.isdigit() or ch == '.':
                if shift_sign is None:
                    shift_sign = 1.0 if positive else -1.0
                if seen_slash:
                    denominator_chars.append(ch)
                else:
                    numerator_chars.append(ch)

        if numerator_chars:
            shift = shift_sign * float(''.join(numerator_chars))
        else:
            shift = 0.0

        if seen_slash:
            shift /= float(''.join(denominator_chars))

        trans[row] = shift

    return rot, trans
1508
1509
1510
def _parse_sitesyms(symmetries: List[str]) -> Tuple[FloatArray, FloatArray]:
    """Parse every xyz-form symmetry string in ``symmetries`` with
    _parse_sitesym() and stack the results into one array of rotations
    and one array of translations.
    """
    parsed = [_parse_sitesym(sym) for sym in symmetries]
    rotations = [rot for rot, _ in parsed]
    translations = [trans for _, trans in parsed]
    return np.array(rotations), np.array(translations)
1521
1522
1523
def _expand_asym_unit(
        asym_unit: FloatArray,
        rotations: FloatArray,
        translations: FloatArray,
        tol: float
) -> Tuple[FloatArray, IntArray]:
    """Apply the symmetries given by ``rotations`` and ``translations``
    to the asymmetric unit. Returns the expanded fractional motif along
    with, for each motif point, the index of the asymmetric unit site
    it was generated from.
    """

    # All inputs are passed on as float64 (no copy if already float64).
    expanded_sites = _expand_sites(
        asym_unit.astype(np.float64, copy=False),
        rotations.astype(np.float64, copy=False),
        translations.astype(np.float64, copy=False),
    )

    # Fast path first; it only removes duplicates within each site's own
    # images, so equal points arising from different sites can remain.
    frac_motif, invs = _reduce_expanded_sites(expanded_sites, tol)
    motif_is_unique = all(_unique_sites(frac_motif, tol))
    if motif_is_unique:
        return frac_motif, invs

    # Fall back to the slower reduction which also compares across sites.
    return _reduce_expanded_equiv_sites(expanded_sites, tol)
1543
1544
1545
@numba.njit(cache=True)
def _expand_sites(
        asym_unit: FloatArray, rotations: FloatArray, translations: FloatArray
) -> FloatArray:
    """Apply every symmetry operation to every asymmetric unit point and
    wrap the results into the unit interval. Points left invariant by a
    symmetry are not yet de-duplicated. The returned 3D array has shape
    (#points, #syms, dims).
    """

    n_points, dims = asym_unit.shape
    n_syms = len(rotations)
    images = np.empty((n_points, n_syms, dims), dtype=np.float64)
    for j in range(n_syms):
        rot = rotations[j]
        shift = translations[j]
        for i in range(n_points):
            images[i, j] = np.dot(rot, asym_unit[i]) + shift
    return np.mod(images, 1)
1564
1565
1566
@numba.njit(cache=True)
def _reduce_expanded_sites(
        expanded_sites: FloatArray, tol: float
) -> Tuple[FloatArray, IntArray]:
    """Reduce the asymmetric unit after being expanded by symmetries by
    removing invariant points. This is the fast version which works in
    the case that no two sites in the asymmetric unit are equivalent.
    If they are, the reduction is re-run with
    _reduce_expanded_equiv_sites() to account for it.

    Returns the reduced fractional motif and an int32 array giving, for
    each motif point, the index of the site it came from.
    """

    all_unique_inds = []
    n_sites, _, dims = expanded_sites.shape
    # Integer dtype: these counts are used as slice bounds below, and
    # float slice indices are not valid.
    multiplicities = np.zeros(shape=(n_sites, ), dtype=np.int64)

    # First pass: find the distinct images of each site and count them.
    for i, sites in enumerate(expanded_sites):
        unique_inds = _unique_sites(sites, tol)
        all_unique_inds.append(unique_inds)
        multiplicities[i] = np.sum(unique_inds)

    m = int(np.sum(multiplicities))
    frac_motif = np.zeros(shape=(m, dims))
    inverses = np.zeros(shape=(m, ), dtype=np.int32)

    # Second pass: copy each site's distinct images into a contiguous
    # run of rows [s, t) of the output.
    s = 0
    for i in range(n_sites):
        t = s + multiplicities[i]
        frac_motif[s:t, :] = expanded_sites[i][all_unique_inds[i]]
        inverses[s:t] = i
        s = t

    return frac_motif, inverses
1598
1599
1600
def _reduce_expanded_equiv_sites(
        expanded_sites: FloatArray, tol: float
) -> Tuple[FloatArray, IntArray]:
    """Slow counterpart of _reduce_expanded_sites(), used when the fast
    reduction leaves equivalent motif points behind. Besides removing
    duplicate images of each site, every image is compared against all
    points already kept, and a warning is issued for each one dropped
    as equivalent.
    """

    # The first site's distinct images seed the motif.
    first_images = expanded_sites[0]
    frac_motif = first_images[_unique_sites(first_images, tol)]
    inverses = [0] * len(frac_motif)

    for i in range(1, len(expanded_sites)):
        images = expanded_sites[i]
        distinct_images = images[_unique_sites(images, tol)]

        new_points = []
        for site in distinct_images:
            # A coordinate matches if it is within tol directly, or
            # within tol after differing by 1 (wrap-around).
            direct = np.abs(site - frac_motif)
            wrapped = np.abs(direct - 1)
            mask = np.all((direct <= tol) | (wrapped <= tol), axis=-1)

            if np.any(mask):
                warnings.warn(
                    'has equivalent sites at positions '
                    f'{inverses[np.argmax(mask)]}, {i}'
                )
                continue
            new_points.append(site)

        if new_points:
            inverses.extend([i] * len(new_points))
            frac_motif = np.concatenate((frac_motif, np.array(new_points)))

    return frac_motif, np.array(inverses, dtype=np.int64)
1637
1638
1639
@numba.njit(cache=True)
def _unique_sites(asym_unit: FloatArray, tol: float) -> npt.NDArray[np.bool_]:
    """Flag which fractional coordinates are first occurrences, treating
    points as equal when they agree within ``tol`` modulo 1. The result
    is a boolean array such that asym_unit[_unique_sites(asym_unit, tol)]
    contains each distinct point once.
    """

    n_points, _ = asym_unit.shape
    where_unique = np.full(shape=(n_points, ), fill_value=True)
    for i in range(1, n_points):
        # Compare point i with every earlier point, coordinate-wise.
        direct = np.abs(asym_unit[:i, :] - asym_unit[i])
        wrapped = np.abs(direct - 1)
        # A coordinate differs only if it is apart both directly and
        # after accounting for a difference of 1.
        coord_differs = (direct > tol) & (wrapped > tol)
        n_differing = np.sum(coord_differs, axis=-1)
        # A zero count means some earlier point matches in every coordinate.
        if not np.all(n_differing):
            where_unique[i] = False
    return where_unique
1655
1656
1657
def _has_disorder(label: str, occupancy) -> bool:
1658
    """Return True if label ends with ? or occupancy is a number < 1."""
1659
    try:
1660
        occupancy = float(occupancy)
1661
    except Exception:
1662
        occupancy = 1
1663
    return (occupancy < 1) or label.endswith('?')
1664
1665
1666
def _atomic_symbols_from_labels(symbols: List[str]) -> List[str]:
1667
    symbols_ = []
1668
    for label in symbols:
1669
        sym = ''
1670
        if label and label not in ('.', '?'):
1671
            match = re.search(r'([A-Za-z][A-Za-z]?)', label)
1672
            if match is not None:
1673
                sym = match.group()
1674
            sym = list(sym)
1675
            if len(sym) > 0:
1676
                sym[0] = sym[0].upper()
1677
            if len(sym) > 1:
1678
                sym[1] = sym[1].lower()
1679
            sym = ''.join(sym)
1680
        symbols_.append(sym)
1681
    return symbols_
1682
1683
1684
def _get_syms_pymatgen(data: dict) -> Tuple[FloatArray, FloatArray]:
    """Get symmetry operations from data = block.data, where block is a
    pymatgen CifBlock object.

    Sources are tried in this order: symmetry operations given as xyz
    strings, then a space group symbol, then an international number.
    If none of them yields any operations, a warning is issued and the
    single operation 'x,y,z' (P1) is used. Returns a pair
    (rotations, translations).
    """

    from pymatgen.symmetry.groups import SpaceGroup
    import pymatgen.io.cif

    # 1. Symmetry operations written out as xyz strings.
    for tag in _CIF_TAGS['symop']:
        xyz_ops = data.get(tag)
        if xyz_ops:
            # A lone operation may be stored as a bare string.
            if isinstance(xyz_ops, str):
                xyz_ops = [xyz_ops]
            return _parse_sitesyms(xyz_ops)

    symops = []

    # 2. Space group symbol.
    for tag in _CIF_TAGS['spacegroup_name']:
        symbol = data.get(tag)
        if not symbol:
            continue
        symbol = re.sub(r'[\s_]', '', symbol)

        try:
            known_symbol = pymatgen.io.cif.space_groups.get(symbol)
            if not known_symbol:
                # Symbol not in pymatgen's table: move on to the next tag.
                continue
            symops = SpaceGroup(known_symbol).symmetry_ops
        except ValueError:
            pass
        else:
            break

        # Only reached after a ValueError above: look the symbol up in
        # pymatgen's COD data instead.
        try:
            for entry in pymatgen.io.cif._get_cod_data():
                if symbol == re.sub(r'\s+', '', entry['hermann_mauguin']):
                    return _parse_sitesyms(entry['symops'])
        except Exception:
            # Any failure during the COD lookup abandons this tag.
            continue

    # 3. International number.
    if not symops:
        for tag in _CIF_TAGS['spacegroup_number']:
            number = data.get(tag)
            if not number:
                continue
            try:
                int_number = int(str2float(number))
                symops = SpaceGroup.from_int_number(int_number).symmetry_ops
            except ValueError:
                continue
            else:
                break

    if not symops:
        warnings.warn('no symmetry data found, defaulting to P1')
        return _parse_sitesyms(['x,y,z'])

    rotations = np.array(
        [op.rotation_matrix for op in symops], dtype=np.float64
    )
    translations = np.array(
        [op.translation_vector for op in symops], dtype=np.float64
    )
    return rotations, translations
1749
1750
1751
def _frac_molecular_centres_ccdc(crystal, tol: float) -> FloatArray:
    """Return the geometric centres of molecules in the unit cell.
    Expects a ccdc Crystal object and returns fractional coordinates.

    Each centre is the plain average of the fractional coordinates of
    the atoms in one component of the 'CentroidIncluded' packing. The
    centres are wrapped into [0, 1) and repeats (within tol, modulo 1)
    are removed.
    """

    centres = []
    packed = crystal.packing(inclusion='CentroidIncluded')
    for component in packed.components:
        atom_coords = [atom.fractional_coordinates for atom in component.atoms]
        n_atoms = len(atom_coords)
        # zip(*...) regroups the per-atom coordinates into one sequence
        # per axis, which is then averaged.
        centres.append([sum(axis) / n_atoms for axis in zip(*atom_coords)])
    wrapped = np.mod(np.array(centres, dtype=np.float64), 1)
    keep = _unique_sites(wrapped, tol)
    return wrapped[keep]
1762
1763
1764
def _heaviest_component_ccdc(molecule):
1765
    """Remove all but the heaviest component of the asymmetric unit.
1766
    Intended for removing solvents. Expects and returns a ccdc Molecule
1767
    object.
1768
    """
1769
1770
    component_weights = []
1771
    for component in molecule.components:
1772
        weight = 0
1773
        for a in component.atoms:
1774
            try:
1775
                occ = float(a.occupancy)
1776
            except:
1777
                occ = 1
1778
            try:
1779
                weight += float(a.atomic_weight) * occ
1780
            except ValueError:
1781
                pass
1782
        component_weights.append(weight)
1783
    largest_component_ind = np.argmax(np.array(component_weights))
1784
    molecule = molecule.components[largest_component_ind]
1785
    return molecule
1786
1787
1788
def str2float(string):
    """Remove uncertainty brackets from strings and return the float.

    Everything from the first opening bracket onwards is discarded
    before conversion, so '1.234(5)' gives 1.234. The placeholder '.'
    (ignoring surrounding whitespace) gives 0. A list holding exactly
    one value is unwrapped and that value is converted the same way.

    Raises ValueError if the remaining text is not a valid float, and
    ParseError if the input is neither a string nor a one-element list.
    """
    # Unwrap one-element lists first so the wrapped value takes exactly
    # the same path as a bare string, including the '.' placeholder.
    if isinstance(string, list) and len(string) == 1:
        return str2float(string[0])
    try:
        return float(re.sub(r'\(.+\)*', '', string))
    except TypeError:
        # Not a string; reported as a ParseError below.
        pass
    except ValueError:
        if string.strip() == '.':
            return 0.
        raise
    raise ParseError(f'{string} cannot be converted to float')
1800
1801
1802
def _snap_small_prec_coords(frac_coords: FloatArray, tol: float) -> FloatArray:
1803
    """Find where frac_coords is within 1e-4 of 1/3 or 2/3, change to
1804
    1/3 and 2/3. Recommended by pymatgen's CIF parser.
1805
    """
1806
    frac_coords[np.abs(1 - 3 * frac_coords) < tol] = 1 / 3.
1807
    frac_coords[np.abs(1 - 3 * frac_coords / 2) < tol] = 2 / 3.
1808
    return frac_coords
1809