| Metric | Value |
| --- | --- |
| Conditions | 8 |
| Total Lines | 94 |
| Code Lines | 74 |
| Comment Lines | 0 |
| Comment Ratio | 0 % |
| Changes | 0 |
Small methods make your code easier to understand, particularly when combined with a good name. Conversely, if a method is small, finding a good name for it is usually much easier.
For example, if you find yourself adding comments to a method's body, that is usually a good sign that you should extract the commented part into a new method and use the comment as a starting point for naming it.
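A minimal sketch of this comment-to-name workflow (the `Document` type and the function names here are illustrative, not taken from the flagged code below):

```python
from dataclasses import dataclass, field

@dataclass
class Document:
    text: str
    subject_set: set[str] = field(default_factory=set)

# Before: a comment explains what the filtering step does.
def training_texts(docs: list[Document]) -> list[str]:
    # keep only documents that have at least one subject
    return [doc.text for doc in docs if doc.subject_set]

# After: the comment has become the name of an extracted function,
# and the caller reads like a sentence.
def documents_with_subjects(docs: list[Document]) -> list[Document]:
    return [doc for doc in docs if doc.subject_set]

def training_texts_refactored(docs: list[Document]) -> list[str]:
    return [doc.text for doc in documents_with_subjects(docs)]
```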
Commonly applied refactorings include Extract Method and Replace Temp with Query. If many parameters or temporary variables are present, consider Introduce Parameter Object, or Replace Method with Method Object (turning the method into its own class, where parameters and temporaries become fields). A sketch applying Extract Method to the flagged `_train` method follows the code listing below.
```python
import os

# ...

def _train(
    self,
    corpus: DocumentCorpus,
    params: dict[str, Any],
    jobs: int = 0,
) -> None:
    self.info("starting train")
    self._model = EbmModel(
        db_path=os.path.join(self.datadir, self.DB_FILE),
        embedding_dimensions=params["embedding_dimensions"],
        chunk_tokenizer=self._analyzer,
        max_chunk_count=params["max_chunk_count"],
        max_chunk_length=params["max_chunk_length"],
        chunking_jobs=params["chunking_jobs"],
        max_sentence_count=params["max_sentence_count"],
        hnsw_index_params=params["hnsw_index_params"],
        candidates_per_chunk=params["candidates_per_chunk"],
        candidates_per_doc=params["candidates_per_doc"],
        query_jobs=params["query_jobs"],
        xgb_shrinkage=params["xgb_shrinkage"],
        xgb_interaction_depth=params["xgb_interaction_depth"],
        xgb_subsample=params["xgb_subsample"],
        xgb_rounds=params["xgb_rounds"],
        xgb_jobs=params["xgb_jobs"],
        duckdb_threads=jobs if jobs else params["duckdb_threads"],
        use_altLabels=params["use_altLabels"],
        embedding_model_name=params["embedding_model_name"],
        embedding_model_deployment=params["embedding_model_deployment"],
        embedding_model_args=params["embedding_model_args"],
        encode_args_vocab=params["encode_args_vocab"],
        encode_args_documents=params["encode_args_documents"],
        logger=self,
    )

    if corpus != "cached":
        if corpus.is_empty():
            raise NotSupportedException(
                f"training backend {self.backend_id} with no documents"
            )

        self.info("creating vector database")
        self._model.create_vector_db(
            vocab_in_path=os.path.join(
                self.project.vocab.datadir, self.project.vocab.INDEX_FILENAME_TTL
            ),
            force=True,
        )

        self.info("preparing training data")
        doc_ids = []
        texts = []
        label_ids = []
        for doc_id, doc in enumerate(corpus.documents):
            for subject_id in [
                subject_id for subject_id in getattr(doc, "subject_set")
            ]:
                doc_ids.append(doc_id)
                texts.append(getattr(doc, "text"))
                label_ids.append(self.project.subjects[subject_id].uri)

        train_data = self._model.prepare_train(
            doc_ids=doc_ids,
            label_ids=label_ids,
            texts=texts,
            n_jobs=jobs,
        )

        atomic_save(
            obj=train_data,
            dirname=self.datadir,
            filename=self.TRAIN_FILE,
            method=joblib.dump,
        )

    else:
        self.info("reusing cached training data from previous run")
        if not os.path.exists(self._model.db_path):
            raise NotInitializedException(
                f"database file {self._model.db_path} not found",
                backend_id=self.backend_id,
            )
        if not os.path.exists(os.path.join(self.datadir, self.TRAIN_FILE)):
            raise NotInitializedException(
                f"train data file {self.TRAIN_FILE} not found",
                backend_id=self.backend_id,
            )

        train_data = joblib.load(os.path.join(self.datadir, self.TRAIN_FILE))

    self.info("training model")
    self._model.train(train_data, jobs)

    self.info("saving model")
    atomic_save(self._model, self.datadir, self.MODEL_FILE)
```
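Here is the promised sketch of Extract Method applied to `_train`. The helper names (`_create_model`, `_prepare_training_data`, `_load_cached_training_data`) are hypothetical, not part of the actual codebase, and the `...` bodies stand for the corresponding blocks of the original method; `DocumentCorpus`, `EbmModel`, and `atomic_save` are the names from the listing above.

```python
from __future__ import annotations

from typing import Any

def _train(
    self,
    corpus: DocumentCorpus,
    params: dict[str, Any],
    jobs: int = 0,
) -> None:
    self.info("starting train")
    self._model = self._create_model(params, jobs)

    if corpus != "cached":
        train_data = self._prepare_training_data(corpus, jobs)
    else:
        train_data = self._load_cached_training_data()

    self.info("training model")
    self._model.train(train_data, jobs)

    self.info("saving model")
    atomic_save(self._model, self.datadir, self.MODEL_FILE)

def _create_model(self, params: dict[str, Any], jobs: int) -> EbmModel:
    """Build the EbmModel from its hyperparameters (the long constructor call)."""
    ...

def _prepare_training_data(self, corpus: DocumentCorpus, jobs: int):
    """Create the vector database, collect the (doc_id, text, label) triples,
    run prepare_train, and cache the result with atomic_save."""
    ...

def _load_cached_training_data(self):
    """Verify that the cached database and train file exist, then load them."""
    ...
```

Each helper is now short enough to name precisely, and the top-level `_train` reads as a summary of the training workflow. The long list of `EbmModel` keyword arguments could be shrunk the same way with Introduce Parameter Object, grouping related settings (for example the `xgb_*` values) into one object.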