Completed
Push — master ( 01edc4...2b1c29 )
by Rich
01:29
created

Converter.save_features()   A

Complexity

Conditions 2

Size

Total Lines 6

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0 Features 1
Metric Value
c 1
b 0
f 1
dl 0
loc 6
rs 9.4285
cc 2
1
#! /usr/bin/env python
2
#
3
# Copyright (C) 2016 Rich Lewis <[email protected]>
4
# License: 3-clause BSD
5
6
import warnings
7
import logging
8
import os
9
from collections import namedtuple
10
11
import numpy as np
12
import pandas as pd
13
import h5py
14
from fuel.datasets import H5PYDataset
15
16
from ... import forcefields
17
from ... import filters
18
from ... import descriptors
19
from ... import cross_validation
20
from ... import standardizers
21
22
logger = logging.getLogger(__name__)


# A feature specification: the fingerprinter to run, the h5 key to store the
# result under, and the fuel axis labels for the resulting array.
Feature = namedtuple('Feature', ('fper', 'key', 'axis_names'))
26
27
# Default feature set calculated for every dataset.  Atom-level calculators
# are padded to a fixed maximum of 100 atoms so their outputs stack into
# rectangular arrays.
DEFAULT_FEATURES = (
    Feature(fper=descriptors.MorganFingerprinter(),
            key='X_morg',
            axis_names=['batch', 'features']),
    Feature(fper=descriptors.PhysicochemicalFingerprinter(),
            key='X_pc',
            axis_names=['batch', 'features']),
    Feature(fper=descriptors.AtomFeatureCalculator(max_atoms=100),
            key='A',
            axis_names=['batch', 'atom_idx', 'features']),
    Feature(fper=descriptors.GraphDistanceCalculator(max_atoms=100),
            key='G',
            axis_names=['batch', 'atom_idx', 'atom_idx']),
    Feature(fper=descriptors.SpaceDistanceCalculator(max_atoms=100),
            key='G_d',
            axis_names=['batch', 'atom_idx', 'atom_idx']),
    Feature(fper=descriptors.ChemAxonFeatureCalculator(feat_set='optimal'),
            key='X_cx',
            axis_names=['batch', 'features']),
    # was `max_atoms=` (invalid syntax — truncated argument); 100 matches the
    # other atom-level calculators above.
    Feature(fper=descriptors.ChemAxonAtomFeatureCalculator(feat_set='all', max_atoms=100),
            key='A_cx',
            axis_names=['batch', 'atom_idx', 'features'])
)
50
51
# Default molecule filters: organic compounds only, 5-100 atoms (counting
# hydrogens), and mass below 1000 Da.
DEFAULT_FILTERS = (
    filters.OrganicFilter(),
    filters.AtomNumberFilter(above=5, below=100, include_hydrogens=True),
    filters.MassFilter(below=1000)
)

# Default standardizer: keep molecules that fail standardization rather than
# dropping them, and do so silently.
DEFAULT_STANDARDIZER = standardizers.ChemAxonStandardizer(keep_failed=True, warn_on_fail=False)

# Default forcefield for 3D geometry optimization (hydrogens added first).
DEFAULT_FORCEFIELD = forcefields.UFF(add_hs=True, warn_on_fail=False)
60
61
class Converter(object):
    """ Create a fuel dataset from molecules and targets.

    Args:
        ms (pd.Series):
            The molecules of the dataset.
        ys (pd.Series or pd.DataFrame):
            The target labels of the dataset.
        output_path (str):
            The path to which the dataset should be saved.
        features (list[Feature]):
            The features to calculate. Defaults are provided.
        splits (dict):
            A dictionary of different splits provided.
            The keys should be the split name, and values an array of indices.
            Alternatively, if `contiguous_splits` is `True`, the keys should be
            the split name, and the values a tuple of start and stop.
            If `None`, use `skchem.cross_validation.SimThresholdSplit`
    """

    def __init__(self, directory, output_directory, output_filename='default.h5'):
        # NotImplementedError is the correct exception class here;
        # `raise NotImplemented` raises a TypeError in Python 3 because
        # NotImplemented is a singleton, not an exception.
        raise NotImplementedError

    def run(self, ms, y, output_path,
                features=DEFAULT_FEATURES, splits=None, contiguous=False):
        """ Perform the conversion: create the h5 file, derive splits when
        none are provided, then save splits, molecules, targets and features.

        Args:
            ms (pd.Series): The molecules to convert.
            y (pd.Series or pd.DataFrame): The target labels.
            output_path (str): Path of the h5 file to create.
            features (iterable[Feature]): The features to calculate.
            splits (dict): Pre-computed splits, or None to derive them.
            contiguous (bool): Whether splits are (start, stop) ranges.
        """
        self.contiguous = contiguous
        self.output_path = output_path
        self.features = features
        self.feature_names = [feat.key for feat in self.features] + ['y']

        self.create_file(output_path)

        if not splits:
            splits, idx = self.create_splits(ms)
            # .loc replaces the deprecated .ix indexer (removed in pandas 1.0);
            # idx is a label index, so label-based .loc is the equivalent.
            ms, y = ms.loc[idx], y.loc[idx]

        split_dict = self.process_splits(splits)

        self.save_splits(split_dict)
        self.save_molecules(ms)
        self.save_targets(y)
        self.save_features(ms)

    def create_file(self, path):
        """ Create (truncating if present) the output h5 file and keep a
        handle to it on the instance. """
        logger.info('Creating h5 file at %s...', self.output_path)
        self.data_file = h5py.File(path, 'w')
        return self.data_file

    def filter(self, data, filters=DEFAULT_FILTERS):

        """ Filter the compounds according to the usual filters. """
        n_initial = len(data)
        logger.info('Filtering %s compounds', n_initial)

        # Iterate the `filters` argument (previously this always looped over
        # DEFAULT_FILTERS, silently ignoring any caller-supplied filters).
        for filt in filters:
            data = filt.filter(data)

        logger.info('Filtered out %s compounds', n_initial - len(data))

        return data

    def standardize(self, data, standardizer=DEFAULT_STANDARDIZER):

        """ Standardize the compounds. """
        logger.info('Standardizing %s compounds', len(data))
        return standardizer.transform(data)

    def optimize(self, data, optimizer=DEFAULT_FORCEFIELD):

        """ Optimize 3D geometry of the compounds. """

        # Log the class of the optimizer actually in use, not the module
        # default (they differ when a caller passes their own optimizer).
        logger.info('Optimizing the geometry of %s compounds with %s',
                    len(data), optimizer.__class__)
        return optimizer.transform(data)

    def save_molecules(self, mols):

        """ Save the molecules to the data file. """

        logger.info('Writing molecules to file...')
        logger.debug('Writing %s molecules to %s', len(mols), self.data_file.filename)
        # pandas' HDF writer emits PerformanceWarnings for object columns;
        # these are expected here, so suppress them.
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            mols.to_hdf(self.data_file.filename, 'structure')
            mols.apply(lambda m: m.to_smiles().encode('utf-8')).to_hdf(self.data_file.filename, 'smiles')

    def save_targets(self, y):

        """ Save the targets to the data file. """
        # Derive a name for the target group: series name, then the columns'
        # name for a DataFrame, then a generic fallback.
        y_name = getattr(y, 'name', None)
        if not y_name:
            y_name = getattr(y.columns, 'name', None)
        if not y_name:
            y_name = 'targets'

        logger.info('Writing %s', y_name)
        logger.debug('Writing targets of shape %s to %s', y.shape, self.data_file.filename)

        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            y.to_hdf(self.data_file.filename, '/targets/' + y_name)

        if isinstance(y, pd.Series):
            self.data_file['y'] = h5py.SoftLink('/targets/{}/values'.format(y_name))
            self.data_file['y'].dims[0].label = 'batch'

        elif isinstance(y, pd.DataFrame):
            self.data_file['y'] = h5py.SoftLink('/targets/{}/block0_values'.format(y_name))
            self.data_file['y'].dims[0].label = 'batch'
            # was dims[0] assigned twice ('batch' then 'task'); the task axis
            # of a 2-D target block is the second dimension.
            self.data_file['y'].dims[1].label = 'task'

    def save_features(self, ms):

        """ Save all features for the dataset. """
        logger.debug('Saving features')
        for feat in self.features:
            self._save_feature(ms, feat)

    def _save_feature(self, ms, feat):

        """ Calculate and save a feature to the data file. """
        logger.info('Calculating %s', feat.key)

        fps = feat.fper.transform(ms)
        if len(feat.axis_names) > 2:
            fps = fps.transpose(2, 1, 0) # panel serialize backwards for some reason...
        logger.debug('Writing features with shape %s to %s', fps.shape, self.data_file.filename)
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            fps.to_hdf(self.data_file.filename, 'features/{}'.format(feat.key))
        # Expose the raw values array at the top level and label its axes for
        # fuel.
        self.data_file[feat.key] = h5py.SoftLink('/features/{}/block0_values'.format(feat.key))
        self.data_file[feat.key].dims[0].label = feat.axis_names[0]
        self.data_file[feat.key].dims[1].label = feat.axis_names[1]
        if len(feat.axis_names) > 2:
            self.data_file[feat.key].dims[2].label = feat.axis_names[2]

    def create_splits(self, ms, contiguous=True):

        """ Create a split dict for fuel from mols, using SimThresholdSplit.

        Args:
            ms (pd.Series):
                The molecules to use to design the splits.
            contiguous (bool):
                Whether the split should be contiguous.  This allows for more
                efficient loading times.  This usually is the appropriate if
                there are no other splits for the dataset, and will reorder
                the dataset.

        Returns:
            (dict, idx)
                The split dict, and the index to align the data with.
        """
        # NOTE(review): the `contiguous` parameter is currently unused — the
        # method reads self.contiguous (set in run()) instead; kept for
        # backward compatibility.

        logger.info('Creating Similarity Threshold splits...')
        cv = cross_validation.SimThresholdSplit(ms, memory_optimized=True)
        train, valid, test = cv.split((70, 15, 15))

        def bool_to_index(ser):
            # Convert a boolean mask Series to an array of positional indices.
            return np.nonzero(ser.values)[0]

        if self.contiguous:
            # Reorder the dataset so each split occupies one contiguous run,
            # then describe each split as a (start, stop) pair.
            dset = pd.Series(0, ms.index)
            dset[train] = 0
            dset[valid] = 1
            dset[test] = 2
            dset = dset.sort_values()
            idx = dset.index
            train_split = bool_to_index(dset == 0)
            valid_split = bool_to_index(dset == 1)
            test_split = bool_to_index(dset == 2)

            def min_max(split):
                # Half-open (start, stop) range covering the split.
                return min(split), max(split) + 1

            splits = {
                'train': min_max(train_split),
                'valid': min_max(valid_split),
                'test': min_max(test_split)
            }

        else:

            idx = ms.index

            splits = {
                'train': bool_to_index(train),
                'valid': bool_to_index(valid),
                'test': bool_to_index(test)
            }

        return splits, idx

    def process_splits(self, splits, contiguous=False):

        """ Create a split dict for fuel from provided indexes. """
        # NOTE(review): the `contiguous` parameter is unused — the method
        # reads self.contiguous; kept for backward compatibility.

        logger.info('Creating split array.')

        split_dict = {}

        if self.contiguous:
            logger.debug('Contiguous splits.')
            # Contiguous splits are (start, stop) ranges with a null reference.
            for split_name, (start, stop) in splits.items():
                split_dict[split_name] = {feat: (start, stop, h5py.Reference()) for feat in self.feature_names}
        else:
            # Non-contiguous splits store their index arrays in the file and
            # reference them, with (-1, -1) as the unused range.
            for split_name, split in splits.items():
                split_indices_name = '{}_indices'.format(split_name).encode('utf-8')
                logger.debug('Saving %s to %s', split_indices_name, self.data_file.filename)
                self.data_file[split_indices_name] = split
                split_ref = self.data_file[split_indices_name].ref
                split_dict[split_name] = {feat: (-1, -1, split_ref) for feat in self.feature_names}

        return split_dict

    def save_splits(self, split_dict):

        """ Save the splits to the data file. """

        logger.info('Producing dataset splits...')
        split = H5PYDataset.create_split_array(split_dict)
        logger.debug('split: %s', split)
        logger.info('Saving splits...')
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            self.data_file.attrs['split'] = split

    @classmethod
    def convert(cls, **kwargs):
        """ Construct a converter with cwd defaults and return its output
        path. """
        kwargs.setdefault('directory', os.getcwd())
        kwargs.setdefault('output_directory', os.getcwd())

        # was `return cls(**kwargs).output_path,` — the stray trailing comma
        # made this return a 1-tuple rather than the path itself.
        return cls(**kwargs).output_path

    @classmethod
    def fill_subparser(cls, subparser):
        """ Return the callable the CLI subparser should dispatch to. """
        return cls.convert