Passed
Push — master ( 8c12c2...4daa36 )
by Daniel
07:46
created

amd.io._heaviest_component_ccdc()   A

Complexity

Conditions 5

Size

Total Lines 22
Code Lines 17

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 17
dl 0
loc 22
rs 9.0833
c 0
b 0
f 0
cc 5
nop 1
1
"""Tools for reading crystals from files, or from the CSD with
2
``csd-python-api``. The readers return
3
:class:`amd.PeriodicSet <.periodicset.PeriodicSet>` objects representing
4
the crystal which can be passed to :func:`amd.AMD() <.calculate.AMD>`
5
and :func:`amd.PDD() <.calculate.PDD>`.
6
"""
7
8
import warnings
9
import os
10
import re
11
import functools
12
import errno
13
import math
14
import json
15
from pathlib import Path
16
from typing import (
17
    Iterable,
18
    Iterator,
19
    Optional,
20
    Union,
21
    Callable,
22
    Tuple,
23
    List
24
)
25
26
import numpy as np
27
import numba
28
import tqdm
29
30
from .utils import cellpar_to_cell
31
from .periodicset import PeriodicSet
32
33
# Names exported by ``from <this module> import *``: the two reader
# classes, the error type raised for unparseable items, and the
# ``periodicset_from_*`` converter functions.
__all__ = [
    'CifReader',
    'CSDReader',
    'ParseError',
    'periodicset_from_gemmi_block',
    'periodicset_from_ase_cifblock',
    'periodicset_from_pymatgen_cifblock',
    'periodicset_from_ase_atoms',
    'periodicset_from_pymatgen_structure',
    'periodicset_from_ccdc_entry',
    'periodicset_from_ccdc_crystal',
]
45
46
def _custom_warning(message, category, filename, lineno, *args, **kwargs):
    """Format a warning as a single ``<CategoryName>: <message>`` line.

    Has the call signature expected of ``warnings.formatwarning``; the
    file name, line number and any further arguments are ignored.
    """
    category_name = category.__name__
    return f'{category_name}: {message}\n'


# Install the compact formatter for every warning printed in this process.
warnings.formatwarning = _custom_warning
50
51
# Lookup table read from 'atomic_numbers.json', which sits next to this
# module. The converters below index it by atomic symbol string to get
# the value stored as an atom's type (hydrogen is tested for as 1).
with open(str(Path(__file__).absolute().parent / 'atomic_numbers.json')) as f:
    _ATOMIC_NUMBERS = json.load(f)

# Tolerance handed to _unique_sites and _expand_asym_unit by the
# converters in this module.
_EQ_SITE_TOL: float = 1e-3
# CIF tag names grouped by the quantity they describe. Where a group
# lists several alternatives for the same quantity ('symop',
# 'spacegroup_name', 'spacegroup_number'), the converters try them in
# the order given here.
_CIF_TAGS: dict = {
    'cellpar': [
        '_cell_length_a',
        '_cell_length_b',
        '_cell_length_c',
        '_cell_angle_alpha',
        '_cell_angle_beta',
        '_cell_angle_gamma'
    ],
    'atom_site_fract': [
        '_atom_site_fract_x',
        '_atom_site_fract_y',
        '_atom_site_fract_z'
    ],
    'atom_site_cartn': [
        '_atom_site_Cartn_x',
        '_atom_site_Cartn_y',
        '_atom_site_Cartn_z'
    ],
    'symop': [
        '_space_group_symop_operation_xyz',
        '_space_group_symop.operation_xyz',
        '_symmetry_equiv_pos_as_xyz'
    ],
    'spacegroup_name': [
        '_space_group_name_H-M_alt',
        '_symmetry_space_group_name_H-M'
    ],
    'spacegroup_number': [
        '_space_group_IT_number',
        '_symmetry_Int_Tables_number'
    ],
}
88
89
90
class _Reader:
    """Base reader class.

    Wraps an iterable of raw items and a converter function, and is
    itself an iterator over the converted items.

    Parameters
    ----------
    iterable : Iterable
        Source of raw items; each one is passed to ``converter``.
    converter : Callable
        Called with a single item and expected to return an object with
        a ``name`` attribute (used to prefix warnings). May raise
        :class:`ParseError <.io.ParseError>` to have the item skipped.
    show_warnings : bool
        If True, warnings issued by ``converter`` and messages of
        skipped items are emitted with :func:`warnings.warn`. If False
        they are discarded.
    verbose : bool
        If True, a ``tqdm`` progress bar counts the items processed.
    """

    def __init__(
            self,
            iterable: Iterable,
            converter: Callable[..., 'PeriodicSet'],
            show_warnings: bool,
            verbose: bool
    ):

        self._iterator = iter(iterable)
        self._converter = converter
        self.show_warnings = show_warnings
        if verbose:
            self._progress_bar = tqdm.tqdm(desc='Reading', delay=1)
        else:
            self._progress_bar = None

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next successfully converted item.

        Items are pulled from self._iterator and passed through
        self._converter. If :class:`ParseError <.io.ParseError>` is
        raised in a call to self._converter, the item is skipped and,
        if self.show_warnings is True, its message is emitted as a
        warning. Warnings raised in self._converter are re-emitted
        prefixed with the item's name if self.show_warnings is True,
        and dropped otherwise. The global warning filters are never
        modified.

        Raises
        ------
        StopIteration
            When self._iterator is exhausted (the progress bar, if any,
            is closed first).
        """

        while True:

            try:
                item = next(self._iterator)
            except StopIteration:
                if self._progress_bar is not None:
                    self._progress_bar.close()
                raise

            # Everything warned inside this block is captured in
            # warning_msgs instead of being shown, so the skip message
            # must be emitted only after the block has exited.
            skip_message = None
            with warnings.catch_warnings(record=True) as warning_msgs:
                try:
                    periodic_set = self._converter(item)
                except ParseError as err:
                    skip_message = str(err)
                finally:
                    if self._progress_bar is not None:
                        self._progress_bar.update(1)

            if skip_message is not None:
                if self.show_warnings:
                    warnings.warn(skip_message)
                continue

            if self.show_warnings:
                for warning in warning_msgs:
                    msg = f'(name={periodic_set.name}) {warning.message}'
                    warnings.warn(msg, category=warning.category)

            return periodic_set

    def read(self) -> Union['PeriodicSet', List['PeriodicSet']]:
        """Read the crystal(s), return one
        :class:`amd.PeriodicSet <.periodicset.PeriodicSet>` if there is
        only one, otherwise return a list (empty if nothing was read).
        """
        items = list(self)
        if len(items) == 1:
            return items[0]
        return items
157
158
159
class CifReader(_Reader):
    """Read all structures in a .cif file or all files in a folder
    with gemmi, pymatgen, ase or csd-python-api (if installed), yielding
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>` s.

    Parameters
    ----------
    path : str
        Path to a .CIF file or directory. (Other files are accepted when
        using ``reader='ccdc'``, if csd-python-api is installed.)
    reader : str, optional
        The backend package used to parse the CIF. The default is
        :code:`gemmi`, :code:`pymatgen` and :code:`ase` are also
        accepted, as well as :code:`ccdc` if csd-python-api is
        installed. :code:`pycodcif` is accepted too and is passed on as
        the ``reader`` argument of ase's ``parse_cif``. The ccdc reader
        should be able to read any format accepted by
        :class:`ccdc.io.EntryReader`, though only CIFs have been tested.
    remove_hydrogens : bool, optional
        Remove Hydrogens from the crystals.
    disorder : str, optional
        Controls how disordered structures are handled. Default is
        ``skip`` which skips any crystal with disorder, since disorder
        conflicts with the periodic set model. To read disordered
        structures anyway, choose either :code:`ordered_sites` to remove
        atoms with disorder or :code:`all_sites` include all atoms
        regardless of disorder.
    heaviest_component : bool, optional
        csd-python-api only. Removes all but the heaviest molecule in
        the asymmetric unit, intended for removing solvents.
    molecular_centres : bool, default False
        csd-python-api only. Extract the centres of molecules in the
        unit cell and store in the attribute molecular_centres.
    show_warnings : bool, optional
        Controls whether warnings that arise during reading are printed.
    verbose : bool, default False
        If True, prints a progress bar showing the number of items
        processed.

    Yields
    ------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        data, e.g. the crystal's name and information about the
        asymmetric unit.

    Raises
    ------
    ValueError
        If ``disorder`` or ``reader`` is not one of the accepted values.
    NotImplementedError
        If ``heaviest_component`` or ``molecular_centres`` is requested
        with a reader other than ``'ccdc'``.
    ImportError
        If ``reader='ccdc'`` and csd-python-api cannot be imported.
    FileNotFoundError
        If ``path`` is neither an existing file nor a directory.

    Examples
    --------

        ::

            # Put all crystals in a .CIF in a list
            structures = list(amd.CifReader('mycif.cif'))

            # Can also accept path to a directory, reading all files inside
            structures = list(amd.CifReader('path/to/folder'))

            # Reads just one if the .CIF has just one crystal
            periodic_set = amd.CifReader('mycif.cif').read()

            # List of AMDs (k=100) of crystals in a .CIF
            amds = [amd.AMD(item, 100) for item in amd.CifReader('mycif.cif')]
    """

    def __init__(
            self,
            path: Union[str, os.PathLike],
            reader: str = 'gemmi',
            remove_hydrogens: bool = False,
            disorder: str = 'skip',
            heaviest_component: bool = False,
            molecular_centres: bool = False,
            show_warnings: bool = True,
            verbose: bool = False
    ):

        if disorder not in ('skip', 'ordered_sites', 'all_sites'):
            raise ValueError(
                f"'disorder' parameter of {self.__class__.__name__} must be "
                f"one of 'skip', 'ordered_sites' or 'all_sites' (passed "
                f"'{disorder}')"
            )

        # These two options rely on csd-python-api functionality.
        if reader != 'ccdc':
            if heaviest_component:
                raise NotImplementedError(
                    "'heaviest_component' parameter of "
                    f"{self.__class__.__name__} only implemented with "
                    "csd-python-api, if installed pass reader='ccdc'"
                )
            if molecular_centres:
                raise NotImplementedError(
                    "'molecular_centres' parameter of "
                    f"{self.__class__.__name__} only implemented with "
                    "csd-python-api, if installed pass reader='ccdc'"
                )

        # Each branch picks the accepted file extensions, a function
        # turning a file path into an iterable of raw items, and the
        # converter turning one raw item into a PeriodicSet. Backend
        # packages are imported lazily so only the chosen one is needed.

        # cannot handle some special characters in cifs
        if reader == 'gemmi':
            import gemmi
            extensions = {'cif'}
            file_parser = gemmi.cif.read_file
            converter = functools.partial(
                periodicset_from_gemmi_block,
                remove_hydrogens=remove_hydrogens,
                disorder=disorder
            )

        elif reader in ('ase', 'pycodcif'):
            from ase.io.cif import parse_cif
            extensions = {'cif'}
            file_parser = functools.partial(parse_cif, reader=reader)
            converter = functools.partial(
                periodicset_from_ase_cifblock,
                remove_hydrogens=remove_hydrogens,
                disorder=disorder
            )

        elif reader == 'pymatgen':

            def _pymatgen_cif_parser(path):
                from pymatgen.io.cif import CifFile
                return CifFile.from_file(path).data.values()

            extensions = {'cif'}
            file_parser = _pymatgen_cif_parser
            converter = functools.partial(
                periodicset_from_pymatgen_cifblock,
                remove_hydrogens=remove_hydrogens,
                disorder=disorder
            )

        elif reader == 'ccdc':
            try:
                import ccdc.io
            except (ImportError, RuntimeError) as e:
                raise ImportError('Failed to import csd-python-api') from e

            extensions = set(ccdc.io.EntryReader.known_suffixes.keys())
            file_parser = ccdc.io.EntryReader
            converter = functools.partial(
                periodicset_from_ccdc_entry,
                remove_hydrogens=remove_hydrogens,
                disorder=disorder,
                molecular_centres=molecular_centres,
                heaviest_component=heaviest_component
            )

        else:
            raise ValueError(
                f"'reader' parameter of {self.__class__.__name__} must be one "
                f"of 'gemmi', 'pymatgen', 'ccdc', 'ase', or 'pycodcif' "
                f"(passed '{reader}')"
            )

        path = Path(path)
        if path.is_file():
            iterable = file_parser(str(path))
        elif path.is_dir():
            iterable = CifReader._dir_generator(path, file_parser, extensions)
        else:
            raise FileNotFoundError(
                errno.ENOENT, os.strerror(errno.ENOENT), str(path)
            )

        super().__init__(iterable, converter, show_warnings, verbose)

    @staticmethod
    def _dir_generator(
            path: Path, file_parser: Callable, extensions: Iterable
    ) -> Iterator:
        """Yield the items parsed from every file directly inside
        ``path`` whose extension (lower-cased, without the dot) is in
        ``extensions``. Sub-directories are not entered. A file whose
        parsing raises is reported with a warning and skipped, so one
        bad file does not stop the whole directory from being read.
        """
        for file_path in path.iterdir():
            if not file_path.is_file():
                continue
            if file_path.suffix[1:].lower() not in extensions:
                continue
            try:
                yield from file_parser(str(file_path))
            except Exception as e:
                warnings.warn(
                    f'Error parsing "{str(file_path)}", skipping file. '
                    f'Exception: {repr(e)}'
                )
342
343
344
class CSDReader(_Reader):
    """Read structures from the CSD with csd-python-api, yielding
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>` s.

    Parameters
    ----------
    refcodes : str or List[str], optional
        Single or list of CSD refcodes to read. If None or 'CSD'
        (case-insensitive), iterates over the whole CSD.
    families : bool, optional
        Read all entries whose refcode starts with the given strings, or
        'families' (e.g. giving 'DEBXIT' reads all entries starting with
        DEBXIT). Ignored when reading the whole CSD.
    remove_hydrogens : bool, optional
        Remove hydrogens from the crystals.
    disorder : str, optional
        Controls how disordered structures are handled. Default is
        ``skip`` which skips any crystal with disorder, since disorder
        conflicts with the periodic set model. To read disordered
        structures anyway, choose either :code:`ordered_sites` to remove
        atoms with disorder or :code:`all_sites` include all atoms
        regardless of disorder.
    heaviest_component : bool, optional
        Removes all but the heaviest molecule in the asymmetric unit,
        intended for removing solvents.
    molecular_centres : bool, default False
        Extract the centres of molecules in the unit cell and store in
        attribute molecular_centres.
    show_warnings : bool, optional
        Controls whether warnings that arise during reading are printed.
    verbose : bool, default False
        If True, prints a progress bar showing the number of items
        processed.

    Yields
    ------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        useful data, e.g. the crystal's name and information about the
        asymmetric unit for calculation.

    Raises
    ------
    ValueError
        If ``disorder`` is not one of the accepted values, or
        ``refcodes`` is not None, a string or a list of strings.
    ImportError
        If csd-python-api cannot be imported.

    Examples
    --------

        ::

            # Put these entries in a list
            refcodes = ['DEBXIT01', 'DEBXIT05', 'HXACAN01']
            structures = list(amd.CSDReader(refcodes))

            # Read refcode families (any whose refcode starts with strings in the list)
            refcode_families = ['ACSALA', 'HXACAN']
            structures = list(amd.CSDReader(refcode_families, families=True))

            # Get AMDs (k=100) for crystals in these families
            refcodes = ['ACSALA', 'HXACAN']
            amds = []
            for periodic_set in amd.CSDReader(refcodes, families=True):
                amds.append(amd.AMD(periodic_set, 100))

            # Giving the reader nothing reads from the whole CSD.
            for periodic_set in amd.CSDReader():
                ...
    """

    def __init__(
            self,
            refcodes: Optional[Union[str, List[str]]] = None,
            families: bool = False,
            remove_hydrogens: bool = False,
            disorder: str = 'skip',
            heaviest_component: bool = False,
            molecular_centres: bool = False,
            show_warnings: bool = True,
            verbose: bool = False
    ):

        if disorder not in ('skip', 'ordered_sites', 'all_sites'):
            raise ValueError(
                f"'disorder' parameter of {self.__class__.__name__} must be "
                f"one of 'skip', 'ordered_sites' or 'all_sites' (passed "
                f"'{disorder}')"
            )

        try:
            import ccdc.search
            import ccdc.io
        except (ImportError, RuntimeError) as e:
            raise ImportError('Failed to import csd-python-api') from e

        # Normalise refcodes to either None (whole CSD) or a list of str.
        if isinstance(refcodes, str) and refcodes.lower() == 'csd':
            refcodes = None
        if refcodes is None:
            # Nothing to expand when the whole CSD is being read.
            families = False
        elif isinstance(refcodes, str):
            refcodes = [refcodes]
        elif isinstance(refcodes, list):
            if not all(isinstance(refcode, str) for refcode in refcodes):
                raise ValueError(
                    f'{self.__class__.__name__} expects None, a string or '
                    'list of strings.'
                )
        else:
            raise ValueError(
                f'{self.__class__.__name__} expects None, a string or list of '
                f'strings, got {refcodes.__class__.__name__}'
            )

        if families:
            # Replace each family name by the identifiers of all hits of
            # an identifier search for it.
            all_refcodes = []
            for refcode in refcodes:
                query = ccdc.search.TextNumericSearch()
                query.add_identifier(refcode)
                all_refcodes.extend(hit.identifier for hit in query.search())
            # filter to unique refcodes while keeping order (dicts
            # preserve the insertion order of their keys)
            refcodes = list(dict.fromkeys(all_refcodes))

        converter = functools.partial(
            periodicset_from_ccdc_entry,
            remove_hydrogens=remove_hydrogens,
            disorder=disorder,
            molecular_centres=molecular_centres,
            heaviest_component=heaviest_component
        )

        entry_reader = ccdc.io.EntryReader('CSD')
        if refcodes is None:
            iterable = entry_reader
        else:
            # Lazily look up each requested entry as the reader advances.
            iterable = map(entry_reader.entry, refcodes)

        super().__init__(iterable, converter, show_warnings, verbose)
483
484
485
class ParseError(ValueError):
    """Signals that an item could not be converted into a periodic set.

    Readers in this module catch it in order to skip the offending item.
    """
488
489
490
def periodicset_from_gemmi_block(
        block, remove_hydrogens: bool = False, disorder: str = 'skip'
) -> PeriodicSet:
    """Convert a :class:`gemmi.cif.Block` object to a
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`.
    :class:`gemmi.cif.Block` is the type returned by
    :func:`gemmi.cif.read_file`.

    Parameters
    ----------
    block : :class:`gemmi.cif.Block`
        A gemmi CIF block representing a crystal.
    remove_hydrogens : bool, optional
        Remove Hydrogens from the crystal.
    disorder : str, optional
        Controls how disordered structures are handled. Default is
        ``skip`` which skips any crystal with disorder, since disorder
        conflicts with the periodic set model. To read disordered
        structures anyway, choose either :code:`ordered_sites` to remove
        atoms with disorder or :code:`all_sites` include all atoms
        regardless of disorder.

    Returns
    -------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        useful data, e.g. the crystal's name and information about the
        asymmetric unit for calculation.

    Raises
    ------
    ParseError
        Raised if the structure fails to be parsed for any of the
        following: 1. Required data is missing (e.g. cell parameters),
        2. ``disorder == 'skip'`` and disorder is found on any
        atom, 3. The motif is empty after removing H or disordered
        sites.
    """

    import gemmi
    from gemmi.cif import as_number, as_string, as_int

    # Unit cell
    cellpar = [block.find_value(t) for t in _CIF_TAGS['cellpar']]
    if not all(isinstance(par, str) for par in cellpar):
        raise ParseError(f'{block.name} has missing cell data')
    cellpar = np.array([as_number(par) for par in cellpar])
    if np.isnan(np.sum(cellpar)):
        raise ParseError(f'{block.name} has missing cell data')
    cell = cellpar_to_cell(cellpar)

    # Asymmetric unit coordinates
    xyz_loop = block.find(_CIF_TAGS['atom_site_fract']).loop
    if xyz_loop is None:
        xyz_loop = block.find(_CIF_TAGS['atom_site_cartn']).loop
        if xyz_loop is None:
            raise ParseError(f'{block.name} has missing coordinate data')
        else:
            raise ParseError(
                f'{block.name} uses _atom_site_Cartn_ tags for coordinates, '
                'only _atom_site_fract_ is supported'
            )

    # The loop's values come as one flat sequence; regroup them into one
    # column per tag.
    tablified_loop = [[] for _ in range(len(xyz_loop.tags))]
    for i, item in enumerate(xyz_loop.values):
        tablified_loop[i % xyz_loop.width()].append(item)
    loop_dict = dict(zip(xyz_loop.tags, tablified_loop))
    xyz_str = [loop_dict[t] for t in _CIF_TAGS['atom_site_fract']]
    asym_unit = np.transpose(np.array(
        [[as_number(c) for c in xyz] for xyz in xyz_str]
    ))
    asym_unit = np.mod(asym_unit, 1)

    # recommended by pymatgen
    # asym_unit = _snap_small_prec_coords(asym_unit, 1e-4)

    # Labels
    if '_atom_site_label' in loop_dict:
        labels = [as_string(label) for label in loop_dict['_atom_site_label']]
    else:
        labels = [''] * xyz_loop.length()

    # Atomic types
    if '_atom_site_type_symbol' in loop_dict:
        symbols = [as_string(s) for s in loop_dict['_atom_site_type_symbol']]
    else:  # Get atomic types from label
        symbols = []
        for label in labels:
            sym = ''
            if label:
                match = re.search(r'([A-Z][a-z]?)', label)
                if match is not None:
                    sym = match.group()
            symbols.append(sym)

    asym_types = []
    for sym in symbols:
        if sym not in _ATOMIC_NUMBERS:
            # Type symbols may carry extra characters (e.g. a charge as
            # in 'O2-'); reduce them to the leading element symbol, the
            # same way the ase converter in this module does, instead
            # of failing with a KeyError.
            match = re.search(r'([A-Z][a-z]?)', sym)
            sym = match.group() if match is not None else ''
        asym_types.append(_ATOMIC_NUMBERS[sym])

    # Occupancies
    if '_atom_site_occupancy' in loop_dict:
        occs = [as_number(occ) for occ in loop_dict['_atom_site_occupancy']]
        occupancies = [occ if not math.isnan(occ) else 1 for occ in occs]
    else:
        occupancies = [1] * xyz_loop.length()

    # Remove sites with missing coordinates, disorder and Hydrogens if needed
    remove_sites = []
    remove_sites.extend(np.nonzero(np.isnan(asym_unit.min(axis=-1)))[0])

    if disorder == 'skip':
        if any(_has_disorder(l, o) for l, o in zip(labels, occupancies)):
            raise ParseError(
                f"{block.name} has disorder, pass disorder='ordered_sites' or "
                "'all_sites' to remove/ignore disorder"
            )
    elif disorder == 'ordered_sites':
        for i, (label, occ) in enumerate(zip(labels, occupancies)):
            if _has_disorder(label, occ):
                remove_sites.append(i)

    if remove_hydrogens:
        remove_sites.extend(i for i, num in enumerate(asym_types) if num == 1)

    asym_unit = np.delete(asym_unit, remove_sites, axis=0)
    # Set lookup avoids rescanning the removal list for every site.
    removed = set(remove_sites)
    asym_types = [s for i, s in enumerate(asym_types) if i not in removed]
    if asym_unit.shape[0] == 0:
        raise ParseError(f'{block.name} has no valid sites')

    if disorder != 'all_sites':
        keep_sites = _unique_sites(asym_unit, _EQ_SITE_TOL)
        if not np.all(keep_sites):
            warnings.warn(
                'may have overlapping sites; duplicates will be removed'
            )
        asym_unit = asym_unit[keep_sites]
        asym_types = [sym for sym, keep in zip(asym_types, keep_sites) if keep]

    # Symmetry operations, try xyz strings first
    for tag in _CIF_TAGS['symop']:
        sitesym = [v.str(0) for v in block.find([tag])]
        if sitesym:
            rot, trans = _parse_sitesyms(sitesym)
            break
    else:
        # Try spacegroup name; can be a pair or in a loop
        spg = None
        for tag in _CIF_TAGS['spacegroup_name']:
            for value in block.find([tag]):
                try:
                    # Some names cannot be parsed by gemmi.SpaceGroup
                    spg = gemmi.SpaceGroup(value.str(0))
                    break
                except ValueError:
                    continue
            if spg is not None:
                break

        if spg is None:
            # Try international number
            for tag in _CIF_TAGS['spacegroup_number']:
                spg_num = block.find_value(tag)
                if spg_num is not None:
                    spg_num = as_int(spg_num)
                    break
            else:
                warnings.warn('no symmetry data found, defaulting to P1')
                spg_num = 1
            spg = gemmi.SpaceGroup(spg_num)

        # gemmi stores operations as integers scaled by DEN
        rot = np.array([np.array(o.rot) / o.DEN for o in spg.operations()])
        trans = np.array([np.array(o.tran) / o.DEN for o in spg.operations()])

    frac_motif, invs = _expand_asym_unit(asym_unit, rot, trans, _EQ_SITE_TOL)
    _, wyc_muls = np.unique(invs, return_counts=True)
    # Start index of each asymmetric-unit site's run of motif points.
    asym_inds = np.zeros_like(wyc_muls, dtype=np.int32)
    asym_inds[1:] = np.cumsum(wyc_muls)[:-1]
    types = np.array([asym_types[i] for i in invs], dtype=np.uint8)
    # Fractional -> Cartesian coordinates
    motif = np.matmul(frac_motif, cell)

    return PeriodicSet(
        motif=motif,
        cell=cell,
        name=block.name,
        asym_unit=asym_inds,
        multiplicities=wyc_muls,
        types=types
    )
676
677
678
def periodicset_from_ase_cifblock(
        block, remove_hydrogens: bool = False, disorder: str = 'skip'
) -> PeriodicSet:
    """Convert a :class:`ase.io.cif.CIFBlock` object to a
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`.
    :class:`ase.io.cif.CIFBlock` is the type returned by
    :func:`ase.io.cif.parse_cif`.

    Parameters
    ----------
    block : :class:`ase.io.cif.CIFBlock`
        An ase :class:`ase.io.cif.CIFBlock` object representing a
        crystal.
    remove_hydrogens : bool, optional
        Remove Hydrogens from the crystal.
    disorder : str, optional
        Controls how disordered structures are handled. Default is
        ``skip`` which skips any crystal with disorder, since disorder
        conflicts with the periodic set model. To read disordered
        structures anyway, choose either :code:`ordered_sites` to remove
        atoms with disorder or :code:`all_sites` include all atoms
        regardless of disorder.

    Returns
    -------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        useful data, e.g. the crystal's name and information about the
        asymmetric unit for calculation.

    Raises
    ------
    ParseError
        Raised if the structure fails to be parsed for any of the
        following: 1. Required data is missing (e.g. cell parameters),
        2. The motif is empty after removing H or disordered sites,
        3. ``disorder == 'skip'`` and disorder is found on any
        atom.
    """

    import ase.spacegroup

    # Unit cell
    cellpar = [block.get(tag) for tag in _CIF_TAGS['cellpar']]
    if None in cellpar:
        raise ParseError(f'{block.name} has missing cell data')
    # A placeholder such as '?' or '.' is left as a string; treat it as
    # missing data rather than letting the cell computation fail.
    if not all(isinstance(par, (int, float)) for par in cellpar):
        raise ParseError(f'{block.name} has missing cell data')
    cell = cellpar_to_cell(np.array(cellpar))

    # Asymmetric unit coordinates. ase removes uncertainty brackets
    asym_unit = [block.get(tag) for tag in _CIF_TAGS['atom_site_fract']]
    if None in asym_unit:
        asym_unit = [
            block.get(tag.lower()) for tag in _CIF_TAGS['atom_site_cartn']
        ]
        if None in asym_unit:
            raise ParseError(f'{block.name} has missing coordinates')
        else:
            raise ParseError(
                f'{block.name} uses _atom_site_Cartn_ tags for coordinates, '
                'only _atom_site_fract_ is supported'
            )
    # Three columns (x, y, z) -> one (x, y, z) tuple per site
    asym_unit = list(zip(*asym_unit))

    # Labels
    asym_labels = block.get('_atom_site_label')
    if asym_labels is None:
        asym_labels = [''] * len(asym_unit)

    # Atomic types
    asym_symbols = block.get('_atom_site_type_symbol')
    if asym_symbols is not None:
        asym_symbols_ = []
        for label in asym_symbols:
            sym = ''
            if label and label not in ('.', '?'):
                # Keep only the leading element symbol
                match = re.search(r'([A-Z][a-z]?)', label)
                if match is not None:
                    sym = match.group()
            asym_symbols_.append(sym)
    else:
        asym_symbols_ = [''] * len(asym_unit)
    asym_types = [_ATOMIC_NUMBERS[s] for s in asym_symbols_]

    # Find where sites have disorder if necessary
    has_disorder = []
    if disorder != 'all_sites':
        occupancies = block.get('_atom_site_occupancy')
        if occupancies is None:
            occupancies = [1] * len(asym_unit)
        for lab, occ in zip(asym_labels, occupancies):
            has_disorder.append(_has_disorder(lab, occ))

    # Remove sites with ?, . or other invalid string for coordinates
    invalid = {
        i for i, xyz in enumerate(asym_unit)
        if not all(isinstance(coord, (int, float)) for coord in xyz)
    }
    if invalid:
        warnings.warn('atoms without sites or missing data will be removed')
        asym_unit = [c for i, c in enumerate(asym_unit) if i not in invalid]
        asym_types = [t for i, t in enumerate(asym_types) if i not in invalid]
        if disorder != 'all_sites':
            has_disorder = [
                d for i, d in enumerate(has_disorder) if i not in invalid
            ]

    # Indices (into the filtered lists above) of sites to drop
    remove_sites = set()

    if remove_hydrogens:
        remove_sites.update(i for i, num in enumerate(asym_types) if num == 1)

    # Remove atoms with fractional occupancy or raise ParseError
    if disorder != 'all_sites':
        for i, dis in enumerate(has_disorder):
            # Disorder on a site that is being removed anyway is ignored
            if i in remove_sites:
                continue
            if dis:
                if disorder == 'skip':
                    raise ParseError(
                        f'{block.name} has disorder, pass '
                        "disorder='ordered_sites' or 'all_sites' to "
                        'remove/ignore disorder'
                    )
                elif disorder == 'ordered_sites':
                    remove_sites.add(i)

    # Asymmetric unit
    asym_unit = [c for i, c in enumerate(asym_unit) if i not in remove_sites]
    asym_types = [t for i, t in enumerate(asym_types) if i not in remove_sites]
    if len(asym_unit) == 0:
        raise ParseError(f'{block.name} has no valid sites')
    asym_unit = np.mod(np.array(asym_unit), 1)

    # recommended by pymatgen
    # asym_unit = _snap_small_prec_coords(asym_unit, 1e-4)

    # Remove overlapping sites unless disorder == 'all_sites'
    if disorder != 'all_sites':
        keep_sites = _unique_sites(asym_unit, _EQ_SITE_TOL)
        if not np.all(keep_sites):
            warnings.warn(
                'may have overlapping sites, duplicates will be removed'
            )
            asym_unit = asym_unit[keep_sites]
            asym_types = [t for t, keep in zip(asym_types, keep_sites) if keep]

    # Get symmetry operations: xyz strings first, then spacegroup name,
    # then spacegroup number, finally falling back to P1
    sitesym = block._get_any(_CIF_TAGS['symop'])
    if sitesym is None:
        label_or_num = block._get_any(
            [s.lower() for s in _CIF_TAGS['spacegroup_name']]
        )
        if label_or_num is None:
            label_or_num = block._get_any(
                [s.lower() for s in _CIF_TAGS['spacegroup_number']]
            )
        if label_or_num is None:
            warnings.warn('no symmetry data found, defaulting to P1')
            label_or_num = 1
        spg = ase.spacegroup.Spacegroup(label_or_num)
        rot, trans = spg.get_op()
    else:
        if isinstance(sitesym, str):
            sitesym = [sitesym]
        rot, trans = _parse_sitesyms(sitesym)

    frac_motif, invs = _expand_asym_unit(asym_unit, rot, trans, _EQ_SITE_TOL)
    _, wyc_muls = np.unique(invs, return_counts=True)
    # Start index of each asymmetric-unit site's run of motif points.
    asym_inds = np.zeros_like(wyc_muls, dtype=np.int32)
    asym_inds[1:] = np.cumsum(wyc_muls)[:-1]
    types = np.array([asym_types[i] for i in invs], dtype=np.uint8)
    # Fractional -> Cartesian coordinates
    motif = np.matmul(frac_motif, cell)

    return PeriodicSet(
        motif=motif,
        cell=cell,
        name=block.name,
        asym_unit=asym_inds,
        multiplicities=wyc_muls,
        types=types
    )
861
862
863
def periodicset_from_pymatgen_cifblock(
        block, remove_hydrogens: bool = False, disorder: str = 'skip'
) -> PeriodicSet:
    """Convert a :class:`pymatgen.io.cif.CifBlock` object to a
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`.
    :class:`pymatgen.io.cif.CifBlock` is the type returned by
    :class:`pymatgen.io.cif.CifFile`.

    Parameters
    ----------
    block : :class:`pymatgen.io.cif.CifBlock`
        A pymatgen CifBlock object representing a crystal.
    remove_hydrogens : bool, optional
        Remove Hydrogens from the crystal.
    disorder : str, optional
        Controls how disordered structures are handled. Default is
        ``skip`` which skips any crystal with disorder, since disorder
        conflicts with the periodic set model. To read disordered
        structures anyway, choose either :code:`ordered_sites` to remove
        atoms with disorder or :code:`all_sites` include all atoms
        regardless of disorder.

    Returns
    -------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        useful data, e.g. the crystal's name and information about the
        asymmetric unit for calculation.

    Raises
    ------
    ParseError
        Raised if the structure can/should not be parsed for the
        following reasons: 1. Cell parameters are missing, 2. No sites
        found or motif is empty after removing Hydrogens & disorder,
        3. Fractional coordinates are missing (including when only
        Cartesian coordinates are given), 4.
        :code:`disorder == 'skip'` and disorder is found on any atom.
    """

    from pymatgen.io.cif import str2float

    odict = block.data

    # Unit cell
    cellpar = [odict.get(tag) for tag in _CIF_TAGS['cellpar']]
    if any(par in (None, '?', '.') for par in cellpar):
        raise ParseError(f'{block.header} has missing cell data')
    cell = cellpar_to_cell(
        np.array([str2float(v) for v in cellpar], dtype=np.float64)
    )

    # Asymmetric unit coordinates; only fractional coordinates are
    # supported, Cartesian tags are looked up just to give a better error.
    asym_unit = [odict.get(tag) for tag in _CIF_TAGS['atom_site_fract']]
    if None in asym_unit:
        cartn = [odict.get(tag) for tag in _CIF_TAGS['atom_site_cartn']]
        if None in cartn:
            raise ParseError(f'{block.header} has missing coordinates')
        raise ParseError(
            f'{block.header} uses _atom_site_Cartn_ tags for coordinates, '
            'only _atom_site_fract_ is supported'
        )
    asym_unit = list(zip(*asym_unit))
    asym_unit = [[str2float(coord) for coord in xyz] for xyz in asym_unit]

    # Atomic types. Only the leading element symbol of each type label is
    # kept; missing or unparseable labels become the empty symbol ''.
    asym_symbols = odict.get('_atom_site_type_symbol')
    if asym_symbols is not None:
        asym_symbols_ = []
        for label in asym_symbols:
            sym = ''
            if label and label not in ('.', '?'):
                match = re.search(r'([A-Z][a-z]?)', label)
                if match is not None:
                    sym = match.group()
            asym_symbols_.append(sym)
    else:
        asym_symbols_ = [''] * len(asym_unit)
    asym_types = [_ATOMIC_NUMBERS[s] for s in asym_symbols_]

    # Find where sites have disorder if necessary
    has_disorder = []
    if disorder != 'all_sites':
        occupancies = odict.get('_atom_site_occupancy')
        if occupancies is None:
            occupancies = np.ones((len(asym_unit), ))
        else:
            occupancies = np.array([str2float(occ) for occ in occupancies])
        labels = odict.get('_atom_site_label')
        if labels is None:
            labels = [''] * len(asym_unit)
        has_disorder = [
            _has_disorder(lab, occ) for lab, occ in zip(labels, occupancies)
        ]

    # Remove sites with ?, . or other invalid string for coordinates
    invalid = {
        i for i, xyz in enumerate(asym_unit)
        if not all(isinstance(coord, (int, float)) for coord in xyz)
    }

    if invalid:
        warnings.warn('atoms without sites or missing data will be removed')
        asym_unit = [c for i, c in enumerate(asym_unit) if i not in invalid]
        asym_types = [c for i, c in enumerate(asym_types) if i not in invalid]
        if disorder != 'all_sites':
            has_disorder = [
                d for i, d in enumerate(has_disorder) if i not in invalid
            ]

    # Indices (into the filtered lists above) of sites to drop
    remove_sites = set()

    if remove_hydrogens:
        remove_sites.update(i for i, n in enumerate(asym_types) if n == 1)

    # Remove atoms with fractional occupancy or raise ParseError
    if disorder != 'all_sites':
        for i, dis in enumerate(has_disorder):
            # Disorder on a site that is removed anyway is ignored
            if not dis or i in remove_sites:
                continue
            if disorder == 'skip':
                raise ParseError(
                    f'{block.header} has disorder, pass '
                    "disorder='ordered_sites' or 'all_sites' to "
                    'remove/ignore disorder'
                )
            elif disorder == 'ordered_sites':
                remove_sites.add(i)

    # Asymmetric unit
    asym_unit = [c for i, c in enumerate(asym_unit) if i not in remove_sites]
    asym_types = [t for i, t in enumerate(asym_types) if i not in remove_sites]
    if len(asym_unit) == 0:
        raise ParseError(f'{block.header} has no valid sites')
    asym_unit = np.mod(np.array(asym_unit), 1)

    # recommended by pymatgen
    # asym_unit = _snap_small_prec_coords(asym_unit, 1e-4)

    # Remove overlapping sites unless disorder == 'all_sites'
    if disorder != 'all_sites':
        keep_sites = _unique_sites(asym_unit, _EQ_SITE_TOL)
        if not np.all(keep_sites):
            warnings.warn(
                'may have overlapping sites; duplicates will be removed'
            )
        asym_unit = asym_unit[keep_sites]
        asym_types = [sym for sym, keep in zip(asym_types, keep_sites) if keep]

    # Apply symmetries to asymmetric unit
    rot, trans = _get_syms_pymatgen(odict)
    frac_motif, invs = _expand_asym_unit(asym_unit, rot, trans, _EQ_SITE_TOL)
    # invs[k] is the asymmetric unit index that motif point k came from;
    # the counts per index are the multiplicities and their running sum
    # gives the position of the first image of each asymmetric site.
    _, wyc_muls = np.unique(invs, return_counts=True)
    asym_inds = np.zeros_like(wyc_muls, dtype=np.int32)
    asym_inds[1:] = np.cumsum(wyc_muls)[:-1]
    types = np.array([asym_types[i] for i in invs], dtype=np.uint8)
    motif = np.matmul(frac_motif, cell)

    return PeriodicSet(
        motif=motif,
        cell=cell,
        name=block.header,
        asym_unit=asym_inds,
        multiplicities=wyc_muls,
        types=types
    )
1036
1037
1038
def periodicset_from_ase_atoms(
        atoms, remove_hydrogens: bool = False
) -> PeriodicSet:
    """Convert an :class:`ase.atoms.Atoms` object to a
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`. Does not have
    the option to remove disorder.

    The ``atoms`` object passed in is not modified: when Hydrogens are
    removed, the removal is done on a copy.

    Parameters
    ----------
    atoms : :class:`ase.atoms.Atoms`
        An ase :class:`ase.atoms.Atoms` object representing a crystal.
    remove_hydrogens : bool, optional
        Remove Hydrogens from the crystal.

    Returns
    -------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        useful data, e.g. the crystal's name and information about the
        asymmetric unit for calculation.

    Raises
    ------
    ParseError
        Raised if there are no valid sites in atoms.
    """

    from ase.spacegroup import get_basis

    cell = atoms.get_cell().array

    if remove_hydrogens:
        # Atoms.pop() edits in place, so work on a copy to avoid
        # silently stripping Hydrogens from the caller's object.
        atoms = atoms.copy()
        hydrogen_inds = np.where(atoms.get_atomic_numbers() == 1)[0]
        # Pop from the back so earlier indices stay valid
        for i in sorted(hydrogen_inds, reverse=True):
            atoms.pop(i)

    if len(atoms) == 0:
        raise ParseError('ase Atoms object has no valid sites')

    # Symmetry operations from spacegroup
    spg = None
    if 'spacegroup' in atoms.info:
        spg = atoms.info['spacegroup']
        rot, trans = spg.rotations, spg.translations
    else:
        warnings.warn('no symmetry data found, defaulting to P1')
        rot = np.identity(3)[None, :]
        trans = np.zeros((1, 3))

    # Asymmetric unit. ase default tol is 1e-5
    # do differently! get_basis determines a reduced asym unit from the atoms;
    # surely this is not needed!
    asym_unit = get_basis(atoms, spacegroup=spg, tol=_EQ_SITE_TOL)
    frac_motif, invs = _expand_asym_unit(asym_unit, rot, trans, _EQ_SITE_TOL)
    _, wyc_muls = np.unique(invs, return_counts=True)
    asym_inds = np.zeros_like(wyc_muls, dtype=np.int32)
    asym_inds[1:] = np.cumsum(wyc_muls)[:-1]
    # NOTE(review): types are read straight from ``atoms`` while the motif
    # is regenerated from the basis; this presumes both list the atoms in
    # the same order and number - confirm.
    types = atoms.get_atomic_numbers().astype(np.uint8)
    motif = np.matmul(frac_motif, cell)

    return PeriodicSet(
        motif=motif,
        cell=cell,
        asym_unit=asym_inds,
        multiplicities=wyc_muls,
        types=types
    )
1108
1109
1110
def periodicset_from_pymatgen_structure(
        structure, remove_hydrogens: bool = False, disorder: str = 'skip'
) -> PeriodicSet:
    """Convert a :class:`pymatgen.core.structure.Structure` object to a
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`. Does not set
    the name of the periodic set, as pymatgen Structure objects seem to
    have no name attribute.

    The ``structure`` passed in is not modified: whenever species or
    sites have to be removed, the removal is done on a copy.

    Parameters
    ----------
    structure : :class:`pymatgen.core.structure.Structure`
        A pymatgen Structure object representing a crystal.
    remove_hydrogens : bool, optional
        Remove Hydrogens from the crystal.
    disorder : str, optional
        Controls how disordered structures are handled. Default is
        ``skip`` which skips any crystal with disorder, since disorder
        conflicts with the periodic set model. To read disordered
        structures anyway, choose either :code:`ordered_sites` to remove
        atoms with disorder or :code:`all_sites` include all atoms
        regardless of disorder.

    Returns
    -------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        useful data, e.g. the crystal's name and information about the
        asymmetric unit for calculation.

    Raises
    ------
    ParseError
        Raised if the :code:`disorder == 'skip'` and
        :code:`not structure.is_ordered`
    """

    from pymatgen.symmetry.analyzer import SpacegroupAnalyzer

    # remove_species() and remove_sites() edit the structure in place, so
    # switch to a copy before either is called to keep the caller's
    # object intact.
    if remove_hydrogens or disorder == 'ordered_sites':
        structure = structure.copy()

    if remove_hydrogens:
        structure.remove_species(['H', 'D'])

    # Disorder
    if disorder == 'skip':
        if not structure.is_ordered:
            raise ParseError(
                'pymatgen Structure has disorder, pass '
                "disorder='ordered_sites' or 'all_sites' to remove/ignore "
                'disorder'
            )
    elif disorder == 'ordered_sites':
        # Sites whose total occupancy is below 1 are treated as disordered
        remove_inds = [
            i for i, comp in enumerate(structure.species_and_occu)
            if comp.num_atoms < 1
        ]
        structure.remove_sites(remove_inds)

    motif = structure.cart_coords
    cell = structure.lattice.matrix
    sym_structure = SpacegroupAnalyzer(structure).get_symmetrized_structure()
    # Each entry of equivalent_indices lists the site indices of one group
    # of symmetry-equivalent sites; the first index of the group is used
    # as its representative and the group size as its multiplicity.
    eq_inds = sym_structure.equivalent_indices
    asym_inds = np.array([ix_list[0] for ix_list in eq_inds], dtype=np.int32)
    wyc_muls = np.array([len(ix_list) for ix_list in eq_inds], dtype=np.int32)
    types = np.array(sym_structure.atomic_numbers, dtype=np.uint8)

    return PeriodicSet(
        motif=motif,
        cell=cell,
        asym_unit=asym_inds,
        multiplicities=wyc_muls,
        types=types
    )
1182
1183
1184
def periodicset_from_ccdc_entry(
        entry,
        remove_hydrogens: bool = False,
        disorder: str = 'skip',
        heaviest_component: bool = False,
        molecular_centres: bool = False
) -> PeriodicSet:
    """Build a :class:`amd.PeriodicSet <.periodicset.PeriodicSet>` from
    a :class:`ccdc.entry.Entry`, the type returned by
    :class:`ccdc.io.EntryReader`.

    Two entry-level checks are made here, then the entry's crystal is
    handed to :func:`periodicset_from_ccdc_crystal` with the same
    options.

    Parameters
    ----------
    entry : :class:`ccdc.entry.Entry`
        A ccdc Entry object representing a database entry.
    remove_hydrogens : bool, optional
        Remove Hydrogens from the crystal.
    disorder : str, optional
        Controls how disordered structures are handled. Default is
        ``skip`` which skips any crystal with disorder, since disorder
        conflicts with the periodic set model. To read disordered
        structures anyway, choose either :code:`ordered_sites` to remove
        atoms with disorder or :code:`all_sites` include all atoms
        regardless of disorder.
    heaviest_component : bool, optional
        Removes all but the heaviest molecule in the asymmetric unit,
        intended for removing solvents.
    molecular_centres : bool, default False
        Use molecular centres of mass as the motif instead of centres of
        atoms.

    Returns
    -------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        useful data, e.g. the crystal's name and information about the
        asymmetric unit for calculation.

    Raises
    ------
    ParseError
        Raised here if ``entry.has_3d_structure`` is False, or if
        :code:`disorder == 'skip'` and ``entry.has_disorder`` is True.
        Any ParseError raised by
        :func:`periodicset_from_ccdc_crystal` also propagates.
    """

    # Entry-level flags are checked before the crystal is touched
    if not entry.has_3d_structure:
        raise ParseError(f'{entry.identifier} has no 3D structure')

    skip_disordered = disorder == 'skip'
    if skip_disordered and entry.has_disorder:
        raise ParseError(
            f"{entry.identifier} has disorder, pass disorder='ordered_sites' "
            "or 'all_sites' to remove/ignore disorder"
        )

    options = {
        'remove_hydrogens': remove_hydrogens,
        'disorder': disorder,
        'heaviest_component': heaviest_component,
        'molecular_centres': molecular_centres,
    }
    return periodicset_from_ccdc_crystal(entry.crystal, **options)
1253
1254
1255
def periodicset_from_ccdc_crystal(
        crystal,
        remove_hydrogens: bool = False,
        disorder: str = 'skip',
        heaviest_component: bool = False,
        molecular_centres: bool = False
) -> PeriodicSet:
    """Convert a :class:`ccdc.crystal.Crystal` object to a
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`.
    Crystal is the type returned by :class:`ccdc.io.CrystalReader`.

    Parameters
    ----------
    crystal : :class:`ccdc.crystal.Crystal`
        A ccdc Crystal object representing a crystal structure. Its
        ``molecule`` attribute is reassigned to the filtered molecule.
    remove_hydrogens : bool, optional
        Remove Hydrogens from the crystal.
    disorder : str, optional
        Controls how disordered structures are handled. Default is
        ``skip`` which skips any crystal with disorder, since disorder
        conflicts with the periodic set model. To read disordered
        structures anyway, choose either :code:`ordered_sites` to remove
        atoms with disorder or :code:`all_sites` include all atoms
        regardless of disorder.
    heaviest_component : bool, optional
        Removes all but the heaviest molecule in the asymmetric unit,
        intended for removing solvents.
    molecular_centres : bool, default False
        Use molecular centres of mass as the motif instead of centres of
        atoms.

    Returns
    -------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        useful data, e.g. the crystal's name and information about the
        asymmetric unit for calculation.

    Raises
    ------
    ParseError
        Raised if the structure fails parsing for any of the following:
        1. :code:`disorder == 'skip'` and disorder is found on any
        atom, 2. a cell length or angle is missing, 3. The motif is
        empty after removing H, disordered sites, atoms without
        coordinates or solvents.
    """

    molecule = crystal.disordered_molecule

    # Disorder
    if disorder == 'skip':
        if crystal.has_disorder or \
         any(_has_disorder(a.label, a.occupancy) for a in molecule.atoms):
            raise ParseError(
                f"{crystal.identifier} has disorder, pass "
                "disorder='ordered_sites' or 'all_sites' to remove/ignore "
                "disorder"
            )
    elif disorder == 'ordered_sites':
        molecule.remove_atoms(
            a for a in molecule.atoms if _has_disorder(a.label, a.occupancy)
        )

    if remove_hydrogens:
        # Tuple membership, not ``in 'HD'``: a substring test would also
        # match an empty symbol (and the literal string 'HD').
        molecule.remove_atoms(
            a for a in molecule.atoms if a.atomic_symbol in ('H', 'D')
        )

    # Remove atoms with missing coordinates and warn
    if any(a.fractional_coordinates is None for a in molecule.atoms):
        warnings.warn('atoms without sites or missing data will be removed')
        molecule.remove_atoms(
            a for a in molecule.atoms if a.fractional_coordinates is None
        )

    if heaviest_component and len(molecule.components) > 1:
        molecule = _heaviest_component_ccdc(molecule)

    # Write the filtered molecule back so that the crystal-level queries
    # below (asymmetric unit, molecular centres) use it.
    crystal.molecule = molecule
    cellpar = crystal.cell_lengths + crystal.cell_angles
    if None in cellpar:
        raise ParseError(f'{crystal.identifier} has missing cell data')
    cell = cellpar_to_cell(np.array(cellpar))

    if molecular_centres:
        frac_centres = _frac_molecular_centres_ccdc(crystal, _EQ_SITE_TOL)
        mol_centres = np.matmul(frac_centres, cell)
        return PeriodicSet(mol_centres, cell, name=crystal.identifier)

    asym_atoms = crystal.asymmetric_unit_molecule.atoms
    # check for None?
    asym_unit = np.array([tuple(a.fractional_coordinates) for a in asym_atoms])

    if asym_unit.shape[0] == 0:
        raise ParseError(f'{crystal.identifier} has no valid sites')

    asym_unit = np.mod(asym_unit, 1)

    # recommended by pymatgen
    # asym_unit = _snap_small_prec_coords(asym_unit, 1e-4)

    asym_types = [a.atomic_number for a in asym_atoms]

    # Remove overlapping sites unless disorder == 'all_sites'
    if disorder != 'all_sites':
        keep_sites = _unique_sites(asym_unit, _EQ_SITE_TOL)
        if not np.all(keep_sites):
            warnings.warn(
                'may have overlapping sites; duplicates will be removed'
            )
        asym_unit = asym_unit[keep_sites]
        asym_types = [sym for sym, keep in zip(asym_types, keep_sites) if keep]

    # Symmetry operations
    sitesym = crystal.symmetry_operators
    # try spacegroup numbers?
    if not sitesym:
        warnings.warn('no symmetry data found, defaulting to P1')
        sitesym = ['x,y,z']

    # Apply symmetries to asymmetric unit
    rot, trans = _parse_sitesyms(sitesym)
    frac_motif, invs = _expand_asym_unit(asym_unit, rot, trans, _EQ_SITE_TOL)
    # invs[k] is the asymmetric unit index that motif point k came from;
    # the counts per index are the multiplicities and their running sum
    # gives the position of the first image of each asymmetric site.
    _, wyc_muls = np.unique(invs, return_counts=True)
    asym_inds = np.zeros_like(wyc_muls, dtype=np.int32)
    asym_inds[1:] = np.cumsum(wyc_muls)[:-1]
    types = np.array([asym_types[i] for i in invs], dtype=np.uint8)
    motif = np.matmul(frac_motif, cell)

    return PeriodicSet(
        motif=motif,
        cell=cell,
        name=crystal.identifier,
        asym_unit=asym_inds,
        multiplicities=wyc_muls,
        types=types
    )
1395
1396
1397
def memoize(f):
    """Cache for _parse_sitesym().

    Wrap the single-argument function ``f`` so that each distinct
    (hashable) argument is only evaluated once; later calls with an
    equal argument return the stored result object. The cache is never
    cleared. The wrapper keeps the name and docstring of ``f``.
    """
    cache = {}

    @functools.wraps(f)
    def wrapper(arg):
        if arg not in cache:
            cache[arg] = f(arg)
        return cache[arg]

    return wrapper
1405
1406
1407
@memoize
def _parse_sitesym(sym: str) -> Tuple[np.ndarray, np.ndarray]:
    """Turn one symmetry operation written as an xyz string (e.g.
    ``'-x,y+1/2,z'``) into a 3x3 rotation matrix and a length-3
    translation vector.

    Each comma-separated term fills one row of the rotation matrix and
    one entry of the translation. Characters other than ``+ - / . x y z``
    and digits are ignored, and letters are matched case-insensitively.
    """

    axis_column = {'x': 0, 'y': 1, 'z': 2}
    rot = np.zeros((3, 3), dtype=np.float64)
    trans = np.zeros((3, ), dtype=np.float64)

    for row, term in enumerate(sym.split(',')):

        sign = 1.0            # sign set by the most recent '+' or '-'
        seen_slash = False    # True once a '/' has been read
        trans_sign = None     # sign in force when the first digit is read
        numerator = ''        # digits read before any '/'
        denominator = ''      # digits read after a '/'

        for ch in term.lower():
            if ch == '+':
                sign = 1.0
            elif ch == '-':
                sign = -1.0
            elif ch == '/':
                seen_slash = True
            elif ch in axis_column:
                rot[row][axis_column[ch]] = sign
            elif ch.isdigit() or ch == '.':
                if trans_sign is None:
                    trans_sign = sign
                if seen_slash:
                    denominator += ch
                else:
                    numerator += ch

        # No digits before a '/' means no translation for this row
        shift = trans_sign * float(numerator) if numerator else 0.0

        if seen_slash:
            shift /= float(denominator)

        trans[row] = shift

    return rot, trans
1453
1454
1455
def _parse_sitesyms(symmetries: List[str]) -> Tuple[np.ndarray, np.ndarray]:
1456
    """Parse a sequence of symmetries in xyz form and return rotation
1457
    and translation arrays.
1458
    """
1459
    rotations = []
1460
    translations = []
1461
    for sym in symmetries:
1462
        rot, trans = _parse_sitesym(sym)
1463
        rotations.append(rot)
1464
        translations.append(trans)
1465
    return np.array(rotations), np.array(translations)
1466
1467
1468
def _expand_asym_unit(
        asym_unit: np.ndarray,
        rotations: np.ndarray,
        translations: np.ndarray,
        tol: float
) -> Tuple[np.ndarray, np.ndarray]:
    """Generate the full motif (in fractional coordinates) from the
    asymmetric unit by applying the symmetries given by ``rotations``
    and ``translations``.

    Returns the motif together with an integer array giving, for each
    motif point, the index of the asymmetric unit site it came from.
    """

    # All inputs are passed on as float64 (no copy if already float64)
    expanded_sites = _expand_sites(
        asym_unit.astype(np.float64, copy=False),
        rotations.astype(np.float64, copy=False),
        translations.astype(np.float64, copy=False)
    )

    # Fast path first: it is only valid when no two asymmetric unit
    # sites are equivalent, which shows up as duplicates in its output.
    frac_motif, invs = _reduce_expanded_sites(expanded_sites, tol)
    if all(_unique_sites(frac_motif, tol)):
        return frac_motif, invs

    # Duplicates found, redo the reduction with the slower version
    return _reduce_expanded_equiv_sites(expanded_sites, tol)
1488
1489
1490
@numba.njit(cache=True)
def _expand_sites(
        asym_unit: np.ndarray, rotations: np.ndarray, translations: np.ndarray
) -> np.ndarray:
    """Apply every symmetry operation to every point of the asymmetric
    unit and wrap the results into the unit cell (modulo 1). Points
    left invariant by a symmetry appear more than once; they are not
    removed here. The result is a 3D array of shape
    (#points, #syms, dims).
    """

    n_points, dims = asym_unit.shape
    n_ops = len(rotations)
    images = np.empty((n_points, n_ops, dims), dtype=np.float64)

    for op in range(n_ops):
        rotation = rotations[op]
        shift = translations[op]
        for pt in range(n_points):
            images[pt, op] = np.dot(rotation, asym_unit[pt]) + shift

    return np.mod(images, 1)
1509
1510
1511
@numba.njit(cache=True)
def _reduce_expanded_sites(
        expanded_sites: np.ndarray, tol: float
) -> Tuple[np.ndarray, np.ndarray]:
    """Drop the repeated images produced when a symmetry leaves a point
    invariant, after the asymmetric unit has been expanded.

    This is the fast version: duplicates are only looked for among the
    images of the same asymmetric unit site, so it is only correct when
    no two sites of the asymmetric unit are equivalent. If they are,
    the reduction is re-run with _reduce_expanded_equiv_sites().

    Returns the reduced motif and, for each of its points, the index of
    the asymmetric unit site it came from.
    """

    n_sites, _, dims = expanded_sites.shape
    keep_masks = []
    counts = np.zeros(shape=(n_sites, ))

    # First pass: find which images of each site are distinct
    for i in range(n_sites):
        mask = _unique_sites(expanded_sites[i], tol)
        keep_masks.append(mask)
        counts[i] = np.sum(mask)

    total = int(np.sum(counts))
    frac_motif = np.zeros(shape=(total, dims))
    inverses = np.zeros(shape=(total, ), dtype=np.int32)

    # Second pass: copy the distinct images of each site into place
    start = 0
    for i in range(n_sites):
        stop = start + counts[i]
        frac_motif[start:stop, :] = expanded_sites[i][keep_masks[i]]
        inverses[start:stop] = i
        start = stop

    return frac_motif, inverses
1543
1544
1545
def _reduce_expanded_equiv_sites(
        expanded_sites: np.ndarray, tol: float
) -> Tuple[np.ndarray, np.ndarray]:
    """Drop repeated images after the asymmetric unit has been expanded,
    also discarding images that coincide (within ``tol``, modulo 1) with
    a point already kept from an earlier asymmetric unit site. This is
    the slower version, used after the fast version when equivalent
    motif points are found.

    A warning is issued for every image discarded because it matches a
    point of an earlier site. Returns the reduced motif and, for each
    of its points, the index of the asymmetric unit site it came from.
    """

    # The first site can only clash with itself
    first_images = expanded_sites[0]
    frac_motif = first_images[_unique_sites(first_images, tol)]
    inverses = [0] * len(frac_motif)

    for i, images in enumerate(expanded_sites[1:], start=1):
        candidates = images[_unique_sites(images, tol)]

        new_points = []
        for site in candidates:
            # Compare against everything kept from earlier sites; a
            # coordinate gap close to 1 also counts as a match.
            direct = np.abs(site - frac_motif)
            wrapped = np.abs(direct - 1)
            matches = np.all((direct <= tol) | (wrapped <= tol), axis=-1)

            if np.any(matches):
                warnings.warn(
                    'has equivalent sites at positions '
                    f'{inverses[np.argmax(matches)]}, {i}'
                )
            else:
                new_points.append(site)

        if new_points:
            inverses.extend([i] * len(new_points))
            frac_motif = np.concatenate((frac_motif, np.array(new_points)))

    return frac_motif, np.array(inverses, dtype=np.int32)
1582
1583
1584
@numba.njit(cache=True)
def _unique_sites(asym_unit: np.ndarray, tol: float) -> np.ndarray:
    """Flag the first occurrence of each distinct point in a list of
    fractional coordinates, treating points as equal when every
    coordinate agrees within ``tol`` modulo 1. Returns a boolean array
    such that asym_unit[_unique_sites(asym_unit, tol)] is the list with
    duplicates removed.
    """

    n_sites, dims = asym_unit.shape
    where_unique = np.full(shape=(n_sites, ), fill_value=True)

    for i in range(1, n_sites):
        for j in range(i):
            # Site i matches earlier site j unless at least one
            # coordinate differs both directly and across the cell edge.
            differs = False
            for k in range(dims):
                gap = np.abs(asym_unit[j, k] - asym_unit[i, k])
                if gap > tol and np.abs(gap - 1) > tol:
                    differs = True
                    break
            if not differs:
                where_unique[i] = False
                break

    return where_unique
1602
1603
1604
def _has_disorder(label: str, occupancy) -> bool:
1605
    """Return True if label ends with ? or occupancy is a number < 1."""
1606
    try:
1607
        occupancy = float(occupancy)
1608
    except Exception:
1609
        occupancy = 1
1610
    return (occupancy < 1) or label.endswith('?')
1611
1612
1613
def _get_syms_pymatgen(data: dict) -> Tuple[np.ndarray, np.ndarray]:
    """Parse symmetry operations given by data = block.data where block
    is a pymatgen CifBlock object.

    Sources are tried in this order, and the first one that yields
    operations wins:

    1. Explicit xyz symmetry operators (tags in ``_CIF_TAGS['symop']``),
       handed to ``_parse_sitesyms``.
    2. A space group symbol (tags in ``_CIF_TAGS['spacegroup_name']``),
       looked up in pymatgen's ``space_groups`` table and, if that
       fails, matched against the Hermann-Mauguin symbols in pymatgen's
       COD data.
    3. An international number (tags in
       ``_CIF_TAGS['spacegroup_number']``).

    If nothing is found a warning is issued and the identity operation
    ('x,y,z', i.e. P1) is used.

    Returns a pair ``(rotations, translations)`` of float64 arrays when
    the operations come from a pymatgen ``SpaceGroup``; otherwise
    returns whatever ``_parse_sitesyms`` returns (presumably the same
    kind of pair -- confirm against its definition).
    """

    from pymatgen.symmetry.groups import SpaceGroup
    import pymatgen.io.cif

    # Try xyz symmetry operations
    for symmetry_label in _CIF_TAGS['symop']:
        xyz = data.get(symmetry_label)
        if not xyz:
            continue
        if isinstance(xyz, str):
            # A single operator may be stored as a bare string.
            xyz = [xyz]
        return _parse_sitesyms(xyz)

    symops = []
    # Try spacegroup symbol
    for symmetry_label in _CIF_TAGS['spacegroup_name']:
        sg = data.get(symmetry_label)
        if not sg:
            continue
        # Strip whitespace and underscores before the lookups.
        sg = re.sub(r'[\s_]', '', sg)
        try:
            spg = pymatgen.io.cif.space_groups.get(sg)
            if spg:
                symops = SpaceGroup(spg).symmetry_ops
                break
        except ValueError:
            pass
        # The symbol is unknown to the space_groups table, or SpaceGroup
        # rejected it: fall back to matching the Hermann-Mauguin symbols
        # in pymatgen's COD data.
        try:
            for d in pymatgen.io.cif._get_cod_data():
                if sg == re.sub(r'\s+', '', d['hermann_mauguin']):
                    return _parse_sitesyms(d['symops'])
        except Exception:
            # Best effort: any failure here just moves on to the next tag.
            continue

    # Try international number
    if not symops:
        for symmetry_label in _CIF_TAGS['spacegroup_number']:
            num = data.get(symmetry_label)
            if not num:
                continue
            try:
                i = int(pymatgen.io.cif.str2float(num))
                symops = SpaceGroup.from_int_number(i).symmetry_ops
                break
            except ValueError:
                continue

    if not symops:
        warnings.warn('no symmetry data found, defaulting to P1')
        return _parse_sitesyms(['x,y,z'])

    rotations = np.array(
        [op.rotation_matrix for op in symops], dtype=np.float64
    )
    translations = np.array(
        [op.translation_vector for op in symops], dtype=np.float64
    )
    return rotations, translations
1678
1679
1680
def _frac_molecular_centres_ccdc(crystal, tol: float) -> np.ndarray:
    """Compute the geometric centre of each molecule in the unit cell.
    Takes a ccdc Crystal object and returns fractional coordinates,
    wrapped into [0, 1) and with duplicates (within tol) removed.
    """

    packed = crystal.packing(inclusion='CentroidIncluded')

    centroids = []
    for molecule in packed.components:
        positions = [atom.fractional_coordinates for atom in molecule.atoms]
        n_atoms = len(positions)
        # Average each axis separately over the atoms of the molecule.
        centroid = [sum(axis_vals) / n_atoms for axis_vals in zip(*positions)]
        centroids.append(centroid)

    wrapped = np.mod(np.array(centroids, dtype=np.float64), 1)
    keep = _unique_sites(wrapped, tol)
    return wrapped[keep]
1691
1692
1693
def _heaviest_component_ccdc(molecule):
1694
    """Remove all but the heaviest component of the asymmetric unit.
1695
    Intended for removing solvents. Expects and returns a ccdc Molecule
1696
    object.
1697
    """
1698
1699
    component_weights = []
1700
    for component in molecule.components:
1701
        weight = 0
1702
        for a in component.atoms:
1703
            try:
1704
                occ = float(a.occupancy)
1705
            except ValueError:
1706
                occ = 1
1707
            try:
1708
                weight += float(a.atomic_weight) * occ
1709
            except ValueError:
1710
                pass
1711
        component_weights.append(weight)
1712
    largest_component_ind = np.argmax(np.array(component_weights))
1713
    molecule = molecule.components[largest_component_ind]
1714
    return molecule
1715
1716
1717
def _snap_small_prec_coords(frac_coords: np.ndarray, tol: float) -> np.ndarray:
1718
    """Find where frac_coords is within 1e-4 of 1/3 or 2/3, change to
1719
    1/3 and 2/3. Recommended by pymatgen's CIF parser.
1720
    """
1721
    frac_coords[np.abs(1 - 3 * frac_coords) < tol] = 1 / 3.
1722
    frac_coords[np.abs(1 - 3 * frac_coords / 2) < tol] = 2 / 3.
1723
    return frac_coords
1724