Passed
Pull Request — main (#681)
by Osma
05:43 queued 02:50
created

EnsembleOptimizer._prepare()   A

Complexity

Conditions 4

Size

Total Lines 24
Code Lines 18

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 4
eloc 18
nop 2
dl 0
loc 24
rs 9.5
c 0
b 0
f 0
1
"""Ensemble backend that combines results from multiple projects"""
2
3
4
import annif.eval
5
import annif.parallel
6
import annif.suggestion
7
import annif.util
8
from annif.exception import NotSupportedException
9
from annif.suggestion import vector_to_suggestions
10
11
from . import backend, hyperopt
12
13
14
class BaseEnsembleBackend(backend.AnnifBackend):
    """Base class for ensemble backends"""

    def _get_sources_attribute(self, attr):
        """Collect the value of the given attribute from every source
        project of this ensemble."""
        params = self._get_backend_params(None)
        source_defs = annif.util.parse_sources(params["sources"])
        registry = self.project.registry
        return [
            getattr(registry.get_project(project_id), attr)
            for project_id, _ in source_defs
        ]

    def initialize(self, parallel=False):
        """Initialize every source project of this ensemble."""
        params = self._get_backend_params(None)
        for project_id, _ in annif.util.parse_sources(params["sources"]):
            self.project.registry.get_project(project_id).initialize(parallel)

    def _normalize_suggestion_batch(self, batch, source_project):
        """Hook for processing a batch of suggestions from backends.
        Intended to be overridden by subclasses."""
        return batch

    def _suggest_with_sources(self, texts, sources):
        """Query each source project for the given texts and return the
        normalized suggestion batches wrapped with their source weights."""
        weighted_batches = []
        for project_id, weight in sources:
            src_project = self.project.registry.get_project(project_id)
            normalized = self._normalize_suggestion_batch(
                src_project.suggest(texts), src_project
            )
            weighted_batches.append(
                annif.suggestion.WeightedSuggestionsBatch(
                    hit_sets=normalized,
                    weight=weight,
                    subjects=src_project.subjects,
                )
            )
        return weighted_batches

    def _merge_hit_sets_from_sources(self, hit_sets_from_sources, params):
        """Hook for merging hit sets from sources. Can be overridden by
        subclasses."""
        return annif.util.merge_hits(hit_sets_from_sources)

    def _suggest_batch(self, texts, params):
        """Suggest subjects for a batch of texts by querying all sources
        and merging their weighted results."""
        sources = annif.util.parse_sources(params["sources"])
        merged_rows = self._merge_hit_sets_from_sources(
            self._suggest_with_sources(texts, sources), params
        )
        limit = int(params["limit"])
        return annif.suggestion.SuggestionBatch.from_sequence(
            [vector_to_suggestions(row, limit) for row in merged_rows],
            self.project.subjects,
        )
69
70
71
class EnsembleOptimizer(hyperopt.HyperparameterOptimizer):
    """Hyperparameter optimizer for the ensemble backend"""

    def __init__(self, backend, corpus, metric):
        super().__init__(backend, corpus, metric)
        parsed = annif.util.parse_sources(backend.config_params["sources"])
        self._sources = [project_id for project_id, _ in parsed]

    def _prepare(self, n_jobs=1):
        """Collect the gold standard subjects and the per-source suggestions
        for the whole corpus, using up to n_jobs parallel workers."""
        self._gold_subjects = []
        self._source_hits = []

        registry = self._backend.project.registry

        # make sure every source project is ready before suggesting
        for project_id in self._sources:
            registry.get_project(project_id).initialize()

        psmap = annif.parallel.ProjectSuggestMap(
            registry,
            self._sources,
            backend_params=None,
            limit=int(self._backend.params["limit"]),
            threshold=0.0,
        )

        jobs, pool_class = annif.parallel.get_pool(n_jobs)

        with pool_class(jobs) as pool:
            for hit_sets, subject_sets in pool.imap_unordered(
                psmap.suggest_batch, self._corpus.doc_batches
            ):
                self._gold_subjects.extend(subject_sets)
                self._source_hits.extend(self._hit_sets_to_list(hit_sets))

    def _hit_sets_to_list(self, hit_sets):
        """Convert a dict of lists of hits to a list of dicts of hits, e.g.
        {"proj-1": [p-1-doc-1-hits, p-1-doc-2-hits]
         "proj-2": [p-2-doc-1-hits, p-2-doc-2-hits]}
        to
        [{"proj-1": p-1-doc-1-hits, "proj-2": p-2-doc-1-hits},
         {"proj-1": p-1-doc-2-hits, "proj-2": p-2-doc-2-hits}]
        """
        project_ids = hit_sets.keys()
        per_document = zip(*hit_sets.values())
        return [dict(zip(project_ids, doc_hits)) for doc_hits in per_document]

    def _normalize(self, hps):
        """Scale the source weights so that they sum to one."""
        total = sum(hps.values())
        return {source: weight / total for source, weight in hps.items()}

    def _format_cfg_line(self, hps):
        """Render the source weights as a sources= configuration line."""
        parts = (f"{src}:{weight:.4f}" for src, weight in hps.items())
        return "sources=" + ",".join(parts)

    def _objective(self, trial):
        """Optuna objective: evaluate one candidate set of source weights
        against the prepared corpus and return the metric value."""
        eval_batch = annif.eval.EvaluationBatch(self._backend.project.subjects)
        weights = {
            project_id: trial.suggest_uniform(project_id, 0.0, 1.0)
            for project_id in self._sources
        }
        limit = int(self._backend.params["limit"])
        for goldsubj, srchits in zip(self._gold_subjects, self._source_hits):
            weighted_hits = [
                annif.suggestion.WeightedSuggestionsBatch(
                    hit_sets=[hits],
                    weight=weights[project_id],
                    subjects=self._backend.project.subjects,
                )
                for project_id, hits in srchits.items()
            ]
            merged = annif.util.merge_hits(weighted_hits)
            eval_batch.evaluate_many(
                [vector_to_suggestions(row, limit) for row in merged],
                [goldsubj],
            )
        results = eval_batch.results(metrics=[self._metric])
        return results[self._metric]

    def _postprocess(self, study):
        """Turn the best trial of the study into a configuration
        recommendation with normalized weights."""
        best_weights = self._normalize(study.best_params)
        return hyperopt.HPRecommendation(
            lines=[self._format_cfg_line(best_weights)], score=study.best_value
        )
158
159
160
class EnsembleBackend(BaseEnsembleBackend, hyperopt.AnnifHyperoptBackend):
    """Ensemble backend that combines results from multiple projects"""

    name = "ensemble"

    @property
    def is_trained(self):
        # the ensemble counts as trained only when every source is trained
        return all(self._get_sources_attribute("is_trained"))

    @property
    def modification_time(self):
        # newest modification time among the sources; None if none is known
        times = self._get_sources_attribute("modification_time")
        return max((t for t in times if t), default=None)

    def get_hp_optimizer(self, corpus, metric):
        """Return a hyperparameter optimizer for this ensemble."""
        return EnsembleOptimizer(self, corpus, metric)

    def _train(self, corpus, params, jobs=0):
        # a plain ensemble has no trainable model of its own
        raise NotSupportedException("Training ensemble backend is not possible.")
180