Passed
Push — master ( d05325...baece5 )
by Daniel
03:54
created

amd.io._parse_sitesyms()   F

Complexity

Conditions 16

Size

Total Lines 49
Code Lines 35

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 35
dl 0
loc 49
rs 2.4
c 0
b 0
f 0
cc 16
nop 1

How to fix   Complexity   

Complexity

Complex classes like amd.io._parse_sitesyms() often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes, or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
"""Tools for reading crystals from files, or from the CSD with
2
``csd-python-api``. The readers return
3
:class:`amd.PeriodicSet <.periodicset.PeriodicSet>` objects representing
4
the crystal which can be passed to :func:`amd.AMD() <.calculate.AMD>`
5
and :func:`amd.PDD() <.calculate.PDD>`.
6
"""
7
8
import warnings
9
import os
10
import re
11
import functools
12
import errno
13
import math
14
import json
15
from pathlib import Path
16
from typing import Iterable, Iterator, Optional, Union, Callable, Tuple, List
17
18
import numpy as np
19
import numpy.typing as npt
20
import numba
21
import tqdm
22
23
from .utils import cellpar_to_cell
24
from .periodicset import PeriodicSet
25
26
__all__ = [
27
    'CifReader', 'CSDReader', 'ParseError', 'periodicset_from_gemmi_block',
28
    'periodicset_from_ase_cifblock', 'periodicset_from_pymatgen_cifblock',
29
    'periodicset_from_ase_atoms', 'periodicset_from_pymatgen_structure',
30
    'periodicset_from_ccdc_entry', 'periodicset_from_ccdc_crystal'
31
]
32
33
def _custom_warning(message, category, filename, lineno, *args, **kwargs):
    """Format a warning as ``'<CategoryName>: <message>'`` plus a newline.

    Used as a replacement for :func:`warnings.formatwarning`. The file
    name, line number and any further arguments are accepted so the
    signature stays compatible, but they are not part of the output.
    """
    return f'{category.__name__}: {message}\n'


# Installed at import time: warnings formatted through the ``warnings``
# module after this point use the short single-line form above.
warnings.formatwarning = _custom_warning
37
38
# Lookup table loaded once at import from the JSON file shipped next to this
# module; the converters below index it with element symbols to get atomic
# numbers. The encoding is given explicitly so that reading the file does not
# depend on the platform's locale default.
with open(
        str(Path(__file__).absolute().parent / 'atomic_numbers.json'),
        encoding='utf-8'
) as f:
    _ATOMIC_NUMBERS = json.load(f)
40
41
# Tolerance handed to _unique_sites() and _expand_asym_unit() by the
# converters in this module.
_EQ_SITE_TOL: float = 1e-3

# CIF tag names grouped by the quantity they describe. Where a group lists
# alternative spellings of the same item ('symop', 'spacegroup_name',
# 'spacegroup_number'), the converters try them in the order given.
_CIF_TAGS: dict = {
    'cellpar': [
        '_cell_length_a', '_cell_length_b', '_cell_length_c',
        '_cell_angle_alpha', '_cell_angle_beta', '_cell_angle_gamma'],
    'atom_site_fract': [
        '_atom_site_fract_x', '_atom_site_fract_y', '_atom_site_fract_z'],
    'atom_site_cartn': [
        '_atom_site_Cartn_x', '_atom_site_Cartn_y', '_atom_site_Cartn_z'],
    'symop': [
        '_space_group_symop_operation_xyz', '_space_group_symop.operation_xyz',
        '_symmetry_equiv_pos_as_xyz'],
    'spacegroup_name': [
        '_space_group_name_H-M_alt', '_symmetry_space_group_name_H-M'],
    'spacegroup_number': [
        '_space_group_IT_number', '_symmetry_Int_Tables_number']
}
58
59
60
class _Reader:
    """Base reader class.

    Wraps an iterable of raw items together with a converter callable.
    Iterating the reader passes each item through the converter and
    yields the results; items for which the converter raises
    :class:`ParseError <.io.ParseError>` are skipped.

    Parameters
    ----------
    iterable : Iterable
        Source of raw items handed one at a time to ``converter``.
    converter : Callable
        Called with a single item, returns the converted object.
    show_warnings : bool
        If False, warnings raised while converting (and the warning
        issued for a skipped item) are suppressed.
    verbose : bool
        If True, a tqdm progress bar counts the items processed.
    """

    def __init__(
            self,
            iterable: Iterable,
            converter: Callable[..., 'PeriodicSet'],
            show_warnings: bool,
            verbose: bool
    ):

        self._iterator = iter(iterable)
        self._converter = converter
        self.show_warnings = show_warnings
        if verbose:
            self._progress_bar = tqdm.tqdm(desc='Reading', delay=1)
        else:
            self._progress_bar = None

    def __iter__(self):
        return self

    def __next__(self):
        """Iterate over self._iterator, passing items through
        self._converter and yielding. If
        :class:`ParseError <.io.ParseError>` is raised in a call to
        self._converter, the item is skipped. Warnings raised in
        self._converter are printed if self.show_warnings is True.
        """

        while True:

            try:
                item = next(self._iterator)
            except StopIteration:
                if self._progress_bar is not None:
                    self._progress_bar.close()
                raise

            parse_err_msg = None
            with warnings.catch_warnings(record=True) as warning_msgs:
                # Silence warnings only inside this context. The filters
                # are restored on exit, so the process-wide warning
                # configuration is left untouched.
                if not self.show_warnings:
                    warnings.simplefilter('ignore')
                try:
                    periodic_set = self._converter(item)
                except ParseError as err:
                    parse_err_msg = str(err)

            if self._progress_bar is not None:
                self._progress_bar.update(1)

            if parse_err_msg is not None:
                if self.show_warnings:
                    warnings.warn(parse_err_msg)
                continue

            # Re-issue the recorded warnings, prefixed with the name of
            # the item they belong to.
            if self.show_warnings:
                for warning in warning_msgs:
                    msg = f'(name={periodic_set.name}) {warning.message}'
                    warnings.warn(msg, category=warning.category)

            return periodic_set

    def read(self) -> Union['PeriodicSet', List['PeriodicSet']]:
        """Read the crystal(s), return one
        :class:`amd.PeriodicSet <.periodicset.PeriodicSet>` if there is
        only one, otherwise return a list. The list is empty if no
        structure could be parsed.
        """
        items = list(self)
        if len(items) == 1:
            return items[0]
        return items
132
133
134
class CifReader(_Reader):
    """Read all structures in a .cif file or all files in a folder
    with the chosen backend package, yielding
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>` s.

    Parameters
    ----------
    path : str
        Path to a .CIF file or directory. (Other files are accepted when
        using ``reader='ccdc'``, if csd-python-api is installed.)
    reader : str, optional
        The backend package used to parse the CIF. The default is
        :code:`gemmi`, :code:`pymatgen` and :code:`ase` are also
        accepted, as well as :code:`ccdc` if csd-python-api is
        installed. The ccdc reader should be able to read any format
        accepted by :class:`ccdc.io.EntryReader`, though only CIFs have
        been tested.
    remove_hydrogens : bool, optional
        Remove Hydrogens from the crystals.
    disorder : str, optional
        Controls how disordered structures are handled. Default is
        ``skip`` which skips any crystal with disorder, since disorder
        conflicts with the periodic set model. To read disordered
        structures anyway, choose either :code:`ordered_sites` to remove
        atoms with disorder or :code:`all_sites` include all atoms
        regardless of disorder.
    heaviest_component : bool, optional
        csd-python-api only. Removes all but the heaviest molecule in
        the asymmetric unit, intended for removing solvents.
    molecular_centres : bool, default False
        csd-python-api only. Extract the centres of molecules in the
        unit cell and store in the attribute molecular_centres.
    show_warnings : bool, optional
        Controls whether warnings that arise during reading are printed.
    verbose : bool, default False
        If True, prints a progress bar showing the number of items
        processed.

    Yields
    ------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        data, e.g. the crystal's name and information about the
        asymmetric unit.

    Raises
    ------
    ValueError
        If ``disorder`` or ``reader`` is not one of the accepted values.
    NotImplementedError
        If ``heaviest_component`` or ``molecular_centres`` is requested
        with a reader other than ``'ccdc'``.
    ImportError
        If ``reader='ccdc'`` and csd-python-api cannot be imported.
    FileNotFoundError
        If ``path`` is neither an existing file nor a directory.

    Examples
    --------

        ::

            # Put all crystals in a .CIF in a list
            structures = list(amd.CifReader('mycif.cif'))

            # Can also accept path to a directory, reading all files inside
            structures = list(amd.CifReader('path/to/folder'))

            # Reads just one if the .CIF has just one crystal
            periodic_set = amd.CifReader('mycif.cif').read()

            # List of AMDs (k=100) of crystals in a .CIF
            amds = [amd.AMD(item, 100) for item in amd.CifReader('mycif.cif')]
    """

    def __init__(
            self,
            path: Union[str, os.PathLike],
            reader: str = 'gemmi',
            remove_hydrogens: bool = False,
            disorder: str = 'skip',
            heaviest_component: bool = False,
            molecular_centres: bool = False,
            show_warnings: bool = True,
            verbose: bool = False
    ):

        if disorder not in ('skip', 'ordered_sites', 'all_sites'):
            raise ValueError(
                f"'disorder' parameter of {self.__class__.__name__} must be "
                f"one of 'skip', 'ordered_sites' or 'all_sites' (passed "
                f"'{disorder}')"
            )

        if reader != 'ccdc':
            if heaviest_component:
                raise NotImplementedError(
                    "'heaviest_component' parameter of "
                    f"{self.__class__.__name__} only implemented with "
                    "csd-python-api, if installed pass reader='ccdc'"
                )
            if molecular_centres:
                raise NotImplementedError(
                    "'molecular_centres' parameter of "
                    f"{self.__class__.__name__} only implemented with "
                    "csd-python-api, if installed pass reader='ccdc'"
                )

        # Each branch selects the accepted file extensions, the function
        # that parses a file into an iterable of blocks/entries, and the
        # converter turning one block/entry into a PeriodicSet. Backend
        # packages are imported lazily so only the chosen one is needed.

        # NOTE: gemmi cannot handle some characters in cifs
        if reader == 'gemmi':
            import gemmi
            extensions = {'cif'}
            file_parser = gemmi.cif.read_file
            converter = functools.partial(
                periodicset_from_gemmi_block,
                remove_hydrogens=remove_hydrogens,
                disorder=disorder
            )

        elif reader in ('ase', 'pycodcif'):
            from ase.io.cif import parse_cif
            extensions = {'cif'}
            file_parser = functools.partial(parse_cif, reader=reader)
            converter = functools.partial(
                periodicset_from_ase_cifblock,
                remove_hydrogens=remove_hydrogens,
                disorder=disorder
            )

        elif reader == 'pymatgen':

            def _pymatgen_cif_parser(path):
                from pymatgen.io.cif import CifFile
                return CifFile.from_file(path).data.values()

            extensions = {'cif'}
            file_parser = _pymatgen_cif_parser
            converter = functools.partial(
                periodicset_from_pymatgen_cifblock,
                remove_hydrogens=remove_hydrogens,
                disorder=disorder
            )

        elif reader == 'ccdc':
            try:
                import ccdc.io
            except (ImportError, RuntimeError) as e:
                raise ImportError('Failed to import csd-python-api') from e

            extensions = set(ccdc.io.EntryReader.known_suffixes.keys())
            file_parser = ccdc.io.EntryReader
            converter = functools.partial(
                periodicset_from_ccdc_entry,
                remove_hydrogens=remove_hydrogens,
                disorder=disorder,
                molecular_centres=molecular_centres,
                heaviest_component=heaviest_component
            )

        else:
            raise ValueError(
                f"'reader' parameter of {self.__class__.__name__} must be one "
                f"of 'gemmi', 'pymatgen', 'ccdc', 'ase', or 'pycodcif' "
                f"(passed '{reader}')"
            )

        path = Path(path)
        if path.is_file():
            iterable = file_parser(str(path))
        elif path.is_dir():
            iterable = CifReader._dir_generator(path, file_parser, extensions)
        else:
            raise FileNotFoundError(
                errno.ENOENT, os.strerror(errno.ENOENT), path
            )

        super().__init__(iterable, converter, show_warnings, verbose)

    @staticmethod
    def _dir_generator(
            path: Path,
            file_parser: Callable,
            extensions: Iterable
    ) -> Iterator:
        """Yield the parsed contents of every file directly inside
        ``path`` whose extension (without the dot, lower-cased) is in
        ``extensions``. Files that fail to parse are skipped with a
        warning.
        """
        # Path.iterdir() yields entries in arbitrary order; sorting makes
        # the order in which structures are read reproducible.
        for file_path in sorted(path.iterdir()):
            if not file_path.is_file():
                continue
            if file_path.suffix[1:].lower() not in extensions:
                continue
            try:
                yield from file_parser(str(file_path))
            except Exception as e:
                # Deliberately broad: one unreadable file should not stop
                # the rest of the directory from being read.
                warnings.warn(
                    f'Error parsing "{str(file_path)}", skipping: {str(e)}'
                )
318
319
320
class CSDReader(_Reader):
    """Read structures from the CSD with csd-python-api, yielding
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>` s.

    Parameters
    ----------
    refcodes : str or List[str], optional
        Single or list of CSD refcodes to read. If None or 'CSD',
        iterates over the whole CSD.
    families : bool, optional
        Read all entries whose refcode starts with the given strings, or
        'families' (e.g. giving 'DEBXIT' reads all entries starting with
        DEBXIT).
    remove_hydrogens : bool, optional
        Remove hydrogens from the crystals.
    disorder : str, optional
        Controls how disordered structures are handled. Default is
        ``skip`` which skips any crystal with disorder, since disorder
        conflicts with the periodic set model. To read disordered
        structures anyway, choose either :code:`ordered_sites` to remove
        atoms with disorder or :code:`all_sites` include all atoms
        regardless of disorder.
    heaviest_component : bool, optional
        Removes all but the heaviest molecule in the asymmetric unit,
        intended for removing solvents.
    molecular_centres : bool, default False
        Extract the centres of molecules in the unit cell and store in
        attribute molecular_centres.
    show_warnings : bool, optional
        Controls whether warnings that arise during reading are printed.
    verbose : bool, default False
        If True, prints a progress bar showing the number of items
        processed.

    Yields
    ------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        useful data, e.g. the crystal's name and information about the
        asymmetric unit for calculation.

    Raises
    ------
    ValueError
        If ``disorder`` is not an accepted value, or ``refcodes`` is
        not None, a string or an iterable of strings.
    ImportError
        If csd-python-api cannot be imported.

    Examples
    --------

        ::

            # Put these entries in a list
            refcodes = ['DEBXIT01', 'DEBXIT05', 'HXACAN01']
            structures = list(amd.CSDReader(refcodes))

            # Read refcode families (any whose refcode starts with strings in the list)
            refcode_families = ['ACSALA', 'HXACAN']
            structures = list(amd.CSDReader(refcode_families, families=True))

            # Get AMDs (k=100) for crystals in these families
            refcodes = ['ACSALA', 'HXACAN']
            amds = []
            for periodic_set in amd.CSDReader(refcodes, families=True):
                amds.append(amd.AMD(periodic_set, 100))

            # Giving the reader nothing reads from the whole CSD.
            for periodic_set in amd.CSDReader():
                ...
    """

    def __init__(
            self,
            refcodes: Optional[Union[str, Iterable[str]]] = None,
            families: bool = False,
            remove_hydrogens: bool = False,
            disorder: str = 'skip',
            heaviest_component: bool = False,
            molecular_centres: bool = False,
            show_warnings: bool = True,
            verbose: bool = False
    ):

        if disorder not in ('skip', 'ordered_sites', 'all_sites'):
            raise ValueError(
                f"'disorder' parameter of {self.__class__.__name__} must be "
                f"one of 'skip', 'ordered_sites' or 'all_sites' (passed "
                f"'{disorder}')"
            )

        try:
            import ccdc.search
            import ccdc.io
        except (ImportError, RuntimeError) as e:
            raise ImportError('Failed to import csd-python-api') from e

        if isinstance(refcodes, str) and refcodes.lower() == 'csd':
            refcodes = None

        if refcodes is None:
            # Reading the whole CSD: there are no prefixes to expand.
            families = False
        elif isinstance(refcodes, str):
            refcodes = [refcodes]
        else:
            # Materialise first, so that a one-shot iterable (e.g. a
            # generator) is not exhausted by the validation below.
            refcodes = list(refcodes)
            if not all(isinstance(refcode, str) for refcode in refcodes):
                raise ValueError(
                    f'First argument of {self.__class__.__name__} expects None, a '
                    'string or list of strings.'
                )

        if families:
            all_refcodes = []
            for refcode in refcodes:
                query = ccdc.search.TextNumericSearch()
                query.add_identifier(refcode)
                all_refcodes.extend(hit.identifier for hit in query.search())

            # filter to unique refcodes while keeping order
            refcodes = list(dict.fromkeys(all_refcodes))

        converter = functools.partial(
            periodicset_from_ccdc_entry,
            remove_hydrogens=remove_hydrogens,
            disorder=disorder,
            molecular_centres=molecular_centres,
            heaviest_component=heaviest_component
        )

        entry_reader = ccdc.io.EntryReader('CSD')
        if refcodes is None:
            iterable = entry_reader
        else:
            # Entries are looked up lazily, one refcode at a time.
            iterable = map(entry_reader.entry, refcodes)

        super().__init__(iterable, converter, show_warnings, verbose)
455
456
457
class ParseError(ValueError):
    """Raised when an item cannot be parsed into a periodic set."""
460
461
462
def periodicset_from_gemmi_block(
        block,
        remove_hydrogens: bool = False,
        disorder: str = 'skip'
) -> PeriodicSet:
    """Convert a :class:`gemmi.cif.Block` object to a
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`.
    :class:`gemmi.cif.Block` is the type returned by
    :func:`gemmi.cif.read_file`.

    Parameters
    ----------
    block : :class:`gemmi.cif.Block`
        A gemmi Block object representing a crystal.
    remove_hydrogens : bool, optional
        Remove Hydrogens from the crystal.
    disorder : str, optional
        Controls how disordered structures are handled. Default is
        ``skip`` which skips any crystal with disorder, since disorder
        conflicts with the periodic set model. To read disordered
        structures anyway, choose either :code:`ordered_sites` to remove
        atoms with disorder or :code:`all_sites` include all atoms
        regardless of disorder.

    Returns
    -------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        useful data, e.g. the crystal's name and information about the
        asymmetric unit for calculation.

    Raises
    ------
    ParseError
        Raised if the structure fails to be parsed for any of the
        following: 1. Required data is missing (e.g. cell parameters),
        2. ``disorder == 'skip'`` and disorder is found on any
        atom, 3. The motif is empty after removing H or disordered
        sites.
    """

    import gemmi
    from gemmi.cif import as_number, as_string, as_int

    # Unit cell
    cellpar = [block.find_value(t) for t in _CIF_TAGS['cellpar']]
    if not all(isinstance(par, str) for par in cellpar):
        raise ParseError(f'{block.name} has missing cell data')
    cellpar = np.array([as_number(par) for par in cellpar])
    if np.isnan(np.sum(cellpar)):
        raise ParseError(f'{block.name} has missing cell data')
    cell = cellpar_to_cell(cellpar)

    # Asymmetric unit coordinates
    xyz_loop = block.find(_CIF_TAGS['atom_site_fract']).loop
    if xyz_loop is None:
        xyz_loop = block.find(_CIF_TAGS['atom_site_cartn']).loop
        if xyz_loop is None:
            raise ParseError(f'{block.name} has missing coordinate data')
        raise ParseError(
            f'{block.name} uses _atom_site_Cartn_ tags for coordinates, '
            'only _atom_site_fract_ is supported'
        )

    # The loop's values come as one flat sequence; deal them out
    # round-robin into one list per column, keyed by tag.
    tablified_loop = [[] for _ in range(len(xyz_loop.tags))]
    n_cols = xyz_loop.width()
    for i, item in enumerate(xyz_loop.values):
        tablified_loop[i % n_cols].append(item)
    loop_dict = dict(zip(xyz_loop.tags, tablified_loop))
    xyz_str = [loop_dict[t] for t in _CIF_TAGS['atom_site_fract']]
    asym_unit = np.transpose(np.array(
        [[as_number(c) for c in coords] for coords in xyz_str]
    ))
    asym_unit = np.mod(asym_unit, 1)

    # Labels
    if '_atom_site_label' in loop_dict:
        labels = [as_string(label) for label in loop_dict['_atom_site_label']]
    else:
        labels = [''] * xyz_loop.length()

    # Atomic types
    element_pattern = re.compile(r'([A-Z][a-z]?)')

    def _element_of(text):
        # Leading element symbol found in a label / type symbol, or ''.
        match = element_pattern.search(text) if text else None
        return match.group() if match is not None else ''

    if '_atom_site_type_symbol' in loop_dict:
        symbols = [as_string(s) for s in loop_dict['_atom_site_type_symbol']]
        # Type symbols may carry extra characters (e.g. a charge, as in
        # 'O2-'). Symbols that are already keys of _ATOMIC_NUMBERS are
        # kept as they are; anything else is reduced to its element, as
        # periodicset_from_ase_cifblock does.
        symbols = [
            s if s in _ATOMIC_NUMBERS else _element_of(s) for s in symbols
        ]
    else:
        symbols = [_element_of(label) for label in labels]
    asym_types = [_ATOMIC_NUMBERS[s] for s in symbols]

    # Occupancies
    if '_atom_site_occupancy' in loop_dict:
        occs = [as_number(occ) for occ in loop_dict['_atom_site_occupancy']]
        occupancies = [occ if not math.isnan(occ) else 1 for occ in occs]
    else:
        occupancies = [1] * xyz_loop.length()

    # Remove sites with missing coordinates, disorder and Hydrogens if needed
    remove_sites = []
    remove_sites.extend(np.nonzero(np.isnan(asym_unit.min(axis=-1)))[0])

    if disorder == 'skip':
        if any(
            _has_disorder(lab, occ) for lab, occ in zip(labels, occupancies)
        ):
            raise ParseError(
                f"{block.name} has disorder, pass disorder='ordered_sites' or "
                "'all_sites' to remove/ignore disorder"
            )
    elif disorder == 'ordered_sites':
        for i, (label, occ) in enumerate(zip(labels, occupancies)):
            if _has_disorder(label, occ):
                remove_sites.append(i)

    if remove_hydrogens:
        remove_sites.extend(
            i for i, num in enumerate(asym_types) if num == 1
        )

    asym_unit = np.delete(asym_unit, remove_sites, axis=0)
    removed = set(remove_sites)  # constant-time membership tests
    asym_types = [s for i, s in enumerate(asym_types) if i not in removed]
    if asym_unit.shape[0] == 0:
        raise ParseError(f'{block.name} has no valid sites')

    if disorder != 'all_sites':
        keep_sites = _unique_sites(asym_unit, _EQ_SITE_TOL)
        if not np.all(keep_sites):
            warnings.warn(
                'may have overlapping sites; duplicates will be removed'
            )
        asym_unit = asym_unit[keep_sites]
        asym_types = [sym for sym, keep in zip(asym_types, keep_sites) if keep]

    # Symmetry operations, try xyz strings first
    for tag in _CIF_TAGS['symop']:
        sitesym = [v.str(0) for v in block.find([tag])]
        if sitesym:
            rot, trans = _parse_sitesyms(sitesym)
            break
    else:
        # Try spacegroup name; can be a pair or in a loop
        spg = None
        for tag in _CIF_TAGS['spacegroup_name']:
            for value in block.find([tag]):
                try:
                    # Some names cannot be parsed by gemmi.SpaceGroup
                    spg = gemmi.SpaceGroup(value.str(0))
                    break
                except ValueError:
                    continue
            if spg is not None:
                break

        if spg is None:
            # Try international number
            for tag in _CIF_TAGS['spacegroup_number']:
                spg_num = block.find_value(tag)
                if spg_num is not None:
                    spg_num = as_int(spg_num)
                    break
            else:
                warnings.warn('no symmetry data found, defaulting to P1')
                spg_num = 1
            spg = gemmi.SpaceGroup(spg_num)

        rot = np.array([np.array(o.rot) / o.DEN for o in spg.operations()])
        trans = np.array([np.array(o.tran) / o.DEN for o in spg.operations()])

    frac_motif, invs = _expand_asym_unit(asym_unit, rot, trans, _EQ_SITE_TOL)
    _, wyc_muls = np.unique(invs, return_counts=True)
    wyc_muls = wyc_muls.astype(np.int32)
    # Start index of each run of wyc_muls entries: [0, m0, m0 + m1, ...]
    asym_inds = np.zeros_like(wyc_muls, dtype=np.int32)
    asym_inds[1:] = np.cumsum(wyc_muls)[:-1]
    types = np.array([asym_types[i] for i in invs], dtype=np.uint8)
    motif = frac_motif @ cell

    return PeriodicSet(
        motif=motif,
        cell=cell,
        name=block.name,
        asymmetric_unit=asym_inds,
        wyckoff_multiplicities=wyc_muls,
        types=types
    )
655
656
657
def periodicset_from_ase_cifblock(
        block,
        remove_hydrogens: bool = False,
        disorder: str = 'skip'
) -> PeriodicSet:
    """Convert a :class:`ase.io.cif.CIFBlock` object to a
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`.
    :class:`ase.io.cif.CIFBlock` is the type returned by
    :func:`ase.io.cif.parse_cif`.

    Parameters
    ----------
    block : :class:`ase.io.cif.CIFBlock`
        An ase :class:`ase.io.cif.CIFBlock` object representing a
        crystal.
    remove_hydrogens : bool, optional
        Remove Hydrogens from the crystal.
    disorder : str, optional
        Controls how disordered structures are handled. Default is
        ``skip`` which skips any crystal with disorder, since disorder
        conflicts with the periodic set model. To read disordered
        structures anyway, choose either :code:`ordered_sites` to remove
        atoms with disorder or :code:`all_sites` include all atoms
        regardless of disorder.

    Returns
    -------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        useful data, e.g. the crystal's name and information about the
        asymmetric unit for calculation.

    Raises
    ------
    ParseError
        Raised if the structure fails to be parsed for any of the
        following: 1. Required data is missing (e.g. cell parameters),
        2. The motif is empty after removing H or disordered sites,
        3. ``disorder == 'skip'`` and disorder is found on any
        atom.
    """

    import ase
    import ase.spacegroup

    # Unit cell
    cellpar = [block.get(tag) for tag in _CIF_TAGS['cellpar']]
    if None in cellpar:
        raise ParseError(f'{block.name} has missing cell data')
    cell = cellpar_to_cell(np.array(cellpar))

    # Asymmetric unit coordinates. ase removes uncertainty brackets
    asym_unit = [block.get(tag) for tag in _CIF_TAGS['atom_site_fract']]
    if None in asym_unit:
        asym_unit = [
            block.get(tag.lower()) for tag in _CIF_TAGS['atom_site_cartn']
        ]
        if None in asym_unit:
            raise ParseError(f'{block.name} has missing coordinates')
        raise ParseError(
            f'{block.name} uses _atom_site_Cartn_ tags for coordinates, '
            'only _atom_site_fract_ is supported'
        )
    # Three columns (x, y, z) -> one (x, y, z) tuple per site
    asym_unit = list(zip(*asym_unit))

    # Labels
    asym_labels = block.get('_atom_site_label')
    if asym_labels is None:
        asym_labels = [''] * len(asym_unit)

    # Atomic types. Without _atom_site_type_symbol, infer the element from
    # the site label as periodicset_from_gemmi_block does; otherwise every
    # type would be unknown and remove_hydrogens would silently do nothing.
    asym_symbols = block.get('_atom_site_type_symbol')
    symbols_from_labels = asym_symbols is None
    if symbols_from_labels:
        asym_symbols = asym_labels
    asym_symbols_ = []
    for label in asym_symbols:
        sym = ''
        # Values may be non-strings (e.g. numbers); only strings are
        # searched for an element symbol.
        if isinstance(label, str) and label and label not in ('.', '?'):
            match = re.search(r'([A-Z][a-z]?)', label)
            if match is not None:
                sym = match.group()
        if symbols_from_labels and sym not in _ATOMIC_NUMBERS:
            # A label need not start with an element; treat as unknown.
            sym = ''
        asym_symbols_.append(sym)
    asym_types = [_ATOMIC_NUMBERS[s] for s in asym_symbols_]

    # Find where sites have disorder if necessary
    has_disorder = []
    if disorder != 'all_sites':
        occupancies = block.get('_atom_site_occupancy')
        if occupancies is None:
            occupancies = [1] * len(asym_unit)
        for lab, occ in zip(asym_labels, occupancies):
            has_disorder.append(_has_disorder(lab, occ))

    # Remove sites with ?, . or other invalid string for coordinates
    invalid = {
        i for i, xyz in enumerate(asym_unit)
        if not all(isinstance(coord, (int, float)) for coord in xyz)
    }
    if invalid:
        warnings.warn('atoms without sites or missing data will be removed')
        asym_unit = [c for i, c in enumerate(asym_unit) if i not in invalid]
        asym_types = [t for i, t in enumerate(asym_types) if i not in invalid]
        if disorder != 'all_sites':
            has_disorder = [d for i, d in enumerate(has_disorder)
                            if i not in invalid]

    # Indices (into the lists filtered above) of sites to drop; a set
    # keeps the membership tests below constant-time.
    remove_sites = set()

    if remove_hydrogens:
        remove_sites.update(i for i, num in enumerate(asym_types) if num == 1)

    # Remove atoms with fractional occupancy or raise ParseError
    if disorder != 'all_sites':
        for i, dis in enumerate(has_disorder):
            if i in remove_sites:
                continue
            if dis:
                if disorder == 'skip':
                    raise ParseError(
                        f'{block.name} has disorder, pass '
                        "disorder='ordered_sites' or 'all_sites' to "
                        'remove/ignore disorder'
                    )
                elif disorder == 'ordered_sites':
                    remove_sites.add(i)

    # Asymmetric unit
    asym_unit = [c for i, c in enumerate(asym_unit) if i not in remove_sites]
    asym_types = [t for i, t in enumerate(asym_types) if i not in remove_sites]
    if len(asym_unit) == 0:
        raise ParseError(f'{block.name} has no valid sites')
    asym_unit = np.mod(np.array(asym_unit), 1)

    # Remove overlapping sites unless disorder == 'all_sites'
    if disorder != 'all_sites':
        keep_sites = _unique_sites(asym_unit, _EQ_SITE_TOL)
        if not np.all(keep_sites):
            warnings.warn(
                'may have overlapping sites, duplicates will be removed'
            )
            asym_unit = asym_unit[keep_sites]
            asym_types = [t for t, keep in zip(asym_types, keep_sites) if keep]

    # Get symmetry operations: explicit xyz operations first, then the
    # spacegroup name, then the spacegroup number, finally P1.
    sitesym = block._get_any(_CIF_TAGS['symop'])
    if sitesym is None:
        label_or_num = block._get_any(
            [s.lower() for s in _CIF_TAGS['spacegroup_name']]
        )
        if label_or_num is None:
            label_or_num = block._get_any(
                [s.lower() for s in _CIF_TAGS['spacegroup_number']]
            )
        if label_or_num is None:
            warnings.warn('no symmetry data found, defaulting to P1')
            label_or_num = 1
        spg = ase.spacegroup.Spacegroup(label_or_num)
        rot, trans = spg.get_op()
    else:
        if isinstance(sitesym, str):
            sitesym = [sitesym]
        rot, trans = _parse_sitesyms(sitesym)

    frac_motif, invs = _expand_asym_unit(asym_unit, rot, trans, _EQ_SITE_TOL)
    _, wyc_muls = np.unique(invs, return_counts=True)
    wyc_muls = wyc_muls.astype(np.int32)
    # Start index of each run of wyc_muls entries: [0, m0, m0 + m1, ...]
    asym_inds = np.zeros_like(wyc_muls, dtype=np.int32)
    asym_inds[1:] = np.cumsum(wyc_muls)[:-1]
    types = np.array([asym_types[i] for i in invs], dtype=np.uint8)
    motif = frac_motif @ cell

    return PeriodicSet(
        motif=motif,
        cell=cell,
        name=block.name,
        asymmetric_unit=asym_inds,
        wyckoff_multiplicities=wyc_muls,
        types=types
    )
842
843
844
def periodicset_from_pymatgen_cifblock(
        block,
        remove_hydrogens: bool = False,
        disorder: str = 'skip'
) -> PeriodicSet:
    """Convert a :class:`pymatgen.io.cif.CifBlock` object to a
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`.
    :class:`pymatgen.io.cif.CifBlock` is the type returned by
    :class:`pymatgen.io.cif.CifFile`.

    Parameters
    ----------
    block : :class:`pymatgen.io.cif.CifBlock`
        A pymatgen CifBlock object representing a crystal.
    remove_hydrogens : bool, optional
        Remove Hydrogens from the crystal.
    disorder : str, optional
        Controls how disordered structures are handled. Default is
        ``skip`` which skips any crystal with disorder, since disorder
        conflicts with the periodic set model. To read disordered
        structures anyway, choose either :code:`ordered_sites` to remove
        atoms with disorder or :code:`all_sites` include all atoms
        regardless of disorder.

    Returns
    -------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        useful data, e.g. the crystal's name and information about the
        asymmetric unit for calculation.

    Raises
    ------
    ParseError
        Raised if the structure can/should not be parsed for the
        following reasons: 1. Unit cell data is missing, 2. The
        fractional coordinate tags are missing (Cartesian coordinates
        are not supported), 3. :code:`disorder == 'skip'` and disorder
        is found on any atom, 4. No sites remain after removing sites
        with invalid coordinates, Hydrogens & disorder.
    """

    from pymatgen.io.cif import str2float

    odict = block.data

    # Unit cell
    cellpar = [odict.get(tag) for tag in _CIF_TAGS['cellpar']]
    if any(par in (None, '?', '.') for par in cellpar):
        raise ParseError(f'{block.header} has missing cell data')
    cell = cellpar_to_cell(np.array([str2float(v) for v in cellpar]))

    # Asymmetric unit coordinates, one column per fractional coordinate tag
    asym_unit = [odict.get(tag) for tag in _CIF_TAGS['atom_site_fract']]
    if None in asym_unit:
        # Only used to pick the more helpful of the two error messages
        asym_unit = [odict.get(tag) for tag in _CIF_TAGS['atom_site_cartn']]
        if None in asym_unit:
            raise ParseError(f'{block.header} has missing coordinates')
        raise ParseError(
            f'{block.header} uses _atom_site_Cartn_ tags for coordinates, '
            'only _atom_site_fract_ is supported'
        )
    # Transpose the columns into one (x, y, z) row per site
    asym_unit = [
        [str2float(coord) for coord in xyz] for xyz in zip(*asym_unit)
    ]

    # Atomic types, taken from the leading element symbol of each entry
    asym_symbols = odict.get('_atom_site_type_symbol')
    if asym_symbols is not None:
        asym_symbols_ = []
        for label in asym_symbols:
            sym = ''
            if label and label not in ('.', '?'):
                match = re.search(r'([A-Z][a-z]?)', label)
                if match is not None:
                    sym = match.group()
            asym_symbols_.append(sym)
    else:
        asym_symbols_ = [''] * len(asym_unit)
    asym_types = [_ATOMIC_NUMBERS[s] for s in asym_symbols_]

    # Find where sites have disorder if necessary
    has_disorder = []
    if disorder != 'all_sites':
        occupancies = odict.get('_atom_site_occupancy')
        if occupancies is None:
            occupancies = np.ones((len(asym_unit), ))
        else:
            occupancies = np.array([str2float(occ) for occ in occupancies])
        labels = odict.get('_atom_site_label')
        if labels is None:
            labels = [''] * len(asym_unit)
        has_disorder = [
            _has_disorder(lab, occ) for lab, occ in zip(labels, occupancies)
        ]

    # Remove sites with ?, . or other invalid string for coordinates.
    # A set keeps the membership tests below O(1).
    invalid = {
        i for i, xyz in enumerate(asym_unit)
        if not all(isinstance(coord, (int, float)) for coord in xyz)
    }

    if invalid:
        warnings.warn('atoms without sites or missing data will be removed')
        asym_unit = [c for i, c in enumerate(asym_unit) if i not in invalid]
        asym_types = [c for i, c in enumerate(asym_types) if i not in invalid]
        if disorder != 'all_sites':
            has_disorder = [
                d for i, d in enumerate(has_disorder) if i not in invalid
            ]

    remove_sites = set()

    if remove_hydrogens:
        remove_sites.update(i for i, n in enumerate(asym_types) if n == 1)

    # Remove atoms with fractional occupancy or raise ParseError
    if disorder != 'all_sites':
        for i, dis in enumerate(has_disorder):
            # Disorder on a site that is already being removed is ignored
            if i in remove_sites or not dis:
                continue
            if disorder == 'skip':
                raise ParseError(
                    f'{block.header} has disorder, pass '
                    "disorder='ordered_sites' or 'all_sites' to "
                    'remove/ignore disorder'
                )
            if disorder == 'ordered_sites':
                remove_sites.add(i)

    # Asymmetric unit
    asym_unit = [c for i, c in enumerate(asym_unit) if i not in remove_sites]
    asym_types = [t for i, t in enumerate(asym_types) if i not in remove_sites]
    if len(asym_unit) == 0:
        raise ParseError(f'{block.header} has no valid sites')
    asym_unit = np.mod(np.array(asym_unit), 1)

    # Remove overlapping sites unless disorder == 'all_sites'
    if disorder != 'all_sites':
        keep_sites = _unique_sites(asym_unit, _EQ_SITE_TOL)
        if not np.all(keep_sites):
            warnings.warn(
                'may have overlapping sites; duplicates will be removed'
            )
            asym_unit = asym_unit[keep_sites]
            asym_types = [
                sym for sym, keep in zip(asym_types, keep_sites) if keep
            ]

    # Apply symmetries to asymmetric unit
    rot, trans = _get_syms_pymatgen(odict)
    frac_motif, invs = _expand_asym_unit(asym_unit, rot, trans, _EQ_SITE_TOL)
    # invs[k] is the asymmetric unit index that motif point k came from;
    # counting them gives the multiplicity of each asymmetric site.
    _, wyc_muls = np.unique(invs, return_counts=True)
    wyc_muls = wyc_muls.astype(np.int32)
    # Index in the motif where each asymmetric site's images begin
    asym_inds = np.zeros_like(wyc_muls, dtype=np.int32)
    asym_inds[1:] = np.cumsum(wyc_muls)[:-1]
    types = np.array([asym_types[i] for i in invs], dtype=np.uint8)
    motif = frac_motif @ cell

    return PeriodicSet(
        motif=motif,
        cell=cell,
        name=block.header,
        asymmetric_unit=asym_inds,
        wyckoff_multiplicities=wyc_muls,
        types=types
    )
1018
1019
1020
def periodicset_from_ase_atoms(
        atoms,
        remove_hydrogens: bool = False
) -> PeriodicSet:
    """Convert an :class:`ase.atoms.Atoms` object to a
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`. Does not have
    the option to remove disorder. The ``atoms`` argument is not
    modified.

    Parameters
    ----------
    atoms : :class:`ase.atoms.Atoms`
        An ase :class:`ase.atoms.Atoms` object representing a crystal.
    remove_hydrogens : bool, optional
        Remove Hydrogens from the crystal.

    Returns
    -------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        useful data, e.g. the crystal's name and information about the
        asymmetric unit for calculation.

    Raises
    ------
    ParseError
        Raised if there are no valid sites in atoms.
    """

    from ase.spacegroup import get_basis

    if remove_hydrogens:
        # Indexing builds a new Atoms object, so the caller's object is
        # left untouched (popping atoms would modify it in place).
        atoms = atoms[np.flatnonzero(atoms.get_atomic_numbers() != 1)]

    if len(atoms) == 0:
        raise ParseError('ase Atoms object has no valid sites')

    cell = atoms.get_cell().array
    atomic_nums = atoms.get_atomic_numbers()
    frac_coords = atoms.get_scaled_positions()

    if 'spacegroup' in atoms.info:
        spg = atoms.info['spacegroup']
        # get_op() includes inversions and centring translations, which
        # the rotations/translations attributes leave out.
        rot, trans = spg.get_op()
        # Reduced basis of the atoms. ase default tol is 1e-5
        asym_unit = np.asarray(
            get_basis(atoms, spacegroup=spg, tol=_EQ_SITE_TOL),
            dtype=np.float64
        )
        # get_basis returns positions only, so recover the atomic number
        # of each basis site from the nearest atom (modulo the lattice).
        diffs = asym_unit[:, None, :] - frac_coords[None, :, :]
        diffs = np.abs(diffs - np.rint(diffs))
        nearest = np.argmin(np.amax(diffs, axis=-1), axis=-1)
        asym_types = atomic_nums[nearest]
    else:
        warnings.warn('no symmetry data found, defaulting to P1')
        rot = np.identity(3)[None, :]
        trans = np.zeros((1, 3))
        # With only the identity operation every atom is its own
        # asymmetric site; reducing the basis here would lose atoms.
        asym_unit = frac_coords
        asym_types = atomic_nums

    frac_motif, invs = _expand_asym_unit(asym_unit, rot, trans, _EQ_SITE_TOL)
    # invs[k] is the asymmetric unit index that motif point k came from
    _, wyc_muls = np.unique(invs, return_counts=True)
    wyc_muls = wyc_muls.astype(np.int32)
    # Index in the motif where each asymmetric site's images begin
    asym_inds = np.zeros_like(wyc_muls, dtype=np.int32)
    asym_inds[1:] = np.cumsum(wyc_muls)[:-1]
    # Types follow the order of the expanded motif, not the input atoms
    types = np.array([asym_types[i] for i in invs], dtype=np.uint8)
    motif = frac_motif @ cell

    return PeriodicSet(
        motif=motif,
        cell=cell,
        asymmetric_unit=asym_inds,
        wyckoff_multiplicities=wyc_muls,
        types=types
    )
1091
1092
1093
def periodicset_from_pymatgen_structure(
        structure,
        remove_hydrogens: bool = False,
        disorder: str = 'skip'
) -> PeriodicSet:
    """Convert a :class:`pymatgen.core.structure.Structure` object to a
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`. Does not set
    the name of the periodic set, as pymatgen Structure objects seem to
    have no name attribute. The ``structure`` argument is not modified.

    Parameters
    ----------
    structure : :class:`pymatgen.core.structure.Structure`
        A pymatgen Structure object representing a crystal.
    remove_hydrogens : bool, optional
        Remove Hydrogens from the crystal.
    disorder : str, optional
        Controls how disordered structures are handled. Default is
        ``skip`` which skips any crystal with disorder, since disorder
        conflicts with the periodic set model. To read disordered
        structures anyway, choose either :code:`ordered_sites` to remove
        atoms with disorder or :code:`all_sites` include all atoms
        regardless of disorder.

    Returns
    -------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        useful data, e.g. the crystal's name and information about the
        asymmetric unit for calculation.

    Raises
    ------
    ParseError
        Raised if :code:`disorder == 'skip'` and
        :code:`not structure.is_ordered`, or if no sites remain after
        removing Hydrogens and disordered sites.
    """

    from pymatgen.symmetry.analyzer import SpacegroupAnalyzer

    # Sites are removed from a copy so the caller's object is unchanged
    if remove_hydrogens or disorder == 'ordered_sites':
        structure = structure.copy()

    if remove_hydrogens:
        structure.remove_species(['H', 'D'])

    # Disorder
    if disorder == 'skip':
        if not structure.is_ordered:
            raise ParseError(
                'pymatgen Structure has disorder, pass '
                "disorder='ordered_sites' or 'all_sites' to remove/ignore "
                'disorder'
            )
    elif disorder == 'ordered_sites':
        # Use the same per-site criterion as structure.is_ordered, so
        # mixed-species sites are removed as well as partially occupied
        # ones.
        remove_inds = [
            i for i, site in enumerate(structure) if not site.is_ordered
        ]
        structure.remove_sites(remove_inds)

    if len(structure) == 0:
        raise ParseError('pymatgen Structure has no valid sites')

    motif = structure.cart_coords
    cell = structure.lattice.matrix
    sym_structure = SpacegroupAnalyzer(structure).get_symmetrized_structure()
    # Each entry of equivalent_indices lists the sites of one orbit; the
    # first index is taken as the representative of that orbit.
    eq_inds = sym_structure.equivalent_indices
    asym_unit = np.array([ix_list[0] for ix_list in eq_inds], dtype=np.int32)
    wyc_muls = np.array([len(ix_list) for ix_list in eq_inds], dtype=np.int32)
    types = np.array(sym_structure.atomic_numbers, dtype=np.uint8)

    return PeriodicSet(
        motif=motif,
        cell=cell,
        asymmetric_unit=asym_unit,
        wyckoff_multiplicities=wyc_muls,
        types=types
    )
1167
1168
1169
def periodicset_from_ccdc_entry(
        entry,
        remove_hydrogens: bool = False,
        disorder: str = 'skip',
        heaviest_component: bool = False,
        molecular_centres: bool = False
) -> PeriodicSet:
    """Build a :class:`amd.PeriodicSet <.periodicset.PeriodicSet>` from
    a :class:`ccdc.entry.Entry`, the type returned by
    :class:`ccdc.io.EntryReader`. After entry-level checks the work is
    delegated to :func:`periodicset_from_ccdc_crystal`.

    Parameters
    ----------
    entry : :class:`ccdc.entry.Entry`
        A ccdc Entry object representing a database entry.
    remove_hydrogens : bool, optional
        Remove Hydrogens from the crystal.
    disorder : str, optional
        Controls how disordered structures are handled. Default is
        ``skip`` which skips any crystal with disorder, since disorder
        conflicts with the periodic set model. To read disordered
        structures anyway, choose either :code:`ordered_sites` to remove
        atoms with disorder or :code:`all_sites` include all atoms
        regardless of disorder.
    heaviest_component : bool, optional
        Removes all but the heaviest molecule in the asymmetric unit,
        intended for removing solvents.
    molecular_centres : bool, default False
        Extract the centres of molecules in the unit cell and store in
        the attribute molecular_centres of the returned PeriodicSet.

    Returns
    -------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        useful data, e.g. the crystal's name and information about the
        asymmetric unit for calculation.

    Raises
    ------
    ParseError
        Raised if ``entry.has_3d_structure`` is False, or if
        :code:`disorder == 'skip'` and ``entry.has_disorder`` is True.
        Any ParseError raised by
        :func:`periodicset_from_ccdc_crystal` also propagates.
    """

    ident = entry.identifier if not entry.has_3d_structure else None
    if ident is not None:
        # Entry specific flag
        raise ParseError(f'{ident} has no 3D structure')

    if disorder == 'skip':
        # Entry-level disorder flag, checked before any atoms are read
        if entry.has_disorder:
            raise ParseError(
                f"{entry.identifier} has disorder, pass "
                "disorder='ordered_sites' "
                "or 'all_sites' to remove/ignore disorder"
            )

    options = {
        'remove_hydrogens': remove_hydrogens,
        'disorder': disorder,
        'heaviest_component': heaviest_component,
        'molecular_centres': molecular_centres,
    }
    return periodicset_from_ccdc_crystal(entry.crystal, **options)
1238
1239
1240
def periodicset_from_ccdc_crystal(
        crystal,
        remove_hydrogens: bool = False,
        disorder: str = 'skip',
        heaviest_component: bool = False,
        molecular_centres: bool = False
) -> PeriodicSet:
    """Convert a :class:`ccdc.crystal.Crystal` object to a
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`.
    Crystal is the type returned by :class:`ccdc.io.CrystalReader`.

    Parameters
    ----------
    crystal : :class:`ccdc.crystal.Crystal`
        A ccdc Crystal object representing a crystal structure. Its
        ``molecule`` attribute is reassigned to the filtered molecule.
    remove_hydrogens : bool, optional
        Remove Hydrogens from the crystal.
    disorder : str, optional
        Controls how disordered structures are handled. Default is
        ``skip`` which skips any crystal with disorder, since disorder
        conflicts with the periodic set model. To read disordered
        structures anyway, choose either :code:`ordered_sites` to remove
        atoms with disorder or :code:`all_sites` include all atoms
        regardless of disorder.
    heaviest_component : bool, optional
        Removes all but the heaviest molecule in the asymmetric unit,
        intended for removing solvents.
    molecular_centres : bool, default False
        Extract the centres of molecules in the unit cell and return a
        PeriodicSet whose motif consists of those centres.

    Returns
    -------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        useful data, e.g. the crystal's name and information about the
        asymmetric unit for calculation.

    Raises
    ------
    ParseError
        Raised if the structure fails parsing for any of the following:
        1. :code:`disorder == 'skip'` and disorder is found on any
        atom, 2. The unit cell data is missing, 3. The motif is empty
        after removing H, disordered sites, solvents or atoms without
        coordinates.
    """

    molecule = crystal.disordered_molecule

    # Disorder
    if disorder == 'skip':
        if crystal.has_disorder or \
         any(_has_disorder(a.label, a.occupancy) for a in molecule.atoms):
            raise ParseError(
                f"{crystal.identifier} has disorder, pass "
                "disorder='ordered_sites' or 'all_sites' to remove/ignore "
                "disorder"
            )
    elif disorder == 'ordered_sites':
        molecule.remove_atoms(
            a for a in molecule.atoms if _has_disorder(a.label, a.occupancy)
        )

    if remove_hydrogens:
        # Compare against whole symbols: a substring test against 'HD'
        # would also match an empty symbol.
        molecule.remove_atoms(
            a for a in molecule.atoms if a.atomic_symbol in ('H', 'D')
        )

    if heaviest_component and len(molecule.components) > 1:
        molecule = _heaviest_component_ccdc(molecule)

    # Remove atoms with missing coordinates and warn
    if any(a.fractional_coordinates is None for a in molecule.atoms):
        warnings.warn('atoms without sites or missing data will be removed')
        molecule.remove_atoms(
            a for a in molecule.atoms if a.fractional_coordinates is None
        )

    # Write the filtered molecule back so the crystal-level queries
    # below operate on it.
    crystal.molecule = molecule
    cellpar = crystal.cell_lengths + crystal.cell_angles
    if None in cellpar:
        raise ParseError(f'{crystal.identifier} has missing cell data')
    cell = cellpar_to_cell(np.array(cellpar))

    if molecular_centres:
        frac_centres = _frac_molecular_centres_ccdc(crystal, _EQ_SITE_TOL)
        mol_centres = frac_centres @ cell
        return PeriodicSet(mol_centres, cell, name=crystal.identifier)

    asym_atoms = crystal.asymmetric_unit_molecule.atoms
    asym_unit = np.array([tuple(a.fractional_coordinates) for a in asym_atoms])

    if asym_unit.shape[0] == 0:
        raise ParseError(f'{crystal.identifier} has no valid sites')

    asym_unit = np.mod(asym_unit, 1)
    asym_types = [a.atomic_number for a in asym_atoms]

    # Remove overlapping sites unless disorder == 'all_sites'
    if disorder != 'all_sites':
        keep_sites = _unique_sites(asym_unit, _EQ_SITE_TOL)
        if not np.all(keep_sites):
            warnings.warn(
                'may have overlapping sites; duplicates will be removed'
            )
            asym_unit = asym_unit[keep_sites]
            asym_types = [
                sym for sym, keep in zip(asym_types, keep_sites) if keep
            ]

    # Symmetry operations
    sitesym = crystal.symmetry_operators
    if not sitesym:
        warnings.warn('no symmetry data found, defaulting to P1')
        sitesym = ['x,y,z']

    # Apply symmetries to asymmetric unit
    rot, trans = _parse_sitesyms(sitesym)
    frac_motif, invs = _expand_asym_unit(asym_unit, rot, trans, _EQ_SITE_TOL)
    # invs[k] is the asymmetric unit index that motif point k came from
    _, wyc_muls = np.unique(invs, return_counts=True)
    wyc_muls = wyc_muls.astype(np.int32)
    # Index in the motif where each asymmetric site's images begin
    asym_inds = np.zeros_like(wyc_muls, dtype=np.int32)
    asym_inds[1:] = np.cumsum(wyc_muls)[:-1]
    motif = frac_motif @ cell
    types = np.array([asym_types[i] for i in invs], dtype=np.uint8)

    return PeriodicSet(
        motif=motif,
        cell=cell,
        name=crystal.identifier,
        asymmetric_unit=asym_inds,
        wyckoff_multiplicities=wyc_muls,
        types=types
    )
1381
1382
1383
def _parse_sitesyms(
        symmetries: List[str]
) -> Tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]:
    """Turn symmetry operations written in xyz form (e.g.
    ``'-x+1/2,y,z'``) into an array of rotation matrices with shape
    (n, 3, 3) and an array of translation vectors with shape (n, 3).
    """

    n_ops = len(symmetries)
    rotations = np.zeros((n_ops, 3, 3), dtype=np.float64)
    translations = np.zeros((n_ops, 3), dtype=np.float64)
    axis_column = {'x': 0, 'y': 1, 'z': 2}

    for op_ind, operation in enumerate(symmetries):
        for row, component in enumerate(operation.split(',')):

            sign = 1.0          # sign of the most recent '+' or '-'
            shift_sign = None   # sign in force when the number started
            after_slash = False
            numerator = ''
            denominator = ''

            # Characters other than those handled below (e.g. spaces)
            # are ignored.
            for ch in component.lower():
                if ch == '+':
                    sign = 1.0
                elif ch == '-':
                    sign = -1.0
                elif ch == '/':
                    after_slash = True
                elif ch in axis_column:
                    rotations[op_ind, row, axis_column[ch]] = sign
                elif ch.isdigit() or ch == '.':
                    if shift_sign is None:
                        shift_sign = sign
                    if after_slash:
                        denominator += ch
                    else:
                        numerator += ch

            shift = shift_sign * float(numerator) if numerator else 0.0
            if after_slash:
                shift /= float(denominator)

            translations[op_ind, row] = shift

    return rotations, translations
1432
1433
1434
def _expand_asym_unit(
        asym_unit: npt.NDArray,
        rotations: npt.NDArray,
        translations: npt.NDArray,
        tol: float
) -> Tuple[npt.NDArray[np.float64], npt.NDArray[np.int32]]:
    """Apply the symmetry operations given by ``rotations`` and
    ``translations`` to every site of the asymmetric unit. Returns the
    resulting fractional motif and, for each motif point, the index of
    the asymmetric site it was generated from.
    """

    sites, rots, shifts = (
        arr.astype(np.float64, copy=False)
        for arr in (asym_unit, rotations, translations)
    )
    expanded_sites = _expand_sites(sites, rots, shifts)

    # Fast reduction first; it is only valid when no two asymmetric
    # sites generate the same motif point.
    frac_motif, invs = _reduce_expanded_sites(expanded_sites, tol)
    if all(_unique_sites(frac_motif, tol)):
        return frac_motif, invs

    # Duplicates were found across sites, so redo with the slower path
    frac_motif, invs = _reduce_expanded_equiv_sites(expanded_sites, tol)
    return frac_motif, invs
1454
1455
1456
@numba.njit(cache=True)
def _expand_sites(
        asym_unit: npt.NDArray[np.float64],
        rotations: npt.NDArray[np.float64],
        translations: npt.NDArray[np.float64]
) -> npt.NDArray[np.float64]:
    """Apply every symmetry operation (``rotations[j]`` followed by
    ``translations[j]``) to every point of the asymmetric unit and wrap
    the results modulo 1. Points left invariant by an operation are not
    removed here. Returns a 3D array shape (#points, #syms, dims).
    """

    n_points, dims = asym_unit.shape
    n_ops = len(rotations)
    images = np.empty((n_points, n_ops, dims), dtype=np.float64)

    for j in range(n_ops):
        rot = rotations[j]
        shift = translations[j]
        for i in range(n_points):
            images[i, j] = np.dot(rot, asym_unit[i]) + shift

    return np.mod(images, 1)
1477
1478
1479
@numba.njit(cache=True)
def _reduce_expanded_sites(
        expanded_sites: npt.NDArray[np.float64],
        tol: float
) -> Tuple[npt.NDArray[np.float64], npt.NDArray[np.int32]]:
    """Reduce the asymmetric unit after being expanded by symmetries by
    removing invariant points. This is the fast version which works in
    the case that no two sites in the asymmetric unit are equivalent.
    If they are, the reduction is re-run with
    _reduce_expanded_equiv_sites() to account for it.

    ``expanded_sites`` has shape (#sites, #syms, dims). Returns the
    fractional motif and an array giving, for each motif point, the
    index of the site in ``expanded_sites`` it came from.
    """

    n_sites, _, dims = expanded_sites.shape
    all_unique_inds = []
    # Integer counts: these values are used as slice bounds below
    multiplicities = np.zeros(shape=(n_sites, ), dtype=np.int64)

    for i in range(n_sites):
        # Boolean mask of the distinct images of site i
        unique_inds = _unique_sites(expanded_sites[i], tol)
        all_unique_inds.append(unique_inds)
        multiplicities[i] = np.sum(unique_inds)

    m = int(np.sum(multiplicities))
    frac_motif = np.zeros(shape=(m, dims))
    inverses = np.zeros(shape=(m, ), dtype=np.int32)

    # Write each site's distinct images into consecutive rows
    s = 0
    for i in range(n_sites):
        t = s + multiplicities[i]
        frac_motif[s:t, :] = expanded_sites[i][all_unique_inds[i]]
        inverses[s:t] = i
        s = t

    return frac_motif, inverses
1512
1513
1514
def _reduce_expanded_equiv_sites(
        expanded_sites: npt.NDArray[np.float64],
        tol: float
) -> Tuple[npt.NDArray[np.float64], npt.NDArray[np.int32]]:
    """Slower reduction of the symmetry-expanded asymmetric unit, used
    when the fast version leaves equivalent motif points. As well as
    dropping the repeated images of each site, any image coinciding
    (within ``tol``, modulo 1) with a point contributed by an earlier
    site is dropped and a warning is issued.
    """

    first_images = expanded_sites[0]
    frac_motif = first_images[_unique_sites(first_images, tol)]
    inverses = [0] * len(frac_motif)

    for i, images in enumerate(expanded_sites[1:], start=1):
        new_points = []

        for site in images[_unique_sites(images, tol)]:
            # Compare with the points kept from earlier sites, treating
            # coordinates that differ by a whole cell as equal.
            direct = np.abs(site - frac_motif)
            wrapped = np.abs(direct - 1)
            clashes = np.all((direct <= tol) | (wrapped <= tol), axis=-1)

            if np.any(clashes):
                warnings.warn(
                    'has equivalent sites at positions '
                    f'{inverses[np.argmax(clashes)]}, {i}'
                )
                continue

            new_points.append(site)

        # The motif is only extended once all images of site i are seen
        if new_points:
            inverses.extend([i] * len(new_points))
            frac_motif = np.concatenate((frac_motif, np.array(new_points)))

    return frac_motif, np.array(inverses, dtype=np.int32)
1552
1553
1554
@numba.njit(cache=True)
def _unique_sites(
        asym_unit: npt.NDArray[np.float64], tol: float
) -> npt.NDArray[np.bool_]:
    """Uniquify (within tol) a list of fractional coordinates,
    considering all points modulo 1. Return an array of bools such that
    asym_unit[_unique_sites(asym_unit, tol)] is the uniquified list.
    The first occurrence of each point is the one that is kept.
    """

    m, _ = asym_unit.shape
    where_unique = np.full(shape=(m, ), fill_value=True)

    for i in range(1, m):
        # Coordinate-wise distance from site i to every earlier site,
        # directly and across the cell boundary.
        site_diffs1 = np.abs(asym_unit[:i, :] - asym_unit[i])
        site_diffs2 = np.abs(site_diffs1 - 1)
        sites_neq_mask = (site_diffs1 > tol) & (site_diffs2 > tol)
        # An earlier site with no differing coordinate is a duplicate
        if not np.all(np.sum(sites_neq_mask, axis=-1)):
            where_unique[i] = False

    return where_unique
1575
1576
1577
def _has_disorder(label: str, occupancy) -> bool:
    """Tell whether a site is disordered: True when the occupancy is a
    number below 1, or when the label ends with ?. An occupancy that
    cannot be converted to a float is treated as 1.
    """
    try:
        occ = float(occupancy)
    except Exception:
        occ = 1
    if occ < 1:
        return True
    return label.endswith('?')
1584
1585
1586
def _get_syms_pymatgen(
        data: dict
) -> Tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]:
    """Parse symmetry operations given by data = block.data where block
    is a pymatgen CifBlock object. If the symops are not present the
    space group symbol/international number is parsed and symops are
    generated.

    Sources are tried in this order, stopping at the first that works:

    1. xyz symmetry operation strings under a ``_CIF_TAGS['symop']``
       tag, parsed with ``_parse_sitesyms``.
    2. A space group symbol under a ``_CIF_TAGS['spacegroup_name']``
       tag, looked up first in pymatgen's symbol table and, if that
       fails, in pymatgen's COD data.
    3. An international number under a
       ``_CIF_TAGS['spacegroup_number']`` tag.

    If none of these gives any operations, a warning is issued and the
    identity operation 'x,y,z' (P1) is used.

    Returns the pair (rotations, translations). When the operations come
    from pymatgen ``SpaceGroup`` objects these are float64 arrays of the
    rotation matrices and translation vectors; otherwise the result of
    ``_parse_sitesyms`` is returned as is.
    """

    from pymatgen.symmetry.groups import SpaceGroup
    import pymatgen.io.cif

    # Try xyz symmetry operations
    for symmetry_label in _CIF_TAGS['symop']:
        xyz = data.get(symmetry_label)
        if not xyz:
            continue
        if isinstance(xyz, str):
            xyz = [xyz]
        return _parse_sitesyms(xyz)

    symops = []
    # Try spacegroup symbol
    for symmetry_label in _CIF_TAGS['spacegroup_name']:
        sg = data.get(symmetry_label)
        if not sg:
            continue
        sg = re.sub(r'[\s_]', '', sg)

        # First choice: pymatgen's own table of space group symbols.
        try:
            spg = pymatgen.io.cif.space_groups.get(sg)
            if spg:
                symops = SpaceGroup(spg).symmetry_ops
        except ValueError:
            pass
        if symops:
            break

        # Second choice: the COD table. This must also run when the
        # symbol is simply absent from pymatgen's table, not only when
        # SpaceGroup() raises.
        try:
            for d in pymatgen.io.cif._get_cod_data():
                if sg == re.sub(r'\s+', '', d['hermann_mauguin']):
                    return _parse_sitesyms(d['symops'])
        except Exception:
            # Best effort only: move on to the next tag.
            continue

    # Try international number
    if not symops:
        for symmetry_label in _CIF_TAGS['spacegroup_number']:
            num = data.get(symmetry_label)
            if not num:
                continue
            try:
                i = int(pymatgen.io.cif.str2float(num))
                symops = SpaceGroup.from_int_number(i).symmetry_ops
                break
            except ValueError:
                continue

    if not symops:
        warnings.warn('no symmetry data found, defaulting to P1')
        return _parse_sitesyms(['x,y,z'])

    rotations = [op.rotation_matrix for op in symops]
    translations = [op.translation_vector for op in symops]
    rotations = np.array(rotations, dtype=np.float64)
    translations = np.array(translations, dtype=np.float64)

    return rotations, translations
1654
1655
1656
def _frac_molecular_centres_ccdc(
        crystal, tol: float
) -> npt.NDArray[np.float64]:
    """Compute the geometric centre of each molecule in the unit cell.

    Takes a ccdc Crystal object, averages the fractional coordinates of
    the atoms of every component of its 'CentroidIncluded' packing,
    wraps the centres modulo 1 and drops centres that coincide within
    tol. The result is in fractional coordinates.
    """

    packed = crystal.packing(inclusion='CentroidIncluded')
    centres = []
    for component in packed.components:
        atom_coords = [atom.fractional_coordinates for atom in component.atoms]
        n_atoms = len(atom_coords)
        # Mean of each axis taken separately.
        centres.append([sum(axis) / n_atoms for axis in zip(*atom_coords)])

    centres = np.mod(np.array(centres, dtype=np.float64), 1)
    keep = _unique_sites(centres, tol)
    return centres[keep]
1669
1670
1671
def _heaviest_component_ccdc(molecule):
    """Keep only the heaviest component of a ccdc Molecule object.

    The weight of a component is the sum over its atoms of atomic weight
    times occupancy. An occupancy that raises ValueError on conversion
    counts as 1, and an atom whose weight raises ValueError contributes
    nothing. The component with the largest weight is returned (the
    first one on ties). Intended for removing solvents.
    """

    def _component_weight(component):
        # Occupancy-weighted mass of one component.
        total = 0
        for atom in component.atoms:
            try:
                occupancy = float(atom.occupancy)
            except ValueError:
                occupancy = 1
            try:
                total += float(atom.atomic_weight) * occupancy
            except ValueError:
                pass
        return total

    weights = [_component_weight(comp) for comp in molecule.components]
    heaviest = np.argmax(np.array(weights))
    return molecule.components[heaviest]
1693
1694
1695
def _snap_small_prec_coords(
        frac_coords: npt.NDArray[np.float64], tol: float
) -> npt.NDArray[np.float64]:
    """Snap coordinates that are nearly 1/3 or 2/3 to exactly those
    values, as recommended by pymatgen's CIF parser.

    Entries x with |1 - 3x| < tol become 1/3, then entries with
    |1 - 3x/2| < tol become 2/3. frac_coords is modified in place and
    the same array is returned.
    """

    near_one_third = np.abs(1 - 3 * frac_coords) < tol
    frac_coords[near_one_third] = 1 / 3.

    # Evaluated after the first snap, on the updated array.
    near_two_thirds = np.abs(1 - 3 * frac_coords / 2) < tol
    frac_coords[near_two_thirds] = 2 / 3.

    return frac_coords
1705