Passed
Push — master ( 1d1d87...c4cb0d )
by Daniel
02:01
created

amd.pset_io.SetWriter.close()   A

Complexity

Conditions 1

Size

Total Lines 3
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 2
dl 0
loc 3
rs 10
c 0
b 0
f 0
cc 1
nop 1
1
"""Contains I/O tools, including a .CIF reader and CSD reader
2
(``csd-python-api`` only) to extract periodic set representations
3
of crystals which can be passed to :func:`.calculate.AMD` and :func:`.calculate.PDD`.
4
5
These intermediate :class:`.periodicset.PeriodicSet` representations can be written
6
to a .hdf5 file with :class:`SetWriter`, which can be read back with :class:`SetReader`.
7
This is much faster than rereading a .CIF and recomputing invariants.
8
"""
9
10
from typing import Iterable, Optional
11
12
import numpy as np
13
import h5py
14
15
from .periodicset import PeriodicSet
16
17
18
class SetWriter:
    """Write several :class:`.periodicset.PeriodicSet` objects to a .hdf5 file.
    Reading the .hdf5 is much faster than parsing a .CIF file.

    Examples:

        Write the crystals in mycif.cif to a .hdf5 file::

            with amd.SetWriter('crystals.hdf5') as writer:

                for periodic_set in amd.CifReader('mycif.cif'):
                    writer.write(periodic_set)

                # use iwrite to write straight from an iterator
                # below is equivalent to the above loop
                writer.iwrite(amd.CifReader('mycif.cif'))

    Read the crystals back from the file with :class:`SetReader`.
    """

    # Variable-length string dtype: lists of strings are written with it,
    # since h5py needs an explicit string dtype to store them as a dataset.
    _str_dtype = h5py.vlen_dtype(str)

    def __init__(self, filename: str):
        """Open ``filename`` for writing.

        The file is opened with h5py mode ``'w'``, so an existing file at
        that path is truncated. ``track_order=True`` records creation order,
        which is what lets :class:`SetReader` iterate in write order.
        """
        self.file = h5py.File(filename, 'w', track_order=True)

    def write(self, periodic_set: PeriodicSet, name: Optional[str] = None):
        """Write a PeriodicSet object to file.

        The set is stored as a group named ``name``. The group holds
        ``motif`` and ``cell`` datasets, plus a ``tags`` subgroup if the set
        has tags. Array and list tags become datasets of that subgroup;
        scalars and ``None`` become attributes of it.

        Parameters
        ----------
        periodic_set : :class:`.periodicset.PeriodicSet`
            The periodic set to write.
        name : str, optional
            Key to store the set under. Defaults to ``periodic_set.name``.

        Raises
        ------
        ValueError
            If ``periodic_set`` is not a PeriodicSet, if no name is given
            and the set has none, or if a tag has a type that cannot be
            stored.
        """

        if not isinstance(periodic_set, PeriodicSet):
            raise ValueError(
                f'Object type {periodic_set.__class__.__name__} cannot be written with SetWriter')

        # need a name to store or you can't access items by key
        if name is None:
            if periodic_set.name is None:
                raise ValueError(
                    'Periodic set must have a name to be written. Either set the name '
                    'attribute of the PeriodicSet or pass a name to SetWriter.write()')
            name = periodic_set.name

        # this group is the PeriodicSet
        group = self.file.create_group(name)

        # datasets in the group for motif and cell
        group.create_dataset('motif', data=periodic_set.motif)
        group.create_dataset('cell', data=periodic_set.cell)

        if periodic_set.tags:
            # a subgroup contains the tags
            tags_group = group.create_group('tags')

            for tag in periodic_set.tags:
                data = periodic_set.tags[tag]

                if data is None:
                    # HDF5 has no null value; store a sentinel string that
                    # SetReader maps back to None
                    tags_group.attrs[tag] = '__None'
                elif np.isscalar(data):
                    # scalars (numbers and strings) are stored as attributes
                    tags_group.attrs[tag] = data
                elif isinstance(data, np.ndarray):
                    tags_group.create_dataset(tag, data=data)
                elif isinstance(data, list):
                    if any(isinstance(d, str) for d in data):
                        # a list containing any str is stored entirely as
                        # strings, using the variable-length string dtype
                        tags_group.create_dataset(
                            tag,
                            data=[str(d) for d in data],
                            dtype=SetWriter._str_dtype)
                    else:
                        # other lists must be castable to an ndarray; a single
                        # conversion is enough (no extra copy)
                        tags_group.create_dataset(tag, data=np.asarray(data))
                else:
                    raise ValueError(
                        f'Cannot store tag of type {type(data)} with SetWriter')

    def iwrite(self, periodic_sets: Iterable[PeriodicSet]):
        """Write :class:`.periodicset.PeriodicSet` objects from an iterable to file.

        Each set is written under its own ``name`` attribute (see :meth:`write`).
        """
        for periodic_set in periodic_sets:
            self.write(periodic_set)

    def close(self):
        """Close the :class:`SetWriter`."""
        self.file.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, tb):
        # Always close the file. Returning None means any exception raised
        # inside the ``with`` block is propagated, not suppressed.
        self.close()
109
110
class SetReader:
    """Read :class:`.periodicset.PeriodicSet` objects from a .hdf5 file written
    with :class:`SetWriter`. Acts like a read-only dict that can be iterated
    over (preserves write order).

    Examples:

        Get PDDs (k=100) of crystals in crystals.hdf5::

            pdds = []
            with amd.SetReader('crystals.hdf5') as reader:
                for periodic_set in reader:
                    pdds.append(amd.PDD(periodic_set, 100))

            # above is equivalent to:
            pdds = [amd.PDD(pset, 100) for pset in amd.SetReader('crystals.hdf5')]
    """

    def __init__(self, filename: str):
        """Open ``filename``, a .hdf5 file written by SetWriter, read-only."""
        self.file = h5py.File(filename, 'r', track_order=True)

    def _get_set(self, name: str) -> PeriodicSet:
        """Rebuild the PeriodicSet stored under the key ``name``.

        A missing ``name`` raises whatever h5py raises for a missing key
        (a KeyError).
        """
        group = self.file[name]
        periodic_set = PeriodicSet(group['motif'][:], group['cell'][:], name=name)

        if 'tags' in group:
            tags_group = group['tags']

            # tags stored as datasets (arrays and lists)
            for tag in tags_group:
                # [()] reads the whole dataset. Unlike [:] it also works on
                # scalar (0-d) datasets, which SetWriter creates for 0-d arrays.
                data = tags_group[tag][()]

                is_str_list = (
                    isinstance(data, np.ndarray)
                    and data.ndim > 0
                    and any(isinstance(d, (bytes, bytearray)) for d in data)
                )
                if is_str_list:
                    # variable-length strings may be read back as bytes
                    periodic_set.tags[tag] = [d.decode() for d in data]
                else:
                    periodic_set.tags[tag] = data

            # tags stored as attributes (scalars and None)
            for attr in tags_group.attrs:
                data = tags_group.attrs[attr]
                # '__None' is SetWriter's sentinel for None. Only compare str
                # values, so non-string attributes never get an elementwise ==.
                is_none = isinstance(data, str) and data == '__None'
                periodic_set.tags[attr] = None if is_none else data

        return periodic_set

    def close(self):
        """Close the :class:`SetReader`."""
        self.file.close()

    def family(self, refcode: str) -> Iterable[PeriodicSet]:
        """Yield any :class:`.periodicset.PeriodicSet` whose name starts with
        input refcode."""
        for name in self.keys():
            if name.startswith(refcode):
                yield self._get_set(name)

    def __getitem__(self, name):
        # Index by name. A missing name propagates h5py's lookup error.
        return self._get_set(name)

    def __len__(self):
        return len(self.keys())

    def __iter__(self):
        # interface to loop over the SetReader; does not close the SetReader when done
        for name in self.keys():
            yield self._get_set(name)

    def __contains__(self, item):
        return item in self.keys()

    def keys(self):
        """Return a view of the names of items in the :class:`SetReader`."""
        return self.file['/'].keys()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, tb):
        # Always close the file. Returning None means exceptions propagate.
        self.close()
188
0 ignored issues
show
coding-style introduced by
Trailing newlines
Loading history...
189