Test Failed
Push — master ( baece5...0f0bce )
by Daniel
07:51
created

amd.io.periodicset_from_pymatgen_cifblock()   F

Complexity

Conditions 28

Size

Total Lines 173
Code Lines 96

Duplication

Lines 24
Ratio 13.87 %

Importance

Changes 0
Metric Value
eloc 96
dl 24
loc 173
rs 0
c 0
b 0
f 0
cc 28
nop 3

How to fix   Long Method    Complexity   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

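Commonly applied refactorings include Extract Method: move a cohesive part of the body into a new, well-named helper function and call it from the original.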

Complexity

Complex functions like amd.io.periodicset_from_pymatgen_cifblock() often do many different things. To break such a function down, identify a cohesive component within it. A common approach to finding such a component is to look for names (fields, methods or local variables) that share the same prefixes or suffixes.

Once you have determined which parts belong together, you can apply the Extract Method refactoring (for a function) or the Extract Class refactoring (for a class). If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often faster.

1
"""Tools for reading crystals from files, or from the CSD with
2
``csd-python-api``. The readers return
3
:class:`amd.PeriodicSet <.periodicset.PeriodicSet>` objects representing
4
the crystal which can be passed to :func:`amd.AMD() <.calculate.AMD>`
5
and :func:`amd.PDD() <.calculate.PDD>`.
6
"""
7
8
import warnings
9
import os
10
import re
11
import functools
12
import errno
13
import math
14
import json
15
from pathlib import Path
16
from typing import Iterable, Iterator, Optional, Union, Callable, Tuple, List
17
18
import numpy as np
19
import numba
20
import tqdm
21
22
from .utils import cellpar_to_cell
23
from .periodicset import PeriodicSet
24
25
__all__ = [
26
    'CifReader', 'CSDReader', 'ParseError', 'periodicset_from_gemmi_block',
27
    'periodicset_from_ase_cifblock', 'periodicset_from_pymatgen_cifblock',
28
    'periodicset_from_ase_atoms', 'periodicset_from_pymatgen_structure',
29
    'periodicset_from_ccdc_entry', 'periodicset_from_ccdc_crystal'
30
]
31
32
def _custom_warning(message, category, filename, lineno, *args, **kwargs):
    """Format a warning as ``'<CategoryName>: <message>'`` plus a newline.

    Drop-in replacement for :func:`warnings.formatwarning`. The file
    name, line number and any extra arguments are accepted but ignored,
    so displayed warnings omit the source location.
    """
    category_name = category.__name__
    return f'{category_name}: {message}\n'


# Replace the module-global formatter used when warnings are displayed.
warnings.formatwarning = _custom_warning
36
37
# Element symbol -> atomic number table, read from atomic_numbers.json in
# the same directory as this file. JSON text is UTF-8 by specification,
# so decode explicitly instead of relying on the platform locale encoding.
with open(Path(__file__).absolute().parent / 'atomic_numbers.json',
          encoding='utf-8') as f:
    _ATOMIC_NUMBERS = json.load(f)

# Tolerance passed to _unique_sites and _expand_asym_unit when deciding
# whether two sites coincide (presumably in fractional coordinates --
# TODO confirm against those helpers).
_EQ_SITE_TOL: float = 1e-3

# CIF tags used by the readers. 'cellpar' and 'atom_site_*' list the
# components of a single quantity; 'symop' and 'spacegroup_*' list
# alternative tags for the same data, which readers try in order.
_CIF_TAGS: dict = {
    'cellpar': [
        '_cell_length_a', '_cell_length_b', '_cell_length_c',
        '_cell_angle_alpha', '_cell_angle_beta', '_cell_angle_gamma'],
    'atom_site_fract': [
        '_atom_site_fract_x', '_atom_site_fract_y', '_atom_site_fract_z'],
    'atom_site_cartn': [
        '_atom_site_Cartn_x', '_atom_site_Cartn_y', '_atom_site_Cartn_z'],
    'symop': [
        '_space_group_symop_operation_xyz', '_space_group_symop.operation_xyz',
        '_symmetry_equiv_pos_as_xyz'],
    'spacegroup_name': [
        '_space_group_name_H-M_alt', '_symmetry_space_group_name_H-M'],
    'spacegroup_number': [
        '_space_group_IT_number', '_symmetry_Int_Tables_number']
}
57
58
59
class _Reader:
    """Base reader class.

    Lazily passes each item of ``iterable`` through ``converter`` and
    yields the resulting
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>` objects. Items
    whose conversion raises :class:`ParseError <.io.ParseError>` are
    skipped, and the error message is emitted as a warning.

    Parameters
    ----------
    iterable : Iterable
        Source of raw items to convert (e.g. parsed CIF blocks or CSD
        entries).
    converter : Callable[..., PeriodicSet]
        Function converting one raw item into a periodic set.
    show_warnings : bool
        If False, warnings raised while reading are suppressed.
    verbose : bool
        If True, show a ``tqdm`` progress bar counting processed items.
    """

    def __init__(
            self,
            iterable: Iterable,
            converter: Callable[..., PeriodicSet],
            show_warnings: bool,
            verbose: bool
    ):
        self._iterator = iter(iterable)
        self._converter = converter
        self.show_warnings = show_warnings
        if verbose:
            self._progress_bar = tqdm.tqdm(desc='Reading', delay=1)
        else:
            self._progress_bar = None

    def __iter__(self):
        return self

    def __next__(self):
        """Iterate over self._iterator, passing items through
        self._converter and yielding. If
        :class:`ParseError <.io.ParseError>` is raised in a call to
        self._converter, the item is skipped. Warnings raised in
        self._converter are printed if self.show_warnings is True.
        """
        # Apply the 'ignore' filter only for the duration of this call.
        # Calling warnings.simplefilter('ignore') directly would mutate the
        # global filter list and silence every warning in the process
        # permanently, not just the warnings produced by this reader.
        with warnings.catch_warnings():
            if not self.show_warnings:
                warnings.simplefilter('ignore')
            return self._next_periodic_set()

    def _next_periodic_set(self) -> PeriodicSet:
        """Convert items until one succeeds and return it; raise
        StopIteration (closing the progress bar) once the underlying
        iterator is exhausted."""
        while True:
            try:
                item = next(self._iterator)
            except StopIteration:
                if self._progress_bar is not None:
                    self._progress_bar.close()
                raise

            parse_err_msg = None
            with warnings.catch_warnings(record=True) as warning_msgs:
                try:
                    periodic_set = self._converter(item)
                except ParseError as err:
                    parse_err_msg = str(err)

            if self._progress_bar is not None:
                self._progress_bar.update(1)

            if parse_err_msg is not None:
                warnings.warn(parse_err_msg)
                continue

            # Re-emit warnings recorded during conversion, tagged with the
            # name of the crystal they came from.
            for warning in warning_msgs:
                msg = f'(name={periodic_set.name}) {warning.message}'
                warnings.warn(msg, category=warning.category)

            return periodic_set

    def read(self) -> Union[PeriodicSet, List[PeriodicSet]]:
        """Read all remaining crystals.

        Returns
        -------
        :class:`amd.PeriodicSet <.periodicset.PeriodicSet>` or list
            The single periodic set if exactly one crystal was read,
            otherwise a list of periodic sets (an empty list if no
            structure could be parsed).
        """
        items = list(self)
        if len(items) == 1:
            return items[0]
        return items
131
132
133
class CifReader(_Reader):
    """Read all structures in a .cif file or all files in a folder
    with gemmi (default), pymatgen, ase or csd-python-api (if
    installed), yielding
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>` s.

    Parameters
    ----------
    path : str
        Path to a .CIF file or directory. (Other files are accepted when
        using ``reader='ccdc'``, if csd-python-api is installed.) When a
        directory is given, only files directly inside it with a
        supported extension are read; files that fail to parse are
        skipped with a warning.
    reader : str, optional
        The backend package used to parse the CIF. The default is
        :code:`gemmi`; :code:`pymatgen`, :code:`ase` and :code:`pycodcif`
        (ase's CIF parser with its pycodcif backend) are also accepted,
        as well as :code:`ccdc` if csd-python-api is installed. The ccdc
        reader should be able to read any format accepted by
        :class:`ccdc.io.EntryReader`, though only CIFs have been tested.
    remove_hydrogens : bool, optional
        Remove Hydrogens from the crystals.
    disorder : str, optional
        Controls how disordered structures are handled. Default is
        ``skip`` which skips any crystal with disorder, since disorder
        conflicts with the periodic set model. To read disordered
        structures anyway, choose either :code:`ordered_sites` to remove
        atoms with disorder or :code:`all_sites` include all atoms
        regardless of disorder.
    heaviest_component : bool, optional
        csd-python-api only. Removes all but the heaviest molecule in
        the asymmetric unit, intended for removing solvents.
    molecular_centres : bool, default False
        csd-python-api only. Extract the centres of molecules in the
        unit cell and store in the attribute molecular_centres.
    show_warnings : bool, optional
        Controls whether warnings that arise during reading are printed.
    verbose : bool, default False
        If True, prints a progress bar showing the number of items
        processed.

    Yields
    ------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        data, e.g. the crystal's name and information about the
        asymmetric unit.

    Raises
    ------
    ValueError
        If ``disorder`` or ``reader`` is not one of the accepted values.
    NotImplementedError
        If ``heaviest_component`` or ``molecular_centres`` is requested
        with a reader other than ``'ccdc'``.
    ImportError
        If ``reader='ccdc'`` and csd-python-api cannot be imported.
    FileNotFoundError
        If ``path`` is neither a file nor a directory.

    Examples
    --------

        ::

            # Put all crystals in a .CIF in a list
            structures = list(amd.CifReader('mycif.cif'))

            # Can also accept path to a directory, reading all files inside
            structures = list(amd.CifReader('path/to/folder'))

            # Reads just one if the .CIF has just one crystal
            periodic_set = amd.CifReader('mycif.cif').read()

            # List of AMDs (k=100) of crystals in a .CIF
            amds = [amd.AMD(item, 100) for item in amd.CifReader('mycif.cif')]
    """

    def __init__(
            self,
            path: Union[str, os.PathLike],
            reader: str = 'gemmi',
            remove_hydrogens: bool = False,
            disorder: str = 'skip',
            heaviest_component: bool = False,
            molecular_centres: bool = False,
            show_warnings: bool = True,
            verbose: bool = False
    ):

        if disorder not in ('skip', 'ordered_sites', 'all_sites'):
            raise ValueError(
                f"'disorder' parameter of {self.__class__.__name__} must be "
                f"one of 'skip', 'ordered_sites' or 'all_sites' (passed "
                f"'{disorder}')"
            )

        if reader != 'ccdc':
            if heaviest_component:
                raise NotImplementedError(
                    "'heaviest_component' parameter of "
                    f"{self.__class__.__name__} only implemented with "
                    "csd-python-api, if installed pass reader='ccdc'"
                )
            if molecular_centres:
                raise NotImplementedError(
                    "'molecular_centres' parameter of "
                    f"{self.__class__.__name__} only implemented with "
                    "csd-python-api, if installed pass reader='ccdc'"
                )

        extensions, file_parser, converter = self._backend(
            reader,
            remove_hydrogens,
            disorder,
            molecular_centres,
            heaviest_component
        )

        path = Path(path)
        if path.is_file():
            iterable = file_parser(str(path))
        elif path.is_dir():
            iterable = CifReader._dir_generator(path, file_parser, extensions)
        else:
            raise FileNotFoundError(
                errno.ENOENT, os.strerror(errno.ENOENT), path
            )

        super().__init__(iterable, converter, show_warnings, verbose)

    def _backend(
            self,
            reader: str,
            remove_hydrogens: bool,
            disorder: str,
            molecular_centres: bool,
            heaviest_component: bool
    ) -> Tuple[set, Callable, Callable[..., PeriodicSet]]:
        """Return ``(extensions, file_parser, converter)`` for ``reader``.

        ``extensions`` are compared against a directory file's lowercased
        suffix (without the dot), ``file_parser`` maps a file path to an
        iterable of raw items, and ``converter`` turns one raw item into
        a :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`.
        """

        # NOTE(review): the original comment here said gemmi "cannot handle
        # some characters" in CIFs (the example character was garbled in
        # the source) -- confirm which inputs gemmi rejects.
        if reader == 'gemmi':
            import gemmi
            extensions = {'cif'}
            file_parser = gemmi.cif.read_file
            converter = functools.partial(
                periodicset_from_gemmi_block,
                remove_hydrogens=remove_hydrogens,
                disorder=disorder
            )

        elif reader in ('ase', 'pycodcif'):
            from ase.io.cif import parse_cif
            extensions = {'cif'}
            # ase's parse_cif selects its own or the pycodcif backend.
            file_parser = functools.partial(parse_cif, reader=reader)
            converter = functools.partial(
                periodicset_from_ase_cifblock,
                remove_hydrogens=remove_hydrogens,
                disorder=disorder
            )

        elif reader == 'pymatgen':

            def _pymatgen_cif_parser(path):
                from pymatgen.io.cif import CifFile
                return CifFile.from_file(path).data.values()

            extensions = {'cif'}
            file_parser = _pymatgen_cif_parser
            converter = functools.partial(
                periodicset_from_pymatgen_cifblock,
                remove_hydrogens=remove_hydrogens,
                disorder=disorder
            )

        elif reader == 'ccdc':
            try:
                import ccdc.io
            except (ImportError, RuntimeError) as e:
                raise ImportError('Failed to import csd-python-api') from e

            extensions = set(ccdc.io.EntryReader.known_suffixes.keys())
            file_parser = ccdc.io.EntryReader
            converter = functools.partial(
                periodicset_from_ccdc_entry,
                remove_hydrogens=remove_hydrogens,
                disorder=disorder,
                molecular_centres=molecular_centres,
                heaviest_component=heaviest_component
            )

        else:
            raise ValueError(
                f"'reader' parameter of {self.__class__.__name__} must be one "
                f"of 'gemmi', 'pymatgen', 'ccdc', 'ase', or 'pycodcif' "
                f"(passed '{reader}')"
            )

        return extensions, file_parser, converter

    @staticmethod
    def _dir_generator(
            path: Path,
            file_parser: Callable,
            extensions: Iterable
    ) -> Iterator:
        """Yield items parsed from every file directly inside ``path``
        whose lowercased suffix is in ``extensions``. Files that raise
        while parsing are skipped with a warning (best effort)."""
        for file_path in path.iterdir():
            if not file_path.is_file():
                continue
            if file_path.suffix[1:].lower() not in extensions:
                continue
            try:
                yield from file_parser(str(file_path))
            except Exception as e:
                warnings.warn(
                    f'Error parsing "{str(file_path)}", skipping: {str(e)}'
                )
317
318
319
class CSDReader(_Reader):
    """Read structures from the CSD with csd-python-api, yielding
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>` s.

    Parameters
    ----------
    refcodes : str or List[str], optional
        Single or list of CSD refcodes to read. If None or 'CSD',
        iterates over the whole CSD.
    families : bool, optional
        Read all entries whose refcode starts with the given strings, or
        'families' (e.g. giving 'DEBXIT' reads all entries starting with
        DEBXIT). Ignored when reading the whole CSD.
    remove_hydrogens : bool, optional
        Remove hydrogens from the crystals.
    disorder : str, optional
        Controls how disordered structures are handled. Default is
        ``skip`` which skips any crystal with disorder, since disorder
        conflicts with the periodic set model. To read disordered
        structures anyway, choose either :code:`ordered_sites` to remove
        atoms with disorder or :code:`all_sites` include all atoms
        regardless of disorder.
    heaviest_component : bool, optional
        Removes all but the heaviest molecule in the asymmetric unit,
        intended for removing solvents.
    molecular_centres : bool, default False
        Extract the centres of molecules in the unit cell and store in
        attribute molecular_centres.
    show_warnings : bool, optional
        Controls whether warnings that arise during reading are printed.
    verbose : bool, default False
        If True, prints a progress bar showing the number of items
        processed.

    Yields
    ------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        useful data, e.g. the crystal's name and information about the
        asymmetric unit for calculation.

    Raises
    ------
    ValueError
        If ``disorder`` is not an accepted value, or ``refcodes`` is not
        None, a string or a list of strings.
    ImportError
        If csd-python-api cannot be imported.

    Examples
    --------

        ::

            # Put these entries in a list
            refcodes = ['DEBXIT01', 'DEBXIT05', 'HXACAN01']
            structures = list(amd.CSDReader(refcodes))

            # Read refcode families (any whose refcode starts with strings in the list)
            refcode_families = ['ACSALA', 'HXACAN']
            structures = list(amd.CSDReader(refcode_families, families=True))

            # Get AMDs (k=100) for crystals in these families
            refcodes = ['ACSALA', 'HXACAN']
            amds = []
            for periodic_set in amd.CSDReader(refcodes, families=True):
                amds.append(amd.AMD(periodic_set, 100))

            # Giving the reader nothing reads from the whole CSD.
            for periodic_set in amd.CSDReader():
                ...
    """

    def __init__(
            self,
            refcodes: Optional[Union[str, List[str]]] = None,
            families: bool = False,
            remove_hydrogens: bool = False,
            disorder: str = 'skip',
            heaviest_component: bool = False,
            molecular_centres: bool = False,
            show_warnings: bool = True,
            verbose: bool = False
    ):

        if disorder not in ('skip', 'ordered_sites', 'all_sites'):
            raise ValueError(
                f"'disorder' parameter of {self.__class__.__name__} must be "
                f"one of 'skip', 'ordered_sites' or 'all_sites' (passed "
                f"'{disorder}')"
            )

        try:
            import ccdc.search
            import ccdc.io
        except (ImportError, RuntimeError) as e:
            raise ImportError('Failed to import csd-python-api') from e

        # Normalise refcodes to None (whole CSD) or a list of strings.
        if isinstance(refcodes, str) and refcodes.lower() == 'csd':
            refcodes = None
        if refcodes is None:
            families = False
        elif isinstance(refcodes, str):
            refcodes = [refcodes]
        elif isinstance(refcodes, list):
            if not all(isinstance(refcode, str) for refcode in refcodes):
                raise ValueError(
                    f'{self.__class__.__name__} expects None, a string or '
                    'list of strings.'
                )
        else:
            raise ValueError(
                f'{self.__class__.__name__} expects None, a string or list of '
                f'strings, got {refcodes.__class__.__name__}'
            )

        if families:
            all_refcodes = []
            for refcode in refcodes:
                query = ccdc.search.TextNumericSearch()
                query.add_identifier(refcode)
                all_refcodes.extend(hit.identifier for hit in query.search())
            # Unique refcodes in first-seen order (dicts preserve insertion
            # order).
            refcodes = list(dict.fromkeys(all_refcodes))

        converter = functools.partial(
            periodicset_from_ccdc_entry,
            remove_hydrogens=remove_hydrogens,
            disorder=disorder,
            molecular_centres=molecular_centres,
            heaviest_component=heaviest_component
        )

        entry_reader = ccdc.io.EntryReader('CSD')
        if refcodes is None:
            iterable = entry_reader
        else:
            # Lazily look up each refcode as the reader is iterated.
            iterable = map(entry_reader.entry, refcodes)

        super().__init__(iterable, converter, show_warnings, verbose)
458
459
460
class ParseError(ValueError):
    """Raised when an item cannot be parsed into a periodic set.

    The readers in this module catch it, emit its message as a warning
    and move on to the next item.
    """
463
464
465
def periodicset_from_gemmi_block(
        block,
        remove_hydrogens: bool = False,
        disorder: str = 'skip'
) -> PeriodicSet:
    """Convert a :class:`gemmi.cif.Block` object to a
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`.
    :class:`gemmi.cif.Block` objects are the items of the document
    returned by :func:`gemmi.cif.read_file`.

    Parameters
    ----------
    block : :class:`gemmi.cif.Block`
        A gemmi Block object representing a crystal.
    remove_hydrogens : bool, optional
        Remove Hydrogens from the crystal.
    disorder : str, optional
        Controls how disordered structures are handled. Default is
        ``skip`` which skips any crystal with disorder, since disorder
        conflicts with the periodic set model. To read disordered
        structures anyway, choose either :code:`ordered_sites` to remove
        atoms with disorder or :code:`all_sites` include all atoms
        regardless of disorder.

    Returns
    -------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        useful data, e.g. the crystal's name and information about the
        asymmetric unit for calculation.

    Raises
    ------
    ParseError
        Raised if the structure fails to be parsed for any of the
        following: 1. Required data is missing (e.g. cell parameters),
        2. ``disorder == 'skip'`` and disorder is found on any
        atom, 3. The motif is empty after removing H or disordered
        sites, 4. An atom type cannot be mapped to an element.
    """

    import gemmi
    from gemmi.cif import as_number, as_string, as_int

    def _atomic_number(symbol: str) -> int:
        # Use exact symbols as-is; otherwise reduce type symbols carrying
        # charges or suffixes (e.g. 'O2-', 'Fe3+') to the element symbol,
        # as periodicset_from_ase_cifblock does. Unknown symbols raise
        # ParseError (skipped by the readers) rather than KeyError.
        if symbol in _ATOMIC_NUMBERS:
            return _ATOMIC_NUMBERS[symbol]
        match = re.search(r'([A-Z][a-z]?)', symbol)
        if match is not None and match.group() in _ATOMIC_NUMBERS:
            return _ATOMIC_NUMBERS[match.group()]
        raise ParseError(
            f'{block.name} has unrecognised atom type {symbol!r}'
        )

    # Unit cell
    cellpar = [block.find_value(t) for t in _CIF_TAGS['cellpar']]
    if not all(isinstance(par, str) for par in cellpar):
        raise ParseError(f'{block.name} has missing cell data')
    cellpar = np.array([as_number(par) for par in cellpar])
    if np.isnan(np.sum(cellpar)):
        raise ParseError(f'{block.name} has missing cell data')
    cell = cellpar_to_cell(cellpar)

    # Asymmetric unit coordinates
    xyz_loop = block.find(_CIF_TAGS['atom_site_fract']).loop
    if xyz_loop is None:
        if block.find(_CIF_TAGS['atom_site_cartn']).loop is None:
            raise ParseError(f'{block.name} has missing coordinate data')
        raise ParseError(
            f'{block.name} uses _atom_site_Cartn_ tags for coordinates, '
            'only _atom_site_fract_ is supported'
        )

    # Split the loop's flat, row-major value list into one list per column.
    tablified_loop = [[] for _ in range(len(xyz_loop.tags))]
    n_cols = xyz_loop.width()
    for i, item in enumerate(xyz_loop.values):
        tablified_loop[i % n_cols].append(item)
    loop_dict = dict(zip(xyz_loop.tags, tablified_loop))
    xyz_str = [loop_dict[t] for t in _CIF_TAGS['atom_site_fract']]
    asym_unit = np.transpose(np.array(
        [[as_number(c) for c in coords] for coords in xyz_str]
    ))
    asym_unit = np.mod(asym_unit, 1)
    # NOTE: pymatgen recommends snapping near-special coordinates here, e.g.
    # _snap_small_prec_coords(asym_unit, 1e-4); currently disabled.

    # Labels
    if '_atom_site_label' in loop_dict:
        labels = [as_string(label) for label in loop_dict['_atom_site_label']]
    else:
        labels = [''] * xyz_loop.length()

    # Atomic types, from type symbols if present, otherwise from labels
    if '_atom_site_type_symbol' in loop_dict:
        symbols = [as_string(s) for s in loop_dict['_atom_site_type_symbol']]
    else:
        symbols = []
        for label in labels:
            sym = ''
            if label:
                match = re.search(r'([A-Z][a-z]?)', label)
                if match is not None:
                    sym = match.group()
            symbols.append(sym)
    asym_types = [_atomic_number(s) for s in symbols]

    # Occupancies (missing/unparsable values are treated as 1)
    if '_atom_site_occupancy' in loop_dict:
        occs = [as_number(occ) for occ in loop_dict['_atom_site_occupancy']]
        occupancies = [occ if not math.isnan(occ) else 1 for occ in occs]
    else:
        occupancies = [1] * xyz_loop.length()

    # Remove sites with missing coordinates, disorder and Hydrogens if needed
    remove_sites = list(np.nonzero(np.isnan(asym_unit.min(axis=-1)))[0])

    if disorder == 'skip':
        if any(
            _has_disorder(lab, occ) for lab, occ in zip(labels, occupancies)
        ):
            raise ParseError(
                f"{block.name} has disorder, pass disorder='ordered_sites' or "
                "'all_sites' to remove/ignore disorder"
            )
    elif disorder == 'ordered_sites':
        for i, (label, occ) in enumerate(zip(labels, occupancies)):
            if _has_disorder(label, occ):
                remove_sites.append(i)

    if remove_hydrogens:
        remove_sites.extend(
            i for i, num in enumerate(asym_types) if num == 1
        )

    asym_unit = np.delete(asym_unit, remove_sites, axis=0)
    removed = set(remove_sites)  # O(1) membership for the filter below
    asym_types = [s for i, s in enumerate(asym_types) if i not in removed]
    if asym_unit.shape[0] == 0:
        raise ParseError(f'{block.name} has no valid sites')

    if disorder != 'all_sites':
        keep_sites = _unique_sites(asym_unit, _EQ_SITE_TOL)
        if not np.all(keep_sites):
            warnings.warn(
                'may have overlapping sites; duplicates will be removed'
            )
        asym_unit = asym_unit[keep_sites]
        asym_types = [sym for sym, keep in zip(asym_types, keep_sites) if keep]

    # Symmetry operations, try xyz strings first
    for tag in _CIF_TAGS['symop']:
        sitesym = [v.str(0) for v in block.find([tag])]
        if sitesym:
            rot, trans = _parse_sitesyms(sitesym)
            break
    else:
        # Try spacegroup name; can be a pair or in a loop
        spg = None
        for tag in _CIF_TAGS['spacegroup_name']:
            for value in block.find([tag]):
                try:
                    # Some names cannot be parsed by gemmi.SpaceGroup
                    spg = gemmi.SpaceGroup(value.str(0))
                    break
                except ValueError:
                    continue
            if spg is not None:
                break

        if spg is None:
            # Try international number
            for tag in _CIF_TAGS['spacegroup_number']:
                spg_num = block.find_value(tag)
                if spg_num is not None:
                    spg_num = as_int(spg_num)
                    break
            else:
                warnings.warn('no symmetry data found, defaulting to P1')
                spg_num = 1
            spg = gemmi.SpaceGroup(spg_num)

        # gemmi stores operations as integers scaled by o.DEN.
        rot = np.array([np.array(o.rot) / o.DEN for o in spg.operations()])
        trans = np.array([np.array(o.tran) / o.DEN for o in spg.operations()])

    frac_motif, invs = _expand_asym_unit(asym_unit, rot, trans, _EQ_SITE_TOL)
    _, wyc_muls = np.unique(invs, return_counts=True)
    wyc_muls = wyc_muls.astype(np.int32)
    asym_inds = np.zeros_like(wyc_muls, dtype=np.int32)
    asym_inds[1:] = np.cumsum(wyc_muls)[:-1]
    types = np.array([asym_types[i] for i in invs], dtype=np.uint8)
    motif = frac_motif @ cell

    return PeriodicSet(
        motif=motif,
        cell=cell,
        name=block.name,
        asymmetric_unit=asym_inds,
        wyckoff_multiplicities=wyc_muls,
        types=types
    )
658
659
660
def periodicset_from_ase_cifblock(
        block,
        remove_hydrogens: bool = False,
        disorder: str = 'skip'
) -> PeriodicSet:
    """Convert a :class:`ase.io.cif.CIFBlock` object to a
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`.
    :class:`ase.io.cif.CIFBlock` is the type returned by
    :func:`ase.io.cif.parse_cif`.

    Parameters
    ----------
    block : :class:`ase.io.cif.CIFBlock`
        An ase :class:`ase.io.cif.CIFBlock` object representing a
        crystal.
    remove_hydrogens : bool, optional
        Remove Hydrogens from the crystal.
    disorder : str, optional
        Controls how disordered structures are handled. Default is
        ``skip`` which skips any crystal with disorder, since disorder
        conflicts with the periodic set model. To read disordered
        structures anyway, choose either :code:`ordered_sites` to remove
        atoms with disorder or :code:`all_sites` include all atoms
        regardless of disorder.

    Returns
    -------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        useful data, e.g. the crystal's name and information about the
        asymmetric unit for calculation.

    Raises
    ------
    ParseError
        Raised if the structure fails to be parsed for any of the
        following: 1. Required data is missing (e.g. cell parameters),
        2. The motif is empty after removing H or disordered sites,
        3. ``disorder == 'skip'`` and disorder is found on any
        atom, 4. An atom type cannot be mapped to an element.
    """

    import ase.spacegroup

    # Unit cell
    cellpar = [block.get(tag) for tag in _CIF_TAGS['cellpar']]
    if None in cellpar:
        raise ParseError(f'{block.name} has missing cell data')
    cell = cellpar_to_cell(np.array(cellpar))

    # Asymmetric unit coordinates. ase removes uncertainty brackets
    asym_unit = [block.get(tag) for tag in _CIF_TAGS['atom_site_fract']]
    if None in asym_unit:
        cartn = [
            block.get(tag.lower()) for tag in _CIF_TAGS['atom_site_cartn']
        ]
        if None in cartn:
            raise ParseError(f'{block.name} has missing coordinates')
        raise ParseError(
            f'{block.name} uses _atom_site_Cartn_ tags for coordinates, '
            'only _atom_site_fract_ is supported'
        )
    # Transpose the per-axis columns into one (x, y, z) tuple per site.
    asym_unit = list(zip(*asym_unit))

    # Labels
    asym_labels = block.get('_atom_site_label')
    if asym_labels is None:
        asym_labels = [''] * len(asym_unit)

    # Atomic types: reduce type symbols (e.g. 'O2-') to an element symbol;
    # '.', '?' or unmatched symbols become ''.
    type_symbols = block.get('_atom_site_type_symbol')
    if type_symbols is None:
        symbols = [''] * len(asym_unit)
    else:
        symbols = []
        for type_symbol in type_symbols:
            sym = ''
            if type_symbol and type_symbol not in ('.', '?'):
                match = re.search(r'([A-Z][a-z]?)', type_symbol)
                if match is not None:
                    sym = match.group()
            symbols.append(sym)
    try:
        asym_types = [_ATOMIC_NUMBERS[s] for s in symbols]
    except KeyError as err:
        # Report as ParseError so readers skip this block instead of
        # aborting the whole iteration.
        raise ParseError(
            f'{block.name} has unrecognised atom type {err.args[0]!r}'
        ) from err

    # Find where sites have disorder if necessary
    has_disorder = []
    if disorder != 'all_sites':
        occupancies = block.get('_atom_site_occupancy')
        if occupancies is None:
            occupancies = [1] * len(asym_unit)
        has_disorder = [
            _has_disorder(lab, occ)
            for lab, occ in zip(asym_labels, occupancies)
        ]

    # Remove sites with ?, . or other invalid string for coordinates
    invalid = {
        i for i, xyz in enumerate(asym_unit)
        if not all(isinstance(coord, (int, float)) for coord in xyz)
    }
    if invalid:
        warnings.warn('atoms without sites or missing data will be removed')
        asym_unit = [c for i, c in enumerate(asym_unit) if i not in invalid]
        asym_types = [t for i, t in enumerate(asym_types) if i not in invalid]
        if disorder != 'all_sites':
            has_disorder = [
                d for i, d in enumerate(has_disorder) if i not in invalid
            ]

    # Indices (after removing invalid sites) to drop; a set for O(1)
    # membership tests.
    remove_sites = set()

    if remove_hydrogens:
        remove_sites.update(i for i, num in enumerate(asym_types) if num == 1)

    # Remove atoms with fractional occupancy or raise ParseError. Sites
    # already being removed (hydrogens) are not checked for disorder.
    if disorder != 'all_sites':
        for i, dis in enumerate(has_disorder):
            if i in remove_sites or not dis:
                continue
            if disorder == 'skip':
                raise ParseError(
                    f'{block.name} has disorder, pass '
                    "disorder='ordered_sites' or 'all_sites' to "
                    'remove/ignore disorder'
                )
            elif disorder == 'ordered_sites':
                remove_sites.add(i)

    # Asymmetric unit
    asym_unit = [c for i, c in enumerate(asym_unit) if i not in remove_sites]
    asym_types = [t for i, t in enumerate(asym_types) if i not in remove_sites]
    if len(asym_unit) == 0:
        raise ParseError(f'{block.name} has no valid sites')
    asym_unit = np.mod(np.array(asym_unit), 1)

    # NOTE: pymatgen recommends snapping near-special coordinates here, e.g.
    # _snap_small_prec_coords(asym_unit, 1e-4); currently disabled.

    # Remove overlapping sites unless disorder == 'all_sites'
    if disorder != 'all_sites':
        keep_sites = _unique_sites(asym_unit, _EQ_SITE_TOL)
        if not np.all(keep_sites):
            warnings.warn(
                'may have overlapping sites, duplicates will be removed'
            )
            asym_unit = asym_unit[keep_sites]
            asym_types = [t for t, keep in zip(asym_types, keep_sites) if keep]

    # Get symmetry operations. NOTE(review): CIFBlock._get_any is a private
    # (underscore) ase method, not part of ase's public API.
    sitesym = block._get_any(_CIF_TAGS['symop'])
    if sitesym is None:
        label_or_num = block._get_any(
            [s.lower() for s in _CIF_TAGS['spacegroup_name']]
        )
        if label_or_num is None:
            label_or_num = block._get_any(
                [s.lower() for s in _CIF_TAGS['spacegroup_number']]
            )
        if label_or_num is None:
            warnings.warn('no symmetry data found, defaulting to P1')
            label_or_num = 1
        spg = ase.spacegroup.Spacegroup(label_or_num)
        rot, trans = spg.get_op()
    else:
        if isinstance(sitesym, str):
            sitesym = [sitesym]
        rot, trans = _parse_sitesyms(sitesym)

    frac_motif, invs = _expand_asym_unit(asym_unit, rot, trans, _EQ_SITE_TOL)
    _, wyc_muls = np.unique(invs, return_counts=True)
    wyc_muls = wyc_muls.astype(np.int32)
    asym_inds = np.zeros_like(wyc_muls, dtype=np.int32)
    asym_inds[1:] = np.cumsum(wyc_muls)[:-1]
    types = np.array([asym_types[i] for i in invs], dtype=np.uint8)
    motif = frac_motif @ cell

    return PeriodicSet(
        motif=motif,
        cell=cell,
        name=block.name,
        asymmetric_unit=asym_inds,
        wyckoff_multiplicities=wyc_muls,
        types=types
    )
845
846
847
def periodicset_from_pymatgen_cifblock(
        block,
        remove_hydrogens: bool = False,
        disorder: str = 'skip'
) -> PeriodicSet:
    """Convert a :class:`pymatgen.io.cif.CifBlock` object to a
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`.
    :class:`pymatgen.io.cif.CifBlock` is the type returned by
    :class:`pymatgen.io.cif.CifFile`.

    Parameters
    ----------
    block : :class:`pymatgen.io.cif.CifBlock`
        A pymatgen CifBlock object representing a crystal.
    remove_hydrogens : bool, optional
        Remove Hydrogens from the crystal.
    disorder : str, optional
        Controls how disordered structures are handled. Default is
        ``skip`` which skips any crystal with disorder, since disorder
        conflicts with the periodic set model. To read disordered
        structures anyway, choose either :code:`ordered_sites` to remove
        atoms with disorder or :code:`all_sites` include all atoms
        regardless of disorder.

    Returns
    -------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        useful data, e.g. the crystal's name and information about the
        asymmetric unit for calculation.

    Raises
    ------
    ParseError
        Raised if the structure can/should not be parsed for the
        following reasons: 1. cell parameters are missing, 2. the
        fractional coordinate tags are missing (or only Cartesian
        ``_atom_site_Cartn_`` coordinates are given), 3. no sites remain
        after removing invalid sites, Hydrogens and disorder, 4.
        :code:``disorder == 'skip'`` and disorder is found on any atom.
        Individual sites whose coordinates are missing or unparseable
        (e.g. ``?``) are removed with a warning instead of raising.
    """

    from pymatgen.io.cif import str2float

    def _as_list(value):
        # A tag written outside a loop_ is stored as a single string, not
        # a list; wrap it so it is treated as the value for one site.
        if isinstance(value, str):
            return [value]
        return value

    def _to_float(value):
        # '?' (unknown) and '.' (inapplicable) mark missing values.
        # str2float raises on '?' and maps '.' to 0, so handle both here
        # and return None so the site is flagged as invalid below.
        if value in ('?', '.'):
            return None
        try:
            return str2float(value)
        except (ValueError, TypeError):
            return None

    def _type_symbol(label):
        # Extract an element symbol (e.g. 'Fe' from 'Fe3+'); '' if none.
        if label and label not in ('.', '?'):
            match = re.search(r'([A-Z][a-z]?)', label)
            if match is not None:
                return match.group()
        return ''

    odict = block.data

    # Unit cell
    cellpar = [odict.get(tag) for tag in _CIF_TAGS['cellpar']]
    if any(par in (None, '?', '.') for par in cellpar):
        raise ParseError(f'{block.header} has missing cell data')
    cell = cellpar_to_cell(np.array([str2float(v) for v in cellpar]))

    # Asymmetric unit coordinates
    coord_cols = [
        _as_list(odict.get(tag)) for tag in _CIF_TAGS['atom_site_fract']
    ]
    if None in coord_cols:
        cartn_cols = [odict.get(tag) for tag in _CIF_TAGS['atom_site_cartn']]
        if None in cartn_cols:
            raise ParseError(f'{block.header} has missing coordinates')
        raise ParseError(
            f'{block.header} uses _atom_site_Cartn_ tags for coordinates, '
            'only _atom_site_fract_ is supported'
        )
    # Unparseable entries become None and the site is removed below
    asym_unit = [[_to_float(c) for c in xyz] for xyz in zip(*coord_cols)]
    n_sites = len(asym_unit)

    # Labels (also used to detect disorder)
    labels = _as_list(odict.get('_atom_site_label'))
    if labels is None:
        labels = [''] * n_sites

    # Atomic types
    symbols = _as_list(odict.get('_atom_site_type_symbol'))
    if symbols is None:
        symbols = [''] * n_sites
    asym_types = [_ATOMIC_NUMBERS[_type_symbol(s)] for s in symbols]

    # Find where sites have disorder if necessary
    has_disorder = []
    if disorder != 'all_sites':
        occupancies = _as_list(odict.get('_atom_site_occupancy'))
        if occupancies is None:
            occupancies = [1.0] * n_sites
        else:
            # Unparseable occupancies become None, which _has_disorder
            # treats as full occupancy
            occupancies = [_to_float(occ) for occ in occupancies]
        has_disorder = [
            _has_disorder(lab, occ) for lab, occ in zip(labels, occupancies)
        ]

    # Remove sites with ?, . or other invalid strings for coordinates
    invalid = {
        i for i, xyz in enumerate(asym_unit)
        if not all(isinstance(coord, (int, float)) for coord in xyz)
    }
    if invalid:
        warnings.warn('atoms without sites or missing data will be removed')
        asym_unit = [c for i, c in enumerate(asym_unit) if i not in invalid]
        asym_types = [t for i, t in enumerate(asym_types) if i not in invalid]
        has_disorder = [
            d for i, d in enumerate(has_disorder) if i not in invalid
        ]

    remove_sites = set()
    if remove_hydrogens:
        remove_sites.update(i for i, n in enumerate(asym_types) if n == 1)

    # Remove atoms with fractional occupancy or raise ParseError
    if disorder != 'all_sites':
        for i, dis in enumerate(has_disorder):
            if i in remove_sites or not dis:
                continue
            if disorder == 'skip':
                raise ParseError(
                    f'{block.header} has disorder, pass '
                    "disorder='ordered_sites' or 'all_sites' to "
                    'remove/ignore disorder'
                )
            if disorder == 'ordered_sites':
                remove_sites.add(i)

    # Asymmetric unit
    asym_unit = [c for i, c in enumerate(asym_unit) if i not in remove_sites]
    asym_types = [t for i, t in enumerate(asym_types) if i not in remove_sites]
    if not asym_unit:
        raise ParseError(f'{block.header} has no valid sites')
    asym_unit = np.mod(np.array(asym_unit), 1)

    # Remove overlapping sites unless disorder == 'all_sites'
    if disorder != 'all_sites':
        keep_sites = _unique_sites(asym_unit, _EQ_SITE_TOL)
        if not np.all(keep_sites):
            warnings.warn(
                'may have overlapping sites; duplicates will be removed'
            )
            asym_unit = asym_unit[keep_sites]
            asym_types = [
                t for t, keep in zip(asym_types, keep_sites) if keep
            ]

    # Apply symmetries to asymmetric unit
    rot, trans = _get_syms_pymatgen(odict)
    frac_motif, invs = _expand_asym_unit(asym_unit, rot, trans, _EQ_SITE_TOL)
    _, wyc_muls = np.unique(invs, return_counts=True)
    wyc_muls = wyc_muls.astype(np.int32)
    # Motif points are grouped by asymmetric unit site, so each site's
    # first index is the cumulative sum of the preceding multiplicities
    asym_inds = np.zeros_like(wyc_muls, dtype=np.int32)
    asym_inds[1:] = np.cumsum(wyc_muls)[:-1]
    types = np.array([asym_types[i] for i in invs], dtype=np.uint8)
    motif = frac_motif @ cell

    return PeriodicSet(
        motif=motif,
        cell=cell,
        name=block.header,
        asymmetric_unit=asym_inds,
        wyckoff_multiplicities=wyc_muls,
        types=types
    )
1021
1022
1023
def periodicset_from_ase_atoms(
        atoms,
        remove_hydrogens: bool = False
) -> PeriodicSet:
    """Convert an :class:`ase.atoms.Atoms` object to a
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`. Does not have
    the option to remove disorder.

    Symmetry operations are taken from ``atoms.info['spacegroup']`` when
    present. Otherwise P1 is assumed and every atom is treated as its own
    asymmetric unit site. The passed ``atoms`` object is not modified.

    Parameters
    ----------
    atoms : :class:`ase.atoms.Atoms`
        An ase :class:`ase.atoms.Atoms` object representing a crystal.
    remove_hydrogens : bool, optional
        Remove Hydrogens from the crystal.

    Returns
    -------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        useful data, e.g. the crystal's name and information about the
        asymmetric unit for calculation.

    Raises
    ------
    ParseError
        Raised if there are no valid sites in atoms.
    """

    cell = atoms.get_cell().array
    numbers = np.asarray(atoms.get_atomic_numbers())
    frac_coords = atoms.get_scaled_positions()

    # Filter the arrays rather than popping atoms, so the caller's Atoms
    # object is left unchanged.
    if remove_hydrogens:
        keep = numbers != 1
        numbers = numbers[keep]
        frac_coords = frac_coords[keep]

    if len(numbers) == 0:
        raise ParseError('ase Atoms object has no valid sites')

    # Symmetry operations from spacegroup
    if 'spacegroup' in atoms.info:
        spg = atoms.info['spacegroup']
        rot, trans = spg.rotations, spg.translations
        # Keep one representative per symmetry orbit. The mask keeps the
        # atomic numbers aligned with the chosen positions.
        asym_unit, mask = spg.unique_sites(
            frac_coords, symprec=_EQ_SITE_TOL, output_mask=True
        )
        asym_types = numbers[mask]
    else:
        warnings.warn('no symmetry data found, defaulting to P1')
        rot = np.identity(3)[None, :]
        trans = np.zeros((1, 3))
        # Only the identity is applied, so every atom must be kept.
        asym_unit = frac_coords
        asym_types = numbers

    frac_motif, invs = _expand_asym_unit(asym_unit, rot, trans, _EQ_SITE_TOL)
    _, wyc_muls = np.unique(invs, return_counts=True)
    wyc_muls = wyc_muls.astype(np.int32)
    asym_inds = np.zeros_like(wyc_muls, dtype=np.int32)
    asym_inds[1:] = np.cumsum(wyc_muls)[:-1]
    # invs maps each motif point to its asymmetric unit site, so types
    # follow the expanded motif's order and length.
    types = np.asarray(asym_types)[invs].astype(np.uint8)
    motif = frac_motif @ cell

    return PeriodicSet(
        motif=motif,
        cell=cell,
        asymmetric_unit=asym_inds,
        wyckoff_multiplicities=wyc_muls,
        types=types
    )
1094
1095
1096
def periodicset_from_pymatgen_structure(
        structure,
        remove_hydrogens: bool = False,
        disorder: str = 'skip'
) -> PeriodicSet:
    """Convert a :class:`pymatgen.core.structure.Structure` object to a
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`. Does not set
    the name of the periodic set, as pymatgen Structure objects seem to
    have no name attribute. The passed ``structure`` is not modified.

    Parameters
    ----------
    structure : :class:`pymatgen.core.structure.Structure`
        A pymatgen Structure object representing a crystal.
    remove_hydrogens : bool, optional
        Remove Hydrogens from the crystal.
    disorder : str, optional
        Controls how disordered structures are handled. Default is
        ``skip`` which skips any crystal with disorder, since disorder
        conflicts with the periodic set model. To read disordered
        structures anyway, choose either :code:`ordered_sites` to remove
        atoms with disorder or :code:`all_sites` include all atoms
        regardless of disorder.

    Returns
    -------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        useful data, e.g. the crystal's name and information about the
        asymmetric unit for calculation.

    Raises
    ------
    ParseError
        Raised if :code:`disorder == 'skip'` and
        :code:`not structure.is_ordered`, or if no sites remain after
        removing Hydrogens and disordered sites.
    """

    from pymatgen.symmetry.analyzer import SpacegroupAnalyzer

    # Work on a copy so removing sites does not alter the caller's object.
    structure = structure.copy()

    if remove_hydrogens:
        structure.remove_species(['H', 'D'])

    # Disorder
    if disorder == 'skip':
        if not structure.is_ordered:
            raise ParseError(
                'pymatgen Structure has disorder, pass '
                "disorder='ordered_sites' or 'all_sites' to remove/ignore "
                'disorder'
            )
    elif disorder == 'ordered_sites':
        # A site is ordered only if it holds a single species at full
        # occupancy. Checking total occupancy alone would keep
        # mixed-species sites, and pymatgen's atomic_numbers (used below)
        # is only available for ordered structures.
        structure.remove_sites(
            [i for i, site in enumerate(structure) if not site.is_ordered]
        )

    if len(structure) == 0:
        raise ParseError('pymatgen Structure has no valid sites')

    motif = structure.cart_coords
    cell = structure.lattice.matrix
    sym_structure = SpacegroupAnalyzer(structure).get_symmetrized_structure()
    eq_inds = sym_structure.equivalent_indices
    # The first index of each equivalence class represents that class.
    asym_unit = np.array([ix_list[0] for ix_list in eq_inds], dtype=np.int32)
    wyc_muls = np.array([len(ix_list) for ix_list in eq_inds], dtype=np.int32)
    types = np.array(sym_structure.atomic_numbers, dtype=np.uint8)

    return PeriodicSet(
        motif=motif,
        cell=cell,
        asymmetric_unit=asym_unit,
        wyckoff_multiplicities=wyc_muls,
        types=types
    )
1170
1171
1172
def periodicset_from_ccdc_entry(
        entry,
        remove_hydrogens: bool = False,
        disorder: str = 'skip',
        heaviest_component: bool = False,
        molecular_centres: bool = False
) -> PeriodicSet:
    """Convert a :class:`ccdc.entry.Entry` object to a
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`.
    Entry is the type returned by :class:`ccdc.io.EntryReader`.

    This performs the entry-level checks and then delegates to
    :func:`periodicset_from_ccdc_crystal` with ``entry.crystal``.

    Parameters
    ----------
    entry : :class:`ccdc.entry.Entry`
        A ccdc Entry object representing a database entry.
    remove_hydrogens : bool, optional
        Remove Hydrogens from the crystal.
    disorder : str, optional
        Controls how disordered structures are handled. Default is
        ``skip`` which skips any crystal with disorder, since disorder
        conflicts with the periodic set model. To read disordered
        structures anyway, choose either :code:`ordered_sites` to remove
        atoms with disorder or :code:`all_sites` include all atoms
        regardless of disorder.
    heaviest_component : bool, optional
        Removes all but the heaviest molecule in the asymmeric unit,
        intended for removing solvents.
    molecular_centres : bool, default False
        Extract the centres of molecules in the unit cell and store in
        the attribute molecular_centres of the returned PeriodicSet.

    Returns
    -------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        useful data, e.g. the crystal's name and information about the
        asymmetric unit for calculation.

    Raises
    ------
    ParseError
        Raised if ``entry.has_3d_structure`` is False, if
        :code:``disorder == 'skip'`` and ``entry.has_disorder`` is True,
        or for any reason :func:`periodicset_from_ccdc_crystal` raises.
    """

    # Entry-level guards; the identifier is only read when reporting.
    if not entry.has_3d_structure:
        raise ParseError(f'{entry.identifier} has no 3D structure')

    skip_disordered = disorder == 'skip'
    if skip_disordered and entry.has_disorder:
        raise ParseError(
            f"{entry.identifier} has disorder, pass disorder='ordered_sites' "
            "or 'all_sites' to remove/ignore disorder"
        )

    crystal_options = {
        'remove_hydrogens': remove_hydrogens,
        'disorder': disorder,
        'heaviest_component': heaviest_component,
        'molecular_centres': molecular_centres,
    }
    return periodicset_from_ccdc_crystal(entry.crystal, **crystal_options)
1241
1242
1243
def periodicset_from_ccdc_crystal(
        crystal,
        remove_hydrogens: bool = False,
        disorder: str = 'skip',
        heaviest_component: bool = False,
        molecular_centres: bool = False
) -> PeriodicSet:
    """Convert a :class:`ccdc.crystal.Crystal` object to a
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`.
    Crystal is the type returned by :class:`ccdc.io.CrystalReader`.

    Note that ``crystal.molecule`` is reassigned to the filtered molecule.

    Parameters
    ----------
    crystal : :class:`ccdc.crystal.Crystal`
        A ccdc Crystal object representing a crystal structure.
    remove_hydrogens : bool, optional
        Remove Hydrogens from the crystal.
    disorder : str, optional
        Controls how disordered structures are handled. Default is
        ``skip`` which skips any crystal with disorder, since disorder
        conflicts with the periodic set model. To read disordered
        structures anyway, choose either :code:`ordered_sites` to remove
        atoms with disorder or :code:`all_sites` include all atoms
        regardless of disorder.
    heaviest_component : bool, optional
        Removes all but the heaviest molecule in the asymmeric unit,
        intended for removing solvents.
    molecular_centres : bool, default False
        Extract the centres of molecules in the unit cell and store in
        the attribute molecular_centres of the returned PeriodicSet.

    Returns
    -------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        useful data, e.g. the crystal's name and information about the
        asymmetric unit for calculation.

    Raises
    ------
    ParseError
        Raised if the structure fails parsing for any of the following:
        1. :code:``disorder == 'skip'`` and disorder is found on any
        atom, 2. the cell parameters are missing, 3. the motif is empty
        after removing H, disordered sites, solvents and atoms without
        coordinates. Atoms without fractional coordinates are removed
        with a warning rather than raising.
    """

    molecule = crystal.disordered_molecule

    # Disorder
    if disorder == 'skip':
        if crystal.has_disorder or \
         any(_has_disorder(a.label, a.occupancy) for a in molecule.atoms):
            raise ParseError(
                f"{crystal.identifier} has disorder, pass "
                "disorder='ordered_sites' or 'all_sites' to remove/ignore "
                "disorder"
            )
    elif disorder == 'ordered_sites':
        molecule.remove_atoms(
            [a for a in molecule.atoms if _has_disorder(a.label, a.occupancy)]
        )

    if remove_hydrogens:
        # Compare whole symbols: a substring test against 'HD' would also
        # match an empty symbol.
        molecule.remove_atoms(
            [a for a in molecule.atoms if a.atomic_symbol in ('H', 'D')]
        )

    if heaviest_component and len(molecule.components) > 1:
        molecule = _heaviest_component_ccdc(molecule)

    # Remove atoms with missing coordinates and warn
    if any(a.fractional_coordinates is None for a in molecule.atoms):
        warnings.warn('atoms without sites or missing data will be removed')
        molecule.remove_atoms(
            [a for a in molecule.atoms if a.fractional_coordinates is None]
        )

    crystal.molecule = molecule
    cellpar = crystal.cell_lengths + crystal.cell_angles
    if None in cellpar:
        raise ParseError(f'{crystal.identifier} has missing cell data')
    cell = cellpar_to_cell(np.array(cellpar))

    if molecular_centres:
        frac_centres = _frac_molecular_centres_ccdc(crystal, _EQ_SITE_TOL)
        mol_centres = frac_centres @ cell
        return PeriodicSet(mol_centres, cell, name=crystal.identifier)

    asym_atoms = crystal.asymmetric_unit_molecule.atoms
    asym_unit = np.array([tuple(a.fractional_coordinates) for a in asym_atoms])

    if asym_unit.shape[0] == 0:
        raise ParseError(f'{crystal.identifier} has no valid sites')

    asym_unit = np.mod(asym_unit, 1)
    asym_types = [a.atomic_number for a in asym_atoms]

    # Remove overlapping sites unless disorder == 'all_sites'
    if disorder != 'all_sites':
        keep_sites = _unique_sites(asym_unit, _EQ_SITE_TOL)
        if not np.all(keep_sites):
            warnings.warn(
                'may have overlapping sites; duplicates will be removed'
            )
            asym_unit = asym_unit[keep_sites]
            asym_types = [
                sym for sym, keep in zip(asym_types, keep_sites) if keep
            ]

    # Symmetry operations; fall back to the identity (P1) if none are given
    sitesym = crystal.symmetry_operators
    if not sitesym:
        warnings.warn('no symmetry data found, defaulting to P1')
        sitesym = ['x,y,z']

    # Apply symmetries to asymmetric unit
    rot, trans = _parse_sitesyms(sitesym)
    frac_motif, invs = _expand_asym_unit(asym_unit, rot, trans, _EQ_SITE_TOL)
    _, wyc_muls = np.unique(invs, return_counts=True)
    wyc_muls = wyc_muls.astype(np.int32)
    # Motif points are grouped by asymmetric unit site, so each site's
    # first index is the cumulative sum of the preceding multiplicities
    asym_inds = np.zeros_like(wyc_muls, dtype=np.int32)
    asym_inds[1:] = np.cumsum(wyc_muls)[:-1]
    motif = frac_motif @ cell
    types = np.array([asym_types[i] for i in invs], dtype=np.uint8)

    return PeriodicSet(
        motif=motif,
        cell=cell,
        name=crystal.identifier,
        asymmetric_unit=asym_inds,
        wyckoff_multiplicities=wyc_muls,
        types=types
    )
1384
1385
1386
def _parse_sitesyms(
1387
        symmetries: List[str]
1388
) -> Tuple[np.ndarray[np.float64], np.ndarray[np.float64]]:
1389
    """Parse a sequence of symmetries in xyz form and return rotation
1390
    and translation arrays.
1391
    """
1392
1393
    n_syms = len(symmetries)
1394
    rotations = np.zeros((n_syms, 3, 3), dtype=np.float64)
1395
    translations = np.zeros((n_syms, 3), dtype=np.float64)
1396
1397
    for i, sym in enumerate(symmetries):
1398
        for ind, element in enumerate(sym.split(',')):
1399
1400
            is_positive = True
1401
            is_fraction = False
1402
            sng_trans = None
1403
            fst_trans = []
1404
            snd_trans = []
1405
1406
            for char in element.lower():
1407
                if char == '+':
1408
                    is_positive = True
1409
                elif char == '-':
1410
                    is_positive = False
1411
                elif char == '/':
1412
                    is_fraction = True
1413
                elif char in 'xyz':
1414
                    rot_sgn = 1.0 if is_positive else -1.0
1415
                    rotations[i][ind][ord(char) - ord('x')] = rot_sgn
0 ignored issues
show
Comprehensibility Best Practice introduced by
The variable ord does not seem to be defined.
Loading history...
1416
                elif char.isdigit() or char == '.':
1417
                    if sng_trans is None:
1418
                        sng_trans = 1.0 if is_positive else -1.0
1419
                    if is_fraction:
1420
                        snd_trans.append(char)
1421
                    else:
1422
                        fst_trans.append(char)
1423
1424
            if not fst_trans:
1425
                e_trans = 0.0
1426
            else:
1427
                e_trans = sng_trans * float(''.join(fst_trans))
1428
1429
            if is_fraction:
1430
                e_trans /= float(''.join(snd_trans))
1431
1432
            translations[i][ind] = e_trans
1433
1434
    return rotations, translations
1435
1436
1437
def _expand_asym_unit(
        asym_unit: np.ndarray,
        rotations: np.ndarray,
        translations: np.ndarray,
        tol: float
) -> Tuple[np.ndarray[np.float64], np.ndarray[np.int32]]:
    """Generate the full fractional motif from an asymmetric unit by
    applying every (rotation, translation) pair and discarding duplicate
    images within ``tol``.

    Returns the motif and, for each motif point, the index of the
    asymmetric unit site that generated it.
    """

    expanded = _expand_sites(
        asym_unit.astype(np.float64, copy=False),
        rotations.astype(np.float64, copy=False),
        translations.astype(np.float64, copy=False),
    )
    motif, inverses = _reduce_expanded_sites(expanded, tol)

    if np.all(_unique_sites(motif, tol)):
        return motif, inverses

    # Some asymmetric unit sites produced coinciding images; redo the
    # reduction with the slower routine that removes them.
    return _reduce_expanded_equiv_sites(expanded, tol)
1457
1458
1459
@numba.njit(cache=True)
def _expand_sites(
        asym_unit: np.ndarray[np.float64],
        rotations: np.ndarray[np.float64],
        translations: np.ndarray[np.float64]
) -> np.ndarray[np.float64]:
    """Image every asymmetric unit site under every symmetry operation.

    Images that coincide (sites invariant under some operation) are kept
    here and removed later. The result has shape
    ``(n_sites, n_symmetries, dims)`` with coordinates reduced modulo 1.
    """

    n_sites, dims = asym_unit.shape
    n_ops = len(rotations)
    images = np.empty((n_sites, n_ops, dims), dtype=np.float64)
    for op in range(n_ops):
        rot = rotations[op]
        shift = translations[op]
        for site in range(n_sites):
            images[site, op] = np.dot(rot, asym_unit[site]) + shift
    return np.mod(images, 1)
1480
1481
1482
@numba.njit(cache=True)
def _reduce_expanded_sites(
        expanded_sites: np.ndarray[np.float64],
        tol: float
) -> Tuple[np.ndarray[np.float64], np.ndarray[np.int32]]:
    """Reduce the asymmetric unit after being expended by symmetries by
    removing invariant points. This is the fast version which works in
    the case that no two sites in the asymmetric unit are equivalent.
    If they are, the reduction is re-ran with
    _reduce_expanded_equiv_sites() to account for it.

    Returns the motif (images of site 0, then site 1, ...) and, for each
    motif point, the index of the asymmetric unit site it came from.
    """

    n_sites, _, dims = expanded_sites.shape
    all_unique_inds = []
    # Integer counts so the slice bounds below are integers rather than
    # floats converted implicitly.
    multiplicities = np.zeros(n_sites, dtype=np.int64)

    for i in range(n_sites):
        unique_inds = _unique_sites(expanded_sites[i], tol)
        all_unique_inds.append(unique_inds)
        multiplicities[i] = np.sum(unique_inds)

    m = int(np.sum(multiplicities))
    frac_motif = np.zeros((m, dims))
    inverses = np.zeros(m, dtype=np.int32)

    # Write each site's distinct images into consecutive rows
    s = 0
    for i in range(n_sites):
        t = s + multiplicities[i]
        frac_motif[s:t, :] = expanded_sites[i][all_unique_inds[i]]
        inverses[s:t] = i
        s = t

    return frac_motif, inverses
1515
1516
1517
def _reduce_expanded_equiv_sites(
        expanded_sites: np.ndarray[np.float64],
        tol: float
) -> Tuple[np.ndarray[np.float64], np.ndarray[np.int32]]:
    """Slow reduction of symmetry-expanded sites, used when the fast
    version found coinciding motif points.

    Images of each asymmetric unit site are first made unique among
    themselves. Any image that matches (within ``tol``, modulo 1) a point
    already in the motif is dropped with a warning.
    """

    first_images = expanded_sites[0]
    frac_motif = first_images[_unique_sites(first_images, tol)]
    inverses = [0] * len(frac_motif)

    for site_ind in range(1, len(expanded_sites)):
        images = expanded_sites[site_ind]
        images = images[_unique_sites(images, tol)]

        new_points = []
        for point in images:
            diff = np.abs(point - frac_motif)
            matches = np.all((diff <= tol) | (np.abs(diff - 1) <= tol), axis=-1)
            if np.any(matches):
                warnings.warn(
                    'has equivalent sites at positions '
                    f'{inverses[np.argmax(matches)]}, {site_ind}'
                )
                continue
            new_points.append(point)

        if new_points:
            inverses += [site_ind] * len(new_points)
            frac_motif = np.concatenate((frac_motif, np.array(new_points)))

    return frac_motif, np.array(inverses, dtype=np.int32)
1555
1556
1557
@numba.njit(cache=True)
def _unique_sites(
        asym_unit: np.ndarray[np.float64], tol: float
) -> np.ndarray[np.bool_]:
    """Uniquify (within tol) a list of fractional coordinates,
    considering all points modulo 1. Return an array of bools such that
    asym_unit[_unique_sites(asym_unit, tol)] is the uniquified list.

    The first occurrence of a point is kept; later points matching any
    earlier point in every coordinate (within tol, modulo 1) are marked
    False.
    """

    m, _ = asym_unit.shape
    where_unique = np.full(shape=(m, ), fill_value=True)

    for i in range(1, m):
        site_diffs1 = np.abs(asym_unit[:i, :] - asym_unit[i])
        # Differences close to 1 are the same coordinate modulo 1
        site_diffs2 = np.abs(site_diffs1 - 1)
        sites_neq_mask = (site_diffs1 > tol) & (site_diffs2 > tol)
        # Site i is a duplicate if some earlier site has no coordinate
        # that differs from it
        if not np.all(np.sum(sites_neq_mask, axis=-1)):
            where_unique[i] = False

    return where_unique
1578
1579
1580
def _has_disorder(label: str, occupancy) -> bool:
1581
    """Return True if label ends with ? or occupancy is a number < 1."""
1582
    try:
1583
        occupancy = float(occupancy)
1584
    except Exception:
1585
        occupancy = 1
1586
    return (occupancy < 1) or label.endswith('?')
1587
1588
1589
def _get_syms_pymatgen(
        data: dict
) -> Tuple[np.ndarray[np.float64], np.ndarray[np.float64]]:
    """Parse symmetry operations given by data = block.data where block
    is a pymatgen CifBlock object. If the symops are not present the
    space group symbol/international number is parsed and symops are
    generated.

    Sources are tried in order, stopping at the first that succeeds:

    1. explicit xyz symmetry operations (tags ``_CIF_TAGS['symop']``);
    2. the space group symbol (tags ``_CIF_TAGS['spacegroup_name']``),
       looked up in pymatgen's ``space_groups`` table and, if that fails,
       in the data returned by pymatgen's private ``_get_cod_data``;
    3. the international number (tags ``_CIF_TAGS['spacegroup_number']``).

    If none succeed, a warning is issued and the identity operation
    ``'x,y,z'`` (P1) is used.

    Returns
    -------
    rotations, translations : numpy.ndarray
        Float64 arrays of rotation matrices and translation vectors, either
        from ``_parse_sitesyms`` or built from pymatgen symmetry operations.
    """

    from pymatgen.symmetry.groups import SpaceGroup
    import pymatgen.io.cif

    # Try xyz symmetry operations
    for symmetry_label in _CIF_TAGS['symop']:
        xyz = data.get(symmetry_label)
        if not xyz:
            continue
        if isinstance(xyz, str):
            xyz = [xyz]
        return _parse_sitesyms(xyz)

    symops = []
    # Try spacegroup symbol
    for symmetry_label in _CIF_TAGS['spacegroup_name']:
        sg = data.get(symmetry_label)
        if not sg:
            continue
        sg = re.sub(r'[\s_]', '', sg)

        try:
            spg = pymatgen.io.cif.space_groups.get(sg)
            if spg:
                symops = SpaceGroup(spg).symmetry_ops
                break
        except ValueError:
            pass

        # The symbol is not in pymatgen's table (or SpaceGroup rejected
        # it): fall back to the COD data, comparing symbols with whitespace
        # removed. _get_cod_data is private pymatgen API, so any failure
        # here is treated as "not found" and the next tag is tried.
        try:
            for d in pymatgen.io.cif._get_cod_data():
                if sg == re.sub(r'\s+', '', d['hermann_mauguin']):
                    return _parse_sitesyms(d['symops'])
        except Exception:
            continue

    # Try international number
    if not symops:
        for symmetry_label in _CIF_TAGS['spacegroup_number']:
            num = data.get(symmetry_label)
            if not num:
                continue
            try:
                i = int(pymatgen.io.cif.str2float(num))
                symops = SpaceGroup.from_int_number(i).symmetry_ops
                break
            except ValueError:
                continue

    if not symops:
        warnings.warn('no symmetry data found, defaulting to P1')
        return _parse_sitesyms(['x,y,z'])

    # symmetry_ops may be a set: materialise it once so rotations and
    # translations are read from the same sequence of operations.
    symops = list(symops)
    rotations = np.array(
        [op.rotation_matrix for op in symops], dtype=np.float64
    )
    translations = np.array(
        [op.translation_vector for op in symops], dtype=np.float64
    )

    return rotations, translations
1657
1658
1659
def _frac_molecular_centres_ccdc(
        crystal, tol: float
) -> np.ndarray[np.float64]:
    """Return the geometric centres of molecules in the unit cell.
    Expects a ccdc Crystal object and returns fractional coordinates.

    Centres are the mean fractional coordinates of each component of
    ``crystal.packing(inclusion='CentroidIncluded')``. They are taken
    modulo 1, and duplicates within ``tol`` (modulo 1) are removed. The
    result is a 2D array with 3 columns, which is empty if there are no
    components.
    """

    frac_centres = []
    for comp in crystal.packing(inclusion='CentroidIncluded').components:
        coords = np.array(
            [a.fractional_coordinates for a in comp.atoms], dtype=np.float64
        )
        frac_centres.append(coords.mean(axis=0))

    # reshape keeps the array 2D even with no components, which
    # _unique_sites requires (it unpacks a 2-tuple shape).
    frac_centres = np.mod(
        np.array(frac_centres, dtype=np.float64).reshape(-1, 3), 1
    )
    return frac_centres[_unique_sites(frac_centres, tol)]
1672
1673
1674
def _heaviest_component_ccdc(molecule):
1675
    """Remove all but the heaviest component of the asymmetric unit.
1676
    Intended for removing solvents. Expects and returns a ccdc Molecule
1677
    object.
1678
    """
1679
1680
    component_weights = []
1681
    for component in molecule.components:
1682
        weight = 0
1683
        for a in component.atoms:
1684
            try:
1685
                occ = float(a.occupancy)
1686
            except ValueError:
1687
                occ = 1
1688
            try:
1689
                weight += float(a.atomic_weight) * occ
1690
            except ValueError:
1691
                pass
1692
        component_weights.append(weight)
1693
    largest_component_ind = np.argmax(np.array(component_weights))
1694
    molecule = molecule.components[largest_component_ind]
1695
    return molecule
1696
1697
1698
def _snap_small_prec_coords(
1699
        frac_coords: np.ndarray[np.float64], tol: float
1700
) -> np.ndarray[np.float64]:
1701
    """Find where frac_coords is within 1e-4 of 1/3 or 2/3, change to
1702
    1/3 and 2/3. Recommended by pymatgen's CIF parser.
1703
    """
1704
1705
    frac_coords[np.abs(1 - 3 * frac_coords) < tol] = 1 / 3.
1706
    frac_coords[np.abs(1 - 3 * frac_coords / 2) < tol] = 2 / 3.
1707
    return frac_coords
1708