1
|
|
|
from __future__ import annotations |
|
|
|
|
2
|
|
|
|
3
|
|
|
from typing import Sequence, Optional, Tuple, Set |
4
|
|
|
|
5
|
|
|
import numpy as np |
|
|
|
|
6
|
|
|
|
7
|
|
|
from mandos.model.apis.chembl_support.chembl_target_graphs import ( |
8
|
|
|
ChemblTargetGraphFactory, |
9
|
|
|
ChemblTargetGraph, |
10
|
|
|
) |
11
|
|
|
from mandos.model.apis.chembl_support.chembl_targets import TargetFactory |
12
|
|
|
from pocketutils.tools.string_tools import StringTools |
|
|
|
|
13
|
|
|
|
14
|
|
|
from mandos import logger |
|
|
|
|
15
|
|
|
from mandos.model.taxonomy import Taxonomy, Taxon |
16
|
|
|
from typeddfs import TypedDf |
|
|
|
|
17
|
|
|
|
18
|
|
|
from mandos.model.apis.chembl_api import ChemblApi |
|
|
|
|
19
|
|
|
from mandos.model.apis.chembl_scrape_api import ( |
20
|
|
|
ChemblScrapePage, |
21
|
|
|
ChemblScrapeApi, |
22
|
|
|
ChemblTargetPredictionTable, |
23
|
|
|
) |
24
|
|
|
from mandos.model.apis.chembl_support import ChemblCompound |
25
|
|
|
from mandos.model.apis.chembl_support.chembl_utils import ChemblUtils |
26
|
|
|
from mandos.search.chembl import ChemblScrapeSearch |
27
|
|
|
from mandos.model.concrete_hits import ChemblTargetPredictionHit |
28
|
|
|
from mandos.search.chembl.target_traversal import TargetTraversalStrategies |
29
|
|
|
|
30
|
|
|
# Short module-level aliases: the target-predictions scrape page and its table type.
# NOTE(review): neither alias is referenced in the visible code; presumably kept for importers.
P, T = ChemblScrapePage.target_predictions, ChemblTargetPredictionTable
32
|
|
|
|
33
|
|
|
|
34
|
|
|
class TargetPredictionSearch(ChemblScrapeSearch[ChemblTargetPredictionHit]):
    """
    Search that converts ChEMBL's scraped target-prediction table into hits.

    For a compound, the target-predictions page is fetched through the scrape API.
    Each predicted-target row is kept only if:

    - its organism passes the ``taxa`` filter; and
    - its activity threshold is at least ``min_threshold``.

    Each kept row is expanded to the target's ancestors with the configured
    traversal strategy. One hit is emitted per ancestor per confidence level
    (70, 80, 90).
    """

    @classmethod
    def page(cls) -> ChemblScrapePage:
        """Return the scrape page this search reads from (target predictions)."""
        return ChemblScrapePage.target_predictions

    def __init__(
        self,
        key: str,
        api: ChemblApi,
        scrape: ChemblScrapeApi,
        taxa: Sequence[Taxonomy],
        traversal: str,
        target_types: Set[str],
        required_level: int = 70,
        min_threshold: float = 1.0,
        binding_score: float = 1.0,
        nonbinding_score: float = 1.0,
    ):
        """
        Configure the search.

        Args:
            key: Search key, forwarded to the base class.
            api: ChEMBL API, forwarded to the base class; also used to build
                the traversal strategy and target factories.
            scrape: ChEMBL scrape API, forwarded to the base class.
            taxa: Taxonomies used to filter target organisms. If empty, every
                organism is allowed.
            traversal: Name of the target-traversal strategy
                (resolved via ``TargetTraversalStrategies.by_name``).
            target_types: Stored on the instance.
            required_level: Must be 70, 80, or 90.
            min_threshold: Must be positive. Rows with a lower activity
                threshold are dropped, and the value also normalises hit
                weights.
            binding_score: Must be positive.
            nonbinding_score: Must be positive.

        Raises:
            ValueError: If any of the numeric arguments is out of range.

        NOTE(review): ``required_level``, ``target_types``, ``binding_score``
        and ``nonbinding_score`` are validated/stored but not consulted by
        ``process``. Hits are emitted for all three confidence levels. Confirm
        whether these should filter or weight hits.
        """
        super().__init__(key, api, scrape)
        self.taxa = taxa
        self.traversal_strategy = TargetTraversalStrategies.by_name(traversal, self.api)
        self.target_types = target_types
        if required_level not in [70, 80, 90]:
            raise ValueError(f"required_level must be 70, 80, or 90, not {required_level}")
        if min_threshold <= 0:
            raise ValueError(f"min_threshold must be positive, not {min_threshold}")
        if binding_score <= 0:
            raise ValueError(f"binding_score must be positive, not {binding_score}")
        if nonbinding_score <= 0:
            raise ValueError(f"nonbinding_score must be positive, not {nonbinding_score}")
        self.required_level = required_level
        self.min_threshold = min_threshold
        self.binding_score = binding_score
        self.nonbinding_score = nonbinding_score

    @property
    def data_source(self) -> str:
        """Human-readable label recorded on every hit."""
        return "ChEMBL :: target predictions"

    def find(self, lookup: str) -> Sequence[ChemblTargetPredictionHit]:
        """
        Fetch the prediction table for ``lookup`` and collect hits from every row.

        Args:
            lookup: Compound identifier, resolved through ``ChemblUtils``.

        Returns:
            All hits produced by ``process`` across the table's rows, in row order.
        """
        utils = ChemblUtils(self.api)
        ch = utils.get_compound_dot_dict(lookup)
        compound = utils.compound_dot_dict_to_obj(ch)
        table: TypedDf = self.scrape.fetch_predictions(compound.chid)
        hits = []
        for row in table.itertuples():
            hits.extend(self.process(lookup, compound, row))
        return hits

    def process(
        self, lookup: str, compound: ChemblCompound, row
    ) -> Sequence[ChemblTargetPredictionHit]:
        """
        Turn one prediction-table row into hits.

        Args:
            lookup: The original compound query string.
            compound: The resolved ChEMBL compound.
            row: One row from ``DataFrame.itertuples``. It must have the
                attributes ``target_organism``, ``activity_threshold``,
                ``target_chembl_id``, ``target_pref_name`` and
                ``confidence_70/80/90``.

        Returns:
            One hit per (traversed ancestor, confidence level).
            Returns an empty list if the organism is excluded or the activity
            threshold is below ``min_threshold``.
        """
        tax_id, tax_name = self._get_taxon(row.target_organism)
        # _get_taxon signals "excluded" by returning (None, None)
        if tax_id is None and tax_name is None:
            return []
        thresh = row.activity_threshold
        if thresh < self.min_threshold:
            return []
        factory = TargetFactory(self.api)
        target_obj = factory.find(row.target_chembl_id)
        graph_factory = ChemblTargetGraphFactory.create(self.api, factory)
        graph = graph_factory.at_target(target_obj)
        ancestors: Sequence[ChemblTargetGraph] = self.traversal_strategy(graph)
        lst = []
        for ancestor in ancestors:
            for conf_t, conf_v in zip(
                [70, 80, 90], [row.confidence_70, row.confidence_80, row.confidence_90]
            ):
                predicate = f"binding:{conf_v.yes_no_mixed}"
                # Weight combines three factors:
                #   - odds of the confidence level, conf_t / (100 - conf_t),
                #     times the prediction score (absolute value);
                #   - sqrt(thresh / min_threshold);
                #   - a constant factor of 1/4.
                weight = (
                    np.sqrt(thresh)
                    * abs(conf_t / (100 - conf_t) * conf_v.score)
                    / 4
                    / np.sqrt(self.min_threshold)
                )
                hit = self._create_hit(
                    c_origin=lookup,
                    c_matched=compound.inchikey,
                    c_id=compound.chid,
                    c_name=compound.name,
                    predicate=predicate,
                    object_id=ancestor.chembl,
                    object_name=ancestor.name,
                    data_source=self.data_source,
                    exact_target_id=row.target_chembl_id,
                    exact_target_name=row.target_pref_name,
                    weight=weight,
                    prediction=conf_v,
                    confidence_set=conf_t,
                    threshold=thresh,
                )
                lst.append(hit)
        return lst

    def _get_taxon(self, organism: str) -> Tuple[Optional[int], Optional[str]]:
        """
        Match an organism name against the configured taxonomies.

        Returns:
            ``(None, organism)`` if no taxa are configured (everything is allowed).
            ``(None, None)`` if nothing matched (the organism is excluded).
            Otherwise ``(id, organism)``, where ``id`` belongs to the chosen
            matching taxon.
        """
        if len(self.taxa) == 0:  # allow all
            return None, organism
        # Collect matches in a set: the previous ``{}`` + ``+=`` raised TypeError.
        # update() accepts whatever iterable get_by_id_or_name returns.
        matches: Set[Taxon] = set()
        for tax in self.taxa:
            matches.update(tax.get_by_id_or_name(organism))
        if not matches:
            logger.debug(f"Taxon {organism} not in set. Excluding.")
            return None, None

        def _is_exact(t: Taxon) -> bool:
            return organism in (t.scientific_name, t.mnemonic)

        # Choose deterministically rather than from arbitrary set order:
        # prefer exact scientific-name/mnemonic matches, then the lowest id.
        best: Taxon = min(matches, key=lambda t: (not _is_exact(t), t.id))
        if not _is_exact(best):
            logger.warning(f"Organism {organism} matched to {best.scientific_name} by common name")
        if len(matches) > 1:
            logger.warning(
                f"Multiple matches for taxon {organism}: {matches}; using {best.scientific_name}"
            )
        return best.id, organism
146
|
|
|
|
147
|
|
|
|
148
|
|
|
# Names exported by ``from ... import *``.
__all__ = [
    "TargetPredictionSearch",
]
149
|
|
|
|