| Metric | Value |
| --- | --- |
| Conditions | 8 |
| Total Lines | 93 |
| Code Lines | 73 |
| Comment Lines | 0 |
| Comment Ratio | 0 % |
| Changes | 0 |
Small methods make your code easier to understand, especially when combined with a good name. Moreover, if a method is small, finding a good name for it is usually much easier.
For example, if you find yourself adding comments to a method's body, that is usually a sign that the commented part should be extracted into a new method, with the comment serving as a starting point for the new method's name.
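A minimal sketch of this idea (the method, constant, and names here are hypothetical examples, not taken from the code below):

```python
from dataclasses import dataclass

TAX_RATE = 0.19  # assumed example rate


@dataclass
class Order:
    price: float


# Before: a comment labels a logical block inside a longer method
def summarize(orders: list[Order]) -> str:
    # compute the total price including tax
    total = sum(order.price for order in orders)
    total *= 1 + TAX_RATE
    return f"{len(orders)} orders, total {total:.2f}"


# After: the block is extracted and the comment becomes the name
def total_price_with_tax(orders: list[Order]) -> float:
    return sum(order.price for order in orders) * (1 + TAX_RATE)


def summarize_extracted(orders: list[Order]) -> str:
    return f"{len(orders)} orders, total {total_price_with_tax(orders):.2f}"
```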
Commonly applied refactorings include Extract Method. If many parameters or temporary variables are present, Replace Method with Method Object can help. The flagged method is shown below:
```python
import os

# ...

def _train(
    self,
    corpus: DocumentCorpus,
    params: dict[str, Any],
    jobs: int = 0,
) -> None:
    self.info("starting train")
    self._model = EbmModel(
        db_path=os.path.join(self.datadir, self.DB_FILE),
        embedding_model_name=params["embedding_model_name"],
        embedding_dimensions=params["embedding_dimensions"],
        chunk_tokenizer=self._analyzer,
        max_chunk_count=params["max_chunk_count"],
        max_chunk_length=params["max_chunk_length"],
        chunking_jobs=params["chunking_jobs"],
        max_sentence_count=params["max_sentence_count"],
        hnsw_index_params=params["hnsw_index_params"],
        candidates_per_chunk=params["candidates_per_chunk"],
        candidates_per_doc=params["candidates_per_doc"],
        query_jobs=params["query_jobs"],
        xgb_shrinkage=params["xgb_shrinkage"],
        xgb_interaction_depth=params["xgb_interaction_depth"],
        xgb_subsample=params["xgb_subsample"],
        xgb_rounds=params["xgb_rounds"],
        xgb_jobs=params["xgb_jobs"],
        duckdb_threads=jobs if jobs else params["duckdb_threads"],
        use_altLabels=params["use_altLabels"],
        model_args=params["model_args"],
        encode_args_vocab=params["encode_args_vocab"],
        encode_args_documents=params["encode_args_documents"],
        logger=self,
    )

    if corpus != "cached":
        if corpus.is_empty():
            raise NotSupportedException(
                f"training backend {self.backend_id} with no documents"
            )

        self.info("creating vector database")
        self._model.create_vector_db(
            vocab_in_path=os.path.join(
                self.project.vocab.datadir, self.project.vocab.INDEX_FILENAME_TTL
            ),
            force=True,
        )

        self.info("preparing training data")
        doc_ids = []
        texts = []
        label_ids = []
        for doc_id, doc in enumerate(corpus.documents):
            for subject_id in [
                subject_id for subject_id in getattr(doc, "subject_set")
            ]:
                doc_ids.append(doc_id)
                texts.append(getattr(doc, "text"))
                label_ids.append(self.project.subjects[subject_id].uri)

        train_data = self._model.prepare_train(
            doc_ids=doc_ids,
            label_ids=label_ids,
            texts=texts,
            n_jobs=jobs,
        )

        atomic_save(
            obj=train_data,
            dirname=self.datadir,
            filename=self.TRAIN_FILE,
            method=joblib.dump,
        )

    else:
        self.info("reusing cached training data from previous run")
        if not os.path.exists(self._model.db_path):
            raise NotInitializedException(
                f"database file {self._model.db_path} not found",
                backend_id=self.backend_id,
            )
        if not os.path.exists(os.path.join(self.datadir, self.TRAIN_FILE)):
            raise NotInitializedException(
                f"train data file {self.TRAIN_FILE} not found",
                backend_id=self.backend_id,
            )

        train_data = joblib.load(os.path.join(self.datadir, self.TRAIN_FILE))

    self.info("training model")
    self._model.train(train_data, jobs)

    self.info("saving model")
    atomic_save(self._model, self.datadir, self.MODEL_FILE)
```
| 181 | |||
| 205 |