Passed
Push — master ( c17a7d...ea11a0 )
by Osma
02:27 queued 13s
created

annif.backend.mllm.MLLMBackend._train()   A

Complexity

Conditions 2

Size

Total Lines 26
Code Lines 23

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 23
nop 3
dl 0
loc 26
rs 9.328
c 0
b 0
f 0
1
"""Maui-like Lexical Matching backend"""
2
3
import os.path
4
import joblib
5
import numpy as np
6
import annif.util
7
from annif.exception import NotInitializedException
8
from annif.lexical.mllm import MLLMModel
9
from annif.suggestion import VectorSuggestionResult
10
from . import backend
11
from . import hyperopt
12
13
14
class MLLMOptimizer(hyperopt.HyperparameterOptimizer):
    """Hyperparameter optimizer for the MLLM backend"""

    def _prepare(self, n_jobs=1):
        """Load the saved training data of the backend and generate the
        candidates and gold standard subjects for every document in the
        corpus, so that each trial only has to fit and score a classifier.

        n_jobs is currently unused (see the TODO below)."""

        # Import the submodule explicitly instead of relying on some other
        # module having imported it already; the module-level
        # "import annif.util" does not make annif.corpus available.
        import annif.corpus

        self._backend.initialize()
        self._train_x, self._train_y = self._backend._load_train_data()
        self._candidates = []
        self._gold_subjects = []

        # TODO parallelize generation of candidates
        for doc in self._corpus.documents:
            candidates = self._backend._generate_candidates(doc.text)
            self._candidates.append(candidates)
            self._gold_subjects.append(
                annif.corpus.SubjectSet((doc.uris, doc.labels)))

    def _objective(self, trial):
        """Fit a classifier using the hyperparameters suggested for the
        given trial, evaluate it against the prepared documents and return
        the value of the metric being optimized."""

        # Explicit import for the same reason as in _prepare: annif.eval is
        # not guaranteed to be loaded by "import annif.util".
        import annif.eval

        params = {
            'min_samples_leaf': trial.suggest_int('min_samples_leaf', 5, 30),
            'max_leaf_nodes': trial.suggest_int('max_leaf_nodes', 100, 2000),
            'max_samples': trial.suggest_float('max_samples', 0.5, 1.0),
            'use_hidden_labels':
                trial.suggest_categorical('use_hidden_labels', [True, False]),
            'limit': 100
        }
        model = self._backend._model._create_classifier(params)
        model.fit(self._train_x, self._train_y)

        batch = annif.eval.EvaluationBatch(self._backend.project.subjects)
        for goldsubj, candidates in zip(self._gold_subjects, self._candidates):
            if candidates:
                features = \
                    self._backend._model._candidates_to_features(candidates)
                scores = model.predict_proba(features)
                ranking = self._backend._model._prediction_to_list(
                    scores, candidates)
            else:
                # no candidates for this document: nothing to score, so it
                # is evaluated with an empty ranking
                ranking = []
            results = self._backend._prediction_to_result(ranking, params)
            batch.evaluate(results, goldsubj)
        results = batch.results(metrics=[self._metric])
        return results[self._metric]

    def _postprocess(self, study):
        """Turn the best parameters found by the study into a
        recommendation made of "name=value" lines plus the best score."""
        bp = study.best_params
        lines = [
            f"min_samples_leaf={bp['min_samples_leaf']}",
            f"max_leaf_nodes={bp['max_leaf_nodes']}",
            f"max_samples={bp['max_samples']:.4f}",
            f"use_hidden_labels={bp['use_hidden_labels']}"
        ]
        return hyperopt.HPRecommendation(lines=lines, score=study.best_value)
66
67
68
class MLLMBackend(hyperopt.AnnifHyperoptBackend):
    """Maui-like Lexical Matching backend for Annif"""
    name = "mllm"
    needs_subject_index = True

    # defaults for uninitialized instances
    _model = None

    # file names, relative to the data directory of the backend
    MODEL_FILE = 'mllm-model.gz'
    TRAIN_FILE = 'mllm-train.gz'

    DEFAULT_PARAMETERS = {
        'min_samples_leaf': 20,
        'max_leaf_nodes': 1000,
        'max_samples': 0.9,
        'use_hidden_labels': False
    }

    def get_hp_optimizer(self, corpus, metric):
        """Return a hyperparameter optimizer for this backend that
        uses the given corpus and metric."""
        return MLLMOptimizer(self, corpus, metric)

    def default_params(self):
        """Return the generic backend defaults overridden with the
        MLLM specific default parameters."""
        params = backend.AnnifBackend.DEFAULT_PARAMETERS.copy()
        params.update(self.DEFAULT_PARAMETERS)
        return params

    def _load_model(self):
        """Load and return the model stored in the data directory. Raises
        NotInitializedException if the model file does not exist."""
        path = os.path.join(self.datadir, self.MODEL_FILE)
        self.debug('loading model from {}'.format(path))
        if os.path.exists(path):
            return MLLMModel.load(path)
        else:
            raise NotInitializedException(
                'model {} not found'.format(path),
                backend_id=self.backend_id)

    def _load_train_data(self):
        """Load and return the training data stored in the data directory.
        Raises NotInitializedException if the file does not exist."""
        path = os.path.join(self.datadir, self.TRAIN_FILE)
        if os.path.exists(path):
            return joblib.load(path)
        else:
            raise NotInitializedException(
                'train data file {} not found'.format(path),
                backend_id=self.backend_id)

    def initialize(self):
        """Load the model from disk unless it has been loaded already."""
        if self._model is None:
            self._model = self._load_model()

    def _train(self, corpus, params):
        """Train the model and save it in the data directory.

        If corpus is the string 'cached', the previously saved model and
        training data are loaded and reused; otherwise new training data is
        prepared from the corpus and saved before training."""
        self.info('starting train')
        if corpus != 'cached':
            self.info("preparing training data")
            self._model = MLLMModel()
            train_data = self._model.prepare_train(corpus,
                                                   self.project.vocab,
                                                   self.project.analyzer,
                                                   params)
            annif.util.atomic_save(train_data,
                                   self.datadir,
                                   self.TRAIN_FILE,
                                   method=joblib.dump)
        else:
            self.info("reusing cached training data from previous run")
            self._model = self._load_model()
            train_data = self._load_train_data()

        self.info("training model")
        self._model.train(train_data[0], train_data[1], params)

        self.info('saving model')
        annif.util.atomic_save(
            self._model,
            self.datadir,
            self.MODEL_FILE)

    def _generate_candidates(self, text):
        """Return the candidates that the model generates for the text,
        using the analyzer of the project."""
        return self._model.generate_candidates(text, self.project.analyzer)

    def _prediction_to_result(self, prediction, params):
        """Convert a prediction, given as (score, subject_id) pairs, into a
        suggestion result filtered using the 'limit' parameter."""
        vector = np.zeros(len(self.project.subjects), dtype=np.float32)
        for score, subject_id in prediction:
            vector[subject_id] = score
        result = VectorSuggestionResult(vector)
        return result.filter(self.project.subjects,
                             limit=int(params['limit']))

    def _suggest(self, text, params):
        """Return the suggestion result for the given text."""
        candidates = self._generate_candidates(text)
        if candidates:
            prediction = self._model.predict(candidates)
        else:
            # Nothing matched in the text. Use an empty prediction instead
            # of calling the model, the same way MLLMOptimizer._objective
            # handles documents without candidates.
            prediction = []
        return self._prediction_to_result(prediction, params)
159