| 1 |  |  | """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 2 |  |  | Run searches and write files. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 3 |  |  | """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 4 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 5 |  |  | from __future__ import annotations | 
            
                                                                                                            
                            
            
                                    
            
            
                | 6 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 7 |  |  | import functools | 
            
                                                                                                            
                            
            
                                    
            
            
                | 8 |  |  | import time | 
            
                                                                                                            
                            
            
                                    
            
            
                | 9 |  |  | from dataclasses import dataclass | 
            
                                                                                                            
                            
            
                                    
            
            
                | 10 |  |  | from datetime import timedelta | 
            
                                                                                                            
                            
            
                                    
            
            
                | 11 |  |  | from pathlib import Path | 
            
                                                                                                            
                            
            
                                    
            
            
                | 12 |  |  | from typing import Sequence | 
            
                                                                                                            
                            
            
                                    
            
            
                | 13 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 14 |  |  | from pocketutils.core.exceptions import IllegalStateError | 
                            
                    |  |  |  | 
                                                                                        
                                                                                     | 
            
                                                                                                            
                            
            
                                    
            
            
                | 15 |  |  | from typeddfs import Checksums, TypedDfs | 
                            
                    |  |  |  | 
                                                                                        
                                                                                     | 
            
                                                                                                            
                            
            
                                    
            
            
                | 16 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 17 |  |  | from mandos.model import CompoundNotFoundError | 
            
                                                                                                            
                            
            
                                    
            
            
                | 18 |  |  | from mandos.model.hit_dfs import HitDf | 
            
                                                                                                            
                            
            
                                    
            
            
                | 19 |  |  | from mandos.model.hits import AbstractHit | 
            
                                                                                                            
                            
            
                                    
            
            
                | 20 |  |  | from mandos.model.search_caches import SearchCache | 
            
                                                                                                            
                            
            
                                    
            
            
                | 21 |  |  | from mandos.model.searches import Search, SearchError | 
            
                                                                                                            
                            
            
                                    
            
            
                | 22 |  |  | from mandos.model.settings import SETTINGS | 
            
                                                                                                            
                            
            
                                    
            
            
                | 23 |  |  | from mandos.model.utils.setup import logger | 
            
                                                                                                            
                            
            
                                    
            
            
                | 24 |  |  |  | 
            
                                                                                                            
                                                                
            
                                    
            
            
                | 25 |  |  |  | 
            
                                                                        
                            
            
                                    
            
            
                | 26 |  |  | def _fix_cols(df): | 
                            
                    |  |  |  | 
                                                                                        
                                                                                     | 
            
                                                                        
                            
            
                                    
            
            
                | 27 |  |  |     return df.rename(columns={s: s.lower() for s in df.columns}) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 28 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 29 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                # Data-frame class for the table of input compounds, assembled with the
                # ``TypedDfs`` fluent builder and finalized by ``build()``.
                # NOTE(review): the per-step notes below are inferred from the builder's
                # method names only -- confirm against the typeddfs documentation.
                InputCompoundsDf = (
                    TypedDfs.typed("InputCompoundsDf")
                    # presumably: an "inchikey" column must be present
                    .require("inchikey")
                    # presumably: optional, string-typed columns
                    .reserve("inchi", "smiles", "compound_id", dtype=str)
                    # _fix_cols lowercases the column names; presumably applied after reading
                    .post(_fix_cols)
                    .strict(cols=False)
                    .secure()
                ).build()
            
                                                                                                            
                            
            
                                    
            
            
                | 38 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 39 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                class MemoizedInputCompounds:
                    """
                    Reads input-compound files, memoizing the result per argument set.
                    """

                    @classmethod
                    @functools.lru_cache(maxsize=None)
                    def read_file(cls, path: Path) -> InputCompoundsDf:
                        """
                        Loads ``path`` with ``InputCompoundsDf.read_file`` and returns the frame.

                        The function is wrapped in an unbounded ``functools`` cache keyed on its
                        arguments, so a repeated call with an equal (hashable) ``path`` returns
                        the stored frame without reading or logging again.
                        """
                        logger.debug(f"Reading compounds from {path}")
                        compounds = InputCompoundsDf.read_file(path)
                        logger.info(f"Read {len(compounds)} compounds from {path}")
                        return compounds
            
                                                                                                            
                            
            
                                    
            
            
                | 48 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 49 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                @dataclass(order=True, frozen=True)
                class SearchReturnInfo:
                    """
                    Immutable summary of one search run.

                    Instances are frozen and support ordering comparisons, which go field by
                    field in declaration order.
                    """

                    # NOTE(review): field meanings are inferred from their names -- confirm
                    # against the code that builds this object.
                    # count of compounds kept
                    n_kept: int
                    # count of compounds processed
                    n_processed: int
                    # count of compounds that errored
                    n_errored: int
                    # duration of the run
                    time_taken: timedelta
            
                                                                                                            
                            
            
                                    
            
            
                | 56 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 57 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 58 |  |  | @dataclass(frozen=True, repr=True) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 59 |  |  | class Searcher: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 60 |  |  |     """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 61 |  |  |     Executes one or more searches and saves the results. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 62 |  |  |     Create and use once. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 63 |  |  |     """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 64 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 65 |  |  |     what: Search | 
            
                                                                                                            
                            
            
                                    
            
            
                | 66 |  |  |     input_df: InputCompoundsDf | 
            
                                                                                                            
                            
            
                                    
            
            
                | 67 |  |  |     to: Path | 
                            
                    |  |  |  | 
                                                                                        
                                                                                     | 
            
                                                                                                            
                            
            
                                    
            
            
                | 68 |  |  |     proceed: bool | 
            
                                                                                                            
                            
            
                                    
            
            
                | 69 |  |  |     restart: bool | 
            
                                                                                                            
                            
            
                                    
            
            
                | 70 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 71 |  |  |     def search(self) -> SearchReturnInfo: | 
                            
                    |  |  |  | 
                                                                                        
                                                                                     | 
            
                                                                                                            
                            
            
                                    
            
            
                | 72 |  |  |         """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 73 |  |  |         Performs the search, and writes data. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 74 |  |  |         """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 75 |  |  |         inchikeys = self.input_df["inchikey"].unique() | 
            
                                                                                                            
                            
            
                                    
            
            
                | 76 |  |  |         if self.is_complete: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 77 |  |  |             logger.info(f"{self.to} already complete") | 
            
                                                                                                            
                            
            
                                    
            
            
                | 78 |  |  |             return SearchReturnInfo( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 79 |  |  |                 n_kept=len(inchikeys), n_processed=0, n_errored=0, time_taken=timedelta(seconds=0) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 80 |  |  |             ) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 81 |  |  |         logger.info(f"Will save every {SETTINGS.save_every} compounds") | 
            
                                                                                                            
                            
            
                                    
            
            
                | 82 |  |  |         logger.info(f"Writing {self.what.key} to {self.to}") | 
            
                                                                                                            
                            
            
                                    
            
            
                | 83 |  |  |         annotes = [] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 84 |  |  |         compounds_run = set() | 
            
                                                                                                            
                            
            
                                    
            
            
                | 85 |  |  |         cache = SearchCache(self.to, inchikeys, restart=self.restart, proceed=self.proceed) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 86 |  |  |         # refresh so we know it's (no longer) complete | 
            
                                                                                                            
                            
            
                                    
            
            
                | 87 |  |  |         # this would only happen if we're forcing this -- which is not currently allowed | 
            
                                                                                                            
                            
            
                                    
            
            
                | 88 |  |  |         ( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 89 |  |  |             Checksums() | 
            
                                                                                                            
                            
            
                                    
            
            
                | 90 |  |  |             .load_dirsum_of_file(self.to, missing_ok=True) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 91 |  |  |             .remove(self.to, missing_ok=True) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 92 |  |  |             .write(rm_if_empty=True) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 93 |  |  |         ) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 94 |  |  |         t0, n0, n_proc, n_err, n_annot = time.monotonic(), cache.at, 0, 0, 0 | 
                            
                    |  |  |  | 
                                                                                        
                                                                                            
                                                                                     | 
            
                                                                                                            
                            
            
                                    
            
            
                | 95 |  |  |         while True: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 96 |  |  |             try: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 97 |  |  |                 compound = cache.next() | 
                            
                    |  |  |  | 
                                                                                        
                                                                                     | 
            
                                                                                                            
                            
            
                                    
            
            
                | 98 |  |  |             except StopIteration: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 99 |  |  |                 break | 
            
                                                                                                            
                            
            
                                    
            
            
                | 100 |  |  |             try: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 101 |  |  |                 with logger.contextualize(compound=compound): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 102 |  |  |                     x = self.what.find(compound) | 
                            
                    |  |  |  | 
                                                                                        
                                                                                     | 
            
                                                                                                            
                            
            
                                    
            
            
                | 103 |  |  |                 annotes.extend(x) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 104 |  |  |             except CompoundNotFoundError: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 105 |  |  |                 logger.info(f"Compound {compound} not found for {self.what.key}") | 
            
                                                                                                            
                            
            
                                    
            
            
                | 106 |  |  |                 x = [] | 
                            
                    |  |  |  | 
                                                                                        
                                                                                     | 
            
                                                                                                            
                            
            
                                    
            
            
                | 107 |  |  |                 n_err += 1 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 108 |  |  |             except Exception: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 109 |  |  |                 raise SearchError( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 110 |  |  |                     f"Failed {self.what.key} [{self.what.search_class}] on compound {compound}", | 
            
                                                                                                            
                            
            
                                    
            
            
                | 111 |  |  |                     compound=compound, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 112 |  |  |                     search_key=self.what.key, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 113 |  |  |                     search_class=self.what.search_class, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 114 |  |  |                 ) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 115 |  |  |             compounds_run.add(compound) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 116 |  |  |             logger.debug(f"Found {len(x)} {self.what.search_name()} annotations for {compound}") | 
            
                                                                                                            
                            
            
                                    
            
            
                | 117 |  |  |             n_annot += len(x) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 118 |  |  |             n_proc += 1 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 119 |  |  |             # logging, caching, and such: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 120 |  |  |             on_nth = cache.at % SETTINGS.save_every == SETTINGS.save_every - 1 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 121 |  |  |             is_last = cache.at == len(inchikeys) - 1 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 122 |  |  |             if on_nth or is_last: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 123 |  |  |                 logger.log( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 124 |  |  |                     "NOTICE" if is_last else "INFO", | 
            
                                                                                                            
                            
            
                                    
            
            
                | 125 |  |  |                     f"Found {len(annotes)} {self.what.search_name()} annotations" | 
            
                                                                                                            
                            
            
                                    
            
            
                | 126 |  |  |                     + f" for {cache.at} of {len(inchikeys)} compounds", | 
            
                                                                                                            
                            
            
                                    
            
            
                | 127 |  |  |                 ) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 128 |  |  |                 self._save(annotes, done=is_last) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 129 |  |  |             cache.save(*compounds_run)  # CRITICAL -- do this AFTER saving | 
            
                                                                                                            
                            
            
                                    
            
            
                | 130 |  |  |         # done! | 
            
                                                                                                            
                            
            
                                    
            
            
                | 131 |  |  |         i1, t1 = cache.at, time.monotonic() | 
                            
                    |  |  |  | 
                                                                                        
                                                                                            
                                                                                     | 
            
                                                                                                            
                            
            
                                    
            
            
                | 132 |  |  |         assert i1 == len(inchikeys) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 133 |  |  |         cache.kill() | 
            
                                                                                                            
                            
            
                                    
            
            
                | 134 |  |  |         logger.success(f"Wrote {self.what.key} to {self.to}") | 
            
                                                                                                            
                            
            
                                    
            
            
                | 135 |  |  |         return SearchReturnInfo( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 136 |  |  |             n_kept=n0, n_processed=n_proc, n_errored=n_err, time_taken=timedelta(seconds=t1 - t0) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 137 |  |  |         ) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 138 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 139 |  |  |     @property | 
            
                                                                                                            
                            
            
                                    
            
            
                | 140 |  |  |     def is_partial(self) -> bool: | 
                            
                    |  |  |  | 
                                                                                        
                                                                                     | 
            
                                                                                                            
                            
            
                                    
            
            
                | 141 |  |  |         return self.to.exists() and not self.is_complete | 
            
                                                                                                            
                            
            
                                    
            
            
                | 142 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 143 |  |  |     @property | 
            
                                                                                                            
                            
            
                                    
            
            
                | 144 |  |  |     def is_complete(self) -> bool: | 
                            
                    |  |  |  | 
                                                                                        
                                                                                     | 
            
                                                                                                            
                            
            
                                    
            
            
                | 145 |  |  |         done = self.to in Checksums().load_dirsum_of_file(self.to) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 146 |  |  |         if done and not self.to.exists(): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 147 |  |  |             raise IllegalStateError(f"{self.to} marked complete but does not exist") | 
            
                                                                                                            
                            
            
                                    
            
            
                | 148 |  |  |         return done | 
            
                                                                                                            
                            
            
                                    
            
            
                | 149 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 150 |  |  |     def _save(self, hits: Sequence[AbstractHit], *, done: bool) -> None: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 151 |  |  |         df = HitDf.from_hits(hits) | 
                            
                    |  |  |  | 
                                                                                        
                                                                                     | 
            
                                                                                                            
                            
            
                                    
            
            
                | 152 |  |  |         # keep all of the original extra columns from the input | 
            
                                                                                                            
                            
            
                                    
            
            
                | 153 |  |  |         # e.g. if the user had 'inchi' or 'smiles' or 'pretty_name' | 
            
                                                                                                            
                            
            
                                    
            
            
                | 154 |  |  |         # if "origin_inchikey" not in df.columns: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 155 |  |  |         for extra_col in [c for c in self.input_df.columns if c != "inchikey"]: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 156 |  |  |             extra_mp = self.input_df.set_index("inchikey")[extra_col].to_dict() | 
            
                                                                                                            
                            
            
                                    
            
            
                | 157 |  |  |             df[extra_col] = df["origin_inchikey"].map(extra_mp.get) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 158 |  |  |         # write the file | 
            
                                                                                                            
                            
            
                                    
            
            
                | 159 |  |  |         df: HitDf = HitDf.of(df) | 
                            
                    |  |  |  | 
                                                                                        
                                                                                     | 
            
                                                                                                            
                            
            
                                    
            
            
                | 160 |  |  |         params = self.what.get_params() | 
            
                                                                                                            
                            
            
                                    
            
            
                | 161 |  |  |         df = df.set_attrs(**params, key=self.what.key) | 
                            
                    |  |  |  | 
                                                                                        
                                                                                     | 
            
                                                                                                            
                            
            
                                    
            
            
                | 162 |  |  |         df.write_file(self.to.resolve(), mkdirs=True, attrs=True, dir_hash=done) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 163 |  |  |         logger.debug(f"Saved {len(df)} rows to {self.to}") | 
            
                                                                                                            
                            
            
                                    
            
            
                | 164 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 165 |  |  |  | 
            
                                                                                                            
                                                                
            
                                    
            
            
                | 166 |  |  | __all__ = ["Searcher", "InputCompoundsDf", "SearchReturnInfo", "MemoizedInputCompounds"] | 
            
                                                        
            
                                    
            
            
                | 167 |  |  |  |