Passed
Push — issue678-refactor-suggestionre... ( ec0260...82f1b2 )
by Osma
05:25 queued 02:48
created

EnsembleOptimizer.__init__()   A

Complexity

Conditions 1

Size

Total Lines 6
Code Lines 6

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 6
nop 4
dl 0
loc 6
rs 10
c 0
b 0
f 0
1
"""Ensemble backend that combines results from multiple projects"""
2
3
4
import annif.eval
5
import annif.parallel
6
import annif.suggestion
7
import annif.util
8
from annif.exception import NotSupportedException
9
10
from . import backend, hyperopt
11
12
13
class BaseEnsembleBackend(backend.AnnifBackend):
    """Base class for ensemble backends"""

    def _get_sources_attribute(self, attr):
        """Collect the value of attribute *attr* from every source project,
        in the order the sources are configured."""
        config = self._get_backend_params(None)
        registry = self.project.registry
        return [
            getattr(registry.get_project(pid), attr)
            for pid, _ in annif.util.parse_sources(config["sources"])
        ]

    def initialize(self, parallel=False):
        """Initialize this backend by initializing all of its source
        projects."""
        config = self._get_backend_params(None)
        for pid, _ in annif.util.parse_sources(config["sources"]):
            self.project.registry.get_project(pid).initialize(parallel)

    def _normalize_suggestion_batch(self, batch, source_project):
        """Hook for processing a batch of suggestions from backends.
        Intended to be overridden by subclasses."""
        return batch

    def _suggest_with_sources(self, texts, sources):
        """Query every source project for *texts* and return the resulting
        list of weighted suggestion batches."""
        weighted_batches = []
        for pid, weight in sources:
            src = self.project.registry.get_project(pid)
            normalized = self._normalize_suggestion_batch(src.suggest(texts), src)
            weighted_batches.append(
                annif.suggestion.WeightedSuggestionsBatch(
                    hit_sets=normalized,
                    weight=weight,
                    subjects=src.subjects,
                )
            )
        return weighted_batches

    def _merge_hit_sets_from_sources(self, hit_sets_from_sources, params):
        """Hook for merging hit sets from sources. Can be overridden by
        subclasses."""
        return annif.util.merge_hits(hit_sets_from_sources, len(self.project.subjects))

    def _suggest_batch(self, texts, params):
        """Suggest subjects for a batch of texts by querying all configured
        source projects and merging their weighted results."""
        source_list = annif.util.parse_sources(params["sources"])
        merged = self._merge_hit_sets_from_sources(
            self._suggest_with_sources(texts, source_list), params
        )
        return annif.suggestion.SuggestionBatch.from_sequence(
            merged, self.project.subjects
        )
63
64
65
class EnsembleOptimizer(hyperopt.HyperparameterOptimizer):
    """Hyperparameter optimizer for the ensemble backend"""

    def __init__(self, backend, corpus, metric):
        super().__init__(backend, corpus, metric)
        # only the project ids are needed here; the configured weights are
        # re-learned during optimization
        self._sources = [
            pid
            for pid, _ in annif.util.parse_sources(backend.config_params["sources"])
        ]

    def _prepare(self, n_jobs=1):
        """Evaluate each source project on the corpus, caching the gold
        standard subjects and the per-source suggestions for later trials."""
        self._gold_subjects = []
        self._source_hits = []

        for pid in self._sources:
            self._backend.project.registry.get_project(pid).initialize()

        psmap = annif.parallel.ProjectSuggestMap(
            self._backend.project.registry,
            self._sources,
            backend_params=None,
            limit=int(self._backend.params["limit"]),
            threshold=0.0,
        )

        n_procs, pool_class = annif.parallel.get_pool(n_jobs)

        with pool_class(n_procs) as pool:
            for hit_sets, subject_sets in pool.imap_unordered(
                psmap.suggest_batch, self._corpus.doc_batches
            ):
                self._gold_subjects.extend(subject_sets)
                self._source_hits.extend(self._hit_sets_to_list(hit_sets))

    def _hit_sets_to_list(self, hit_sets):
        """Convert a dict of lists of hits to a list of dicts of hits, e.g.
        {"proj-1": [p-1-doc-1-hits, p-1-doc-2-hits]
         "proj-2": [p-2-doc-1-hits, p-2-doc-2-hits]}
        to
        [{"proj-1": p-1-doc-1-hits, "proj-2": p-2-doc-1-hits},
         {"proj-1": p-1-doc-2-hits, "proj-2": p-2-doc-2-hits}]
        """
        project_ids = hit_sets.keys()
        return [dict(zip(project_ids, per_doc)) for per_doc in zip(*hit_sets.values())]

    def _normalize(self, hps):
        """Scale the source weights so that they sum to one."""
        # NOTE(review): raises ZeroDivisionError if every weight is 0.0 —
        # presumably never happens with optimizer-suggested values; confirm
        total = sum(hps.values())
        return {src: weight / total for src, weight in hps.items()}

    def _format_cfg_line(self, hps):
        """Render the given weights as a 'sources=...' configuration line."""
        pairs = [f"{src}:{weight:.4f}" for src, weight in hps.items()]
        return "sources=" + ",".join(pairs)

    def _objective(self, trial):
        """Objective function: evaluate one candidate set of source weights
        against the cached gold standard and return the metric value."""
        eval_batch = annif.eval.EvaluationBatch(self._backend.project.subjects)
        # NOTE(review): trial.suggest_uniform is deprecated in newer Optuna
        # releases in favour of trial.suggest_float — confirm pinned version
        weights = {
            pid: trial.suggest_uniform(pid, 0.0, 1.0) for pid in self._sources
        }
        n_subjects = len(self._backend.project.subjects)
        for goldsubj, srchits in zip(self._gold_subjects, self._source_hits):
            weighted = [
                annif.suggestion.WeightedSuggestionsBatch(
                    hit_sets=[hits],
                    weight=weights[pid],
                    subjects=self._backend.project.subjects,
                )
                for pid, hits in srchits.items()
            ]
            eval_batch.evaluate_many(
                annif.util.merge_hits(weighted, n_subjects),
                [goldsubj],
            )
        return eval_batch.results(metrics=[self._metric])[self._metric]

    def _postprocess(self, study):
        """Turn the best trial of *study* into a configuration
        recommendation with normalized weights."""
        weights = self._normalize(study.best_params)
        return hyperopt.HPRecommendation(
            lines=[self._format_cfg_line(weights)], score=study.best_value
        )
151
152
153
class EnsembleBackend(BaseEnsembleBackend, hyperopt.AnnifHyperoptBackend):
    """Ensemble backend that combines results from multiple projects"""

    name = "ensemble"

    @property
    def is_trained(self):
        """True only when every source project reports being trained."""
        return all(self._get_sources_attribute("is_trained"))

    @property
    def modification_time(self):
        """The most recent modification time among the source projects, or
        None when no source reports one."""
        mtimes = self._get_sources_attribute("modification_time")
        known = (mtime for mtime in mtimes if mtime)
        return max(known, default=None)

    def get_hp_optimizer(self, corpus, metric):
        """Return a hyperparameter optimizer for this ensemble."""
        return EnsembleOptimizer(self, corpus, metric)

    def _train(self, corpus, params, jobs=0):
        # an ensemble has no model of its own to train
        raise NotSupportedException("Training ensemble backend is not possible.")
173