1 | #! /usr/bin/env python |
||
2 | # |
||
3 | # Copyright (C) 2016 Rich Lewis <[email protected]> |
||
4 | # License: 3-clause BSD |
||
5 | |||
6 | |||
7 | 1 | """ |
|
8 | # skchem.base |
||
9 | |||
10 | Base classes for scikit-chem objects. |
||
11 | """ |
||
12 | 1 | import subprocess |
|
13 | 1 | from abc import ABCMeta, abstractmethod |
|
14 | 1 | import multiprocessing |
|
15 | 1 | from tempfile import NamedTemporaryFile |
|
16 | 1 | import time |
|
17 | 1 | import logging |
|
18 | |||
19 | 1 | import pandas as pd |
|
20 | |||
21 | 1 | from .utils import NamedProgressBar, DummyProgressBar |
|
22 | 1 | from . import core |
|
23 | 1 | from .utils import (iterable_to_series, optional_second_method, nanarray, |
|
24 | squeeze, yaml_dump, json_dump) |
||
25 | 1 | from . import io |
|
26 | |||
# Logger for this module; its name comes from the module's ``__name__``.
LOGGER = logging.getLogger(name=__name__)
28 | |||
29 | |||
class BaseTransformer(object):

    """ Transformer Base Class.

    Specific Base Transformer classes inherit from this class and implement
    `transform` and `axes_names`.
    """

    # NOTE(review): `__metaclass__` is only honoured by Python 2. Under
    # Python 3 this assignment is inert, so the `@abstractmethod` markers
    # below are not enforced when subclasses are instantiated.
    __metaclass__ = ABCMeta

    # Shares functionality between Transformer and AtomTransformer.

    def __init__(self, n_jobs=1, verbose=True):
        self._n_jobs = None  # backing store for the `n_jobs` property
        self.n_jobs = n_jobs
        self.verbose = verbose

    @property
    def n_jobs(self):
        """ int: Number of worker processes used when transforming. """
        return self._n_jobs

    @n_jobs.setter
    def n_jobs(self, val):
        # Accepts a positive number of processes, or -1 meaning "one process
        # per CPU reported by multiprocessing.cpu_count()".
        if val == -1:
            val = multiprocessing.cpu_count()
        elif val < 1:
            # Rejected explicitly: silently ignoring the value would leave
            # `n_jobs` at its previous value (None during construction).
            raise ValueError('n_jobs must be a positive integer or -1, '
                             'got {!r}.'.format(val))
        self._n_jobs = val

    def get_params(self):
        """ Get a dictionary of the parameters of this object.

        The parameters are the named arguments of the concrete class's
        ``__init__`` (excluding ``self``); each value is read back from the
        attribute of the same name. ``*args``/``**kwargs`` catch-alls and
        local variables of ``__init__`` are not parameters and are skipped.

        Returns:
            dict: parameter name -> current value.
        """
        code = self.__class__.__init__.__code__
        # co_varnames lists positional args, then keyword-only args, then the
        # *args/**kwargs names, then ordinary locals; keep only the named
        # arguments.
        n_named = code.co_argcount + getattr(code, 'co_kwonlyargcount', 0)
        params = [name for name in code.co_varnames[:n_named]
                  if name != 'self']
        return {param: getattr(self, param) for param in params}

    @classmethod
    def from_params(cls, params):
        """ Create an instance from a params dictionary. """
        return cls(**params)

    def to_dict(self):

        """ Return a dictionary representation of the object.

        The single key is the fully qualified class name
        (``module.ClassName``) and the value is `get_params()`.
        """

        n = '{}.{}'.format(self.__class__.__module__, self.__class__.__name__)
        return {n: self.get_params()}

    def to_json(self, target=None):

        """ Serialize the object as JSON.

        Args:
            target (str or file-like):
                A file or filepath to serialize the object to. If `None`,
                return the JSON as a string.

        Returns:
            None or str
        """

        return json_dump(self.to_dict(), target)

    def to_yaml(self, target=None):

        """ Serialize the object as YAML.

        Args:
            target (str or file-like):
                A file or filepath to serialize the object to. If `None`,
                return the YAML as a string.

        Returns:
            None or str
        """

        return yaml_dump(self.to_dict(), target)

    def copy(self):
        """ Return a copy of this object, rebuilt from `get_params()`. """
        return self.__class__(**self.get_params())

    def optional_bar(self, **kwargs):
        """ Return a progress bar for this transformer.

        A `NamedProgressBar` labelled with the class name when `verbose` is
        set, otherwise a `DummyProgressBar` (presumably a no-op). `kwargs`
        are forwarded to the bar's constructor.
        """
        if self.verbose:
            bar = NamedProgressBar(name=self.__class__.__name__, **kwargs)
        else:
            bar = DummyProgressBar(**kwargs)
        return bar

    @property
    @abstractmethod
    def axes_names(self):
        """ tuple: The names of the axes. """
        pass

    @abstractmethod
    def transform(self, mols):
        """ Transform objects according to the objects transform protocol.

        Args:
            mols (skchem.Mol or pd.Series or iterable):
                The mol objects to transform.

        Returns:
            pd.Series or pd.DataFrame
        """
        pass

    def __eq__(self, other):
        """ Transformers are equal when they share a class and parameters. """
        # Defer to Python's default handling for foreign objects and for
        # transformers of a different class, rather than crashing or
        # comparing unrelated transformers equal.
        if type(self) is not type(other):
            return NotImplemented
        return self.get_params() == other.get_params()
138 | |||
139 | |||
class Transformer(BaseTransformer):

    """ Base class for transformers that map each molecule to one row.

    Concrete subclasses supply `_transform_mol` (one molecule -> one row of
    values) and the `columns` property (labels for those values).

    See Also:
        AtomTransformer."""

    @property
    @abstractmethod
    def columns(self):
        """ pd.Index: Labels for the values produced per molecule. """
        return pd.Index(None)

    @abstractmethod
    def _transform_mol(self, mol):
        """ Compute the row of values for a single molecule. """
        pass

    def _transform_series(self, ser):
        """ Compute one row per molecule of `ser`, returned as a list.

        Runs in-process when `n_jobs` is 1, otherwise maps a copy of this
        transformer over a `multiprocessing.Pool` of `n_jobs` workers.
        """
        LOGGER.debug('Transforming series of length %s with %s jobs',
                     len(ser), self.n_jobs)

        progress = self.optional_bar(max_value=len(ser))
        if self.n_jobs != 1:
            # The workers receive a copy rebuilt from this object's params.
            worker = self.copy()
            with multiprocessing.Pool(processes=self.n_jobs) as pool:
                rows = pool.imap(worker._transform_mol, ser)
                return list(progress(rows))
        return [self._transform_mol(mol) for mol in progress(ser)]

    @optional_second_method
    def transform(self, mols, **kwargs):
        """ Transform objects according to the objects transform protocol.

        Args:
            mols (skchem.Mol or pd.Series or iterable):
                The mol objects to transform.

        Returns:
            pd.Series or pd.DataFrame
        """
        if isinstance(mols, core.Mol):
            # Series.squeeze turns a length-1 result into a scalar.
            single = pd.Series(self._transform_mol(mols),
                               index=self.columns,
                               name=self.__class__.__name__)
            return single.squeeze()

        if not isinstance(mols, pd.Series):
            mols = iterable_to_series(mols)

        frame = pd.DataFrame(self._transform_series(mols),
                             index=mols.index,
                             columns=self.columns)
        return squeeze(frame, axis=1)

    @property
    def axes_names(self):
        """ tuple: The names of the axes. """
        return ('batch', self.columns.name)
204 | |||
205 | |||
class BatchTransformer(BaseTransformer):
    """ Mixin for transformers that save overhead by working on many
    molecules at once.

    Implement `_transform_series` with the transformation rather than
    `_transform_mol`. Must occur before `Transformer` or `AtomTransformer` in
    method resolution order.

    See Also:
        Transformer, AtomTransformer.
    """

    def _transform_mol(self, mol):
        """ Transform a molecule by running it as a one-element batch.

        Returns:
            The first entry (`.iloc[0]`) of `transform([mol])`.
        """
        # Silence the progress bar for the single-element batch. Restore the
        # caller's setting even if `transform` raises, so a failure does not
        # leave the transformer permanently quiet.
        verbose = self.verbose
        self.verbose = False
        try:
            return self.transform([mol]).iloc[0]
        finally:
            self.verbose = verbose

    @abstractmethod
    def _transform_series(self, ser):
        """ Transform a series of molecules to an np.ndarray. """
        pass
230 | |||
231 | |||
class AtomTransformer(BaseTransformer):
    """ Transformer that produces one row of values per atom.

    A batch of molecules is returned as a `pd.Panel` with axes
    (molecule, atom position, `minor_axis`).

    Concrete classes inheriting from this must implement `_transform_atom`
    and `minor_axis`. `_transform_mol` has a default implementation that
    applies `_transform_atom` to each atom in turn and may be overridden.

    See Also:
        Transformer
    """

    def __init__(self, max_atoms=100, **kwargs):
        self.max_atoms = max_atoms
        # Atom-position labels 0 .. max_atoms - 1, shared by every molecule.
        self.major_axis = pd.RangeIndex(self.max_atoms, name='atom_idx')
        super(AtomTransformer, self).__init__(**kwargs)

    @property
    @abstractmethod
    def minor_axis(self):
        """ pd.Index: Minor axis of transformed values. """
        return pd.Index(None)  # expects a length

    @property
    def axes_names(self):
        """ tuple: The names of the axes. """
        return 'batch', 'atom_idx', self.minor_axis.name

    def _check_n_atoms(self, n_atoms):
        """ Raise a ValueError if `n_atoms` exceeds `max_atoms`. """
        # Without this check an oversized molecule fails deep inside numpy
        # or pandas with an opaque shape-mismatch error.
        if n_atoms > self.max_atoms:
            raise ValueError(
                'Molecule has {} atoms, more than max_atoms={}; increase '
                '`max_atoms` to transform it.'.format(n_atoms,
                                                      self.max_atoms))

    @optional_second_method
    def transform(self, mols):
        """ Transform objects according to the objects transform protocol.

        Args:
            mols (skchem.Atom, skchem.Mol, pd.Series or iterable):
                The atom or mol objects to transform.

        Returns:
            pd.Series, pd.DataFrame or pd.Panel:
                A Series for a single atom, a DataFrame (atoms x minor_axis)
                for a single molecule, or a Panel for several molecules,
                each after squeezing.

        Raises:
            ValueError: if a molecule has more than `max_atoms` atoms.
        """
        if isinstance(mols, core.Atom):
            # Series.squeeze turns a length-1 result into a scalar.
            return pd.Series(self._transform_atom(mols),
                             index=self.minor_axis).squeeze()

        elif isinstance(mols, core.Mol):
            self._check_n_atoms(len(mols.atoms))
            res = pd.DataFrame(self._transform_mol(mols),
                               index=self.major_axis[:len(mols.atoms)],
                               columns=self.minor_axis)
            return squeeze(res, axis=1)

        elif not isinstance(mols, pd.Series):
            mols = iterable_to_series(mols)

        res = pd.Panel(self._transform_series(mols),
                       items=mols.index,
                       major_axis=self.major_axis,
                       minor_axis=self.minor_axis)

        return squeeze(res, axis=(1, 2))

    @abstractmethod
    def _transform_atom(self, atom):
        """ Transform an atom to a 1D array of length `len(self.minor_axis)`.
        """

        pass

    def _transform_mol(self, mol):
        """ Transform a Mol to a 2D array of shape
        (len(mol.atoms), len(self.minor_axis)), one row per atom. """

        res = nanarray((len(mol.atoms), len(self.minor_axis)))
        for i, atom in enumerate(mol.atoms):
            res[i] = self._transform_atom(atom)
        return res

    def _transform_series(self, ser):
        """ Transform a Series<Mol> to a 3D array.

        The array has shape (len(ser), max_atoms, len(minor_axis)). Rows past
        a molecule's last atom keep the fill value produced by `nanarray`.

        Raises:
            ValueError: if a molecule has more than `max_atoms` atoms.
        """
        LOGGER.debug('Transforming series of length %s with %s jobs',
                     len(ser), self.n_jobs)
        bar = self.optional_bar(max_value=len(ser))

        n_features = len(self.minor_axis)
        res = nanarray((len(ser), self.max_atoms, n_features))

        if self.n_jobs == 1:
            for i, mol in enumerate(bar(ser)):
                self._check_n_atoms(len(mol.atoms))
                res[i, :len(mol.atoms), :n_features] = self._transform_mol(mol)
        else:
            # Workers receive a copy rebuilt from this object's params.
            cpy = self.copy()
            with multiprocessing.Pool(self.n_jobs) as pool:
                results = bar(pool.imap(cpy._transform_mol, ser))
                for i, ans in enumerate(results):
                    self._check_n_atoms(len(ans))
                    res[i, :len(ans), :n_features] = ans
        return res
323 | |||
class External(object):
    """ Mixin for wrappers of external CLI tools.

    Concrete classes must implement `validate_install`.

    Attributes:
        install_hint (str): an explanation of how to install external tool.
    """

    # NOTE(review): `__metaclass__` is only honoured by Python 2; under
    # Python 3 abstract methods here are not enforced at instantiation.
    __metaclass__ = ABCMeta

    install_hint = ""

    def __init__(self, **kwargs):
        """ Refuse to construct the wrapper if the tool is not installed.

        Raises:
            RuntimeError: if `validate_install` reports the tool missing.
        """
        if not self.validated:
            msg = 'External tool not installed. {}'.format(self.install_hint)
            raise RuntimeError(msg)
        super(External, self).__init__(**kwargs)

    @property
    def validated(self):
        """ bool: whether the external tool is installed and active.

        `validate_install` runs at most once per class; its result is cached
        as a `_validated` attribute on that class.
        """
        cls = self.__class__
        # Look only in this class's own namespace. `hasattr` would also find
        # a value cached on a parent class, so a subclass wrapping a
        # different tool would silently reuse its parent's result.
        if '_validated' not in vars(cls):
            cls._validated = self.validate_install()
        return cls._validated

    @staticmethod
    @abstractmethod
    def validate_install():
        """ Determine if the external tool is available. """
        pass
355 | |||
356 | |||
class CLIWrapper(External, BaseTransformer):
    """ CLI wrapper.

    Concrete classes inheriting from this must implement `_cli_args`,
    `monitor_progress`, `_parse_outfile`, `_parse_errors`."""

    def __init__(self, error_on_fail=False, warn_on_fail=True, **kwargs):
        """
        Args:
            error_on_fail (bool):
                Raise a ValueError when the tool fails on any molecule.
            warn_on_fail (bool):
                Log a warning for each molecule the tool fails on.
            **kwargs:
                Forwarded to the parent initialisers (e.g. `n_jobs`,
                `verbose`).
        """
        super(CLIWrapper, self).__init__(**kwargs)
        self.error_on_fail = error_on_fail
        self.warn_on_fail = warn_on_fail

    @property
    def n_jobs(self):
        """ int: Always 1; external tools are run in a single process. """
        return self._n_jobs

    @n_jobs.setter
    def n_jobs(self, val):
        if val != 1:
            raise NotImplementedError('Multiprocessed external code is not yet'
                                      ' supported.')
        self._n_jobs = val

    def _transform_series(self, ser):
        """ Transform a series by running the external tool over it.

        The molecules are written to a temporary SDF file, the tool is run
        on it, and its output file and stderr are parsed. Molecules the tool
        failed on are reinserted as empty rows so that no input is lost.

        Returns:
            np.ndarray: the parsed results, ordered like `ser`.

        Raises:
            ValueError: if the parsed result type is unsupported, or if
                `error_on_fail` is set and the tool failed on a molecule.
        """
        with NamedTemporaryFile(suffix='.sdf') as infile, \
                NamedTemporaryFile() as outfile, \
                NamedTemporaryFile() as errfile:
            io.write_sdf(ser, infile.name)
            args = self._cli_args(infile.name, outfile.name)
            # stderr goes to a temporary file rather than a PIPE: nothing
            # reads a pipe while the process runs, so a tool writing more
            # than the OS pipe buffer to stderr would block forever.
            proc = subprocess.Popen(args, stderr=errfile)

            if self.verbose:
                bar = self.optional_bar(max_value=len(ser))
                while proc.poll() is None:
                    time.sleep(0.5)
                    bar.update(self.monitor_progress(outfile.name))
                bar.finish()

            proc.wait()
            res = self._parse_outfile(outfile.name)

            errfile.seek(0)
            errs = self._parse_errors(errfile.read().decode('utf-8',
                                                            'replace'))

        # The parsed results only hold rows for the inputs that succeeded, so
        # label them with the input index minus the failed positions.
        if isinstance(res, (pd.Series, pd.DataFrame)):
            res.index = ser.index.delete(errs)
        elif isinstance(res, pd.Panel):
            res.items = ser.index.delete(errs)
        else:
            msg = 'Parsed datatype ({}) not supported.'.format(type(res))
            raise ValueError(msg)

        # Put the failures back in as empty rows (transform doesn't lose
        # instances).
        for err in errs:
            label = ser.index[err]
            if self.error_on_fail:
                raise ValueError('Failed to transform {}.'.format(label))
            if self.warn_on_fail:
                LOGGER.warning('Failed to transform %s', label)
            res.loc[label] = None

        return res.loc[ser.index].values

    @abstractmethod
    def _cli_args(self, infile, outfile):
        """ list: The cli arguments to run, given input and output paths. """
        return []

    @abstractmethod
    def monitor_progress(self, filename):
        """ Return the progress to pass to the progress bar's `update`,
        given the path of the output file being written. """
        pass

    @abstractmethod
    def _parse_outfile(self, outfile):
        """ Parse the file written and return a Series, DataFrame or Panel
        holding the results for the inputs that succeeded. """
        pass

    @abstractmethod
    def _parse_errors(self, errs):
        """ Parse the stderr text and return the positions (in the input
        series) of the molecules that failed. """
        pass
442 | |||
443 | |||
444 | 1 | class Featurizer(object): |
|
445 | |||
446 | """ Base class for m -> data transforms, such as Fingerprinting etc. |
||
447 | |||
448 | Concrete subclasses should implement `name`, returning a string uniquely |
||
449 | identifying the featurizer. """ |
||
450 | |||
451 | __metaclass__ = ABCMeta |
||
452 |
This check looks for invalid names for a range of different identifiers.
You can set regular expressions to which the identifiers must conform if the defaults do not match your requirements.
If your project includes a Pylint configuration file, the settings contained in that file take precedence.
To find out more about Pylint, please refer to their site.