Passed
Push — master ( a617b6...34fcd6 )
by Daniel
07:54
created

amd.io.CifReader._dir_generator()   A

Complexity

Conditions 5

Size

Total Lines 16
Code Lines 13

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 13
dl 0
loc 16
rs 9.2833
c 0
b 0
f 0
cc 5
nop 3
1
"""Tools for reading crystals from files, or from the CSD with
2
``csd-python-api``. The readers return
3
:class:`amd.PeriodicSet <.periodicset.PeriodicSet>` objects representing
4
the crystal which can be passed to :func:`amd.AMD() <.calculate.AMD>`
5
and :func:`amd.PDD() <.calculate.PDD>`.
6
"""
7
8
import warnings
9
import collections
10
import os
11
import re
12
import functools
13
import errno
14
import math
15
import json
16
from pathlib import Path
17
from typing import (
18
    Iterable,
19
    Iterator,
20
    Optional,
21
    Union,
22
    Callable,
23
    Tuple,
24
    List
25
)
26
27
import numpy as np
28
import numba
29
import tqdm
30
31
from .utils import cellpar_to_cell
32
from .periodicset import PeriodicSet
33
34
__all__ = [
    # Reader classes
    'CifReader',
    'CSDReader',
    # Exception raised by converters for unparseable items
    'ParseError',
    # Converters from third-party CIF block types
    'periodicset_from_gemmi_block',
    'periodicset_from_ase_cifblock',
    'periodicset_from_pymatgen_cifblock',
    # Converters from third-party structure objects
    'periodicset_from_ase_atoms',
    'periodicset_from_pymatgen_structure',
    # Converters from csd-python-api objects
    'periodicset_from_ccdc_entry',
    'periodicset_from_ccdc_crystal',
]
46
47
def _custom_warning(message, category, filename, lineno, *args, **kwargs):
    """Render a warning as ``'<CategoryName>: <message>'`` plus a newline.

    Signature-compatible replacement for :func:`warnings.formatwarning`;
    the file name, line number and any extra arguments are ignored so
    only the category and message are shown.
    """
    category_name = category.__name__
    return '{}: {}\n'.format(category_name, message)


# Installed process-wide at import time.
warnings.formatwarning = _custom_warning
51
52
# Lookup table loaded from atomic_numbers.json next to this module. It is
# indexed by element symbol and gives the atomic number (Hydrogen is 1,
# which the converters use for remove_hydrogens). JSON is UTF-8 by spec,
# so decode explicitly rather than relying on the platform default.
with open(
        Path(__file__).absolute().parent / 'atomic_numbers.json',
        encoding='utf-8'
) as f:
    _ATOMIC_NUMBERS = json.load(f)
54
55
# Tolerance handed to _unique_sites / _expand_asym_unit together with
# fractional coordinates; presumably sites closer than this are treated as
# the same site (TODO confirm against those helpers).
_EQ_SITE_TOL: float = 1e-3

# CIF data names used by the converters, grouped by quantity.
# 'cellpar', 'atom_site_fract' and 'atom_site_cartn' list names that are
# read together; 'symop', 'spacegroup_name' and 'spacegroup_number' list
# alternative spellings which are tried in order.
_CIF_TAGS: dict = dict(
    cellpar=[
        '_cell_length_a', '_cell_length_b', '_cell_length_c',
        '_cell_angle_alpha', '_cell_angle_beta', '_cell_angle_gamma',
    ],
    atom_site_fract=[
        '_atom_site_fract_x', '_atom_site_fract_y', '_atom_site_fract_z',
    ],
    atom_site_cartn=[
        '_atom_site_Cartn_x', '_atom_site_Cartn_y', '_atom_site_Cartn_z',
    ],
    symop=[
        '_space_group_symop_operation_xyz',
        '_space_group_symop.operation_xyz',
        '_symmetry_equiv_pos_as_xyz',
    ],
    spacegroup_name=[
        '_space_group_name_H-M_alt',
        '_symmetry_space_group_name_H-M',
    ],
    spacegroup_number=[
        '_space_group_IT_number',
        '_symmetry_Int_Tables_number',
    ],
)
89
90
91
class _Reader(collections.abc.Iterator):
    """Base reader class.

    Wraps an iterable of raw items and a converter which turns each item
    into a :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`.

    Parameters
    ----------
    iterable : Iterable
        Source of raw items to convert.
    converter : Callable[..., PeriodicSet]
        Called on each item. Raising
        :class:`ParseError <.io.ParseError>` causes the item to be
        skipped.
    show_warnings : bool
        If True, warnings raised while reading/converting (including the
        message of each ParseError that caused a skip) are emitted. If
        False they are suppressed for the duration of each ``next()``
        call only; the global warnings filters are left untouched.
    verbose : bool
        If True, show a tqdm progress bar counting processed items.
    """

    def __init__(
            self,
            iterable: Iterable,
            converter: Callable[..., PeriodicSet],
            show_warnings: bool,
            verbose: bool
    ):

        self._iterator = iter(iterable)
        self._converter = converter
        self.show_warnings = show_warnings
        if verbose:
            self._progress_bar = tqdm.tqdm(desc='Reading', delay=1)
        else:
            self._progress_bar = None

    def __next__(self):
        """Return the next successfully converted item.

        Items whose conversion raises
        :class:`ParseError <.io.ParseError>` are skipped. Warnings are
        printed only if self.show_warnings is True.
        """

        if self.show_warnings:
            return self._next_item()

        # Suppress warnings locally (restored on exit) instead of
        # mutating the process-wide filters. This also covers warnings
        # raised while advancing the underlying iterator, e.g. files
        # skipped by CifReader._dir_generator.
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            return self._next_item()

    def _next_item(self):
        """Advance past unparseable items and return the next
        periodic set, re-emitting converter warnings tagged with the
        item's name."""

        while True:

            try:
                item = next(self._iterator)
            except StopIteration:
                if self._progress_bar is not None:
                    self._progress_bar.close()
                raise

            parse_error = None
            with warnings.catch_warnings(record=True) as warning_msgs:
                try:
                    periodic_set = self._converter(item)
                except ParseError as err:
                    parse_error = err
                finally:
                    if self._progress_bar is not None:
                        self._progress_bar.update(1)

            # Emitted outside the recording context; warning inside it
            # would only append to warning_msgs, which is then discarded.
            if parse_error is not None:
                warnings.warn(str(parse_error))
                continue

            for warning in warning_msgs:
                msg = f'(name={periodic_set.name}) {warning.message}'
                warnings.warn(msg, category=warning.category)

            return periodic_set

    def read(self) -> Union[PeriodicSet, List[PeriodicSet]]:
        """Read the crystal(s), return one
        :class:`amd.PeriodicSet <.periodicset.PeriodicSet>` if there is
        only one, otherwise return a list (empty if nothing was read).
        """
        items = list(self)
        if len(items) == 1:
            return items[0]
        return items
155
156
157
class CifReader(_Reader):
    """Read all structures in a .cif file or all files in a folder
    with gemmi (default), pymatgen, ase or csd-python-api (if
    installed), yielding
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>` s.

    Parameters
    ----------
    path : str or os.PathLike
        Path to a .CIF file or directory. (Other files are accepted when
        using ``reader='ccdc'``, if csd-python-api is installed.) If a
        directory, files are read in sorted order, files with an
        unsupported extension are ignored, and files which fail to parse
        are skipped with a warning.
    reader : str, optional
        The backend package used to parse the CIF. The default is
        :code:`gemmi`, :code:`pymatgen`, :code:`ase` and
        :code:`pycodcif` (ase's pycodcif backend) are also accepted, as
        well as :code:`ccdc` if csd-python-api is installed. The ccdc
        reader should be able to read any format accepted by
        :class:`ccdc.io.EntryReader`, though only CIFs have been tested.
    remove_hydrogens : bool, optional
        Remove Hydrogens from the crystals.
    disorder : str, optional
        Controls how disordered structures are handled. Default is
        ``skip`` which skips any crystal with disorder, since disorder
        conflicts with the periodic set model. To read disordered
        structures anyway, choose either :code:`ordered_sites` to remove
        atoms with disorder or :code:`all_sites` to include all atoms
        regardless of disorder.
    heaviest_component : bool, optional
        csd-python-api only. Removes all but the heaviest molecule in
        the asymmetric unit, intended for removing solvents.
    molecular_centres : bool, default False
        csd-python-api only. Extract the centres of molecules in the
        unit cell and store in the attribute molecular_centres.
    show_warnings : bool, optional
        Controls whether warnings that arise during reading are printed.
    verbose : bool, default False
        If True, prints a progress bar showing the number of items
        processed.

    Raises
    ------
    ValueError
        If ``disorder`` or ``reader`` is not an accepted value.
    NotImplementedError
        If ``heaviest_component`` or ``molecular_centres`` is requested
        with a reader other than ``ccdc``.
    ImportError
        If ``reader='ccdc'`` and csd-python-api cannot be imported.
    FileNotFoundError
        If ``path`` is neither a file nor a directory.

    Yields
    ------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        data, e.g. the crystal's name and information about the
        asymmetric unit.

    Examples
    --------

        ::

            # Put all crystals in a .CIF in a list
            structures = list(amd.CifReader('mycif.cif'))

            # Can also accept path to a directory, reading all files inside
            structures = list(amd.CifReader('path/to/folder'))

            # Reads just one if the .CIF has just one crystal
            periodic_set = amd.CifReader('mycif.cif').read()

            # List of AMDs (k=100) of crystals in a .CIF
            amds = [amd.AMD(item, 100) for item in amd.CifReader('mycif.cif')]
    """

    def __init__(
            self,
            path: Union[str, os.PathLike],
            reader: str = 'gemmi',
            remove_hydrogens: bool = False,
            disorder: str = 'skip',
            heaviest_component: bool = False,
            molecular_centres: bool = False,
            show_warnings: bool = True,
            verbose: bool = False
    ):

        if disorder not in ('skip', 'ordered_sites', 'all_sites'):
            raise ValueError(
                f"'disorder' parameter of {self.__class__.__name__} must be "
                f"one of 'skip', 'ordered_sites' or 'all_sites' (passed "
                f"'{disorder}')"
            )

        if reader != 'ccdc':
            if heaviest_component:
                raise NotImplementedError(
                    "'heaviest_component' parameter of "
                    f"{self.__class__.__name__} only implemented with "
                    "csd-python-api, if installed pass reader='ccdc'"
                )
            if molecular_centres:
                raise NotImplementedError(
                    "'molecular_centres' parameter of "
                    f"{self.__class__.__name__} only implemented with "
                    "csd-python-api, if installed pass reader='ccdc'"
                )

        # Each backend supplies: accepted file extensions (for directory
        # reads), a parser path -> iterable of raw items, and a converter
        # raw item -> PeriodicSet.
        # Note: gemmi cannot handle some characters in CIFs.
        if reader == 'gemmi':
            import gemmi
            extensions = {'cif'}
            file_parser = gemmi.cif.read_file
            converter = functools.partial(
                periodicset_from_gemmi_block,
                remove_hydrogens=remove_hydrogens,
                disorder=disorder
            )

        elif reader in ('ase', 'pycodcif'):
            from ase.io.cif import parse_cif
            extensions = {'cif'}
            file_parser = functools.partial(parse_cif, reader=reader)
            converter = functools.partial(
                periodicset_from_ase_cifblock,
                remove_hydrogens=remove_hydrogens,
                disorder=disorder
            )

        elif reader == 'pymatgen':

            def _pymatgen_cif_parser(path):
                from pymatgen.io.cif import CifFile
                return CifFile.from_file(path).data.values()

            extensions = {'cif'}
            file_parser = _pymatgen_cif_parser
            converter = functools.partial(
                periodicset_from_pymatgen_cifblock,
                remove_hydrogens=remove_hydrogens,
                disorder=disorder
            )

        elif reader == 'ccdc':
            try:
                import ccdc.io
            except (ImportError, RuntimeError) as e:
                raise ImportError('Failed to import csd-python-api') from e

            extensions = set(ccdc.io.EntryReader.known_suffixes.keys())
            file_parser = ccdc.io.EntryReader
            converter = functools.partial(
                periodicset_from_ccdc_entry,
                remove_hydrogens=remove_hydrogens,
                disorder=disorder,
                molecular_centres=molecular_centres,
                heaviest_component=heaviest_component
            )

        else:
            raise ValueError(
                f"'reader' parameter of {self.__class__.__name__} must be one "
                f"of 'gemmi', 'pymatgen', 'ccdc', 'ase', or 'pycodcif' "
                f"(passed '{reader}')"
            )

        path = Path(path)
        if path.is_file():
            iterable = file_parser(str(path))
        elif path.is_dir():
            iterable = CifReader._dir_generator(path, file_parser, extensions)
        else:
            raise FileNotFoundError(
                errno.ENOENT, os.strerror(errno.ENOENT), str(path)
            )

        super().__init__(iterable, converter, show_warnings, verbose)

    @staticmethod
    def _dir_generator(
            path: Path, file_parser: Callable, extensions: Iterable
    ) -> Iterator:
        """Generate items from all files in directory ``path`` whose
        lower-cased suffix (without the dot) is in ``extensions``.

        Files are visited in sorted order so the output is reproducible
        across platforms; ``Path.iterdir`` gives no ordering guarantee.
        A file whose parsing raises is skipped with a warning (best
        effort, so one bad file does not abort the whole directory).
        """
        for file_path in sorted(path.iterdir()):
            if not file_path.is_file():
                continue
            if file_path.suffix[1:].lower() not in extensions:
                continue
            try:
                yield from file_parser(str(file_path))
            except Exception as e:
                warnings.warn(
                    f'Error parsing "{str(file_path)}", skipping file. '
                    f'Exception: {repr(e)}'
                )
342
343
344
class CSDReader(_Reader):
    """Read structures from the CSD with csd-python-api, yielding
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>` s.

    Parameters
    ----------
    refcodes : str or List[str], optional
        Single or list of CSD refcodes to read. If None or 'CSD',
        iterates over the whole CSD.
    families : bool, optional
        Read all entries whose refcode starts with the given strings, or
        'families' (e.g. giving 'DEBXIT' reads all entries starting with
        DEBXIT). Ignored when reading the whole CSD.
    remove_hydrogens : bool, optional
        Remove hydrogens from the crystals.
    disorder : str, optional
        Controls how disordered structures are handled. Default is
        ``skip`` which skips any crystal with disorder, since disorder
        conflicts with the periodic set model. To read disordered
        structures anyway, choose either :code:`ordered_sites` to remove
        atoms with disorder or :code:`all_sites` to include all atoms
        regardless of disorder.
    heaviest_component : bool, optional
        Removes all but the heaviest molecule in the asymmetric unit,
        intended for removing solvents.
    molecular_centres : bool, default False
        Extract the centres of molecules in the unit cell and store in
        attribute molecular_centres.
    show_warnings : bool, optional
        Controls whether warnings that arise during reading are printed.
    verbose : bool, default False
        If True, prints a progress bar showing the number of items
        processed.

    Raises
    ------
    ValueError
        If ``disorder`` is not an accepted value, or ``refcodes`` is not
        None, a string or a list of strings.
    ImportError
        If csd-python-api cannot be imported.

    Yields
    ------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        useful data, e.g. the crystal's name and information about the
        asymmetric unit for calculation.

    Examples
    --------

        ::

            # Put these entries in a list
            refcodes = ['DEBXIT01', 'DEBXIT05', 'HXACAN01']
            structures = list(amd.CSDReader(refcodes))

            # Read refcode families (any whose refcode starts with strings in the list)
            refcode_families = ['ACSALA', 'HXACAN']
            structures = list(amd.CSDReader(refcode_families, families=True))

            # Get AMDs (k=100) for crystals in these families
            refcodes = ['ACSALA', 'HXACAN']
            amds = []
            for periodic_set in amd.CSDReader(refcodes, families=True):
                amds.append(amd.AMD(periodic_set, 100))

            # Giving the reader nothing reads from the whole CSD.
            for periodic_set in amd.CSDReader():
                ...
    """

    def __init__(
            self,
            refcodes: Optional[Union[str, List[str]]] = None,
            families: bool = False,
            remove_hydrogens: bool = False,
            disorder: str = 'skip',
            heaviest_component: bool = False,
            molecular_centres: bool = False,
            show_warnings: bool = True,
            verbose: bool = False
    ):

        if disorder not in ('skip', 'ordered_sites', 'all_sites'):
            raise ValueError(
                f"'disorder' parameter of {self.__class__.__name__} must be "
                f"one of 'skip', 'ordered_sites' or 'all_sites' (passed "
                f"'{disorder}')"
            )

        try:
            import ccdc.search
            import ccdc.io
        except (ImportError, RuntimeError) as e:
            raise ImportError('Failed to import csd-python-api') from e

        # Normalise refcodes to None (whole CSD) or a list of strings.
        if isinstance(refcodes, str) and refcodes.lower() == 'csd':
            refcodes = None
        if refcodes is None:
            families = False
        elif isinstance(refcodes, str):
            refcodes = [refcodes]
        elif isinstance(refcodes, list):
            if not all(isinstance(refcode, str) for refcode in refcodes):
                raise ValueError(
                    f'{self.__class__.__name__} expects None, a string or '
                    'list of strings.'
                )
        else:
            raise ValueError(
                f'{self.__class__.__name__} expects None, a string or list of '
                f'strings, got {refcodes.__class__.__name__}'
            )

        if families:
            # Expand each family prefix into the refcodes found by a
            # TextNumericSearch identifier query.
            all_refcodes = []
            for refcode in refcodes:
                query = ccdc.search.TextNumericSearch()
                query.add_identifier(refcode)
                all_refcodes.extend(hit.identifier for hit in query.search())
            # Unique refcodes, keeping first-seen order (dicts preserve
            # insertion order).
            refcodes = list(dict.fromkeys(all_refcodes))

        converter = functools.partial(
            periodicset_from_ccdc_entry,
            remove_hydrogens=remove_hydrogens,
            disorder=disorder,
            molecular_centres=molecular_centres,
            heaviest_component=heaviest_component
        )

        entry_reader = ccdc.io.EntryReader('CSD')
        if refcodes is None:
            iterable = entry_reader
        else:
            iterable = map(entry_reader.entry, refcodes)

        super().__init__(iterable, converter, show_warnings, verbose)
483
484
485
class ParseError(ValueError):
    """Signals that an item could not be converted into a periodic set.

    The converters in this module raise it for unusable input. The
    readers catch it and skip the offending item. Because it subclasses
    ValueError, callers may also catch it as such.
    """
488
489
490
def periodicset_from_gemmi_block(
        block, remove_hydrogens: bool = False, disorder: str = 'skip'
) -> PeriodicSet:
    """Convert a :class:`gemmi.cif.Block` object to a
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`.
    :class:`gemmi.cif.Block` is the type returned by
    :func:`gemmi.cif.read_file`.

    Parameters
    ----------
    block : :class:`gemmi.cif.Block`
        A gemmi CIF block representing a crystal.
    remove_hydrogens : bool, optional
        Remove Hydrogens from the crystal.
    disorder : str, optional
        Controls how disordered structures are handled. Default is
        ``skip`` which skips any crystal with disorder, since disorder
        conflicts with the periodic set model. To read disordered
        structures anyway, choose either :code:`ordered_sites` to remove
        atoms with disorder or :code:`all_sites` to include all atoms
        regardless of disorder. Disorder is only checked on sites that
        are not already being removed (sites with missing coordinates,
        or Hydrogens if ``remove_hydrogens`` is True).

    Returns
    -------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        useful data, e.g. the crystal's name and information about the
        asymmetric unit for calculation.

    Raises
    ------
    ParseError
        Raised if the structure fails to be parsed for any of the
        following: 1. Required data is missing (e.g. cell parameters),
        2. :code:`disorder == 'skip'` and disorder is found on any
        kept atom, 3. An atom type symbol is not recognised, 4. The
        motif is empty after removing H or disordered sites.
    """

    import gemmi

    # Unit cell
    cellpar = [block.find_value(t) for t in _CIF_TAGS['cellpar']]
    if not all(isinstance(par, str) for par in cellpar):
        raise ParseError(f'{block.name} has missing cell data')
    cellpar = np.array([gemmi.cif.as_number(par) for par in cellpar])
    if np.isnan(np.sum(cellpar)):
        raise ParseError(f'{block.name} has missing cell data')
    cell = cellpar_to_cell(cellpar)

    # Asymmetric unit coordinates
    xyz_loop = block.find(_CIF_TAGS['atom_site_fract']).loop
    if xyz_loop is None:
        if block.find(_CIF_TAGS['atom_site_cartn']).loop is None:
            raise ParseError(f'{block.name} has missing coordinate data')
        raise ParseError(
            f'{block.name} uses _atom_site_Cartn_ tags for coordinates, '
            'only _atom_site_fract_ is supported'
        )

    # The loop's values are a flat row-by-row list; split them into one
    # column per tag.
    width = xyz_loop.width()
    tablified_loop = [[] for _ in range(len(xyz_loop.tags))]
    for i, item in enumerate(xyz_loop.values):
        tablified_loop[i % width].append(item)
    loop_dict = {tag: col for tag, col in zip(xyz_loop.tags, tablified_loop)}
    xyz_str = [loop_dict[t] for t in _CIF_TAGS['atom_site_fract']]
    asym_unit = np.transpose(np.array(
        [[gemmi.cif.as_number(c) for c in xyz] for xyz in xyz_str]
    ))
    asym_unit = np.mod(asym_unit, 1)

    # recommended by pymatgen
    # asym_unit = _snap_small_prec_coords(asym_unit, 1e-4)

    # Labels
    if '_atom_site_label' in loop_dict:
        labels = [
            gemmi.cif.as_string(lab) for lab in loop_dict['_atom_site_label']
        ]
    else:
        labels = [''] * xyz_loop.length()

    # Atomic types. An unrecognised symbol raises ParseError so the reader
    # skips this crystal instead of aborting on AttributeError/KeyError.
    if '_atom_site_type_symbol' in loop_dict:
        symbols = []
        for s in loop_dict['_atom_site_type_symbol']:
            match = re.search(r'([A-Z][a-z]?)', gemmi.cif.as_string(s))
            if match is None:
                raise ParseError(
                    f'{block.name} has unrecognised atom type symbol {s!r}'
                )
            symbols.append(match.group())
    else:  # Get atomic types from label
        symbols = _process_symbols(labels)
    try:
        asym_types = [_ATOMIC_NUMBERS[s] for s in symbols]
    except KeyError as err:
        raise ParseError(
            f'{block.name} has unrecognised element symbol {err}'
        ) from err

    # Occupancies, defaulting to 1 where missing/non-numeric
    if '_atom_site_occupancy' in loop_dict:
        occs = [
            gemmi.cif.as_number(oc) for oc in loop_dict['_atom_site_occupancy']
        ]
        occupancies = [occ if not math.isnan(occ) else 1 for occ in occs]
    else:
        occupancies = [1] * xyz_loop.length()

    # Sites removed regardless of disorder: missing coordinates (NaN),
    # then Hydrogens if requested. A set keeps membership tests O(1).
    remove_sites = set(np.nonzero(np.isnan(asym_unit.min(axis=-1)))[0].tolist())
    if remove_hydrogens:
        remove_sites.update(i for i, num in enumerate(asym_types) if num == 1)

    # Disorder only matters on sites that would otherwise be kept, which
    # matches periodicset_from_ase_cifblock.
    if disorder != 'all_sites':
        disordered = [
            i for i, (label, occ) in enumerate(zip(labels, occupancies))
            if i not in remove_sites and _has_disorder(label, occ)
        ]
        if disordered and disorder == 'skip':
            raise ParseError(
                f"{block.name} has disorder, pass disorder='ordered_sites' or "
                "'all_sites' to remove/ignore disorder"
            )
        if disorder == 'ordered_sites':
            remove_sites.update(disordered)

    asym_unit = np.delete(asym_unit, sorted(remove_sites), axis=0)
    asym_types = [s for i, s in enumerate(asym_types) if i not in remove_sites]
    if asym_unit.shape[0] == 0:
        raise ParseError(f'{block.name} has no valid sites')

    if disorder != 'all_sites':
        keep_sites = _unique_sites(asym_unit, _EQ_SITE_TOL)
        if not np.all(keep_sites):
            warnings.warn(
                'may have overlapping sites; duplicates will be removed'
            )
        asym_unit = asym_unit[keep_sites]
        asym_types = [sym for sym, keep in zip(asym_types, keep_sites) if keep]

    # Symmetry operations, try xyz strings first
    for tag in _CIF_TAGS['symop']:
        sitesym = [v.str(0) for v in block.find([tag])]
        if sitesym:
            rot, trans = _parse_sitesyms(sitesym)
            break
    else:
        # Try spacegroup name; can be a pair or in a loop
        spg = None
        for tag in _CIF_TAGS['spacegroup_name']:
            for value in block.find([tag]):
                try:
                    # Some names cannot be parsed by gemmi.SpaceGroup
                    spg = gemmi.SpaceGroup(value.str(0))
                    break
                except ValueError:
                    continue
            if spg is not None:
                break

        if spg is None:
            # Try international number
            for tag in _CIF_TAGS['spacegroup_number']:
                spg_num = block.find_value(tag)
                if spg_num is not None:
                    spg_num = gemmi.cif.as_int(spg_num)
                    break
            else:
                warnings.warn('no symmetry data found, defaulting to P1')
                spg_num = 1
            spg = gemmi.SpaceGroup(spg_num)

        # gemmi stores operations as integers scaled by Op.DEN
        rot = np.array([np.array(o.rot) / o.DEN for o in spg.operations()])
        trans = np.array([np.array(o.tran) / o.DEN for o in spg.operations()])

    frac_motif, invs = _expand_asym_unit(asym_unit, rot, trans, _EQ_SITE_TOL)
    _, wyc_muls = np.unique(invs, return_counts=True)
    asym_inds = np.zeros_like(wyc_muls, dtype=np.int32)
    asym_inds[1:] = np.cumsum(wyc_muls)[:-1]
    types = np.array([asym_types[i] for i in invs], dtype=np.uint8)
    motif = np.matmul(frac_motif, cell)

    return PeriodicSet(
        motif=motif,
        cell=cell,
        name=block.name,
        asym_unit=asym_inds,
        multiplicities=wyc_muls,
        types=types
    )
676
677
678
def periodicset_from_ase_cifblock(
        block, remove_hydrogens: bool = False, disorder: str = 'skip'
) -> PeriodicSet:
    """Convert a :class:`ase.io.cif.CIFBlock` object to a
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`.
    :class:`ase.io.cif.CIFBlock` is the type returned by
    :func:`ase.io.cif.parse_cif`.

    Parameters
    ----------
    block : :class:`ase.io.cif.CIFBlock`
        An ase :class:`ase.io.cif.CIFBlock` object representing a
        crystal.
    remove_hydrogens : bool, optional
        Remove Hydrogens from the crystal.
    disorder : str, optional
        Controls how disordered structures are handled. Default is
        ``skip`` which skips any crystal with disorder, since disorder
        conflicts with the periodic set model. To read disordered
        structures anyway, choose either :code:`ordered_sites` to remove
        atoms with disorder or :code:`all_sites` to include all atoms
        regardless of disorder.

    Returns
    -------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        useful data, e.g. the crystal's name and information about the
        asymmetric unit for calculation.

    Raises
    ------
    ParseError
        Raised if the structure fails to be parsed for any of the
        following: 1. Required data is missing (e.g. cell parameters),
        2. The motif is empty after removing H or disordered sites,
        3. :code:`disorder == 'skip'` and disorder is found on any
        kept atom, 4. An atom type symbol is not recognised.
    """

    import ase.spacegroup

    # Unit cell
    cellpar = [block.get(tag) for tag in _CIF_TAGS['cellpar']]
    if None in cellpar:
        raise ParseError(f'{block.name} has missing cell data')
    cell = cellpar_to_cell(np.array(cellpar))

    # Asymmetric unit coordinates. ase removes uncertainty brackets
    asym_unit = [block.get(tag) for tag in _CIF_TAGS['atom_site_fract']]
    if None in asym_unit:
        asym_unit = [
            block.get(tag.lower()) for tag in _CIF_TAGS['atom_site_cartn']
        ]
        if None in asym_unit:
            raise ParseError(f'{block.name} has missing coordinates')
        raise ParseError(
            f'{block.name} uses _atom_site_Cartn_ tags for coordinates, '
            'only _atom_site_fract_ is supported'
        )
    asym_unit = list(zip(*asym_unit))

    # Labels
    asym_labels = block.get('_atom_site_label')
    labels_found = asym_labels is not None
    if not labels_found:
        asym_labels = [''] * len(asym_unit)

    # Atomic types; fall back to deriving them from labels, as the gemmi
    # converter does. An unknown symbol skips the crystal via ParseError.
    asym_symbols = block.get('_atom_site_type_symbol')
    if asym_symbols is not None:
        asym_symbols_ = _process_symbols(asym_symbols)
    elif labels_found:
        asym_symbols_ = _process_symbols(asym_labels)
    else:
        asym_symbols_ = [''] * len(asym_unit)
    try:
        asym_types = [_ATOMIC_NUMBERS[s] for s in asym_symbols_]
    except KeyError as err:
        raise ParseError(
            f'{block.name} has unrecognised element symbol {err}'
        ) from err

    # Find where sites have disorder if necessary
    has_disorder = []
    if disorder != 'all_sites':
        occupancies = block.get('_atom_site_occupancy')
        if occupancies is None:
            occupancies = [1] * len(asym_unit)
        for lab, occ in zip(asym_labels, occupancies):
            has_disorder.append(_has_disorder(lab, occ))

    # Remove sites with ?, . or other invalid string for coordinates
    invalid = {
        i for i, xyz in enumerate(asym_unit)
        if not all(isinstance(coord, (int, float)) for coord in xyz)
    }
    if invalid:
        warnings.warn('atoms without sites or missing data will be removed')
        asym_unit = [c for i, c in enumerate(asym_unit) if i not in invalid]
        asym_types = [t for i, t in enumerate(asym_types) if i not in invalid]
        has_disorder = [
            d for i, d in enumerate(has_disorder) if i not in invalid
        ]

    # Indices below refer to the arrays after invalid sites were dropped.
    remove_sites = set()

    if remove_hydrogens:
        remove_sites.update(i for i, num in enumerate(asym_types) if num == 1)

    # Remove atoms with fractional occupancy or raise ParseError; sites
    # already being removed are not checked for disorder.
    if disorder != 'all_sites':
        for i, dis in enumerate(has_disorder):
            if not dis or i in remove_sites:
                continue
            if disorder == 'skip':
                raise ParseError(
                    f'{block.name} has disorder, pass '
                    "disorder='ordered_sites' or 'all_sites' to "
                    'remove/ignore disorder'
                )
            elif disorder == 'ordered_sites':
                remove_sites.add(i)

    # Asymmetric unit
    asym_unit = [c for i, c in enumerate(asym_unit) if i not in remove_sites]
    asym_types = [t for i, t in enumerate(asym_types) if i not in remove_sites]
    if len(asym_unit) == 0:
        raise ParseError(f'{block.name} has no valid sites')
    asym_unit = np.mod(np.array(asym_unit), 1)

    # recommended by pymatgen
    # asym_unit = _snap_small_prec_coords(asym_unit, 1e-4)

    # Remove overlapping sites unless disorder == 'all_sites'
    if disorder != 'all_sites':
        keep_sites = _unique_sites(asym_unit, _EQ_SITE_TOL)
        if not np.all(keep_sites):
            warnings.warn(
                'may have overlapping sites, duplicates will be removed'
            )
            asym_unit = asym_unit[keep_sites]
            asym_types = [t for t, keep in zip(asym_types, keep_sites) if keep]

    # Get symmetry operations: explicit xyz strings, else spacegroup
    # name, else spacegroup number, else P1. ase stores tags lower-cased.
    sitesym = block._get_any(_CIF_TAGS['symop'])
    if sitesym is None:
        label_or_num = block._get_any(
            [s.lower() for s in _CIF_TAGS['spacegroup_name']]
        )
        if label_or_num is None:
            label_or_num = block._get_any(
                [s.lower() for s in _CIF_TAGS['spacegroup_number']]
            )
        if label_or_num is None:
            warnings.warn('no symmetry data found, defaulting to P1')
            label_or_num = 1
        spg = ase.spacegroup.Spacegroup(label_or_num)
        rot, trans = spg.get_op()
    else:
        if isinstance(sitesym, str):
            sitesym = [sitesym]
        rot, trans = _parse_sitesyms(sitesym)

    frac_motif, invs = _expand_asym_unit(asym_unit, rot, trans, _EQ_SITE_TOL)
    _, wyc_muls = np.unique(invs, return_counts=True)
    asym_inds = np.zeros_like(wyc_muls, dtype=np.int32)
    asym_inds[1:] = np.cumsum(wyc_muls)[:-1]
    types = np.array([asym_types[i] for i in invs], dtype=np.uint8)
    motif = np.matmul(frac_motif, cell)

    return PeriodicSet(
        motif=motif,
        cell=cell,
        name=block.name,
        asym_unit=asym_inds,
        multiplicities=wyc_muls,
        types=types
    )
854
855
856
def periodicset_from_pymatgen_cifblock(
        block, remove_hydrogens: bool = False, disorder: str = 'skip'
) -> PeriodicSet:
    """Convert a :class:`pymatgen.io.cif.CifBlock` object to a
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`.
    :class:`pymatgen.io.cif.CifBlock` is the type returned by
    :class:`pymatgen.io.cif.CifFile`.

    Parameters
    ----------
    block : :class:`pymatgen.io.cif.CifBlock`
        A pymatgen CifBlock object representing a crystal.
    remove_hydrogens : bool, optional
        Remove Hydrogens from the crystal.
    disorder : str, optional
        Controls how disordered structures are handled. Default is
        ``skip`` which skips any crystal with disorder, since disorder
        conflicts with the periodic set model. To read disordered
        structures anyway, choose either :code:`ordered_sites` to remove
        atoms with disorder or :code:`all_sites` include all atoms
        regardless of disorder.

    Returns
    -------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        useful data, e.g. the crystal's name and information about the
        asymmetric unit for calculation.

    Raises
    ------
    ParseError
        Raised if the structure can/should not be parsed for the
        following reasons: 1. Cell parameters are missing, 2. The
        _atom_site_fract_ coordinate tags are missing (including when
        only _atom_site_Cartn_ tags are given), 3. No sites remain after
        removing sites with missing/invalid coordinates, Hydrogens and
        disorder, 4. :code:`disorder == 'skip'` and disorder is found on
        any atom.

    Warns
    -----
    UserWarning
        If any site has an unknown ('?'), inapplicable ('.') or otherwise
        non-numeric coordinate; such sites are removed.
    """

    from pymatgen.io.cif import str2float

    def _as_list(value):
        # Data items outside a loop are stored as single strings; wrap them
        # so a one-site structure is not iterated character by character.
        return [value] if isinstance(value, str) else value

    def _to_float(value):
        # str2float handles bracketed uncertainties, e.g. '0.25(3)'. Values
        # it cannot convert are reported as None instead of raising.
        try:
            return str2float(value)
        except (TypeError, ValueError):
            return None

    odict = block.data

    # Unit cell
    cellpar = [odict.get(tag) for tag in _CIF_TAGS['cellpar']]
    if any(par in (None, '?', '.') for par in cellpar):
        raise ParseError(f'{block.header} has missing cell data')
    cell = cellpar_to_cell(
        np.array([str2float(v) for v in cellpar], dtype=np.float64)
    )

    # Asymmetric unit coordinates; only _atom_site_fract_ is supported
    coord_cols = [
        _as_list(odict.get(tag)) for tag in _CIF_TAGS['atom_site_fract']
    ]
    if any(col is None for col in coord_cols):
        cartn_cols = [odict.get(tag) for tag in _CIF_TAGS['atom_site_cartn']]
        if any(col is None for col in cartn_cols):
            raise ParseError(f'{block.header} has missing coordinates')
        raise ParseError(
            f'{block.header} uses _atom_site_Cartn_ tags for coordinates, '
            'only _atom_site_fract_ is supported'
        )

    # '?' (unknown), '.' (inapplicable) and unparseable coordinates become
    # None so the site is removed below instead of raising or being kept.
    asym_unit = [
        [None if c in ('?', '.') else _to_float(c) for c in xyz]
        for xyz in zip(*coord_cols)
    ]
    n_sites = len(asym_unit)

    # Atomic types
    asym_symbols = odict.get('_atom_site_type_symbol')
    if asym_symbols is not None:
        asym_symbols_ = _process_symbols(_as_list(asym_symbols))
    else:
        asym_symbols_ = [''] * n_sites
    asym_types = [_ATOMIC_NUMBERS[s] for s in asym_symbols_]

    # Find where sites have disorder if necessary
    has_disorder = []
    if disorder != 'all_sites':
        occupancies = odict.get('_atom_site_occupancy')
        if occupancies is None:
            occupancies = [1.0] * n_sites
        else:
            # Unparseable occupancies become None, which _has_disorder
            # treats as full occupancy.
            occupancies = [_to_float(occ) for occ in _as_list(occupancies)]
        labels = odict.get('_atom_site_label')
        labels = [''] * n_sites if labels is None else _as_list(labels)
        has_disorder = [
            _has_disorder(lab, occ) for lab, occ in zip(labels, occupancies)
        ]

    # Indices (into the original site order) of sites to drop
    remove_sites = {
        i for i, xyz in enumerate(asym_unit) if any(c is None for c in xyz)
    }
    if remove_sites:
        warnings.warn('atoms without sites or missing data will be removed')

    if remove_hydrogens:
        remove_sites.update(i for i, n in enumerate(asym_types) if n == 1)

    # Remove atoms with fractional occupancy or raise ParseError
    if disorder != 'all_sites':
        for i, dis in enumerate(has_disorder):
            if not dis or i in remove_sites:
                continue
            if disorder == 'skip':
                raise ParseError(
                    f'{block.header} has disorder, pass '
                    "disorder='ordered_sites' or 'all_sites' to "
                    'remove/ignore disorder'
                )
            if disorder == 'ordered_sites':
                remove_sites.add(i)

    # Asymmetric unit
    asym_unit = [c for i, c in enumerate(asym_unit) if i not in remove_sites]
    asym_types = [t for i, t in enumerate(asym_types) if i not in remove_sites]
    if not asym_unit:
        raise ParseError(f'{block.header} has no valid sites')
    asym_unit = np.mod(np.array(asym_unit, dtype=np.float64), 1)

    # NOTE: pymatgen recommends snapping near-special coordinates, e.g.
    # asym_unit = _snap_small_prec_coords(asym_unit, 1e-4)

    # Remove overlapping sites unless disorder == 'all_sites'
    if disorder != 'all_sites':
        keep_sites = _unique_sites(asym_unit, _EQ_SITE_TOL)
        if not np.all(keep_sites):
            warnings.warn(
                'may have overlapping sites; duplicates will be removed'
            )
        asym_unit = asym_unit[keep_sites]
        asym_types = [sym for sym, keep in zip(asym_types, keep_sites) if keep]

    # Apply symmetries to asymmetric unit
    rot, trans = _get_syms_pymatgen(odict)
    frac_motif, invs = _expand_asym_unit(asym_unit, rot, trans, _EQ_SITE_TOL)
    _, wyc_muls = np.unique(invs, return_counts=True)
    asym_inds = np.zeros_like(wyc_muls, dtype=np.int32)
    asym_inds[1:] = np.cumsum(wyc_muls)[:-1]
    types = np.array([asym_types[i] for i in invs], dtype=np.uint8)
    motif = np.matmul(frac_motif, cell)

    return PeriodicSet(
        motif=motif,
        cell=cell,
        name=block.header,
        asym_unit=asym_inds,
        multiplicities=wyc_muls,
        types=types
    )
1022
1023
1024
def periodicset_from_ase_atoms(
        atoms, remove_hydrogens: bool = False
) -> PeriodicSet:
    """Convert an :class:`ase.atoms.Atoms` object to a
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`. Does not have
    the option to remove disorder. The input ``atoms`` is not modified.

    If ``atoms.info`` contains a ``'spacegroup'`` entry, its symmetry
    operations are used and the asymmetric unit is the subset of sites
    that are symmetry-unique under it. Otherwise every atom is its own
    asymmetric unit site (P1).

    Parameters
    ----------
    atoms : :class:`ase.atoms.Atoms`
        An ase :class:`ase.atoms.Atoms` object representing a crystal.
    remove_hydrogens : bool, optional
        Remove Hydrogens from the crystal.

    Returns
    -------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        useful data, e.g. information about the asymmetric unit for
        calculation.

    Raises
    ------
    ParseError
        Raised if there are no valid sites in atoms.
    """

    cell = atoms.get_cell().array
    frac_coords = atoms.get_scaled_positions()
    numbers = atoms.get_atomic_numbers()

    # Filter with a mask rather than popping so the caller's Atoms object
    # is left untouched.
    if remove_hydrogens:
        keep = numbers != 1
        frac_coords = frac_coords[keep]
        numbers = numbers[keep]

    if len(numbers) == 0:
        raise ParseError('ase Atoms object has no valid sites')

    # The asymmetric unit must be taken under the same operations used to
    # expand it, and its atomic numbers are kept alongside so types stay
    # aligned with the expanded motif.
    if 'spacegroup' in atoms.info:
        spg = atoms.info['spacegroup']
        rot, trans = spg.rotations, spg.translations
        asym_unit, asym_mask = spg.unique_sites(
            frac_coords, symprec=_EQ_SITE_TOL, output_mask=True
        )
        asym_types = numbers[asym_mask]
    else:
        warnings.warn('no symmetry data found, defaulting to P1')
        rot = np.identity(3)[None, :]
        trans = np.zeros((1, 3))
        asym_unit = frac_coords
        asym_types = numbers

    frac_motif, invs = _expand_asym_unit(asym_unit, rot, trans, _EQ_SITE_TOL)
    _, wyc_muls = np.unique(invs, return_counts=True)
    asym_inds = np.zeros_like(wyc_muls, dtype=np.int32)
    asym_inds[1:] = np.cumsum(wyc_muls)[:-1]
    # invs[k] is the asymmetric unit site that motif point k came from
    types = np.asarray(asym_types)[invs].astype(np.uint8)
    motif = np.matmul(frac_motif, cell)

    return PeriodicSet(
        motif=motif,
        cell=cell,
        asym_unit=asym_inds,
        multiplicities=wyc_muls,
        types=types
    )
1094
1095
1096
def periodicset_from_pymatgen_structure(
        structure, remove_hydrogens: bool = False, disorder: str = 'skip'
) -> PeriodicSet:
    """Convert a :class:`pymatgen.core.structure.Structure` object to a
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`. Does not set
    the name of the periodic set, as pymatgen Structure objects seem to
    have no name attribute. The input ``structure`` is not modified; a
    copy is made before any sites are removed.

    Parameters
    ----------
    structure : :class:`pymatgen.core.structure.Structure`
        A pymatgen Structure object representing a crystal.
    remove_hydrogens : bool, optional
        Remove Hydrogens from the crystal.
    disorder : str, optional
        Controls how disordered structures are handled. Default is
        ``skip`` which skips any crystal with disorder, since disorder
        conflicts with the periodic set model. To read disordered
        structures anyway, choose either :code:`ordered_sites` to remove
        atoms with disorder or :code:`all_sites` include all atoms
        regardless of disorder.

    Returns
    -------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        useful data, e.g. information about the asymmetric unit for
        calculation.

    Raises
    ------
    ParseError
        Raised if the :code:`disorder == 'skip'` and
        :code:`not structure.is_ordered`
    """

    from pymatgen.symmetry.analyzer import SpacegroupAnalyzer

    # Work on a copy whenever sites may be removed, so the caller's
    # Structure is left unchanged.
    if remove_hydrogens or disorder == 'ordered_sites':
        structure = structure.copy()

    if remove_hydrogens:
        structure.remove_species(['H', 'D'])

    # Disorder
    if disorder == 'skip':
        if not structure.is_ordered:
            raise ParseError(
                'pymatgen Structure has disorder, pass '
                "disorder='ordered_sites' or 'all_sites' to remove/ignore "
                'disorder'
            )
    elif disorder == 'ordered_sites':
        # Sites whose total occupancy is below 1 are treated as disordered
        remove_inds = [
            i for i, comp in enumerate(structure.species_and_occu)
            if comp.num_atoms < 1
        ]
        structure.remove_sites(remove_inds)

    motif = structure.cart_coords
    cell = structure.lattice.matrix
    sym_structure = SpacegroupAnalyzer(structure).get_symmetrized_structure()
    eq_inds = sym_structure.equivalent_indices
    # First index of each symmetry-equivalent group represents that group
    asym_inds = np.array([ix_list[0] for ix_list in eq_inds], dtype=np.int32)
    wyc_muls = np.array([len(ix_list) for ix_list in eq_inds], dtype=np.int32)
    types = np.array(sym_structure.atomic_numbers, dtype=np.uint8)

    return PeriodicSet(
        motif=motif,
        cell=cell,
        asym_unit=asym_inds,
        multiplicities=wyc_muls,
        types=types
    )
1168
1169
1170
def periodicset_from_ccdc_entry(
        entry,
        remove_hydrogens: bool = False,
        disorder: str = 'skip',
        heaviest_component: bool = False,
        molecular_centres: bool = False
) -> PeriodicSet:
    """Build an :class:`amd.PeriodicSet <.periodicset.PeriodicSet>` from
    a :class:`ccdc.entry.Entry`, the type yielded by
    :class:`ccdc.io.EntryReader`.

    Entry-level checks are done here (3D structure present, disorder
    when skipping it). The conversion itself is delegated to
    :func:`periodicset_from_ccdc_crystal` on ``entry.crystal``.

    Parameters
    ----------
    entry : :class:`ccdc.entry.Entry`
        A ccdc Entry object representing a database entry.
    remove_hydrogens : bool, optional
        Remove Hydrogens from the crystal.
    disorder : str, optional
        Controls how disordered structures are handled. Default is
        ``skip`` which skips any crystal with disorder, since disorder
        conflicts with the periodic set model. To read disordered
        structures anyway, choose either :code:`ordered_sites` to remove
        atoms with disorder or :code:`all_sites` include all atoms
        regardless of disorder.
    heaviest_component : bool, optional
        Removes all but the heaviest molecule in the asymmetric unit,
        intended for removing solvents.
    molecular_centres : bool, default False
        Use molecular centres of mass as the motif instead of centres of
        atoms.

    Returns
    -------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        useful data, e.g. the crystal's name and information about the
        asymmetric unit for calculation.

    Raises
    ------
    ParseError
        Raised if ``entry.has_3d_structure`` is False, if
        :code:`disorder == 'skip'` and the entry has disorder, or for any
        reason :func:`periodicset_from_ccdc_crystal` raises it.
    """

    # An entry without 3D coordinates cannot become a periodic set
    if not entry.has_3d_structure:
        raise ParseError(f'{entry.identifier} has no 3D structure')

    # The disorder flag is only consulted when disordered entries are
    # to be skipped
    if disorder == 'skip' and entry.has_disorder:
        message = (
            f"{entry.identifier} has disorder, pass disorder='ordered_sites' "
            "or 'all_sites' to remove/ignore disorder"
        )
        raise ParseError(message)

    crystal_options = {
        'remove_hydrogens': remove_hydrogens,
        'disorder': disorder,
        'heaviest_component': heaviest_component,
        'molecular_centres': molecular_centres,
    }
    return periodicset_from_ccdc_crystal(entry.crystal, **crystal_options)
1239
1240
1241
def periodicset_from_ccdc_crystal(
        crystal,
        remove_hydrogens: bool = False,
        disorder: str = 'skip',
        heaviest_component: bool = False,
        molecular_centres: bool = False
) -> PeriodicSet:
    """Convert a :class:`ccdc.crystal.Crystal` object to a
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`.
    Crystal is the type returned by :class:`ccdc.io.CrystalReader`.

    Note that ``crystal.molecule`` is replaced by the crystal's disordered
    molecule after the requested atoms have been removed from it.

    Parameters
    ----------
    crystal : :class:`ccdc.crystal.Crystal`
        A ccdc Crystal object representing a crystal structure.
    remove_hydrogens : bool, optional
        Remove Hydrogens from the crystal.
    disorder : str, optional
        Controls how disordered structures are handled. Default is
        ``skip`` which skips any crystal with disorder, since disorder
        conflicts with the periodic set model. To read disordered
        structures anyway, choose either :code:`ordered_sites` to remove
        atoms with disorder or :code:`all_sites` include all atoms
        regardless of disorder.
    heaviest_component : bool, optional
        Removes all but the heaviest molecule in the asymmetric unit,
        intended for removing solvents.
    molecular_centres : bool, default False
        Use molecular centres of mass as the motif instead of centres of
        atoms.

    Returns
    -------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        useful data, e.g. the crystal's name and information about the
        asymmetric unit for calculation.

    Raises
    ------
    ParseError
        Raised if the structure fails parsing for any of the following:
        1. :code:`disorder == 'skip'` and disorder is found on any
        atom, 2. the cell parameters are missing, 3. The motif is empty
        after removing H, disordered sites or solvents.
    """

    molecule = crystal.disordered_molecule

    # Disorder
    if disorder == 'skip':
        if crystal.has_disorder or \
         any(_has_disorder(a.label, a.occupancy) for a in molecule.atoms):
            raise ParseError(
                f"{crystal.identifier} has disorder, pass "
                "disorder='ordered_sites' or 'all_sites' to remove/ignore "
                "disorder"
            )
    elif disorder == 'ordered_sites':
        molecule.remove_atoms(
            [a for a in molecule.atoms if _has_disorder(a.label, a.occupancy)]
        )

    if remove_hydrogens:
        # Exact comparison: a substring test against 'HD' would also match
        # an empty symbol.
        molecule.remove_atoms(
            [a for a in molecule.atoms if a.atomic_symbol in ('H', 'D')]
        )

    # Remove atoms with missing coordinates and warn
    if any(a.fractional_coordinates is None for a in molecule.atoms):
        warnings.warn('atoms without sites or missing data will be removed')
        molecule.remove_atoms(
            [a for a in molecule.atoms if a.fractional_coordinates is None]
        )

    if heaviest_component and len(molecule.components) > 1:
        molecule = _heaviest_component_ccdc(molecule)

    crystal.molecule = molecule
    cellpar = crystal.cell_lengths + crystal.cell_angles
    if None in cellpar:
        raise ParseError(f'{crystal.identifier} has missing cell data')
    cell = cellpar_to_cell(np.array(cellpar))

    if molecular_centres:
        frac_centres = _frac_molecular_centres_ccdc(crystal, _EQ_SITE_TOL)
        mol_centres = np.matmul(frac_centres, cell)
        return PeriodicSet(mol_centres, cell, name=crystal.identifier)

    asym_atoms = crystal.asymmetric_unit_molecule.atoms
    # NOTE(review): assumes asymmetric_unit_molecule reflects the atoms
    # removed above, so no fractional_coordinates are None here — confirm.
    asym_unit = np.array([tuple(a.fractional_coordinates) for a in asym_atoms])

    if asym_unit.shape[0] == 0:
        raise ParseError(f'{crystal.identifier} has no valid sites')

    asym_unit = np.mod(asym_unit, 1)

    # NOTE: pymatgen recommends snapping near-special coordinates, e.g.
    # asym_unit = _snap_small_prec_coords(asym_unit, 1e-4)

    asym_types = [a.atomic_number for a in asym_atoms]

    # Remove overlapping sites unless disorder == 'all_sites'
    if disorder != 'all_sites':
        keep_sites = _unique_sites(asym_unit, _EQ_SITE_TOL)
        if not np.all(keep_sites):
            warnings.warn(
                'may have overlapping sites; duplicates will be removed'
            )
        asym_unit = asym_unit[keep_sites]
        asym_types = [sym for sym, keep in zip(asym_types, keep_sites) if keep]

    # Symmetry operations; fall back to the identity (P1) if none are given
    sitesym = crystal.symmetry_operators
    if not sitesym:
        warnings.warn('no symmetry data found, defaulting to P1')
        sitesym = ['x,y,z']

    # Apply symmetries to asymmetric unit
    rot, trans = _parse_sitesyms(sitesym)
    frac_motif, invs = _expand_asym_unit(asym_unit, rot, trans, _EQ_SITE_TOL)
    _, wyc_muls = np.unique(invs, return_counts=True)
    asym_inds = np.zeros_like(wyc_muls, dtype=np.int32)
    asym_inds[1:] = np.cumsum(wyc_muls)[:-1]
    types = np.array([asym_types[i] for i in invs], dtype=np.uint8)
    motif = np.matmul(frac_motif, cell)

    return PeriodicSet(
        motif=motif,
        cell=cell,
        name=crystal.identifier,
        asym_unit=asym_inds,
        multiplicities=wyc_muls,
        types=types
    )
1381
1382
1383
def memoize(f):
    """Cache the results of a single-argument function, keyed on that
    (hashable) argument. Used for _parse_sitesym().

    The cache is unbounded and the same cached object is returned on every
    hit, so callers must not mutate returned values. ``functools.wraps``
    keeps the wrapped function's name and docstring.
    """
    cache = {}

    @functools.wraps(f)
    def wrapper(arg):
        if arg not in cache:
            cache[arg] = f(arg)
        return cache[arg]

    return wrapper
1391
1392
1393
@memoize
def _parse_sitesym(sym: str) -> Tuple[np.ndarray, np.ndarray]:
    """Turn one symmetry operation in xyz notation (e.g. ``'-x+1/2,y,z'``)
    into a 3x3 rotation matrix and a length-3 translation vector.

    Each comma-separated component fills one row. The letters x/y/z set
    +/-1 entries according to the most recent sign. Digits and '.' form a
    translation, optionally as a fraction with '/'. Results are cached by
    ``memoize``, so the returned arrays are shared between calls.
    """

    rotation = np.zeros((3, 3), dtype=np.float64)
    translation = np.zeros((3, ), dtype=np.float64)

    for row, component in enumerate(sym.split(',')):
        positive = True
        in_denominator = False
        sign = None
        numerator_chars = []
        denominator_chars = []

        for ch in component.lower():
            if ch in '+-':
                positive = ch == '+'
            elif ch == '/':
                in_denominator = True
            elif ch in 'xyz':
                rotation[row]['xyz'.index(ch)] = 1.0 if positive else -1.0
            elif ch.isdigit() or ch == '.':
                # Sign of a translation is fixed by its first digit
                if sign is None:
                    sign = 1.0 if positive else -1.0
                target = denominator_chars if in_denominator else numerator_chars
                target.append(ch)

        if numerator_chars:
            value = sign * float(''.join(numerator_chars))
        else:
            value = 0.0

        if in_denominator:
            value /= float(''.join(denominator_chars))

        translation[row] = value

    return rotation, translation
1439
1440
1441
def _parse_sitesyms(symmetries: List[str]) -> Tuple[np.ndarray, np.ndarray]:
    """Parse every xyz-form symmetry operation in ``symmetries`` and stack
    the results into one array of rotation matrices and one array of
    translation vectors (same order as the input).
    """
    parsed_ops = [_parse_sitesym(op) for op in symmetries]
    rotations = np.array([rot for rot, _ in parsed_ops])
    translations = np.array([shift for _, shift in parsed_ops])
    return rotations, translations
1452
1453
1454
def _expand_asym_unit(
        asym_unit: np.ndarray,
        rotations: np.ndarray,
        translations: np.ndarray,
        tol: float
) -> Tuple[np.ndarray, np.ndarray]:
    """Apply the symmetry operations (``rotations``, ``translations``) to
    the asymmetric unit and return ``(frac_motif, invs)``.

    ``invs[k]`` is the asymmetric unit site that motif point ``k`` came
    from. The fast reduction is tried first. If its motif still contains
    sites equal within ``tol``, the slower equivalent-site reduction is
    used instead.
    """

    expanded_sites = _expand_sites(
        asym_unit.astype(np.float64, copy=False),
        rotations.astype(np.float64, copy=False),
        translations.astype(np.float64, copy=False),
    )
    frac_motif, invs = _reduce_expanded_sites(expanded_sites, tol)

    if np.all(_unique_sites(frac_motif, tol)):
        return frac_motif, invs
    return _reduce_expanded_equiv_sites(expanded_sites, tol)
1474
1475
1476
@numba.njit(cache=True)
def _expand_sites(
        asym_unit: np.ndarray, rotations: np.ndarray, translations: np.ndarray
) -> np.ndarray:
    """Apply each symmetry operation ``(rotations[k], translations[k])`` to
    each point of the asymmetric unit and reduce the results modulo 1.

    Points that are invariant under an operation are not deduplicated
    here. The result has shape (#points, #syms, dims).
    """

    n_points, n_dims = asym_unit.shape
    n_ops = len(rotations)
    images = np.empty((n_points, n_ops, n_dims), dtype=np.float64)
    for op in range(n_ops):
        rot = rotations[op]
        shift = translations[op]
        for pt in range(n_points):
            images[pt, op] = np.dot(rot, asym_unit[pt]) + shift
    return np.mod(images, 1)
1495
1496
1497
@numba.njit(cache=True)
def _reduce_expanded_sites(
        expanded_sites: np.ndarray, tol: float
) -> Tuple[np.ndarray, np.ndarray]:
    """Reduce the asymmetric unit after it has been expanded by symmetries,
    by removing invariant points.

    This is the fast version. It is correct only when no two sites in the
    asymmetric unit are equivalent; otherwise the reduction is re-run with
    _reduce_expanded_equiv_sites(). Returns ``(frac_motif, inverses)``,
    where ``inverses[k]`` is the asymmetric unit site of motif point ``k``.
    """

    all_unique_inds = []
    n_sites, _, dims = expanded_sites.shape
    # Integer counts, so the running offsets below are valid slice bounds
    multiplicities = np.zeros(shape=(n_sites, ), dtype=np.int64)

    for i in range(n_sites):
        unique_inds = _unique_sites(expanded_sites[i], tol)
        all_unique_inds.append(unique_inds)
        multiplicities[i] = np.sum(unique_inds)

    m = int(np.sum(multiplicities))
    frac_motif = np.zeros(shape=(m, dims))
    inverses = np.zeros(shape=(m, ), dtype=np.int32)

    # Images of site i occupy the contiguous block [s, t) of the motif
    s = 0
    for i in range(n_sites):
        t = s + multiplicities[i]
        frac_motif[s:t, :] = expanded_sites[i][all_unique_inds[i]]
        inverses[s:t] = i
        s = t

    return frac_motif, inverses
1529
1530
1531
def _reduce_expanded_equiv_sites(
        expanded_sites: np.ndarray, tol: float
) -> Tuple[np.ndarray, np.ndarray]:
    """Reduce the asymmetric unit after symmetry expansion, also dropping
    images that coincide with points already in the motif.

    This is the slower fallback, used when the fast reduction leaves
    equivalent points. Site 0's unique images seed the motif. For each
    later site, a unique image is kept only if it does not match (within
    ``tol``, modulo 1) any point already collected from earlier sites;
    each match emits a warning. Returns ``(frac_motif, inverses)``.
    """

    seed = expanded_sites[0]
    frac_motif = seed[_unique_sites(seed, tol)]
    inverses = [0] * len(frac_motif)

    for asym_ind, images in enumerate(expanded_sites[1:], start=1):
        new_points = []
        for site in images[_unique_sites(images, tol)]:
            abs_diff = np.abs(site - frac_motif)
            matches = np.all(
                (abs_diff <= tol) | (np.abs(abs_diff - 1) <= tol), axis=-1
            )
            if np.any(matches):
                warnings.warn(
                    'has equivalent sites at positions '
                    f'{inverses[np.argmax(matches)]}, {asym_ind}'
                )
            else:
                new_points.append(site)

        # Motif grows only after all images of this site are checked
        if new_points:
            inverses.extend([asym_ind] * len(new_points))
            frac_motif = np.concatenate((frac_motif, np.array(new_points)))

    return frac_motif, np.array(inverses, dtype=np.int32)
1568
1569
1570
@numba.njit(cache=True)
def _unique_sites(asym_unit: np.ndarray, tol: float) -> np.ndarray:
    """Return a boolean mask of the rows of ``asym_unit`` that are unique
    within ``tol``.

    Two coordinates count as equal when their absolute difference is
    within ``tol`` of 0 or of 1, so equality is taken modulo 1. A row is
    marked False if it equals some earlier row in every coordinate. Hence
    ``asym_unit[_unique_sites(asym_unit, tol)]`` keeps the first
    occurrence of each site.
    """

    n_points = asym_unit.shape[0]
    is_unique = np.full(shape=(n_points, ), fill_value=True)

    for j in range(1, n_points):
        abs_diff = np.abs(asym_unit[:j, :] - asym_unit[j])
        differs = (abs_diff > tol) & (np.abs(abs_diff - 1) > tol)
        for k in range(j):
            if not np.any(differs[k]):
                is_unique[j] = False
                break

    return is_unique
1588
1589
1590
def _has_disorder(label: str, occupancy) -> bool:
1591
    """Return True if label ends with ? or occupancy is a number < 1."""
1592
    try:
1593
        occupancy = float(occupancy)
1594
    except Exception:
1595
        occupancy = 1
1596
    return (occupancy < 1) or label.endswith('?')
1597
1598
1599
def _process_symbols(symbols):
1600
    """Remove extra parts from atom symbols (_atom_site_type_symbol),
1601
    e.g. charges, and replace unknown values with empty strings.
1602
    """
1603
    symbols_ = []
1604
    for label in symbols:
1605
        sym = ''
1606
        if label and label not in ('.', '?'):
1607
            match = re.search(r'([A-Z][a-z]?)', label)
1608
            if match is not None:
1609
                sym = match.group() 
1610
        symbols_.append(sym)
1611
    return symbols_
1612
1613
1614
def _get_syms_pymatgen(data: dict) -> Tuple[np.ndarray, np.ndarray]:
    """Parse symmetry operations given by data = block.data where block
    is a pymatgen CifBlock object. If the symops are not present the
    space group symbol/international number is parsed and symops are
    generated.

    Sources are tried in this order:
      1. explicit xyz symmetry operations (tags in _CIF_TAGS['symop']);
      2. the space group symbol (_CIF_TAGS['spacegroup_name']), first via
         pymatgen.io.cif.space_groups + SpaceGroup, then via the table
         returned by pymatgen.io.cif._get_cod_data();
      3. the international number (_CIF_TAGS['spacegroup_number']).
    If none of these yields symmetry, a warning is issued and P1
    ('x,y,z') is used.

    Returns
    -------
    rotations, translations : numpy.ndarray
        float64 arrays of rotation matrices and translation vectors
        (from _parse_sitesyms on the early-return paths).
    """

    from pymatgen.symmetry.groups import SpaceGroup
    import pymatgen.io.cif

    # Try xyz symmetry operations
    for symmetry_label in _CIF_TAGS['symop']:
        xyz = data.get(symmetry_label)
        if not xyz:
            continue
        if isinstance(xyz, str):
            xyz = [xyz]
        return _parse_sitesyms(xyz)

    symops = []
    # Try spacegroup symbol
    for symmetry_label in _CIF_TAGS['spacegroup_name']:
        sg = data.get(symmetry_label)
        if not sg:
            continue
        sg = re.sub(r'[\s_]', '', sg)

        try:
            spg = pymatgen.io.cif.space_groups.get(sg)
            if spg:
                symops = SpaceGroup(spg).symmetry_ops
                break
        except ValueError:
            pass

        # The symbol was not in space_groups (or SpaceGroup rejected it):
        # fall back to the COD table instead of skipping straight to the
        # next tag. _get_cod_data is private pymatgen API, so any failure
        # here is treated as "not found".
        try:
            for d in pymatgen.io.cif._get_cod_data():
                if sg == re.sub(r'\s+', '', d['hermann_mauguin']):
                    return _parse_sitesyms(d['symops'])
        except Exception:
            continue

    # Try international number
    if not symops:
        for symmetry_label in _CIF_TAGS['spacegroup_number']:
            num = data.get(symmetry_label)
            if not num:
                continue
            try:
                i = int(pymatgen.io.cif.str2float(num))
                symops = SpaceGroup.from_int_number(i).symmetry_ops
                break
            except ValueError:
                continue

    if not symops:
        warnings.warn('no symmetry data found, defaulting to P1')
        return _parse_sitesyms(['x,y,z'])

    rotations = np.array(
        [op.rotation_matrix for op in symops], dtype=np.float64
    )
    translations = np.array(
        [op.translation_vector for op in symops], dtype=np.float64
    )
    return rotations, translations
1679
1680
1681
def _frac_molecular_centres_ccdc(crystal, tol: float) -> np.ndarray:
    """Return the geometric centres of molecules in the unit cell.
    Expects a ccdc Crystal object and returns fractional coordinates.

    Each centre is the mean of the fractional coordinates of a component
    of crystal.packing(inclusion='CentroidIncluded'). Centres are
    reduced modulo 1 and de-duplicated within tol by _unique_sites.
    Components without atoms are skipped. If there are no centres, an
    empty (0, 3) float64 array is returned.
    """

    frac_centres = []
    for comp in crystal.packing(inclusion='CentroidIncluded').components:
        coords = [a.fractional_coordinates for a in comp.atoms]
        if not coords:
            # No atoms means no centre (and would divide by zero below).
            continue
        frac_centres.append([sum(ax) / len(coords) for ax in zip(*coords)])

    if not frac_centres:
        # _unique_sites unpacks a 2D shape, so do not pass it an empty
        # 1D array.
        return np.empty((0, 3), dtype=np.float64)

    frac_centres = np.mod(np.array(frac_centres, dtype=np.float64), 1)
    return frac_centres[_unique_sites(frac_centres, tol)]
1692
1693
1694
def _heaviest_component_ccdc(molecule):
    """Remove all but the heaviest component of the asymmetric unit.
    Intended for removing solvents. Expects and returns a ccdc Molecule
    object.

    A component's weight is the sum over its atoms of atomic weight
    times occupancy. An atom whose occupancy cannot be converted to
    float counts as fully occupied (occupancy 1). An atom whose atomic
    weight cannot be converted is ignored. Ties go to the first
    component (np.argmax). A molecule with no components is returned
    unchanged.
    """

    # Fetch components once, so the weights and the returned component
    # are taken from the same sequence.
    components = molecule.components
    if not components:
        return molecule

    component_weights = []
    for component in components:
        weight = 0
        for a in component.atoms:
            try:
                occ = float(a.occupancy)
            except (TypeError, ValueError):
                occ = 1
            try:
                weight += float(a.atomic_weight) * occ
            except (TypeError, ValueError):
                # e.g. a missing (None) or non-numeric atomic weight
                pass
        component_weights.append(weight)

    largest_component_ind = int(np.argmax(component_weights))
    return components[largest_component_ind]
1716
1717
1718
def _snap_small_prec_coords(frac_coords: np.ndarray, tol: float) -> np.ndarray:
    """Snap fractional coordinates that are nearly 1/3 or 2/3 to exactly
    1/3 or 2/3, modifying frac_coords in place and returning it.

    The test is relative, following pymatgen's CIF parser: a value x is
    snapped to 1/3 when |1 - 3x| < tol, and to 2/3 when
    |1 - 3x/2| < tol. The 1/3 pass runs first.
    """
    near_one_third = np.abs(1 - 3 * frac_coords) < tol
    frac_coords[near_one_third] = 1 / 3.

    # Evaluated after the first pass has already written its values.
    near_two_thirds = np.abs(1 - 3 * frac_coords / 2) < tol
    frac_coords[near_two_thirds] = 2 / 3.

    return frac_coords
1725