Passed
Push — main ( da1836...1db6a8 )
by Osma
07:26 queued 04:14
created

EnsembleHPObjective.objective()   A

Complexity

Conditions 2

Size

Total Lines 18
Code Lines 15

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 15
nop 3
dl 0
loc 18
rs 9.65
c 0
b 0
f 0
1
"""Ensemble backend that combines results from multiple projects"""
2
3
from __future__ import annotations
4
5
from typing import TYPE_CHECKING, Any
6
7
import annif.eval
8
import annif.parallel
9
import annif.util
10
from annif.exception import NotSupportedException
11
from annif.suggestion import SuggestionBatch
12
13
from . import backend, hyperopt
14
15
if TYPE_CHECKING:
16
    from datetime import datetime
17
18
    from optuna.study.study import Study
19
    from optuna.trial import Trial
20
21
    from annif.backend.hyperopt import HPRecommendation
22
    from annif.corpus.document import Document, DocumentCorpus
23
24
25
class BaseEnsembleBackend(backend.AnnifBackend):
    """Shared behaviour for backends that delegate to several source projects.

    The source projects are read from the ``sources`` backend parameter and
    looked up through the registry of the project that owns this backend.
    """

    def _get_sources_attribute(self, attr: str) -> list[bool | None]:
        """Return the value of attribute `attr` of every source project, in
        the order in which the sources are listed."""

        backend_params = self._get_backend_params(None)
        configured = annif.util.parse_sources(backend_params["sources"])
        values = []
        for source_id, _weight in configured:
            source_project = self.project.registry.get_project(source_id)
            values.append(getattr(source_project, attr))
        return values

    def initialize(self, parallel: bool = False) -> None:
        """Initialize each source project, forwarding the `parallel` flag."""

        source_spec = self._get_backend_params(None)["sources"]
        for source_id, _weight in annif.util.parse_sources(source_spec):
            self.project.registry.get_project(source_id).initialize(parallel)

    def _suggest_with_sources(
        self, documents: list[Document], sources: list[tuple[str, float]]
    ) -> dict[str, SuggestionBatch]:
        """Ask every source project for suggestions on `documents` and return
        the results keyed by project id. The weights in `sources` are not
        used here."""

        results = {}
        for source_id, _weight in sources:
            source_project = self.project.registry.get_project(source_id)
            results[source_id] = source_project.suggest(documents)
        return results

    def _merge_source_batches(
        self,
        batch_by_source: dict[str, SuggestionBatch],
        sources: list[tuple[str, float]],
        params: dict[str, Any],
    ) -> SuggestionBatch:
        """Combine the per-source SuggestionBatches into one SuggestionBatch.

        This default implementation averages the batches using the weights
        from the `sources` tuples and then applies the ``limit`` parameter.
        Subclasses are expected to override it."""

        ordered_batches = []
        ordered_weights = []
        for source_id, weight in sources:
            ordered_batches.append(batch_by_source[source_id])
            ordered_weights.append(weight)
        averaged = SuggestionBatch.from_averaged(ordered_batches, ordered_weights)
        return averaged.filter(limit=int(params["limit"]))

    def _suggest_batch(
        self, documents: list[Document], params: dict[str, Any]
    ) -> SuggestionBatch:
        """Collect suggestions from the sources named in ``params["sources"]``
        and merge them into a single SuggestionBatch."""

        source_list = annif.util.parse_sources(params["sources"])
        per_source = self._suggest_with_sources(documents, source_list)
        return self._merge_source_batches(per_source, source_list, params)
74
75
76
class EnsembleHPObjective(hyperopt.HPObjective):
    """Objective evaluated for each trial when optimizing ensemble weights"""

    @classmethod
    def objective(cls, trial: Trial, args) -> float:
        """Score one trial.

        A weight between 0.0 and 1.0 is requested from `trial` for every
        project id in ``args["sources"]``. The precomputed source batches are
        averaged with those weights, limited to ``args["limit"]`` and
        evaluated against the matching gold batches. The return value is the
        result for the metric named by ``args["metric"]``."""

        evaluator = annif.eval.EvaluationBatch(args["subject_index"])
        source_ids = args["sources"]

        trial_weights = {}
        for source_id in source_ids:
            trial_weights[source_id] = trial.suggest_float(source_id, 0.0, 1.0)

        paired = zip(args["gold_batches"], args["source_batches"])
        for gold, by_source in paired:
            member_batches = [by_source[sid] for sid in source_ids]
            member_weights = [trial_weights[sid] for sid in source_ids]
            combined = SuggestionBatch.from_averaged(member_batches, member_weights)
            evaluator.evaluate_many(combined.filter(limit=int(args["limit"])), gold)

        metric_name = args["metric"]
        return evaluator.results(metrics=[metric_name])[metric_name]
97
98
99
class EnsembleOptimizer(hyperopt.HyperparameterOptimizer):
    """Hyperparameter optimizer for the ensemble backend.

    The hyperparameters are one weight per source project. The sources are
    taken from the ``sources`` configuration parameter of the backend."""

    def __init__(
        self, backend: EnsembleBackend, corpus: DocumentCorpus, metric: str
    ) -> None:
        """Set up the optimizer with EnsembleHPObjective as the objective and
        remember the project ids of the configured sources (the configured
        weights are discarded)."""
        super().__init__(backend, corpus, metric, EnsembleHPObjective)
        self._sources = [
            project_id
            for project_id, _ in annif.util.parse_sources(
                backend.config_params["sources"]
            )
        ]

    def _prepare(self, n_jobs: int = 1) -> dict[str, Any]:
        """Compute the suggestions of every source project for every document
        batch of the corpus and return them, together with the gold standard
        batches and the settings read by EnsembleHPObjective.objective(), as
        a dict."""
        gold_batches = []
        source_batches = []

        # make sure the source projects are initialized before they are used
        for project_id in self._sources:
            project = self._backend.project.registry.get_project(project_id)
            project.initialize()

        psmap = annif.parallel.ProjectSuggestMap(
            self._backend.project.registry,
            self._sources,
            backend_params=None,
            limit=int(self._backend.params["limit"]),
            threshold=0.0,
        )

        jobs, pool_class = annif.parallel.get_pool(n_jobs)

        with pool_class(jobs) as pool:
            # Results may arrive in any order, but each item carries its own
            # gold batch, so the two lists stay aligned index by index.
            for suggestions, gold_batch in pool.imap_unordered(
                psmap.suggest_batch, self._corpus.doc_batches
            ):
                source_batches.append(suggestions)
                gold_batches.append(gold_batch)

        return {
            "gold_batches": gold_batches,
            "source_batches": source_batches,
            "subject_index": self._backend.project.subjects,
            "sources": self._sources,
            "limit": self._backend.params["limit"],
            "metric": self._metric,
        }

    def _normalize(self, hps: dict[str, float]) -> dict[str, float]:
        """Scale the weights in `hps` so that they sum to 1.

        If the weights sum to zero (every weight is 0.0) there is nothing to
        scale; equal weights are returned instead of raising
        ZeroDivisionError. An empty dict is returned unchanged."""
        total = sum(hps.values())
        if not total:
            # 0.0 is a legal value for every weight, so guard the division
            return {source: 1.0 / len(hps) for source in hps}
        return {source: weight / total for source, weight in hps.items()}

    def _format_cfg_line(self, hps: dict[str, float]) -> str:
        """Format the weights as a ``sources=`` configuration line, with each
        weight written with four decimals."""
        return "sources=" + ",".join(
            [f"{src}:{weight:.4f}" for src, weight in hps.items()]
        )

    def _postprocess(self, study: Study) -> HPRecommendation:
        """Turn the best parameters of `study` into a recommendation holding
        the normalized ``sources=`` line and the best score."""
        line = self._format_cfg_line(self._normalize(study.best_params))
        return hyperopt.HPRecommendation(lines=[line], score=study.best_value)
159
160
161
class EnsembleBackend(BaseEnsembleBackend, hyperopt.AnnifHyperoptBackend):
    """Backend that merges the suggestions of several source projects"""

    name = "ensemble"

    @property
    def is_trained(self) -> bool:
        """True when the ``is_trained`` attribute of every source project is
        truthy (also True when there are no sources)."""
        for trained in self._get_sources_attribute("is_trained"):
            if not trained:
                return False
        return True

    @property
    def modification_time(self) -> datetime | None:
        """The latest modification time among the source projects, or None
        when no source project reports one."""
        known_mtimes = [
            mtime
            for mtime in self._get_sources_attribute("modification_time")
            if mtime
        ]
        if not known_mtimes:
            return None
        return max(known_mtimes)

    def get_hp_optimizer(
        self, corpus: DocumentCorpus, metric: str
    ) -> EnsembleOptimizer:
        """Create an EnsembleOptimizer for this backend, `corpus` and
        `metric`."""
        optimizer = EnsembleOptimizer(self, corpus, metric)
        return optimizer

    def _train(self, corpus: DocumentCorpus, params: dict[str, Any], jobs: int = 0):
        """Always raises NotSupportedException: this backend cannot be
        trained."""
        raise NotSupportedException("Training ensemble backend is not possible.")
183