Passed
Push — testing-on-windows-and-macos ( 782857...ea99ad )
by Juho
04:06
created

annif.backend.mllm.MLLMOptimizer._postprocess()   A

Complexity

Conditions 1

Size

Total Lines 8
Code Lines 7

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 7
dl 0
loc 8
rs 10
c 0
b 0
f 0
cc 1
nop 2
1
"""Maui-like Lexical Matching backend"""
2
3
from __future__ import annotations
4
5
import os.path
6
from typing import TYPE_CHECKING, Any
7
8
import joblib
9
import numpy as np
10
11
import annif.eval
12
import annif.util
13
from annif.exception import NotInitializedException, NotSupportedException
14
from annif.lexical.mllm import MLLMModel
15
from annif.suggestion import vector_to_suggestions
16
17
from . import hyperopt
18
19
if TYPE_CHECKING:
20
    from collections.abc import Iterator
21
22
    from optuna.study.study import Study
23
    from optuna.trial import Trial
24
25
    from annif.backend.hyperopt import HPRecommendation
26
    from annif.corpus.document import DocumentCorpus
27
    from annif.lexical.mllm import Candidate
28
29
30
class MLLMOptimizer(hyperopt.HyperparameterOptimizer):
    """Hyperparameter optimizer for the MLLM backend"""

    def _prepare(self, n_jobs: int = 1) -> None:
        """Initialize the backend and precompute, for every document in the
        corpus, its match candidates and gold-standard subject set."""
        self._backend.initialize()
        self._train_x, self._train_y = self._backend._load_train_data()
        self._candidates = []
        self._gold_subjects = []

        # TODO parallelize generation of candidates
        for doc in self._corpus.documents:
            self._candidates.append(self._backend._generate_candidates(doc.text))
            self._gold_subjects.append(doc.subject_set)

    def _objective(self, trial: Trial) -> float:
        """Train a classifier with trial-suggested hyperparameters and return
        its score (self._metric) evaluated over the prepared corpus."""
        params = {
            "min_samples_leaf": trial.suggest_int("min_samples_leaf", 5, 30),
            "max_leaf_nodes": trial.suggest_int("max_leaf_nodes", 100, 2000),
            "max_samples": trial.suggest_float("max_samples", 0.5, 1.0),
            "limit": 100,
        }
        classifier = self._backend._model._create_classifier(params)
        classifier.fit(self._train_x, self._train_y)

        batch = annif.eval.EvaluationBatch(self._backend.project.subjects)
        mllm_model = self._backend._model
        for gold, cands in zip(self._gold_subjects, self._candidates):
            # documents without candidates get an empty ranking
            if not cands:
                ranking = []
            else:
                features = mllm_model._candidates_to_features(cands)
                proba = classifier.predict_proba(features)
                ranking = mllm_model._prediction_to_list(proba, cands)
            suggestions = self._backend._prediction_to_result(ranking, params)
            batch.evaluate_many([suggestions], [gold])
        return batch.results(metrics=[self._metric])[self._metric]

    def _postprocess(self, study: Study) -> HPRecommendation:
        """Format the best trial of *study* as an HPRecommendation."""
        best = study.best_params
        lines = [
            f"min_samples_leaf={best['min_samples_leaf']}",
            f"max_leaf_nodes={best['max_leaf_nodes']}",
            f"max_samples={best['max_samples']:.4f}",
        ]
        return hyperopt.HPRecommendation(lines=lines, score=study.best_value)
76
77
78
class MLLMBackend(hyperopt.AnnifHyperoptBackend):
    """Maui-like Lexical Matching backend for Annif"""

    name = "mllm"

    # defaults for uninitialized instances
    _model = None

    MODEL_FILE = "mllm-model.gz"
    TRAIN_FILE = "mllm-train.gz"

    DEFAULT_PARAMETERS = {
        "min_samples_leaf": 20,
        "max_leaf_nodes": 1000,
        "max_samples": 0.9,
        "use_hidden_labels": False,
    }

    def get_hp_optimizer(self, corpus: DocumentCorpus, metric: str) -> MLLMOptimizer:
        """Return a hyperparameter optimizer for this backend instance."""
        return MLLMOptimizer(self, corpus, metric)

    def _load_model(self) -> MLLMModel:
        """Load the trained MLLM model from the data directory.

        Raises NotInitializedException when the model file does not exist."""
        path = os.path.join(self.datadir, self.MODEL_FILE)
        self.debug("loading model from {}".format(path))
        if not os.path.exists(path):
            raise NotInitializedException(
                "model {} not found".format(path), backend_id=self.backend_id
            )
        return MLLMModel.load(path)

    def _load_train_data(self) -> tuple[np.ndarray, np.ndarray]:
        """Load the cached training features and labels from the data directory.

        Raises NotInitializedException when the train data file does not exist."""
        path = os.path.join(self.datadir, self.TRAIN_FILE)
        if not os.path.exists(path):
            raise NotInitializedException(
                "train data file {} not found".format(path), backend_id=self.backend_id
            )
        return joblib.load(path)

    def initialize(self, parallel: bool = False) -> None:
        """Lazily load the model on first use."""
        if self._model is None:
            self._model = self._load_model()

    def _train(
        self,
        corpus: DocumentCorpus,
        params: dict[str, Any],
        jobs: int = 0,
    ) -> None:
        """Train the MLLM model on *corpus* (or reuse cached training data
        when corpus is the sentinel string "cached") and save it."""
        self.info("starting train")
        if corpus == "cached":
            self.info("reusing cached training data from previous run")
            self._model = self._load_model()
            train_data = self._load_train_data()
        else:
            if corpus.is_empty():
                raise NotSupportedException(
                    "training backend {} with no documents".format(self.backend_id)
                )
            self.info("preparing training data")
            self._model = MLLMModel()
            train_data = self._model.prepare_train(
                corpus, self.project.vocab, self.project.analyzer, params, jobs
            )
            # persist prepared features so later runs can train with "cached"
            annif.util.atomic_save(
                train_data, self.datadir, self.TRAIN_FILE, method=joblib.dump
            )

        self.info("training model")
        self._model.train(train_data[0], train_data[1], params)

        self.info("saving model")
        annif.util.atomic_save(self._model, self.datadir, self.MODEL_FILE)

    def _generate_candidates(self, text: str) -> list[Candidate]:
        """Find lexical candidate subjects for the given document text."""
        return self._model.generate_candidates(text, self.project.analyzer)

    def _prediction_to_result(
        self,
        prediction: list[tuple[np.float64, int]],
        params: dict[str, Any],
    ) -> Iterator:
        """Convert (score, subject_id) pairs into a suggestion iterator,
        limited to int(params["limit"]) entries."""
        scores = np.zeros(len(self.project.subjects), dtype=np.float32)
        for value, subject_id in prediction:
            scores[subject_id] = value
        return vector_to_suggestions(scores, int(params["limit"]))

    def _suggest(self, text: str, params: dict[str, Any]) -> Iterator:
        """Suggest subjects for a single document text."""
        prediction = self._model.predict(self._generate_candidates(text))
        return self._prediction_to_result(prediction, params)
170