Passed
Push — master ( eb28f3...7c7660 )
by Daniel
04:02
created

amd.io._expand_sites()   A

Complexity

Conditions 3

Size

Total Lines 18
Code Lines 13

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 13
dl 0
loc 18
rs 9.75
c 0
b 0
f 0
cc 3
nop 3
1
"""Tools for reading crystals from files, or from the CSD with
2
``csd-python-api``. The readers return
3
:class:`amd.PeriodicSet <.periodicset.PeriodicSet>` objects representing
4
the crystal which can be passed to :func:`amd.AMD() <.calculate.AMD>`
5
and :func:`amd.PDD() <.calculate.PDD>`.
6
"""
7
8
import os
9
import re
10
import functools
11
import warnings
12
import errno
13
from typing import Tuple, List
14
import math
15
16
import numpy as np
17
import numba
18
19
from .utils import cellpar_to_cell
20
from .periodicset import PeriodicSet
21
22
23
def _custom_warning(message, category, filename, lineno, *args, **kwargs):
    """Render a warning as a single ``CategoryName: message`` line.

    The file name, line number and any further arguments supplied by
    the :mod:`warnings` machinery are accepted but left out of the
    returned text.
    """
    return '{}: {}\n'.format(category.__name__, message)


# Use the compact format for every warning shown through the warnings module.
warnings.formatwarning = _custom_warning
26
27
28
# Tolerance handed to the site de-duplication and symmetry-expansion helpers.
_EQ_SITE_TOL = 1e-3

# Accepted values of the ``disorder`` parameter of the readers.
_DISORDER_OPTIONS = {'skip', 'ordered_sites', 'all_sites'}

# Message of the ImportError raised when csd-python-api cannot be imported.
_CCDC_IMPORT_ERR_MSG = 'Failed to import csd-python-api.'

# CIF data names, grouped by the quantity they describe. Where a group
# lists several alternative spellings they are tried in the order given.
_CIF_TAGS = dict(
    cellpar=(
        [f'_cell_length_{axis}' for axis in 'abc']
        + [f'_cell_angle_{angle}' for angle in ('alpha', 'beta', 'gamma')]
    ),
    atom_site_fract=[f'_atom_site_fract_{axis}' for axis in 'xyz'],
    atom_site_cartn=[f'_atom_site_cartn_{axis}' for axis in 'xyz'],
    atom_symbol=[
        '_atom_site_type_symbol',
        '_atom_site_label',
    ],
    symop=[
        '_symmetry_equiv_pos_as_xyz',
        '_space_group_symop_operation_xyz',
        '_space_group_symop.operation_xyz',
        '_symmetry_equiv_pos_as_xyz_',
        '_space_group_symop_operation_xyz_',
    ],
    spacegroup_name=[
        '_space_group_name_Hall',
        '_symmetry_space_group_name_hall',
        '_space_group_name_H-M_alt',
        '_symmetry_space_group_name_H-M',
        '_symmetry_space_group_name_H_M',
        '_symmetry_space_group_name_h-m',
    ],
    spacegroup_number=[
        '_space_group_IT_number',
        '_symmetry_Int_Tables_number',
        '_space_group_IT_number_',
        '_symmetry_Int_Tables_number_',
    ],
)
75
76
77
class _Reader:
    """Base reader class.

    Wraps an iterable of raw items together with a converter callable.
    Iterating over the reader yields ``converter(item)`` for each item,
    skipping any item for which the converter raises
    :class:`ParseError <.io.ParseError>`.

    Parameters
    ----------
    iterable : iterable
        Source of raw items to convert.
    converter : callable
        Called with one raw item; returns the converted object (which
        must have a ``name`` attribute) or raises ParseError.
    show_warnings : bool, optional
        If False, warnings arising while reading are suppressed.
    """

    def __init__(self, iterable, converter, show_warnings=True):
        self._iterator = iter(iterable)
        self._converter = converter
        self.show_warnings = show_warnings

    def __iter__(self):
        if self._iterator is None or self._converter is None:
            raise RuntimeError(f'{self.__class__.__name__} not initialized.')
        return self

    def __next__(self):
        """Iterates over self._iterator, passing items through
        self._converter. Catches :class:`ParseError <.io.ParseError>`
        and warnings raised in self._converter, optionally printing
        them.

        Raises
        ------
        StopIteration
            When the underlying iterator is exhausted.
        """

        while True:

            item = next(self._iterator)

            with warnings.catch_warnings(record=True) as warning_msgs:
                # Silence warnings only inside this context: the filter
                # list is restored on exit, so the process-wide warning
                # configuration is never modified by the reader.
                if not self.show_warnings:
                    warnings.simplefilter('ignore')
                try:
                    periodic_set = self._converter(item)
                except ParseError as err:
                    parse_error = err
                else:
                    parse_error = None

            # Test for the exception itself rather than a truthy message,
            # so a ParseError with an empty message still skips the item.
            if parse_error is not None:
                if self.show_warnings:
                    warnings.warn(str(parse_error))
                continue

            # Re-emit recorded warnings, prefixed with the item's name.
            # Nothing was recorded when show_warnings is False.
            for warning in warning_msgs:
                msg = f'(name={periodic_set.name}) {warning.message}'
                warnings.warn(msg, category=warning.category)

            return periodic_set

    def read(self):
        """Reads the crystal(s), returns one
        :class:`amd.PeriodicSet <.periodicset.PeriodicSet>` if there is
        only one, otherwise returns a list.
        """

        items = list(self)
        if len(items) == 1:
            return items[0]
        return items
132
133
134
class CifReader(_Reader):
    """Iterate over the crystals in a .cif file, or in every matching
    file of a directory, as
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>` objects.

    Parameters
    ----------
    path : str
        A .CIF file or a directory. (With ``reader='ccdc'`` other file
        types are accepted as well, provided csd-python-api is
        installed.)
    reader : str, optional
        Backend package that parses the file. Defaults to :code:`ase`;
        :code:`pycodcif`, :code:`pymatgen` and :code:`gemmi` may also
        be given, and :code:`ccdc` when csd-python-api is installed.
        The ccdc reader should cope with any format that
        :class:`ccdc.io.EntryReader` accepts, although only CIFs have
        been tested.
    remove_hydrogens : bool, optional
        Drop Hydrogens from the crystals.
    disorder : str, optional
        How disordered structures are treated. The default ``skip``
        skips every crystal with disorder, because disorder conflicts
        with the periodic set model. To read such structures anyway
        pass :code:`ordered_sites` (atoms with disorder are removed) or
        :code:`all_sites` (all atoms are kept regardless of disorder).
    heaviest_component : bool, optional
        csd-python-api only. Keep just the heaviest molecule of the
        asymmetric unit; intended for stripping solvents.
    molecular_centres : bool, default False
        csd-python-api only. Extract the centres of the molecules in
        the unit cell and store them in the attribute
        molecular_centres.
    show_warnings : bool, optional
        Whether warnings arising during reading are printed.

    Yields
    ------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        The crystal as a periodic set: a finite set of points (motif)
        plus a lattice (unit cell), together with other data such as
        the crystal's name and information about the asymmetric unit.

    Raises
    ------
    ValueError
        If ``disorder`` or ``reader`` is not a recognised option.
    NotImplementedError
        If a csd-python-api only option is requested with another
        reader.
    ImportError
        If ``reader='ccdc'`` and csd-python-api cannot be imported.
    FileNotFoundError
        If ``path`` is neither a file nor a directory.

    Examples
    --------

        ::

            # Put all crystals in a .CIF in a list
            structures = list(amd.CifReader('mycif.cif'))

            # Can also accept path to a directory, reading all files inside
            structures = list(amd.CifReader('path/to/folder'))

            # Reads just one if the .CIF has just one crystal
            periodic_set = amd.CifReader('mycif.cif').read()

            # List of AMDs (k=100) of crystals in a .CIF
            amds = [amd.AMD(periodic_set, 100) for periodic_set in amd.CifReader('mycif.cif')]
    """

    def __init__(
            self,
            path,
            reader='ase',
            remove_hydrogens=False,
            disorder='skip',
            heaviest_component=False,
            molecular_centres=False,
            show_warnings=True,
    ):

        if disorder not in _DISORDER_OPTIONS:
            raise ValueError(
                'disorder parameter must be one of '
                f'{_DISORDER_OPTIONS} (passed {disorder})'
            )

        if reader != 'ccdc':
            # These options are only available through csd-python-api.
            ccdc_only_flags = (
                ('heaviest_component', heaviest_component),
                ('molecular_centres', molecular_centres),
            )
            for flag_name, flag_value in ccdc_only_flags:
                if flag_value:
                    raise NotImplementedError(
                        f'Parameter {flag_name} only implemented for '
                        'reader="ccdc".'
                    )

        # Options understood by every converter.
        shared_opts = {
            'remove_hydrogens': remove_hydrogens,
            'disorder': disorder,
        }

        # Backends are imported lazily so only the chosen one is required.
        if reader in ('ase', 'pycodcif'):
            from ase.io.cif import parse_cif
            extensions = {'cif'}
            file_parser = functools.partial(parse_cif, reader=reader)
            converter = functools.partial(
                periodicset_from_ase_cifblock, **shared_opts)

        elif reader == 'pymatgen':
            extensions = {'cif'}
            file_parser = CifReader._pymatgen_cifblock_generator
            converter = functools.partial(
                periodicset_from_pymatgen_cifblock, **shared_opts)

        elif reader == 'gemmi':
            import gemmi
            extensions = {'cif'}
            file_parser = gemmi.cif.read_file
            converter = functools.partial(
                periodicset_from_gemmi_block, **shared_opts)

        elif reader == 'ccdc':
            try:
                import ccdc.io
            except (ImportError, RuntimeError) as e:
                raise ImportError(_CCDC_IMPORT_ERR_MSG) from e

            extensions = ccdc.io.EntryReader.known_suffixes
            file_parser = ccdc.io.EntryReader
            converter = functools.partial(
                periodicset_from_ccdc_entry,
                molecular_centres=molecular_centres,
                heaviest_component=heaviest_component,
                **shared_opts
            )

        else:
            raise ValueError(f'Unknown reader "{reader}".')

        if os.path.isfile(path):
            iterable = file_parser(path)
        elif os.path.isdir(path):
            iterable = CifReader._dir_generator(path, file_parser, extensions)
        else:
            raise FileNotFoundError(
                errno.ENOENT, os.strerror(errno.ENOENT), path)

        super().__init__(iterable, converter, show_warnings=show_warnings)

    @staticmethod
    def _dir_generator(path, callable, extensions):
        """Yield everything ``callable`` produces for each file in the
        directory ``path`` whose extension (lower-cased, without the
        dot) is in ``extensions``.
        """
        for entry in os.listdir(path):
            _, dotted_ext = os.path.splitext(entry)
            if dotted_ext[1:].lower() in extensions:
                yield from callable(os.path.join(path, entry))

    @staticmethod
    def _pymatgen_cifblock_generator(path):
        """Path to .cif --> generator of pymatgen CifBlock objects."""
        from pymatgen.io.cif import CifFile
        cif_file = CifFile.from_file(path)
        yield from cif_file.data.values()
292
293
294
class CSDReader(_Reader):
    """Iterate over structures from the CSD (via csd-python-api) as
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>` objects.

    Parameters
    ----------
    refcodes : str or List[str], optional
        One CSD refcode or a list of them. If None or 'CSD', the whole
        CSD is iterated over.
    families : bool, optional
        Treat the given strings as refcode 'families' and read every
        entry whose refcode starts with one of them (e.g. 'DEBXIT'
        reads all entries starting with DEBXIT).
    remove_hydrogens : bool, optional
        Drop hydrogens from the crystals.
    disorder : str, optional
        How disordered structures are treated. The default ``skip``
        skips every crystal with disorder, because disorder conflicts
        with the periodic set model. To read such structures anyway
        pass :code:`ordered_sites` (atoms with disorder are removed) or
        :code:`all_sites` (all atoms are kept regardless of disorder).
    heaviest_component : bool, optional
        Keep just the heaviest molecule of the asymmetric unit;
        intended for stripping solvents.
    molecular_centres : bool, default False
        Extract the centres of the molecules in the unit cell and store
        them in the attribute molecular_centres.
    show_warnings : bool, optional
        Whether warnings arising during reading are printed.

    Yields
    ------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        The crystal as a periodic set: a finite set of points (motif)
        plus a lattice (unit cell), together with other useful data
        such as the crystal's name and information about the asymmetric
        unit for calculation.

    Raises
    ------
    ImportError
        If csd-python-api cannot be imported.
    ValueError
        If ``disorder`` is not a recognised option.

    Examples
    --------

        ::

            # Put these entries in a list
            refcodes = ['DEBXIT01', 'DEBXIT05', 'HXACAN01']
            structures = list(amd.CSDReader(refcodes))

            # Read refcode families (any whose refcode starts with strings in the list)
            refcode_families = ['ACSALA', 'HXACAN']
            structures = list(amd.CSDReader(refcode_families, families=True))

            # Get AMDs (k=100) for crystals in these families
            refcodes = ['ACSALA', 'HXACAN']
            amds = []
            for periodic_set in amd.CSDReader(refcodes, families=True):
                amds.append(amd.AMD(periodic_set, 100))

            # Giving the reader nothing reads from the whole CSD.
            for periodic_set in amd.CSDReader():
                ...
    """

    def __init__(
            self,
            refcodes=None,
            families=False,
            remove_hydrogens=False,
            disorder='skip',
            heaviest_component=False,
            molecular_centres=False,
            show_warnings=True,
    ):

        try:
            import ccdc.io
        except (ImportError, RuntimeError) as e:
            raise ImportError(_CCDC_IMPORT_ERR_MSG) from e

        if disorder not in _DISORDER_OPTIONS:
            raise ValueError(
                'disorder parameter must be one of '
                f'{_DISORDER_OPTIONS} (passed {disorder})'
            )

        # Normalise refcodes to either None (whole CSD) or a list.
        if isinstance(refcodes, str):
            refcodes = None if refcodes.lower() == 'csd' else [refcodes]
        elif refcodes is not None:
            refcodes = list(refcodes)

        # Families are meaningless when reading the whole database.
        if refcodes is not None and families:
            refcodes = self._refcodes_from_families_ccdc(refcodes)

        converter = functools.partial(
            periodicset_from_ccdc_entry,
            remove_hydrogens=remove_hydrogens,
            disorder=disorder,
            molecular_centres=molecular_centres,
            heaviest_component=heaviest_component
        )
        entry_reader = ccdc.io.EntryReader('CSD')
        super().__init__(
            self._ccdc_generator(refcodes, entry_reader),
            converter,
            show_warnings=show_warnings
        )

    @staticmethod
    def _ccdc_generator(refcodes, entry_reader):
        """Generates ccdc Entries from CSD refcodes. With ``refcodes``
        None every entry of ``entry_reader`` is generated; a refcode
        whose lookup raises RuntimeError is reported with a warning and
        skipped.
        """

        if refcodes is None:
            for entry in entry_reader:
                yield entry
            return

        for refcode in refcodes:
            try:
                yield entry_reader.entry(refcode)
            except RuntimeError:
                warnings.warn(f'{refcode} not found in database')

    @staticmethod
    def _refcodes_from_families_ccdc(refcode_families):
        """List of strings --> all CSD refcodes starting with any of the
        strings. Intended to be passed a list of families and return all
        refcodes in them, without repeats and in the order found.
        """

        try:
            import ccdc.search
        except (ImportError, RuntimeError) as e:
            raise ImportError(_CCDC_IMPORT_ERR_MSG) from e

        found = []
        for family in refcode_families:
            search = ccdc.search.TextNumericSearch()
            search.add_identifier(family)
            found.extend(hit.identifier for hit in search.search())

        # dict keys keep first-seen order, giving an ordered de-duplication
        return list(dict.fromkeys(found))
446
447
448
class ParseError(ValueError):
    """Signals that an item could not be converted into a periodic
    set.
    """
452
453
454
def periodicset_from_ase_cifblock(
        block,
        remove_hydrogens=False,
        disorder='skip'
) -> PeriodicSet:
    """:class:`ase.io.cif.CIFBlock` -->
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`.
    :class:`ase.io.cif.CIFBlock` is the type returned by
    :class:`ase.io.cif.parse_cif`.

    Parameters
    ----------
    block : :class:`ase.io.cif.CIFBlock`
        An ase :class:`ase.io.cif.CIFBlock` object representing a
        crystal.
    remove_hydrogens : bool, optional
        Remove Hydrogens from the crystal.
    disorder : str, optional
        Controls how disordered structures are handled. Default is
        ``skip`` which skips any crystal with disorder, since disorder
        conflicts with the periodic set model. To read disordered
        structures anyway, choose either :code:`ordered_sites` to remove
        atoms with disorder or :code:`all_sites` include all atoms
        regardless of disorder.

    Returns
    -------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        useful data, e.g. the crystal's name and information about the
        asymmetric unit for calculation.

    Raises
    ------
    ParseError
        Raised if the structure fails to be parsed for any of the
        following: 1. Required data is missing (e.g. cell parameters),
        2. The motif is empty after removing H or disordered sites,
        3. :code:`disorder == 'skip'` and disorder is found on any
        atom.
    """

    import ase.data
    import ase.spacegroup

    # Unit cell
    cellpar = [block.get(tag) for tag in _CIF_TAGS['cellpar']]
    if None in cellpar:
        raise ParseError(f'{block.name} has missing cell data')
    cell = cellpar_to_cell(np.array(cellpar))

    # Asymmetric unit coordinates. ase removes uncertainty brackets
    cartesian = False # flag needed for later
    asym_unit = [block.get(name) for name in _CIF_TAGS['atom_site_fract']]
    if None in asym_unit: # missing scaled coords, try Cartesian
        asym_unit = [block.get(name) for name in _CIF_TAGS['atom_site_cartn']]
        if None in asym_unit:
            raise ParseError(f'{block.name} has missing coordinates')
        cartesian = True
    asym_unit = list(zip(*asym_unit)) # transpose [xs,ys,zs] -> [p1,p2,...]

    # Atomic types
    asym_symbols = block._get_any(_CIF_TAGS['atom_symbol'])
    if asym_symbols is None:
        warnings.warn('missing atomic types will be labelled 0')
        asym_types = [0] * len(asym_unit)
    else:
        asym_types = []
        for label in asym_symbols:
            if label in ('.', '?'):
                warnings.warn('missing atomic types will be labelled 0')
                num = 0
            else:
                # A label holding no element symbol, or an unknown one,
                # is labelled 0 rather than aborting the whole read.
                num = None
                found = re.search(r'([A-Z][a-z]?)', str(label))
                if found is not None:
                    sym = found.group(0)
                    if sym == 'D':
                        sym = 'H'
                    num = ase.data.atomic_numbers.get(sym)
                if num is None:
                    warnings.warn(
                        'unrecognised atomic types will be labelled 0')
                    num = 0
            asym_types.append(num)

    # Find where sites have disorder if necessary
    has_disorder = []
    if disorder != 'all_sites':
        occupancies = block.get('_atom_site_occupancy')
        if occupancies is None:
            occupancies = [1] * len(asym_unit)
        labels = block.get('_atom_site_label')
        if labels is None:
            labels = [''] * len(asym_unit)
        for lab, occ in zip(labels, occupancies):
            has_disorder.append(_has_disorder(lab, occ))

    # Remove sites with ?, . or other invalid string for coordinates.
    # A set keeps the membership tests below linear overall.
    invalid = {
        i for i, xyz in enumerate(asym_unit)
        if not all(isinstance(coord, (int, float)) for coord in xyz)
    }
    if invalid:
        warnings.warn('atoms without sites or missing data will be removed')
        asym_unit = [c for i, c in enumerate(asym_unit) if i not in invalid]
        asym_types = [t for i, t in enumerate(asym_types) if i not in invalid]
        if disorder != 'all_sites':
            has_disorder = [d for i, d in enumerate(has_disorder)
                            if i not in invalid]

    remove_sites = set()

    if remove_hydrogens:
        remove_sites.update(i for i, num in enumerate(asym_types) if num == 1)

    # Remove atoms with fractional occupancy or raise ParseError
    if disorder != 'all_sites':
        for i, dis in enumerate(has_disorder):
            # Sites already scheduled for removal cannot cause a skip.
            if i in remove_sites:
                continue
            if dis:
                if disorder == 'skip':
                    msg = f"{block.name} has disorder, pass " \
                           "disorder='ordered_sites' or 'all_sites' to " \
                           "remove/ignore disorder"
                    raise ParseError(msg)
                elif disorder == 'ordered_sites':
                    remove_sites.add(i)

    # Asymmetric unit
    asym_unit = [c for i, c in enumerate(asym_unit) if i not in remove_sites]
    asym_types = [t for i, t in enumerate(asym_types) if i not in remove_sites]
    if len(asym_unit) == 0:
        raise ParseError(f'{block.name} has no valid sites')
    asym_unit = np.array(asym_unit)

    # If Cartesian coords were given, convert to scaled
    if cartesian:
        asym_unit = asym_unit @ np.linalg.inv(cell)
    asym_unit = np.mod(asym_unit, 1)

    # recommended by pymatgen, they use tol=1e-4
    # asym_unit = _snap_small_prec_coords(asym_unit, 1e-4)

    # Remove overlapping sites unless disorder == 'all_sites'
    if disorder != 'all_sites':
        keep_sites = _unique_sites(asym_unit, _EQ_SITE_TOL)
        if not np.all(keep_sites):
            msg = 'may have overlapping sites, duplicates will be removed'
            warnings.warn(msg)
            asym_unit = asym_unit[keep_sites]
            asym_types = [t for t, keep in zip(asym_types, keep_sites) if keep]

    # Get symmetry operations: explicit operations first, then a space
    # group name, then a space group number, finally falling back to P1.
    sitesym = block._get_any(_CIF_TAGS['symop'])
    if sitesym is None:
        label_or_num = block._get_any(_CIF_TAGS['spacegroup_name'])
        if label_or_num is None:
            label_or_num = block._get_any(_CIF_TAGS['spacegroup_number'])
        if label_or_num is None:
            warnings.warn('no symmetry data found, defaulting to P1')
            label_or_num = 1
        spg = ase.spacegroup.Spacegroup(label_or_num)
        rot, trans = spg.get_op()
    else:
        if isinstance(sitesym, str):
            sitesym = [sitesym]
        rot, trans = _parse_sitesyms(sitesym)

    frac_motif, invs = _expand_asym_unit(asym_unit, rot, trans, _EQ_SITE_TOL)
    _, wyc_muls = np.unique(invs, return_counts=True)
    # Running offsets of the multiplicities: 0, m0, m0 + m1, ...
    # NOTE(review): presumably the motif index at which each asymmetric
    # site's images start -- confirm against _expand_asym_unit.
    asym_inds = np.zeros_like(wyc_muls)
    asym_inds[1:] = np.cumsum(wyc_muls)[:-1]
    types = np.array([asym_types[i] for i in invs])
    motif = frac_motif @ cell

    return PeriodicSet(
        motif=motif,
        cell=cell,
        name=block.name,
        asymmetric_unit=asym_inds,
        wyckoff_multiplicities=wyc_muls,
        types=types
    )
633
634
635
def periodicset_from_ase_atoms(
        atoms,
        remove_hydrogens=False
) -> PeriodicSet:
    """:class:`ase.atoms.Atoms` -->
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`. Does not have
    the option to remove disorder. The ``atoms`` object passed in is
    left unmodified.

    Parameters
    ----------
    atoms : :class:`ase.atoms.Atoms`
        An ase :class:`ase.atoms.Atoms` object representing a crystal.
    remove_hydrogens : bool, optional
        Remove Hydrogens from the crystal.

    Returns
    -------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        useful data, e.g. the crystal's name and information about the
        asymmetric unit for calculation.

    Raises
    ------
    ParseError
        Raised if there are no valid sites in atoms.
    """

    from ase.spacegroup import get_basis

    cell = atoms.get_cell().array

    if remove_hydrogens:
        hydrogen_inds = np.where(atoms.get_atomic_numbers() == 1)[0]
        if len(hydrogen_inds) > 0:
            # Work on a copy so the caller's Atoms object keeps its
            # hydrogens.
            atoms = atoms.copy()
            # Pop from the highest index down so that the remaining
            # indices stay valid.
            for i in sorted(hydrogen_inds, reverse=True):
                atoms.pop(i)

    if len(atoms) == 0:
        raise ParseError('ase Atoms object has no valid sites')

    # Symmetry operations from spacegroup
    spg = None
    if 'spacegroup' in atoms.info:
        spg = atoms.info['spacegroup']
        rot, trans = spg.rotations, spg.translations
    else:
        warnings.warn('no symmetry data found, defaulting to P1')
        rot = np.identity(3)[None, :]
        trans = np.zeros((1, 3))

    # Asymmetric unit. ase default tol is 1e-5
    # do differently! get_basis determines a reduced asym unit from the atoms;
    # surely this is not needed!
    asym_unit = get_basis(atoms, spacegroup=spg, tol=_EQ_SITE_TOL)
    frac_motif, invs = _expand_asym_unit(asym_unit, rot, trans, _EQ_SITE_TOL)
    _, wyc_muls = np.unique(invs, return_counts=True)
    # Running offsets of the multiplicities: 0, m0, m0 + m1, ...
    asym_inds = np.zeros_like(wyc_muls)
    asym_inds[1:] = np.cumsum(wyc_muls)[:-1]
    motif = frac_motif @ cell

    # NOTE(review): types are taken in the order of ``atoms`` while the
    # motif is rebuilt from the reduced basis -- confirm that the two
    # orderings (and lengths) agree.
    return PeriodicSet(
        motif=motif,
        cell=cell,
        asymmetric_unit=asym_inds,
        wyckoff_multiplicities=wyc_muls,
        types=atoms.get_atomic_numbers()
    )
705
706
707
def periodicset_from_pymatgen_cifblock(
        block,
        remove_hydrogens=False,
        disorder='skip'
) -> PeriodicSet:
    """:class:`pymatgen.io.cif.CifBlock` -->
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`.

    Parameters
    ----------
    block : :class:`pymatgen.io.cif.CifBlock`
        A pymatgen CifBlock object representing a crystal.
    remove_hydrogens : bool, optional
        Remove Hydrogens from the crystal.
    disorder : str, optional
        Controls how disordered structures are handled. Default is
        ``skip`` which skips any crystal with disorder, since disorder
        conflicts with the periodic set model. To read disordered
        structures anyway, choose either :code:`ordered_sites` to remove
        atoms with disorder or :code:`all_sites` include all atoms
        regardless of disorder.

    Returns
    -------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        useful data, e.g. the crystal's name and information about the
        asymmetric unit for calculation.

    Raises
    ------
    ParseError
        Raised if the structure can/should not be parsed for the
        following reasons: 1. No sites found or motif is empty after
        removing Hydrogens & disorder, 2. A site has missing
        coordinates, 3. :code:`disorder == 'skip'` and disorder is
        found on any atom.
    """

    from pymatgen.io.cif import str2float
    import pymatgen.core.periodic_table as periodic_table

    odict = block.data

    # Unit cell
    cellpar = [odict.get(tag) for tag in _CIF_TAGS['cellpar']]
    if None in cellpar:
        raise ParseError(f'{block.header} has missing cell data')
    cell = cellpar_to_cell(np.array([str2float(v) for v in cellpar]))

    # Asymmetric unit coordinates
    cartesian = False
    asym_unit = [odict.get(tag) for tag in _CIF_TAGS['atom_site_fract']]

    if None in asym_unit: # missing scaled coordinates, try Cartesian
        asym_unit = [odict.get(tag) for tag in _CIF_TAGS['atom_site_cartn']]
        if None in asym_unit:
            raise ParseError(f'{block.header} has no coordinates')
        cartesian = True

    asym_unit = list(zip(*asym_unit)) # transpose [xs,ys,zs] -> [p1,p2,...]
    asym_unit = [[str2float(coord) for coord in xyz] for xyz in asym_unit]

    # Atomic types: use the first tag present in the block
    for tag in _CIF_TAGS['atom_symbol']:
        asym_symbols = odict.get(tag)
        if asym_symbols is not None:
            asym_types = []
            for label in asym_symbols:
                if label in ('.', '?'):
                    warnings.warn('missing atomic types will be labelled 0')
                    num = 0
                else:
                    # A label holding no element symbol, or an unknown
                    # one, is labelled 0 rather than aborting the read.
                    num = None
                    found = re.search(r'([A-Z][a-z]?)', label)
                    if found is not None:
                        sym = found.group(0)
                        if sym == 'D':
                            sym = 'H'
                        try:
                            num = periodic_table.Element[sym].number
                        except KeyError:
                            num = None
                    if num is None:
                        warnings.warn(
                            'unrecognised atomic types will be labelled 0')
                        num = 0
                asym_types.append(num)
            break
    else:
        # no tag was present
        warnings.warn('missing atomic types will be labelled 0')
        asym_types = [0] * len(asym_unit)

    # Find where sites have disorder if necessary
    has_disorder = []
    if disorder != 'all_sites':
        occupancies = odict.get('_atom_site_occupancy')
        if occupancies is None:
            occupancies = np.ones((len(asym_unit), ))
        else:
            occupancies = np.array([str2float(occ) for occ in occupancies])
        labels = odict.get('_atom_site_label')
        if labels is None:
            labels = [''] * len(asym_unit)
        for lab, occ in zip(labels, occupancies):
            has_disorder.append(_has_disorder(lab, occ))

    # Remove sites with ?, . or other invalid string for coordinates.
    # A set keeps the membership tests below linear overall.
    # NOTE(review): every coordinate has already been passed through
    # str2float above, so this check may never find anything -- confirm.
    invalid = {
        i for i, xyz in enumerate(asym_unit)
        if not all(isinstance(coord, (int, float)) for coord in xyz)
    }

    if invalid:
        warnings.warn('atoms without sites or missing data will be removed')
        asym_unit = [c for i, c in enumerate(asym_unit) if i not in invalid]
        asym_types = [c for i, c in enumerate(asym_types) if i not in invalid]
        if disorder != 'all_sites':
            has_disorder = [d for i, d in enumerate(has_disorder)
                            if i not in invalid]

    remove_sites = set()

    if remove_hydrogens:
        remove_sites.update(i for i, n in enumerate(asym_types) if n == 1)

    # Remove atoms with fractional occupancy or raise ParseError
    if disorder != 'all_sites':
        for i, dis in enumerate(has_disorder):
            # Sites already scheduled for removal cannot cause a skip.
            if i in remove_sites:
                continue
            if dis:
                if disorder == 'skip':
                    msg = f"{block.header} has disorder, pass " \
                            "disorder='ordered_sites' or 'all_sites' to " \
                            "remove/ignore disorder"
                    raise ParseError(msg)
                elif disorder == 'ordered_sites':
                    remove_sites.add(i)

    # Asymmetric unit
    asym_unit = [c for i, c in enumerate(asym_unit) if i not in remove_sites]
    asym_types = [t for i, t in enumerate(asym_types) if i not in remove_sites]
    if len(asym_unit) == 0:
        raise ParseError(f'{block.header} has no valid sites')
    asym_unit = np.array(asym_unit)

    # If Cartesian coords were given, convert to scaled
    if cartesian:
        asym_unit = asym_unit @ np.linalg.inv(cell)
    asym_unit = np.mod(asym_unit, 1)

    # recommended by pymatgen, they use tol=1e-4
    # asym_unit = _snap_small_prec_coords(asym_unit, 1e-4)

    # Remove overlapping sites unless disorder == 'all_sites'
    if disorder != 'all_sites':
        keep_sites = _unique_sites(asym_unit, _EQ_SITE_TOL)
        if not np.all(keep_sites):
            msg = 'may have overlapping sites; duplicates will be removed'
            warnings.warn(msg)
        asym_unit = asym_unit[keep_sites]
        asym_types = [sym for sym, keep in zip(asym_types, keep_sites) if keep]

    # Apply symmetries to asymmetric unit
    rot, trans = _get_syms_pymatgen(odict)
    frac_motif, invs = _expand_asym_unit(asym_unit, rot, trans, _EQ_SITE_TOL)
    _, wyc_muls = np.unique(invs, return_counts=True)
    # Running offsets of the multiplicities: 0, m0, m0 + m1, ...
    asym_inds = np.zeros_like(wyc_muls)
    asym_inds[1:] = np.cumsum(wyc_muls)[:-1]
    types = np.array([asym_types[i] for i in invs])
    motif = frac_motif @ cell

    return PeriodicSet(
        motif=motif,
        cell=cell,
        name=block.header,
        asymmetric_unit=asym_inds,
        wyckoff_multiplicities=wyc_muls,
        types=types
    )
879
880
881
def periodicset_from_pymatgen_structure(
        structure,
        remove_hydrogens=False,
        disorder='skip'
) -> PeriodicSet:
    """Convert a :class:`pymatgen.core.structure.Structure` into a
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`.
    The returned periodic set is given no name, as there seems to be no
    name attribute in the pymatgen Structure object. Hydrogens and
    disordered sites are removed by calling ``remove_species`` /
    ``remove_sites`` on the ``structure`` that was passed in.

    Parameters
    ----------
    structure : :class:`pymatgen.core.structure.Structure`
        A pymatgen Structure object representing a crystal.
    remove_hydrogens : bool, optional
        Remove Hydrogens (species ``H`` and ``D``) from the crystal.
    disorder : str, optional
        Controls how disordered structures are handled. Default is
        ``skip`` which skips any crystal with disorder, since disorder
        conflicts with the periodic set model. To read disordered
        structures anyway, choose either :code:`ordered_sites` to remove
        sites whose composition sums to less than one atom, or
        :code:`all_sites` to include all sites regardless of disorder.

    Returns
    -------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell), together with
        atomic numbers and asymmetric unit information taken from
        pymatgen's symmetrized structure.

    Raises
    ------
    ParseError
        Raised if :code:`disorder == 'skip'` and
        :code:`not structure.is_ordered`.
    """

    from pymatgen.symmetry.analyzer import SpacegroupAnalyzer

    if remove_hydrogens:
        structure.remove_species(['H', 'D'])

    # Disorder: either refuse the structure or strip partially occupied sites
    if disorder == 'skip' and not structure.is_ordered:
        raise ParseError(
            "pymatgen Structure has disorder, pass "
            "disorder='ordered_sites' or 'all_sites' to "
            "remove/ignore disorder"
        )
    if disorder == 'ordered_sites':
        partial_inds = [
            ind for ind, comp in enumerate(structure.species_and_occu)
            if comp.num_atoms < 1
        ]
        structure.remove_sites(partial_inds)

    motif = structure.cart_coords
    cell = structure.lattice.matrix

    # Each entry of equivalent_indices lists the site indices of one
    # symmetry-equivalent group; its first member represents the group.
    symmetrized = SpacegroupAnalyzer(structure).get_symmetrized_structure()
    asym_unit = np.array([inds[0] for inds in symmetrized.equivalent_indices])
    wyc_muls = np.array([len(inds) for inds in symmetrized.equivalent_indices])
    types = np.array(symmetrized.atomic_numbers)

    return PeriodicSet(
        motif=motif,
        cell=cell,
        asymmetric_unit=asym_unit,
        wyckoff_multiplicities=wyc_muls,
        types=types
    )
953
954
955
def periodicset_from_ccdc_entry(
        entry,
        remove_hydrogens=False,
        disorder='skip',
        heaviest_component=False,
        molecular_centres=False
) -> PeriodicSet:
    """Convert a :class:`ccdc.entry.Entry` into a
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`.
    Entry is the type returned by :class:`ccdc.io.EntryReader`. After
    the entry-level checks, the work is delegated to
    :func:`periodicset_from_ccdc_crystal` on ``entry.crystal``.

    Parameters
    ----------
    entry : :class:`ccdc.entry.Entry`
        A ccdc Entry object representing a database entry.
    remove_hydrogens : bool, optional
        Remove Hydrogens from the crystal.
    disorder : str, optional
        Controls how disordered structures are handled. Default is
        ``skip`` which skips any crystal with disorder, since disorder
        conflicts with the periodic set model. To read disordered
        structures anyway, choose either :code:`ordered_sites` to remove
        atoms with disorder or :code:`all_sites` to include all atoms
        regardless of disorder.
    heaviest_component : bool, optional
        Removes all but the heaviest molecule in the asymmetric unit,
        intended for removing solvents.
    molecular_centres : bool, default False
        Extract the centres of molecules in the unit cell and use them
        as the motif of the returned PeriodicSet.

    Returns
    -------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        useful data, e.g. the crystal's name and information about the
        asymmetric unit for calculation.

    Raises
    ------
    ParseError
        Raised here if ``entry.has_3d_structure`` is False, or if
        :code:`disorder == 'skip'` and ``entry.has_disorder`` is True.
        Any ParseError raised while reading ``entry.crystal`` also
        propagates.
    """

    # Entry specific flag
    if not entry.has_3d_structure:
        raise ParseError(f'{entry.identifier} has no 3D structure')

    # Disorder flagged at the entry level
    if disorder == 'skip' and entry.has_disorder:
        raise ParseError(
            f"{entry.identifier} has disorder, pass "
            "disorder='ordered_sites' or 'all_sites' to remove/ignore "
            "disorder"
        )

    reader_options = {
        'remove_hydrogens': remove_hydrogens,
        'disorder': disorder,
        'heaviest_component': heaviest_component,
        'molecular_centres': molecular_centres,
    }
    return periodicset_from_ccdc_crystal(entry.crystal, **reader_options)
1024
1025
1026
def periodicset_from_ccdc_crystal(
        crystal,
        remove_hydrogens=False,
        disorder='skip',
        heaviest_component=False,
        molecular_centres=False
) -> PeriodicSet:
    """:class:`ccdc.crystal.Crystal` -->
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`.
    Crystal is the type returned by :class:`ccdc.io.CrystalReader`.
    Note the edited molecule is assigned back to ``crystal.molecule``.

    Parameters
    ----------
    crystal : :class:`ccdc.crystal.Crystal`
        A ccdc Crystal object representing a crystal structure.
    remove_hydrogens : bool, optional
        Remove Hydrogens (atomic symbol ``H`` or ``D``) from the crystal.
    disorder : str, optional
        Controls how disordered structures are handled. Default is
        ``skip`` which skips any crystal with disorder, since disorder
        conflicts with the periodic set model. To read disordered
        structures anyway, choose either :code:`ordered_sites` to remove
        atoms with disorder or :code:`all_sites` to include all atoms
        regardless of disorder.
    heaviest_component : bool, optional
        Removes all but the heaviest molecule in the asymmetric unit,
        intended for removing solvents.
    molecular_centres : bool, default False
        Use the centres of molecules in the unit cell as the motif of
        the returned PeriodicSet (no asymmetric unit data or types are
        set in this case).

    Returns
    -------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        useful data, e.g. the crystal's name and information about the
        asymmetric unit for calculation.

    Raises
    ------
    ParseError
        Raised if the structure fails parsing for any of the following:
        1. :code:`disorder == 'skip'` and disorder is found on the
        crystal or any atom, 2. any cell length or angle is None,
        3. no sites remain in the asymmetric unit after removing H,
        disordered sites, solvents or atoms without coordinates.

    Warns
    -----
    Atoms whose fractional coordinates are None are removed with a
    warning; warnings are also issued when overlapping sites are dropped
    and when no symmetry operators are found (P1 is assumed).
    """

    molecule = crystal.disordered_molecule

    # Disorder
    if disorder == 'skip':
        if crystal.has_disorder or \
         any(_has_disorder(a.label, a.occupancy) for a in molecule.atoms):
            msg = f"{crystal.identifier} has disorder, pass " \
                   "disorder='ordered_sites' or 'all_sites' to " \
                   "remove/ignore disorder"
            raise ParseError(msg)

    elif disorder == 'ordered_sites':
        molecule.remove_atoms(
            a for a in molecule.atoms if _has_disorder(a.label, a.occupancy)
        )

    if remove_hydrogens:
        # Compare against whole symbols: a substring test against 'HD'
        # would also match an empty atomic symbol.
        molecule.remove_atoms(
            a for a in molecule.atoms if a.atomic_symbol in ('H', 'D')
        )

    if heaviest_component and len(molecule.components) > 1:
        molecule = _heaviest_component_ccdc(molecule)

    # Remove atoms with missing coordinates and warn
    if any(a.fractional_coordinates is None for a in molecule.atoms):
        warnings.warn('atoms without sites or missing data will be removed')
        molecule.remove_atoms(
            a for a in molecule.atoms if a.fractional_coordinates is None
        )

    crystal.molecule = molecule
    cellpar = crystal.cell_lengths + crystal.cell_angles
    if None in cellpar:
        raise ParseError(f'{crystal.identifier} has missing cell data')
    cell = cellpar_to_cell(np.array(cellpar))

    if molecular_centres:
        frac_centres = _frac_molecular_centres_ccdc(crystal, _EQ_SITE_TOL)
        mol_centres = frac_centres @ cell
        return PeriodicSet(mol_centres, cell, name=crystal.identifier)

    asym_atoms = crystal.asymmetric_unit_molecule.atoms
    # check for None?
    asym_unit = np.array([tuple(a.fractional_coordinates) for a in asym_atoms])

    if asym_unit.shape[0] == 0:
        raise ParseError(f'{crystal.identifier} has no valid sites')

    # Wrap fractional coordinates into the unit cell
    asym_unit = np.mod(asym_unit, 1)

    # recommended by pymatgen, they use tol=1e-4
    # asym_unit = _snap_small_prec_coords(asym_unit, 1e-4)

    asym_types = [a.atomic_number for a in asym_atoms]

    # Remove overlapping sites unless disorder == 'all_sites'
    if disorder != 'all_sites':
        keep_sites = _unique_sites(asym_unit, _EQ_SITE_TOL)
        if not np.all(keep_sites):
            msg = 'may have overlapping sites; duplicates will be removed'
            warnings.warn(msg)
        asym_unit = asym_unit[keep_sites]
        asym_types = [sym for sym, keep in zip(asym_types, keep_sites) if keep]

    # Symmetry operations
    sitesym = crystal.symmetry_operators
    # try spacegroup numbers?
    if not sitesym:
        warnings.warn('no symmetry data found, defaulting to P1')
        sitesym = ['x,y,z']

    # Apply symmetries to asymmetric unit
    rot, trans = _parse_sitesyms(sitesym)
    frac_motif, invs = _expand_asym_unit(asym_unit, rot, trans, _EQ_SITE_TOL)
    # invs[k] is the asymmetric unit index motif point k came from; the
    # counts are the multiplicities and their running sum (shifted by
    # one) gives the index in the motif where each group starts.
    _, wyc_muls = np.unique(invs, return_counts=True)
    asym_inds = np.zeros_like(wyc_muls)
    asym_inds[1:] = np.cumsum(wyc_muls)[:-1]
    motif = frac_motif @ cell
    types = np.array([asym_types[i] for i in invs])

    return PeriodicSet(
        motif=motif,
        cell=cell,
        name=crystal.identifier,
        asymmetric_unit=asym_inds,
        wyckoff_multiplicities=wyc_muls,
        types=types
    )
1165
1166
1167
def periodicset_from_gemmi_block(
        block,
        remove_hydrogens=False,
        disorder='skip'
) -> PeriodicSet:
    """:class:`gemmi.cif.Block` -->
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`.
    Block is the type returned by :class:`gemmi.cif.read_file`.

    Parameters
    ----------
    block : :class:`gemmi.cif.Block`
        A gemmi Block object representing a crystal.
    remove_hydrogens : bool, optional
        Remove Hydrogens (sites with atomic number 1) from the crystal.
    disorder : str, optional
        Controls how disordered structures are handled. Default is
        ``skip`` which skips any crystal with disorder, since disorder
        conflicts with the periodic set model. To read disordered
        structures anyway, choose either :code:`ordered_sites` to remove
        atoms with disorder or :code:`all_sites` to include all atoms
        regardless of disorder.

    Returns
    -------
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        Represents the crystal as a periodic set, consisting of a finite
        set of points (motif) and lattice (unit cell). Contains other
        useful data, e.g. the crystal's name and information about the
        asymmetric unit for calculation.

    Raises
    ------
    ParseError
        Raised if the structure fails to be parsed for any of the
        following: 1. Required data is missing (cell parameters or
        fractional coordinates), 2. :code:`disorder == 'skip'` and
        disorder is found on any atom, 3. The motif is empty after
        removing H or disordered sites.
    """

    import gemmi
    from gemmi.cif import as_number, as_string, as_int

    # Unit cell. find_value() gives None for an absent tag; keep that
    # None out of as_number() so the check below reports it as missing.
    cellpar = []
    for tag in _CIF_TAGS['cellpar']:
        raw_value = block.find_value(tag)
        cellpar.append(None if raw_value is None else as_number(raw_value))
    if any(p is None or math.isnan(p) for p in cellpar):
        raise ParseError(f'{block.name} has missing cell data')
    cell = cellpar_to_cell(np.array(cellpar))

    xyz_loop = block.find(_CIF_TAGS['atom_site_fract']).loop
    if xyz_loop is None:
        # TODO: check for Cartesian coordinates
        raise ParseError(f'{block.name} has missing coordinate data')

    loop_dict = _loop_to_dict_gemmi(xyz_loop)

    # Asymmetric unit coordinates: three columns of strings are parsed,
    # transposed to one row per site and wrapped into the unit cell.
    xyz_str = [loop_dict[t] for t in _CIF_TAGS['atom_site_fract']]
    asym_unit = [[as_number(c) for c in coords] for coords in xyz_str]
    asym_unit = np.mod(np.array(asym_unit).T, 1)

    # recommended by pymatgen, they use tol=1e-4
    # asym_unit = _snap_small_prec_coords(asym_unit, 1e-4)

    # Asymmetric unit types (0 where the symbol is empty or absent)
    if '_atom_site_type_symbol' in loop_dict:
        asym_syms = [as_string(s) for s in loop_dict['_atom_site_type_symbol']]
        asym_types = [
            gemmi.Element(s).atomic_number if s else 0 for s in asym_syms
        ]
    else:
        warnings.warn('missing atomic types will be labelled 0')
        asym_types = [0] * len(asym_unit)

    remove_sites = []

    # Disorder
    if '_atom_site_label' in loop_dict:
        labels = [as_string(l) for l in loop_dict['_atom_site_label']]
    else:
        labels = [''] * xyz_loop.length()

    if '_atom_site_occupancy' in loop_dict:
        occs = [as_number(occ) for occ in loop_dict['_atom_site_occupancy']]
        # An occupancy that parses to NaN is treated as full occupancy
        occupancies = [1 if math.isnan(occ) else occ for occ in occs]
    else:
        occupancies = [1] * xyz_loop.length()

    if disorder == 'skip':
        if any(_has_disorder(l, occ) for l, occ in zip(labels, occupancies)):
            msg = f"{block.name} has disorder, pass " \
                   "disorder='ordered_sites' or 'all_sites' to " \
                   "remove/ignore disorder"
            raise ParseError(msg)
    elif disorder == 'ordered_sites':
        for i, (label, occ) in enumerate(zip(labels, occupancies)):
            if _has_disorder(label, occ):
                remove_sites.append(i)

    if remove_hydrogens:
        remove_sites.extend(
            i for i, num in enumerate(asym_types) if num == 1
        )

    # Asymmetric unit
    removed = set(remove_sites)
    asym_unit = np.delete(asym_unit, remove_sites, axis=0)
    asym_types = [s for i, s in enumerate(asym_types) if i not in removed]
    if asym_unit.shape[0] == 0:
        raise ParseError(f'{block.name} has no valid sites')

    # Remove overlapping sites unless disorder == 'all_sites'
    if disorder != 'all_sites':
        keep_sites = _unique_sites(asym_unit, _EQ_SITE_TOL)
        if not np.all(keep_sites):
            msg = 'may have overlapping sites; duplicates will be removed'
            warnings.warn(msg)
        asym_unit = asym_unit[keep_sites]
        asym_types = [sym for sym, keep in zip(asym_types, keep_sites) if keep]

    # TODO: recheck below making sure missing/bad values are handled (as_string)
    # Get symmetry operations: first tag that has a loop wins
    sitesym = []
    for tag in _CIF_TAGS['symop']:
        symop_loop = block.find([tag]).loop
        if symop_loop is not None:
            sitesym = _loop_to_dict_gemmi(symop_loop)[tag]
            break

    if not sitesym:
        # Fall back to the space group name, then its number, then P1.
        label_or_num = None
        for tag in _CIF_TAGS['spacegroup_name']:
            raw_value = block.find_value(tag)
            if raw_value is not None:
                label_or_num = as_string(raw_value)
                break
        if label_or_num is None:
            for tag in _CIF_TAGS['spacegroup_number']:
                raw_value = block.find_value(tag)
                if raw_value is not None:
                    label_or_num = as_int(raw_value)
                    break
        if label_or_num is None:
            warnings.warn('no symmetry data found, defaulting to P1')
            label_or_num = 1
        ops = list(gemmi.SpaceGroup(label_or_num).operations())
        # rot/tran are stored as integers scaled by DEN
        rot = np.array([np.array(o.rot) / o.DEN for o in ops])
        trans = np.array([np.array(o.tran) / o.DEN for o in ops])
    else:
        rot, trans = _parse_sitesyms(sitesym)

    frac_motif, invs = _expand_asym_unit(asym_unit, rot, trans, _EQ_SITE_TOL)
    # Multiplicity of each asymmetric site, and the index in the motif
    # where each site's group of images starts.
    _, wyc_muls = np.unique(invs, return_counts=True)
    asym_inds = np.zeros_like(wyc_muls)
    asym_inds[1:] = np.cumsum(wyc_muls)[:-1]
    types = np.array([asym_types[i] for i in invs])
    motif = frac_motif @ cell

    return PeriodicSet(
        motif=motif,
        cell=cell,
        name=block.name,
        asymmetric_unit=asym_inds,
        wyckoff_multiplicities=wyc_muls,
        types=types
    )
1334
1335
1336
def _parse_sitesyms(symmetries: List[str]) -> Tuple[np.ndarray, np.ndarray]:
1337
    """Parses a sequence of symmetries in xyz form and returns rotation
1338
    and translation arrays. Similar to function found in ase.spacegroup.
1339
    """
1340
1341
    nsyms = len(symmetries)
1342
    rotations = np.zeros((nsyms, 3, 3), dtype=np.float64)
1343
    translations = np.zeros((nsyms, 3), dtype=np.float64)
1344
1345
    for i, sym in enumerate(symmetries):
1346
        for ind, element in enumerate(sym.split(',')):
1347
1348
            is_positive = True
1349
            is_fraction = False
1350
            sng_trans = None
1351
            fst_trans = []
1352
            snd_trans = []
1353
1354
            for char in element.lower():
1355
                if char == '+':
1356
                    is_positive = True
1357
                elif char == '-':
1358
                    is_positive = False
1359
                elif char == '/':
1360
                    is_fraction = True
1361
                elif char in 'xyz':
1362
                    rot_sgn = 1.0 if is_positive else -1.0
1363
                    rotations[i][ind][ord(char) - ord('x')] = rot_sgn
0 ignored issues
show
Comprehensibility Best Practice introduced by
The variable ord does not seem to be defined.
Loading history...
1364
                elif char.isdigit() or char == '.':
1365
                    if sng_trans is None:
1366
                        sng_trans = 1.0 if is_positive else -1.0
1367
                    if is_fraction:
1368
                        snd_trans.append(char)
1369
                    else:
1370
                        fst_trans.append(char)
1371
1372
            if not fst_trans:
1373
                e_trans = 0.0
1374
            else:
1375
                e_trans = sng_trans * float(''.join(fst_trans))
1376
1377
            if is_fraction:
1378
                e_trans /= float(''.join(snd_trans))
1379
1380
            translations[i][ind] = e_trans
1381
1382
    return rotations, translations
1383
1384
1385
def _expand_asym_unit(
        asym_unit: np.ndarray,
        rotations: np.ndarray,
        translations: np.ndarray,
        tol: float
) -> Tuple[np.ndarray, np.ndarray]:
    """Apply the symmetry operations (rotations & translations) to the
    fractional coordinates of the asymmetric unit. Returns the full
    fractional motif together with inverse indices, which record the
    asymmetric unit site each motif point was generated from.
    """

    # All inputs are cast to float64 before the jitted helpers are called
    images = _expand_sites(
        asym_unit.astype(np.float64),
        rotations.astype(np.float64),
        translations.astype(np.float64)
    )

    # Fast path first; it is only valid if the resulting motif has no
    # coincident points, otherwise fall back to the slower reduction.
    frac_motif, invs = _reduce_expanded_sites(images, tol)
    if np.all(_unique_sites(frac_motif, tol)):
        return frac_motif, invs

    return _reduce_expanded_equiv_sites(images, tol)
1406
1407
1408
@numba.njit()
def _expand_sites(
        asym_unit: np.ndarray,
        rotations: np.ndarray,
        translations: np.ndarray
) -> np.ndarray:
    """Apply every (rotation, translation) pair to every point of the
    asymmetric unit and reduce the results modulo 1. Returns a 3D array
    of shape (# points, # syms, dims) whose entry [i, j] is the image of
    point i under symmetry operation j.
    """

    n_sites, dims = asym_unit.shape
    n_syms = len(rotations)
    images = np.empty((n_sites, n_syms, dims), dtype=np.float64)

    for sym_ind in range(n_syms):
        rot = rotations[sym_ind]
        shift = translations[sym_ind]
        for site_ind in range(n_sites):
            images[site_ind, sym_ind] = np.dot(rot, asym_unit[site_ind]) + shift

    return np.mod(images, 1)
1426
1427
1428
@numba.njit()
def _reduce_expanded_sites(
        expanded_sites: np.ndarray,
        tol: float
) -> Tuple[np.ndarray, np.ndarray]:
    """Reduce the asymmetric unit after being expanded by symmetries by
    removing invariant points. This is the fast version which works in
    the case that no two sites in the asymmetric unit are equivalent.
    If they are, the reduction is re-run with
    _reduce_expanded_equiv_sites() to account for it.

    ``expanded_sites[i]`` holds the images of asymmetric site i. Returns
    the stacked unique images (fractional motif) and, for each motif
    point, the index i of the site it came from.
    """

    n_sites = expanded_sites.shape[0]
    all_unique_inds = []
    # Integer dtype: these counts are used below as slice bounds
    multiplicities = np.zeros(shape=(n_sites, ), dtype=np.int64)

    for i, sites in enumerate(expanded_sites):
        unique_inds = _unique_sites(sites, tol)
        all_unique_inds.append(unique_inds)
        multiplicities[i] = np.sum(unique_inds)

    m = int(np.sum(multiplicities))
    frac_motif = np.zeros(shape=(m, expanded_sites.shape[-1]))
    inverses = np.zeros(shape=(m, ), dtype=np.int32)

    # Copy each site's unique images into its own contiguous slice
    s = 0
    for i in range(n_sites):
        t = s + multiplicities[i]
        frac_motif[s:t, :] = expanded_sites[i][all_unique_inds[i]]
        inverses[s:t] = i
        s = t

    return frac_motif, inverses
1460
1461
1462
def _reduce_expanded_equiv_sites(
        expanded_sites: np.ndarray,
        tol: float
) -> Tuple[np.ndarray, np.ndarray]:
    """Reduce the symmetry-expanded asymmetric unit by removing
    invariant points, additionally dropping (with a warning) any image
    which coincides modulo 1 with a motif point from an earlier site.
    This is the slower version, called after the fast version if
    equivalent motif points are found which need to be removed.
    """

    first_images = expanded_sites[0]
    frac_motif = first_images[_unique_sites(first_images, tol)]
    inverses = [0] * len(frac_motif)

    for i in range(1, len(expanded_sites)):
        images = expanded_sites[i]
        new_points = []

        for site in images[_unique_sites(images, tol)]:
            # A coordinate matches if it differs by ~0 or by ~1 (wrap)
            direct = np.abs(site - frac_motif)
            wrapped = np.abs(direct - 1)
            clashes = np.all((direct <= tol) | (wrapped <= tol), axis=-1)

            if np.any(clashes):
                msg = f'has equivalent sites at positions ' \
                      f'{inverses[np.argmax(clashes)]}, {i}'
                warnings.warn(msg)
            else:
                new_points.append(site)

        # The motif only grows once all images of site i were checked
        if new_points:
            inverses.extend([i] * len(new_points))
            frac_motif = np.concatenate((frac_motif, np.array(new_points)))

    return frac_motif, np.array(inverses)
1499
1500
1501
@numba.njit()
def _unique_sites(asym_unit: np.ndarray, tol: float) -> np.ndarray:
    """Uniquify (within tol) a list of fractional coordinates,
    considering all points modulo 1. Returns an array of bools such that
    asym_unit[_unique_sites(asym_unit, tol)] is the uniquified list.

    A point is marked False if every one of its coordinates is within
    tol of (or within tol of differing by exactly 1 from) the matching
    coordinate of some earlier point in the array.
    """

    m, _ = asym_unit.shape
    where_unique = np.full(shape=(m, ), fill_value=True)

    for i in range(1, m):
        site_diffs1 = np.abs(asym_unit[:i, :] - asym_unit[i])
        site_diffs2 = np.abs(site_diffs1 - 1)
        # True where a coordinate differs from the earlier site
        sites_neq_mask = (site_diffs1 > tol) & (site_diffs2 > tol)
        # A zero row-sum means some earlier site matches on every axis
        if not np.all(np.sum(sites_neq_mask, axis=-1)):
            where_unique[i] = False

    return where_unique
1520
1521
1522
def _has_disorder(label: str, occupancy):
1523
    """Return True if label ends with ?, or occupancy is a number < 1.
1524
    """
1525
    try:
1526
        occupancy = float(occupancy)
1527
    except:
1528
        occupancy = 1
1529
    return (occupancy < 1) or label.endswith('?')
1530
1531
1532
def _get_syms_pymatgen(data: dict) -> Tuple[np.ndarray, np.ndarray]:
    """Parse symmetry operations given by data = block.data where block
    is a pymatgen CifBlock object. If the symops are not present the
    space group symbol/international number is parsed and symops are
    generated.

    Sources are tried in order: the xyz operation tags in
    ``_CIF_TAGS['symop']``, then the space group names in
    ``_CIF_TAGS['spacegroup_name']``, then the numbers in
    ``_CIF_TAGS['spacegroup_number']``. If none yields any operations a
    warning is issued and the identity ('x,y,z', i.e. P1) is used.

    Returns a float64 array of the operations' rotation matrices and a
    float64 array of their translation vectors.
    """

    from pymatgen.symmetry.groups import SpaceGroup
    from pymatgen.core.operations import SymmOp
    import pymatgen.io.cif

    symops = []

    # Try xyz symmetry operations
    for symmetry_label in _CIF_TAGS['symop']:
        xyz = data.get(symmetry_label)
        if not xyz:
            continue
        # A single operation may be given as a bare string
        if isinstance(xyz, str):
            xyz = [xyz]
        try:
            symops = [SymmOp.from_xyz_string(s) for s in xyz]
            break
        except ValueError:
            # Unparseable operations under this tag: try the next tag
            continue

    # Try spacegroup symbol
    if not symops:
        for symmetry_label in _CIF_TAGS['spacegroup_name']:
            sg = data.get(symmetry_label)
            if not sg:
                continue
            # Strip whitespace and underscores from the symbol
            sg = re.sub(r'[\s_]', '', sg)

            try:
                # Look the symbol up in pymatgen's space group table
                spg = pymatgen.io.cif.space_groups.get(sg)
                if not spg:
                    continue
                symops = SpaceGroup(spg).symmetry_ops
                break
            except ValueError:
                pass

            # Only reached if the lookup above raised ValueError: search
            # pymatgen's COD data for a matching Hermann-Mauguin symbol.
            try:
                for d in pymatgen.io.cif._get_cod_data():
                    if sg == re.sub(r'\s+', '', d['hermann_mauguin']):
                        symops = [SymmOp.from_xyz_string(s)
                                  for s in d['symops']]
                        break
            except Exception as e:
                # Any failure here just moves on to the next tag
                continue

            if symops:
                break

    # Try international number
    if not symops:
        for symmetry_label in _CIF_TAGS['spacegroup_number']:
            num = data.get(symmetry_label)
            if not num:
                continue

            try:
                i = int(pymatgen.io.cif.str2float(num))
                symops = SpaceGroup.from_int_number(i).symmetry_ops
                break
            except ValueError:
                continue

    if not symops:
        warnings.warn('no symmetry data found, defaulting to P1')
        symops = [SymmOp.from_xyz_string('x,y,z')]

    rotations = [op.rotation_matrix for op in symops]
    translations = [op.translation_vector for op in symops]
    rotations = np.array(rotations, dtype=np.float64)
    translations = np.array(translations, dtype=np.float64)

    return rotations, translations
1611
1612
1613
def _frac_molecular_centres_ccdc(crystal, tol):
    """Returns the geometric centres of molecules in the unit cell.
    Expects a ccdc Crystal object and returns fractional coordinates,
    reduced modulo 1 and with duplicates (within tol) removed.
    """

    frac_centres = []
    for comp in crystal.packing(inclusion='CentroidIncluded').components:
        coords = [a.fractional_coordinates for a in comp.atoms]
        # Mean along each axis. This must be a list, not a generator,
        # so that np.array() below builds a numeric 2D array.
        frac_centres.append([sum(ax) / len(coords) for ax in zip(*coords)])
    frac_centres = np.mod(np.array(frac_centres, dtype=np.float64), 1)
    return frac_centres[_unique_sites(frac_centres, tol)]
1624
1625
1626
def _heaviest_component_ccdc(molecule):
1627
    """Removes all but the heaviest component of the asymmetric unit.
1628
    Intended for removing solvents. Expects and returns a ccdc Molecule
1629
    object.
1630
    """
1631
1632
    component_weights = []
1633
    for component in molecule.components:
1634
        weight = 0
1635
        for a in component.atoms:
1636
            if isinstance(a.atomic_weight, (float, int)):
1637
                if isinstance(a.occupancy, (float, int)):
1638
                    weight += a.occupancy * a.atomic_weight
1639
                else:
1640
                    weight += a.atomic_weight
1641
        component_weights.append(weight)
1642
    largest_component_ind = np.argmax(np.array(component_weights))
1643
    molecule = molecule.components[largest_component_ind]
1644
    return molecule
1645
1646
1647
def _loop_to_dict_gemmi(gemmi_loop):
1648
    """gemmi Loop object --> dict, tags: values
1649
    """
1650
1651
    tablified_loop = [[] for _ in range(len(gemmi_loop.tags))]
1652
    n_cols = gemmi_loop.width()
1653
    for i, item in enumerate(gemmi_loop.values):
1654
        tablified_loop[i % n_cols].append(item)
1655
    return {tag: l for tag, l in zip(gemmi_loop.tags, tablified_loop)}
1656
1657
1658
def _snap_small_prec_coords(frac_coords: np.ndarray, tol: float) -> np.ndarray:
1659
    """Find where frac_coords is within 1e-4 of 1/3 or 2/3, change to
1660
    1/3 and 2/3. Recommended by pymatgen's CIF parser.
1661
    """
1662
    frac_coords[np.abs(1 - 3 * frac_coords) < tol] = 1 / 3.
1663
    frac_coords[np.abs(1 - 3 * frac_coords / 2) < tol] = 2 / 3.
1664
    return frac_coords
1665