#! /usr/bin/env python
#
# Copyright (C) 2016 Rich Lewis <[email protected]>
# License: 3-clause BSD

"""
## skchem.cross_validation.similarity_threshold

Similarity threshold dataset partitioning functionality.
"""

import logging
import sys

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.spatial.distance import pdist, squareform
from scipy.sparse import triu
from scipy.optimize import minimize_scalar

import multiprocessing
from functools import partial, wraps

from .. import descriptors

LOGGER = logging.getLogger(__name__)


def returns_pairs(func):
    """ Wraps a function that returns a ((i, j), sim) list to return a dataframe. """
    @wraps(func)
    def inner(*args, **kwargs):
        pairs = func(*args, **kwargs)
        return pd.DataFrame([(p[0][0], p[0][1], p[1]) for p in pairs],
                            columns=['i', 'j', 'sim']).sort_values('sim')
    return inner


def _above_minimum(args, X, metric, threshold, size):
    """ finds pairs above a minimum similarity in chunks """
    from scipy.spatial.distance import cdist
    from scipy.sparse import dok_matrix
    import numpy as np
    i, j = slice(*args[0]), slice(*args[1])
    x_i, x_j = X[i], X[j]
    C = 1 - cdist(x_i, x_j, metric=metric)
    if i == j:
        C = np.triu(C, k=1)
    C[C <= threshold] = 0
    M = dok_matrix((size, size), dtype=float)
    M[i, j] = C
    return list(M.items())


class SimThresholdSplit(object):

    def __init__(self, min_threshold=0.45, largest_cluster_fraction=0.1, fper='morgan',
                 similarity_metric='jaccard', memory_optimized=True, n_jobs=1, block_width=1000,
                 verbose=False):
        """ Threshold similarity split for chemical datasets.

        This class implements a splitting technique that will pool compounds
        with similarity above a threshold into the same splits. The threshold
        value is decided by specifying the maximum fraction of the dataset to
        pool into a single cluster, since the density of compounds varies
        between datasets.

        Machine learning techniques should be able to extrapolate outside of a
        molecular series, or scaffold; however, random splits will result in
        some 'easy' test sets whose compounds are either *identical* to, in the
        same molecular series as, or share a significant scaffold with training
        set compounds.

        This splitting technique reduces or eliminates (depending on the
        threshold set) this effect, making the problem harder.

        Args:
            min_threshold (float):
                The minimum similarity threshold. Lower values are slower to
                compute.

            largest_cluster_fraction (float):
                The fraction of the total dataset that the largest cluster may
                be. This decides the final similarity threshold.

            fper (str or skchem.Fingerprinter):
                The fingerprinting technique to use to generate the similarity
                matrix.

            similarity_metric (str or callable):
                The similarity metric to use.

            memory_optimized (bool):
                Whether to use the memory optimized implementation.

            n_jobs (int):
                If memory_optimized is True, how many processes to run it over.

            block_width (int):
                If memory_optimized, the width of the sub matrices that are
                calculated at a time.

        Notes:
            The splits will not always be exactly the size requested, due to
            the clustering constraint and the requirement to maintain random
            shuffling.

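        Example:
            A sketch of typical usage (`ms` here stands for a `pd.Series` of
            `skchem.Mol` instances):

                st = SimThresholdSplit(fper='morgan', similarity_metric='jaccard').fit(ms)
                train, valid, test = st.split((70, 15, 15))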
        """

        if isinstance(fper, str):
            fper = descriptors.get(fper)

        self.fper = fper
        self.similarity_metric = similarity_metric
        self.memory_optimized = memory_optimized
        self.n_jobs = n_jobs
        self._block_width = block_width
        self.min_threshold = min_threshold
        self.largest_cluster = largest_cluster_fraction

        if self.fper:
            self.fper.verbose = verbose

    def fit(self, inp, pairs=None):
        """ Fit the split to the data.

        Args:
            inp (pd.Series or pd.DataFrame or np.array):
                - `pd.Series` of `skchem.Mol` instances
                - `pd.DataFrame` with `skchem.Mol` instances as a `structure` column
                - `pd.DataFrame` of fingerprints if `fper` is `None`
                - `pd.DataFrame` of similarity matrix if `similarity_metric` is `None`
                - `np.array` of similarity matrix if `similarity_metric` is `None`

            pairs (list<tuple<tuple(i, j), sim>>):
                An optional precalculated list of pairwise similarities.
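
        Example:
            A sketch of fitting from a precomputed similarity matrix (`S` here
            stands for a square `np.array` of similarities; `block_width` must
            not exceed the number of instances):

                st = SimThresholdSplit(fper=None, similarity_metric=None,
                                       block_width=len(S))
                st.fit(S)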
        """

        self.n_instances_ = len(inp)
        self.pairs_ = pairs

        if isinstance(inp, (pd.Series, pd.DataFrame)):
            self.index = inp.index
        else:
            self.index = pd.RangeIndex(len(inp), name='batch')

        if self.similarity_metric is None:
            # we were passed a similarity matrix directly
            self.pairs_ = self._pairs_from_sim_mat(inp)

        elif self.fper is None:
            # we were passed fingerprints directly
            self.fps = inp
            if self.pairs_ is None:
                self.pairs_ = self._pairs_from_fps(inp)

        else:
            # we were passed Mol
            if self.pairs_ is None:
                self.fps = self.fper.transform(inp)
                self.pairs_ = self._pairs_from_fps(self.fps)

        self.threshold_ = self._optimal_thresh()

        return self

    @property
    def n_jobs(self):
        """ The number of processes to use to calculate the distance matrix. -1 for all available. """
        return self._n_jobs

    @n_jobs.setter
    def n_jobs(self, val):
        if val == -1:
            self._n_jobs = multiprocessing.cpu_count()
        else:
            self._n_jobs = val

    @property
    def block_width(self):
        """ The width of the blocks of the pairwise matrix computed at a time.

        Only used by the memory optimized implementation. """
        return self._block_width

    @block_width.setter
    def block_width(self, val):
        assert val <= self.n_instances_, \
            'The block width should be less than or equal to the number of instances'
        self._block_width = val

    @property
    def n_instances_(self):
        """ The number of instances that were used to fit the object. """
        return self._n_instances_

    @n_instances_.setter
    def n_instances_(self, val):
        assert val >= self._block_width, \
            'The block width should be less than or equal to the number of instances'
        self._n_instances_ = val

    def _cluster_cumsum(self, shuffled=True):
        nums = self.clusters.value_counts()
        if shuffled:
            nums = nums.loc[np.random.permutation(nums.index)].cumsum()
        return nums

    def split(self, ratio):
        """ Return splits of the data with thresholded similarity according to a
        specified ratio.

        Args:
            ratio (tuple[int]):
                the ratio to use.

        Returns:
            generator[pd.Series]:
                Generator of boolean split masks for the requested splits.

        Example:
            st = SimThresholdSplit(fper='morgan', similarity_metric='jaccard').fit(ms)
            train, valid, test = st.split(ratio=(70, 15, 15))
        """

        ratio = self._split_sizes(ratio)
        nums = self._cluster_cumsum()
        res = pd.Series(np.nan, index=nums.index, name='split')

        for i in range(len(ratio)):
            lower = 0 if i == 0 else sum(ratio[:i])
            upper = sum(ratio[:i + 1])
            res[nums[(nums > lower) & (nums <= upper)].index] = i

        res = res.sort_index()
        res = self.clusters.to_frame().join(res, on='clusters')['split']
        return (res == i for i in range(len(ratio)))

    def k_fold(self, n_folds):
        """ Returns k-fold cross-validated folds with thresholded similarity.

        Args:
            n_folds (int):
                The number of folds to provide.

        Returns:
            generator[(pd.Series, pd.Series)]:
                Generator of (train, test) boolean mask pairs, in series.
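
        Example:
            A sketch of five fold cross validation (assumes `st` is already
            fitted; `model`, `X` and `y` are placeholders):

                for train_mask, test_mask in st.k_fold(5):
                    model.fit(X[train_mask], y[train_mask])
                    model.score(X[test_mask], y[test_mask])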
        """

        folds = self.split((1,) * n_folds)
        return ((~fold, fold) for fold in folds)

    def _split_sizes(self, ratio):
        """ Calculate the sizes of the splits. """
        tot = sum(ratio)
        return [self.n_instances_ * rat / tot for rat in ratio]

    @staticmethod
    @returns_pairs
    def _pairs_from_sim_mat(S):
        """ Pairs from a precomputed similarity matrix. """
        S = triu(S, k=1).todok()
        return list(S.items())

    @returns_pairs
    def _pairs_from_fps(self, fps):
        """ Pairs from fingerprints. """
        if self.memory_optimized:
            pairs = self._pairs_from_fps_mem_opt(fps)
        else:
            pairs = self._pairs_from_fps_mem_intensive(fps)

        return pairs

    def _pairs_from_fps_mem_intensive(self, fps):
        """ Fast, single process but memory intensive generation of pairs. """
        LOGGER.debug('Generating pairs using memory intensive technique.')
        D = squareform(pdist(fps, self.similarity_metric))
        S = 1 - D  # similarity is 1 - distance
        S[S <= self.min_threshold] = 0
        return self._pairs_from_sim_mat(S)

    def _pairs_from_fps_mem_opt(self, fps):
        """ Fast, multi-processed and memory efficient generation of pairwise
        similarities above a certain threshold. """

        def slice_generator(low, high, width, end=False):
            """ Generate the block indices of the upper triangle of a matrix. """
            while low < high:
                res = (low, low + width if low + width < high else high)
                if end:
                    yield res
                else:
                    # py2 compat
                    # yield from ((res, j) for j in slice_generator(low, high, width, end=True))
                    for slice_ in ((res, j) for j in slice_generator(low, high, width, end=True)):
                        yield slice_
                low += width

        size = len(fps)
        fps = fps.values
        f = partial(_above_minimum, X=fps, threshold=self.min_threshold,
                    metric=self.similarity_metric, size=size)
        slices = slice_generator(0, len(fps), self.block_width)

        if self.n_jobs == 1:
            # single processed
            LOGGER.debug('Generating pairs using memory optimized technique.')
            return sum((f(slc) for slc in slices), [])
        else:
            # multiprocessed
            LOGGER.debug('Generating pairs using memory optimized technique '
                         'with %s processes', self.n_jobs)
            # py2 compat
            # with multiprocessing.Pool(self.n_jobs) as p:
            #     return sum(p.map(f, [(i, j) for i, j in slices]), [])
            p = multiprocessing.Pool(self.n_jobs)
            res = sum(p.map(f, [(i, j) for i, j in slices]), [])
            p.close()
            return res

    def _cluster(self, pairs):
        """ Assign instances to clusters. """
        LOGGER.debug('Generating clusters with %s close pairs', len(pairs))
        clustered = np.arange(self.n_instances_)

        for i, j in pairs.values.tolist():  # faster as list
            i_clust, j_clust = clustered[i], clustered[j]
            if i_clust < j_clust:
                clustered[clustered == j_clust] = i_clust
            else:
                clustered[clustered == i_clust] = j_clust
        return clustered

    def _optimal_thresh(self):
        """ Calculate the optimal threshold for the given largest cluster fraction. """
        def f(threshold):
            pairs = self.pairs_.loc[self.pairs_.sim > threshold, ['i', 'j']]
            res = pd.Series(self._cluster(pairs))
            return np.abs(res.value_counts().max() - self.largest_cluster * self.n_instances_)

        self.threshold_ = minimize_scalar(f, bounds=(self.min_threshold, 1), method='bounded').x
        LOGGER.info('Optimal threshold: %s', self.threshold_)
        self.clusters = pd.Series(self._cluster(self.pairs_.loc[self.pairs_.sim > self.threshold_, ['i', 'j']]),
                                  index=self.index,
                                  name='clusters')
        return self.threshold_

    def visualize_similarities(self, subsample=5000, ax=None):
        """ Plot a histogram of similarities, with the threshold plotted.

        Args:
            subsample (int):
                For a large dataset, subsample the number of compounds to
                consider.
            ax (matplotlib.axis):
                Axis to make the plot on.

        Returns:
            matplotlib.axes
        """

        if not ax:
            ax = plt.gca()

        if subsample and len(self.fps) > subsample:
            fps = self.fps.sample(subsample)
        else:
            fps = self.fps

        sims = 1 - squareform(pdist(fps, self.similarity_metric))
        sims = (sims - np.identity(sims.shape[0])).flatten()  # zero the self similarities
        hist = ax.hist(sims, bins=50)
        ax.vlines(self.threshold_, 0, max(hist[0]))
        ax.set_xlabel('similarity')
        return ax

    def visualize_space(self, dim_reducer='tsne', dim_red_kw=None, subsample=5000, ax=None, c=None):
        """ Plot the chemical space using a dimensionality reducer.

        Args:
            dim_reducer (str or sklearn object):
                Technique to use to reduce fingerprint space.

            dim_red_kw (dict):
                Keyword arguments to pass to the dimensionality reducer.

            subsample (int):
                For a large dataset, subsample the number of compounds to
                consider.

            ax (matplotlib.axis):
                Axis to make the plot on.

            c (pd.Series or np.ndarray):
                Colour values passed through to the scatter plot.

        Returns:
            matplotlib.axes
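
        Example:
            A sketch using MDS, with illustrative keyword arguments (assumes
            `st` is already fitted):

                st.visualize_space(dim_reducer='mds', dim_red_kw={'n_init': 1},
                                   subsample=None, c=st.clusters)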
        """

        if dim_red_kw is None:
            dim_red_kw = {}

        if isinstance(dim_reducer, str):
            if dim_reducer not in ('tsne', 'mds'):
                raise NotImplementedError('Dimensionality reducer {} not available'.format(dim_reducer))

            from sklearn.manifold import TSNE, MDS
            reducers = {'tsne': TSNE, 'mds': MDS}
            dim_reducer = reducers[dim_reducer](**dim_red_kw)

        if subsample and len(self.fps) > subsample:
            fps = self.fps.sample(subsample)
        else:
            fps = self.fps

        two_d = dim_reducer.fit_transform(fps)

        if not ax:
            ax = plt.gca()

        return ax.scatter(two_d[:, 0], two_d[:, 1], c=c)