Passed
Pull Request — main (#681)
by Osma
02:36
created

EnsembleOptimizer._postprocess()   A

Complexity

Conditions 1

Size

Total Lines 3
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 3
dl 0
loc 3
rs 10
c 0
b 0
f 0
cc 1
nop 2
1
"""Ensemble backend that combines results from multiple projects"""
2
3
4
import annif.eval
5
import annif.parallel
6
import annif.suggestion
7
import annif.util
8
from annif.exception import NotSupportedException
9
from annif.suggestion import SuggestionBatch
10
11
from . import backend, hyperopt
12
13
14
class BaseEnsembleBackend(backend.AnnifBackend):
    """Shared functionality for backends that combine other projects"""

    def _get_sources_attribute(self, attr):
        """Return a list with the value of the attribute `attr` of every
        source project, in the order the sources are configured."""

        backend_params = self._get_backend_params(None)
        values = []
        for source_id, _ in annif.util.parse_sources(backend_params["sources"]):
            source_project = self.project.registry.get_project(source_id)
            values.append(getattr(source_project, attr))
        return values

    def initialize(self, parallel=False):
        """Initialize each of the configured source projects, passing the
        `parallel` flag through to them."""

        source_spec = self._get_backend_params(None)["sources"]
        for source_id, _ in annif.util.parse_sources(source_spec):
            self.project.registry.get_project(source_id).initialize(parallel)

    def _suggest_with_sources(self, texts, sources):
        """Ask every source project for suggestions on the given texts and
        return the results as a dict keyed by project ID."""

        results = {}
        for source_id, _ in sources:
            source_project = self.project.registry.get_project(source_id)
            results[source_id] = source_project.suggest(texts)
        return results

    def _merge_source_batches(self, batch_by_source, sources, params):
        """Combine the per-source SuggestionBatches into one SuggestionBatch.

        This default implementation takes a weighted average, using the
        weights from the (project_id, weight) tuples in `sources`, and then
        applies the configured limit. Subclasses are expected to override
        it with their own merging strategy."""

        source_ids = [source_id for source_id, _ in sources]
        weights = [weight for _, weight in sources]
        batches = [batch_by_source[source_id] for source_id in source_ids]
        averaged = SuggestionBatch.from_averaged(batches, weights)
        return averaged.filter(limit=int(params["limit"]))

    def _suggest_batch(self, texts, params):
        """Collect suggestions from all configured sources for the given
        texts and merge them into a single batch."""

        sources = annif.util.parse_sources(params["sources"])
        return self._merge_source_batches(
            self._suggest_with_sources(texts, sources), sources, params
        )
54
55
56
class EnsembleOptimizer(hyperopt.HyperparameterOptimizer):
    """Hyperparameter optimizer for the ensemble backend"""

    def __init__(self, backend, corpus, metric):
        """Set up the optimizer and record the IDs of the source projects
        listed in the backend's "sources" configuration parameter."""
        super().__init__(backend, corpus, metric)
        # only the project IDs are kept; the configured weights are what
        # the optimization is going to replace
        self._sources = [
            project_id
            for project_id, _ in annif.util.parse_sources(
                backend.config_params["sources"]
            )
        ]

    def _prepare(self, n_jobs=1):
        """Initialize the source projects and collect their suggestions for
        every document batch of the corpus, so that the trials only need to
        re-weight these stored results."""
        self._gold_batches = []
        self._source_batches = []

        for project_id in self._sources:
            project = self._backend.project.registry.get_project(project_id)
            project.initialize()

        psmap = annif.parallel.ProjectSuggestMap(
            self._backend.project.registry,
            self._sources,
            backend_params=None,
            limit=int(self._backend.params["limit"]),
            threshold=0.0,
        )

        jobs, pool_class = annif.parallel.get_pool(n_jobs)

        with pool_class(jobs) as pool:
            # results may arrive in any order, but each result carries its
            # own gold batch, so the two lists stay aligned index by index
            for suggestions, gold_batch in pool.imap_unordered(
                psmap.suggest_batch, self._corpus.doc_batches
            ):
                self._source_batches.append(suggestions)
                self._gold_batches.append(gold_batch)

    def _normalize(self, hps):
        """Scale the given {source: weight} dict so the weights sum to 1."""
        total = sum(hps.values())
        return {source: hps[source] / total for source in hps}

    def _format_cfg_line(self, hps):
        """Format the {source: weight} dict as a "sources=..." configuration
        line, with each weight shown to four decimal places."""
        return "sources=" + ",".join(
            [f"{src}:{weight:.4f}" for src, weight in hps.items()]
        )

    def _objective(self, trial):
        """Evaluate one trial: sample a weight in [0.0, 1.0] for each source
        project, average the stored source suggestions with those weights
        and return the resulting value of the target metric."""
        eval_batch = annif.eval.EvaluationBatch(self._backend.project.subjects)
        # suggest_float replaces suggest_uniform, which Optuna deprecated
        # in 3.0; it samples from the same [low, high] range
        proj_weights = {
            project_id: trial.suggest_float(project_id, 0.0, 1.0)
            for project_id in self._sources
        }
        # these do not change between batches, so compute them once
        limit = int(self._backend.params["limit"])
        weights = [proj_weights[project_id] for project_id in self._sources]
        for gold_batch, src_batches in zip(self._gold_batches, self._source_batches):
            batches = [src_batches[project_id] for project_id in self._sources]
            avg_batch = SuggestionBatch.from_averaged(batches, weights).filter(
                limit=limit
            )
            eval_batch.evaluate_many(avg_batch, gold_batch)
        results = eval_batch.results(metrics=[self._metric])
        return results[self._metric]

    def _postprocess(self, study):
        """Turn the best parameters of the finished study into a
        recommendation: a normalized "sources=..." line plus the best
        score."""
        line = self._format_cfg_line(self._normalize(study.best_params))
        return hyperopt.HPRecommendation(lines=[line], score=study.best_value)
121
122
123
class EnsembleBackend(BaseEnsembleBackend, hyperopt.AnnifHyperoptBackend):
    """Backend that merges the suggestions of several source projects"""

    name = "ensemble"

    @property
    def is_trained(self):
        """True only when every source project reports being trained."""
        return all(self._get_sources_attribute("is_trained"))

    @property
    def modification_time(self):
        """The latest modification time among the source projects, skipping
        sources without one; None if no source has a modification time."""
        source_mtimes = self._get_sources_attribute("modification_time")
        known_mtimes = (mtime for mtime in source_mtimes if mtime)
        return max(known_mtimes, default=None)

    def get_hp_optimizer(self, corpus, metric):
        """Create an EnsembleOptimizer for this backend using the given
        corpus and metric."""
        optimizer = EnsembleOptimizer(self, corpus, metric)
        return optimizer

    def _train(self, corpus, params, jobs=0):
        """Always raise NotSupportedException, as this backend cannot be
        trained."""
        message = "Training ensemble backend is not possible."
        raise NotSupportedException(message)
143