|
1
|
|
|
""" |
|
2
|
|
|
Calculations of overlap (similarity) between annotation sets. |
|
3
|
|
|
""" |
|
4
|
|
|
import abc |
|
5
|
|
|
import enum |
|
6
|
|
|
import math |
|
7
|
|
|
import time |
|
8
|
|
|
from collections import defaultdict |
|
9
|
|
|
from pathlib import Path |
|
10
|
|
|
from typing import Collection, Sequence, Type, Union |
|
11
|
|
|
|
|
12
|
|
|
import decorateme |
|
|
|
|
|
|
13
|
|
|
import numpy as np |
|
|
|
|
|
|
14
|
|
|
import pandas as pd |
|
|
|
|
|
|
15
|
|
|
from pocketutils.core.chars import Chars |
|
|
|
|
|
|
16
|
|
|
from pocketutils.core.enums import CleverEnum |
|
|
|
|
|
|
17
|
|
|
from pocketutils.tools.unit_tools import UnitTools |
|
|
|
|
|
|
18
|
|
|
from typeddfs.df_errors import HashFileMissingError |
|
|
|
|
|
|
19
|
|
|
|
|
20
|
|
|
from mandos.analysis import AnalysisUtils as Au |
|
21
|
|
|
from mandos.analysis.io_defns import SimilarityDfLongForm, SimilarityDfShortForm |
|
22
|
|
|
from mandos.model.hit_dfs import HitDf |
|
23
|
|
|
from mandos.model.hits import AbstractHit |
|
24
|
|
|
from mandos.model.utils import unlink |
|
25
|
|
|
|
|
26
|
|
|
# note that most of these math functions are much faster than their numpy counterparts |
|
27
|
|
|
# if we're not broadcasting, it's almost always better to use them |
|
28
|
|
|
# some are more accurate, too |
|
29
|
|
|
# e.g. we're using fsum rather than sum |
|
30
|
|
|
from mandos.model.utils.setup import logger |
|
31
|
|
|
|
|
32
|
|
|
|
|
33
|
|
|
@decorateme.auto_repr_str()
class MatrixCalculator(metaclass=abc.ABCMeta):
    """
    Abstract base for pairwise-similarity matrix calculators.

    Subclasses read hits from a file, compute one similarity value per
    compound pair per search key, and write a long-form matrix.
    """

    def __init__(self, *, min_compounds: int, min_nonzero: int, min_hits: int):
        # Per-key inclusion thresholds; keys falling below any of these
        # are excluded from the final concatenated matrix.
        self.min_compounds = min_compounds
        self.min_nonzero = min_nonzero
        self.min_hits = min_hits

    def calc_all(self, hits: Path, to: Path, *, keep_temp: bool = False) -> SimilarityDfLongForm:
        """
        Calculate similarities over all keys in ``hits`` and write the result to ``to``.

        Args:
            hits: Path to a hits file readable by the subclass
            to: Output path for the long-form similarity matrix
            keep_temp: If True, keep per-key intermediate files

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        # Bug fix: was ``raise NotImplemented()`` -- NotImplemented is a
        # constant (not callable, not an exception), so that line raised
        # a confusing TypeError instead of signaling "abstract".
        raise NotImplementedError()
|
|
|
|
|
|
42
|
|
|
|
|
43
|
|
|
|
|
44
|
|
|
class _Inf: |
|
45
|
|
|
def __init__(self, n: int): |
|
46
|
|
|
self.n = n |
|
|
|
|
|
|
47
|
|
|
self.used, self.t0, self.nonzeros = set(), time.monotonic(), 0 |
|
|
|
|
|
|
48
|
|
|
|
|
49
|
|
|
def is_used(self, c1: str, c2: str) -> bool: |
|
|
|
|
|
|
50
|
|
|
return (c1, c2) in self.used or (c2, c1) in self.used |
|
51
|
|
|
|
|
52
|
|
|
def got(self, c1: str, c2: str, z: float) -> None: |
|
|
|
|
|
|
53
|
|
|
self.used.add((c1, c2)) |
|
54
|
|
|
self.nonzeros += int(c1 != c2 and not np.isnan(z) and 0 < z < 1) |
|
55
|
|
|
if self.i % 20000 == 0: |
|
56
|
|
|
self.log("info") |
|
57
|
|
|
|
|
58
|
|
|
@property |
|
59
|
|
|
def i(self) -> int: |
|
|
|
|
|
|
60
|
|
|
return len(self.used) |
|
61
|
|
|
|
|
62
|
|
|
def log(self, level: str) -> None: |
|
|
|
|
|
|
63
|
|
|
delta = UnitTools.delta_time_to_str(time.monotonic() - self.t0, space=Chars.narrownbsp) |
|
64
|
|
|
logger.log( |
|
65
|
|
|
level.upper(), |
|
66
|
|
|
f"Processed {self.i:,}/{self.n:,} pairs in {delta};" |
|
67
|
|
|
+ f" {self.nonzeros:,} ({self.nonzeros / self.i * 100:.1f}%) are nonzero", |
|
68
|
|
|
) |
|
69
|
|
|
|
|
70
|
|
|
def __repr__(self): |
|
71
|
|
|
return f"{self.__class__.__name__}({self.i}/{self.n})" |
|
72
|
|
|
|
|
73
|
|
|
def __str__(self): |
|
74
|
|
|
return repr(self) |
|
75
|
|
|
|
|
76
|
|
|
|
|
77
|
|
|
class JPrimeMatrixCalculator(MatrixCalculator):
    """
    Calculates a J-prime similarity matrix over compound pairs, one key at a time.

    For each search key, builds a short-form matrix over all compound pairs,
    caches it to a temp feather file next to the input (so interrupted runs
    can resume), then concatenates the per-key results that pass the
    ``min_compounds`` / ``min_hits`` / ``min_nonzero`` filters.
    """

    def calc_all(self, path: Path, to: Path, *, keep_temp: bool = False) -> SimilarityDfLongForm:
        """
        Calculate and write similarities for every search key found in ``path``.

        Args:
            path: Hits file readable by :meth:`HitDf.read_file`
            to: Output path for the concatenated long-form matrix
            keep_temp: If True, keep the per-key temp feather files

        Returns:
            The concatenated long-form matrix (also written to ``to``).
        """
        hits = HitDf.read_file(path).to_hits()
        key_to_hit = Au.hit_multidict(hits, "search_key")
        logger.notice(f"Calculating J on {len(key_to_hit):,} keys from {len(hits):,} hits")
        # deltas: per-key compute times; collected but currently unreported
        deltas, files, good_keys = [], [], {}
        for key, key_hits in key_to_hit.items():
            key: str = key
            key_hits: Sequence[AbstractHit] = key_hits
            n_compounds_0 = len({k.origin_inchikey for k in key_hits})
            part_path = self._path_of(path, key)
            n_compounds_in_mx = None
            n_nonzero = None
            df = None
            if part_path.exists():
                # Resume path: reuse an existing, complete per-key result
                try:
                    df = SimilarityDfLongForm.read_file(
                        part_path, file_hash=False
                    )  # TODO: file_hash=True
                    logger.warning(f"Results for key {key} already exist ({len(df):,} rows)")
                    n_compounds_in_mx = len(df["inchikey_1"].unique())
                except HashFileMissingError:
                    logger.error(f"Extant results for key {key} appear incomplete; restarting")
                    logger.opt(exception=True).debug(f"Hash error for {key}")
                    unlink(part_path)
                    # now let it go into the next block -- calculate from scratch
            if n_compounds_0 >= self.min_compounds:
                t1 = time.monotonic()
                df: SimilarityDfShortForm = self.calc_one(key, key_hits)
                t2 = time.monotonic()
                deltas.append(t2 - t1)
                # NOTE(review): kind="psi" although this computes J -- confirm label
                df = df.to_long_form(kind="psi", key=key)
                n_compounds_in_mx = len(df["inchikey_1"].unique())
                df.write_file(part_path)
                logger.debug(f"Wrote results for {key} to {part_path}")
            if df is not None:
                n_nonzero = len(df[df["value"] > 0])
            if n_compounds_in_mx is None:
                # Bug fix: previously this fell through to ``None < int``
                # (TypeError) when the key had no cached matrix and too few
                # compounds to calculate; fall back to the hit-derived count,
                # which is < min_compounds exactly when we reach here with None.
                n_compounds_in_mx = n_compounds_0
            if n_compounds_in_mx < self.min_compounds:
                logger.warning(
                    f"Key {key} has {n_compounds_in_mx:,} < {self.min_compounds:,} compounds; skipping"
                )
            elif len(key_hits) < self.min_hits:
                logger.warning(
                    f"Key {key} has {len(key_hits):,} < {self.min_hits:,} hits; skipping"
                )
            elif n_nonzero is not None and n_nonzero < self.min_nonzero:
                logger.warning(
                    f"Key {key} has {n_nonzero:,} < {self.min_nonzero:,} nonzero pairs; skipping"
                )  # TODO: percent nonzero?
            else:
                files.append(part_path)
                good_keys[key] = n_compounds_in_mx
            del df
        logger.debug(f"Concatenating {len(files):,} files")
        df = SimilarityDfLongForm(
            pd.concat(
                [SimilarityDfLongForm.read_file(self._path_of(path, k)) for k in good_keys.keys()]
            )
        )
        logger.notice(f"Included {len(good_keys):,} keys: {', '.join(good_keys.keys())}")
        quartiles = {}
        for k, v in good_keys.items():
            vals = df[df["key"] == k]["value"]
            qs = {x: vals.quantile(x) for x in [0, 0.25, 0.5, 0.75, 1]}
            quartiles[k] = list(qs.values())
            logger.info(f"Key {k} has {v:,} compounds and {len(key_to_hit[k]):,} hits")
            # Bug fix: message previously read "unique values = N unique values"
            logger.info(
                f"  {k} {Chars.fatright} unique values = {len(vals.unique())}"
            )
            # Bug fix: str.join requires strings; qs values are floats
            logger.info(
                f"  {k} {Chars.fatright} quartiles: " + " | ".join(str(q) for q in qs.values())
            )
        df = df.set_attrs(
            dict(
                keys={
                    k: dict(compounds=v, hits=len(key_to_hit[k]), quartiles=quartiles[k])
                    for k, v in good_keys.items()
                }
            )
        )
        df.write_file(to, attrs=True, file_hash=True)
        logger.notice(f"Wrote {len(df):,} rows to {to}")
        if not keep_temp:
            for k in key_to_hit.keys():
                unlink(self._path_of(path, k))
        return df

    def calc_one(self, key: str, hits: Sequence[AbstractHit]) -> SimilarityDfShortForm:
        """
        Calculate the short-form J matrix for one search key.

        Iterates every unordered compound pair (diagonal = 1) and fills a
        nested dict, then converts it to a short-form DataFrame.
        """
        ik2hits = Au.hit_multidict(hits, "origin_inchikey")
        logger.info(f"Calculating J on {key} for {len(ik2hits):,} compounds and {len(hits):,} hits")
        data = defaultdict(dict)
        # n counts off-diagonal unordered pairs; used only for progress logging
        inf = _Inf(n=int(len(ik2hits) * (len(ik2hits) - 1) / 2))
        for (c1, hits1) in ik2hits.items():
            for (c2, hits2) in ik2hits.items():
                if inf.is_used(c1, c2):
                    continue
                z = 1 if c1 == c2 else self._j_prime(key, hits1, hits2)
                data[c1][c2] = z
                inf.got(c1, c2, z)
        inf.log("success")
        return SimilarityDfShortForm.from_dict(data)

    def _path_of(self, path: Path, key: str):
        """Return the per-key temp feather path alongside the input file."""
        return path.parent / f".{path.name}-{key}.tmp.feather"

    def _j_prime(
        self, key: str, hits1: Collection[AbstractHit], hits2: Collection[AbstractHit]
    ) -> float:
        """
        J-prime between two compounds' hit sets.

        Mean of the per-source :meth:`_jx` values over the data sources common
        to both compounds; 0 if either set is empty; NaN if no common source.
        """
        if len(hits1) == 0 or len(hits2) == 0:
            return 0
        sources = {h.data_source for h in hits1}.intersection({h.data_source for h in hits2})
        if len(sources) == 0:
            return float("NaN")
        values = [
            self._jx(
                key,
                [h for h in hits1 if h.data_source == source],
                [h for h in hits2 if h.data_source == source],
            )
            for source in sources
        ]
        # fsum is more accurate (and faster here) than np.sum for a small list
        return float(math.fsum(values) / len(values))

    def _jx(
        self, key: str, hits1: Collection[AbstractHit], hits2: Collection[AbstractHit]
    ) -> float:
        """Single-source J term: mean of wedge/vee over paired weights."""
        # TODO -- for testing only
        # TODO: REMOVE ME!
        if key in ["core.chemidplus.effects", "extra.chemidplus.specific-effects"]:
            hits1 = [h.copy(weight=math.pow(10, -h.weight)) for h in hits1]
            hits2 = [h.copy(weight=math.pow(10, -h.weight)) for h in hits2]
        pair_to_weights = Au.weights_of_pairs(hits1, hits2)
        values = [self._wedge(ca, cb) / self._vee(ca, cb) for ca, cb in pair_to_weights.values()]
        return float(math.fsum(values) / len(values))

    def _wedge(self, ca: float, cb: float) -> float:
        """Numerator term: sqrt(elle(ca) * elle(cb))."""
        return math.sqrt(Au.elle(ca) * Au.elle(cb))

    def _vee(self, ca: float, cb: float) -> float:
        """Denominator term: elle(ca) + elle(cb) - sqrt(elle(ca) * elle(cb))."""
        return Au.elle(ca) + Au.elle(cb) - math.sqrt(Au.elle(ca) * Au.elle(cb))
|
215
|
|
|
|
|
216
|
|
|
|
|
217
|
|
|
class MatrixAlg(CleverEnum):
    """Known matrix-calculation algorithms, each mapped to a calculator class."""

    j = enum.auto()

    @property
    def clazz(self) -> Type[MatrixCalculator]:
        """The :class:`MatrixCalculator` subclass implementing this algorithm."""
        by_member = {MatrixAlg.j: JPrimeMatrixCalculator}
        return by_member[self]
|
223
|
|
|
|
|
224
|
|
|
|
|
225
|
|
|
@decorateme.auto_utils()
class MatrixCalculation:
    """Factory for :class:`MatrixCalculator` instances."""

    @classmethod
    def create(
        cls,
        algorithm: Union[str, MatrixAlg],
        *,
        min_compounds: int,
        min_nonzero: int,
        min_hits: int,
    ) -> MatrixCalculator:
        """
        Build the calculator for ``algorithm`` with the given thresholds.

        Args:
            algorithm: A :class:`MatrixAlg` or its string name
            min_compounds: Minimum unique compounds per key
            min_nonzero: Minimum nonzero pair values per key
            min_hits: Minimum hits per key
        """
        alg = MatrixAlg.of(algorithm)
        return alg.clazz(
            min_compounds=min_compounds,
            min_nonzero=min_nonzero,
            min_hits=min_hits,
        )
|
239
|
|
|
|
|
240
|
|
|
|
|
241
|
|
|
# Public API of this module
__all__ = ["MatrixCalculator", "JPrimeMatrixCalculator", "MatrixCalculation", "MatrixAlg"]
|
242
|
|
|
|