| 1 |  |  | """ | 
            
                                                                                                            
                            
            
                                    
            
            
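                | 2 |  |  | Preparation of similarity matrices: loading from files, rescaling, and ECFP-based Jaccard computation. | 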
            
                                                                                                            
                            
            
                                    
            
            
                | 3 |  |  | """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 4 |  |  | from __future__ import annotations | 
            
                                                                                                            
                            
            
                                    
            
            
                | 5 |  |  | from dataclasses import dataclass | 
            
                                                                                                            
                            
            
                                    
            
            
                | 6 |  |  | from pathlib import Path | 
            
                                                                                                            
                            
            
                                    
            
            
                | 7 |  |  | from typing import TypeVar, Mapping, Sequence, List | 
            
                                                                                                            
                            
            
                                    
            
            
                | 8 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 9 |  |  | import numpy as np | 
                            
                    |  |  |  | 
                                                                                        
                                                                                     | 
            
                                                                                                            
                            
            
                                    
            
            
                | 10 |  |  | import pandas as pd | 
                            
                    |  |  |  | 
                                                                                        
                                                                                     | 
            
                                                                                                            
                            
            
                                    
            
            
                | 11 |  |  | from typeddfs import BaseDf | 
                            
                    |  |  |  | 
                                                                                        
                                                                                     | 
            
                                                                                                            
                            
            
                                    
            
            
                | 12 |  |  | from typeddfs.df_errors import UnsupportedOperationError | 
                            
                    |  |  |  | 
                                                                                        
                                                                                     | 
            
                                                                                                            
                            
            
                                    
            
            
                | 13 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 14 |  |  | from mandos import logger | 
            
                                                                                                            
                            
            
                                    
            
            
                | 15 |  |  | from mandos.analysis.io_defns import SimilarityDfLongForm, SimilarityDfShortForm | 
            
                                                                                                            
                            
            
                                    
            
            
                | 16 |  |  | from mandos.entries.searcher import InputFrame | 
            
                                                                                                            
                            
            
                                    
            
            
                | 17 |  |  | from mandos.model.rdkit_utils import RdkitUtils | 
            
                                                                                                            
                            
            
                                    
            
            
                | 18 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
# Generic type variable restricted to subclasses of typeddfs' BaseDf.
# NOTE(review): not referenced anywhere in the visible code -- confirm it is still needed.
T = TypeVar("T", bound=BaseDf)
                            
                    |  |  |  | 
                                                                                        
                                                                                     | 
            
                                                                                                            
                            
            
                                    
            
            
                | 20 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 21 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 22 |  |  | @dataclass(frozen=True, repr=True) | 
                            
                    |  |  |  | 
                                                                                        
                                                                                     | 
            
                                                                                                            
                            
            
                                    
            
            
                | 23 |  |  | class MatrixPrep: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 24 |  |  |     kind: str | 
            
                                                                                                            
                            
            
                                    
            
            
                | 25 |  |  |     normalize: bool | 
            
                                                                                                            
                            
            
                                    
            
            
                | 26 |  |  |     log: bool | 
            
                                                                                                            
                            
            
                                    
            
            
                | 27 |  |  |     invert: bool | 
            
                                                                                                            
                                                                
            
                                    
            
            
                | 28 |  |  |  | 
            
                                                                        
                            
            
                                    
            
            
                | 29 |  |  |     def from_files(self, paths: Sequence[Path]) -> SimilarityDfLongForm: | 
                            
                    |  |  |  | 
                                                                                        
                                                                                     | 
            
                                                                        
                            
            
                                    
            
            
                | 30 |  |  |         dct = {} | 
            
                                                                        
                            
            
                                    
            
            
                | 31 |  |  |         for p in paths: | 
                            
                    |  |  |  | 
                                                                                        
                                                                                     | 
            
                                                                        
                            
            
                                    
            
            
                | 32 |  |  |             key = p.with_suffix("").name | 
            
                                                                        
                            
            
                                    
            
            
                | 33 |  |  |             try: | 
            
                                                                        
                            
            
                                    
            
            
                | 34 |  |  |                 mx = SimilarityDfShortForm.read_file(p) | 
                            
                    |  |  |  | 
                                                                                        
                                                                                     | 
            
                                                                        
                            
            
                                    
            
            
                | 35 |  |  |                 dct[key] = mx | 
            
                                                                        
                            
            
                                    
            
            
                | 36 |  |  |             except (OSError, UnsupportedOperationError, ValueError): | 
            
                                                                        
                            
            
                                    
            
            
                | 37 |  |  |                 logger.error(f"Failed to load matrix at {str(p)}") | 
            
                                                                        
                            
            
                                    
            
            
                | 38 |  |  |                 raise | 
            
                                                                        
                            
            
                                    
            
            
                | 39 |  |  |         return self.create(dct) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 40 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 41 |  |  |     def create(self, key_to_mx: Mapping[str, SimilarityDfShortForm]) -> SimilarityDfLongForm: | 
                            
                    |  |  |  | 
                                                                                        
                                                                                     | 
            
                                                                                                            
                            
            
                                    
            
            
                | 42 |  |  |         df = SimilarityDfLongForm( | 
                            
                    |  |  |  | 
                                                                                        
                                                                                     | 
            
                                                                                                            
                            
            
                                    
            
            
                | 43 |  |  |             pd.concat([mx.to_long_form(self.kind, key) for key, mx in key_to_mx.items()]) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 44 |  |  |         ) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 45 |  |  |         vals = df["value"] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 46 |  |  |         if self.invert: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 47 |  |  |             vals = -vals | 
            
                                                                                                            
                            
            
                                    
            
            
                | 48 |  |  |         if self.normalize: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 49 |  |  |             mn, mx = vals.min(), vals.max() | 
                            
                    |  |  |  | 
                                                                                        
                                                                                            
                                                                                     | 
            
                                                                                                            
                            
            
                                    
            
            
                | 50 |  |  |             vals = (vals - mn) / (mn - mx) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 51 |  |  |         if self.log: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 52 |  |  |             # this is a bit stupid, but calc the log then normalize again | 
            
                                                                                                            
                            
            
                                    
            
            
                | 53 |  |  |             # we can't take the log before normalization because we might have negative values | 
            
                                                                                                            
                            
            
                                    
            
            
                | 54 |  |  |             vals = vals.map(np.log10) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 55 |  |  |             mn, mx = vals.min(), vals.max() | 
                            
                    |  |  |  | 
                                                                                        
                                                                                            
                                                                                     | 
            
                                                                                                            
                            
            
                                    
            
            
                | 56 |  |  |             vals = (vals - mn) / (mn - mx) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 57 |  |  |         df["value"] = vals | 
            
                                                                                                            
                            
            
                                    
            
            
                | 58 |  |  |         return SimilarityDfLongForm.convert(df) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 59 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 60 |  |  |     @classmethod | 
            
                                                                                                            
                            
            
                                    
            
            
                | 61 |  |  |     def ecfp_matrix(cls, df: InputFrame, radius: int, n_bits: int) -> SimilarityDfShortForm: | 
                            
                    |  |  |  | 
                                                                                        
                                                                                            
                                                                                            
                                                                                     | 
            
                                                                                                            
                            
            
                                    
            
            
                | 62 |  |  |         # TODO: This is inefficient and long | 
                            
                    |  |  |  | 
                                                                                        
                                                                                     | 
            
                                                                                                            
                            
            
                                    
            
            
                | 63 |  |  |         indices = range(len(df)) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 64 |  |  |         keys = df["inchikey"] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 65 |  |  |         on_bits = [ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 66 |  |  |             RdkitUtils.ecfp(c, radius=radius, n_bits=n_bits).list_on for c in df.get_structures() | 
            
                                                                                                            
                            
            
                                    
            
            
                | 67 |  |  |         ] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 68 |  |  |         the_rows: List[List[float]] = [] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 69 |  |  |         for i, row_key, row_print in zip(indices, keys, on_bits): | 
                            
                    |  |  |  | 
                                                                                        
                                                                                     | 
            
                                                                                                            
                            
            
                                    
            
            
                | 70 |  |  |             for j, col_key, col_print in zip(indices, keys, on_bits): | 
                            
                    |  |  |  | 
                                                                                        
                                                                                     | 
            
                                                                                                            
                            
            
                                    
            
            
                | 71 |  |  |                 the_row = [] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 72 |  |  |                 if i < j: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 73 |  |  |                     jaccard = len(row_print.intersection(col_print)) / len( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 74 |  |  |                         row_print.union(col_print) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 75 |  |  |                     ) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 76 |  |  |                     the_row.append(jaccard) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 77 |  |  |                 the_rows.append(the_row) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 78 |  |  |         short = SimilarityDfShortForm(the_rows) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 79 |  |  |         short["inchikey"] = keys | 
            
                                                                                                            
                            
            
                                    
            
            
                | 80 |  |  |         short = short.set_index("inchikey") | 
            
                                                                                                            
                            
            
                                    
            
            
                | 81 |  |  |         short.columns = keys | 
            
                                                                                                            
                            
            
                                    
            
            
                | 82 |  |  |         return SimilarityDfShortForm.convert(short) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 83 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 84 |  |  |  | 
            
                                                                                                            
                                                                
            
                                    
            
            
# Public API of this module: only MatrixPrep is exported by a star-import.
__all__ = ["MatrixPrep"]
            
                                                        
            
                                    
            
            
                | 86 |  |  |  |