Passed
Pull Request — master (#677)
by Juho
03:06
created

BaseEnsembleBackend._merge_hit_sets_from_sources()   A

Complexity

Conditions 1

Size

Total Lines 6
Code Lines 4

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 4
nop 3
dl 0
loc 6
rs 10
c 0
b 0
f 0
1
"""Ensemble backend that combines results from multiple projects"""
2
3
4
import annif.eval
5
import annif.parallel
6
import annif.suggestion
7
import annif.util
8
from annif.exception import NotSupportedException
9
10
from . import backend, hyperopt
11
12
13
class BaseEnsembleBackend(backend.AnnifBackend):
    """Base class for ensemble backends.

    Ensemble backends combine the suggestions of the source projects listed
    in the ``sources`` backend parameter."""

    def _get_sources_attribute(self, attr):
        """Return attribute ``attr`` of every source project, in the order
        the sources are listed in the ``sources`` backend parameter."""
        params = self._get_backend_params(None)
        sources = annif.util.parse_sources(params["sources"])
        return [
            getattr(self.project.registry.get_project(project_id), attr)
            for project_id, _ in sources
        ]

    def initialize(self, parallel=False):
        """Initialize all the source projects.

        ``parallel`` is passed through to each source project's
        ``initialize`` call."""
        params = self._get_backend_params(None)
        for project_id, _ in annif.util.parse_sources(params["sources"]):
            project = self.project.registry.get_project(project_id)
            project.initialize(parallel)

    def _normalize_hits(self, hits, source_project):
        """Hook for processing hits from backends. Intended to be overridden
        by subclasses. The default implementation returns ``hits`` as is."""
        return hits

    def _suggest_with_sources(self, texts, sources):
        """Collect weighted suggestions for ``texts`` from every source.

        ``sources`` is a sequence of ``(project_id, weight)`` pairs.

        Returns a list with one entry per source, in source order. Each entry
        is a list with one ``WeightedSuggestion`` per text, in text order.
        The result is therefore indexed ``[source][document]``."""
        hit_sets_from_sources = []
        for project_id, weight in sources:
            source_project = self.project.registry.get_project(project_id)
            hit_sets = source_project.suggest(texts)
            norm_hit_sets = [
                self._normalize_hits(hits, source_project) for hits in hit_sets
            ]
            hit_sets_from_sources.append(
                [
                    annif.suggestion.WeightedSuggestion(
                        hits=norm_hits, weight=weight, subjects=source_project.subjects
                    )
                    for norm_hits in norm_hit_sets
                ]
            )
        return hit_sets_from_sources

    def _merge_hit_sets_from_sources(self, hit_sets_from_sources, params):
        """Hook for merging hits from sources. Can be overridden by
        subclasses.

        ``hit_sets_from_sources`` is indexed ``[source][document]``, as
        returned by ``_suggest_with_sources``. The default implementation
        merges, per document, the weighted suggestions of all sources. It
        returns one merged result per document, in document order."""
        # Transpose [source][document] -> [document][source]. Each merge_hits
        # call must combine the sources' suggestions for a single document.
        # Iterating the outer list directly would instead average one
        # source's suggestions across different documents.
        return [
            annif.util.merge_hits(list(doc_hits), len(self.project.subjects))
            for doc_hits in zip(*hit_sets_from_sources)
        ]

    def _suggest_batch(self, texts, params):
        """Suggest subjects for a batch of ``texts`` by querying every
        configured source and merging their weighted suggestions."""
        sources = annif.util.parse_sources(params["sources"])
        hit_sets_from_sources = self._suggest_with_sources(texts, sources)
        return self._merge_hit_sets_from_sources(hit_sets_from_sources, params)
66
67
68
class EnsembleOptimizer(hyperopt.HyperparameterOptimizer):
    """Hyperparameter optimizer for the ensemble backend.

    The hyperparameters are the weights of the source projects. The source
    hits are computed once in ``_prepare``. Each trial then only re-weights
    and merges those stored hits."""

    def __init__(self, backend, corpus, metric):
        super().__init__(backend, corpus, metric)
        # Only the source project IDs are kept; the configured weights are
        # what the optimizer searches for.
        self._sources = [
            project_id
            for project_id, _ in annif.util.parse_sources(
                backend.config_params["sources"]
            )
        ]

    def _prepare(self, n_jobs=1):
        """Run every source project over the corpus once and store the hits
        and gold-standard subjects for use by the trials.

        ``n_jobs`` is passed to ``annif.parallel.get_pool`` to select the
        worker pool."""
        self._gold_subjects = []
        self._source_hits = []

        for project_id in self._sources:
            project = self._backend.project.registry.get_project(project_id)
            project.initialize()

        psmap = annif.parallel.ProjectSuggestMap(
            self._backend.project.registry,
            self._sources,
            backend_params=None,
            limit=int(self._backend.params["limit"]),
            threshold=0.0,
        )

        jobs, pool_class = annif.parallel.get_pool(n_jobs)

        with pool_class(jobs) as pool:
            # Batches may arrive in any order. The gold subjects and hits of
            # one batch are appended together, so the two lists stay aligned
            # document by document.
            for hit_sets, subject_sets in pool.imap_unordered(
                psmap.suggest_batch, self._corpus.doc_batches
            ):
                self._gold_subjects.extend(subject_sets)
                self._source_hits.extend(self._hit_sets_to_list(hit_sets))

    def _hit_sets_to_list(self, hit_sets):
        """Convert a dict of lists of hits to a list of dicts of hits, e.g.
        {"proj-1": [p-1-doc-1-hits, p-1-doc-2-hits]
         "proj-2": [p-2-doc-1-hits, p-2-doc-2-hits]}
        to
        [{"proj-1": p-1-doc-1-hits, "proj-2": p-2-doc-1-hits},
         {"proj-1": p-1-doc-2-hits, "proj-2": p-2-doc-2-hits}]
        """
        return [
            dict(zip(hit_sets.keys(), doc_hits)) for doc_hits in zip(*hit_sets.values())
        ]

    def _normalize(self, hps):
        """Scale the weights in ``hps`` (source -> weight) so they sum to 1."""
        total = sum(hps.values())
        return {source: hps[source] / total for source in hps}

    def _format_cfg_line(self, hps):
        """Format source weights as a ``sources=id:weight,...`` config line,
        with weights rounded to four decimals."""
        return "sources=" + ",".join(
            [f"{src}:{weight:.4f}" for src, weight in hps.items()]
        )

    def _objective(self, trial):
        """Evaluate one trial.

        Draws a weight in [0, 1] for each source. For every document, the
        stored source hits are merged using those weights and evaluated
        against the gold subjects. Returns the value of the target metric."""
        # Loop-invariant: the same subject vocabulary is used for every
        # document.
        subjects = self._backend.project.subjects
        n_subjects = len(subjects)
        batch = annif.eval.EvaluationBatch(subjects)
        # suggest_float is the non-deprecated equivalent of suggest_uniform
        # (same uniform distribution over [low, high]).
        weights = {
            project_id: trial.suggest_float(project_id, 0.0, 1.0)
            for project_id in self._sources
        }
        for goldsubj, srchits in zip(self._gold_subjects, self._source_hits):
            weighted_hits = [
                annif.suggestion.WeightedSuggestion(
                    hits=hits,
                    weight=weights[project_id],
                    subjects=subjects,
                )
                for project_id, hits in srchits.items()
            ]
            batch.evaluate(
                annif.util.merge_hits(weighted_hits, n_subjects),
                goldsubj,
            )
        results = batch.results(metrics=[self._metric])
        return results[self._metric]

    def _postprocess(self, study):
        """Turn the best trial into a recommendation. The recommended line
        uses the best weights normalized to sum to 1, and the score is the
        best objective value."""
        line = self._format_cfg_line(self._normalize(study.best_params))
        return hyperopt.HPRecommendation(lines=[line], score=study.best_value)
154
155
156
class EnsembleBackend(BaseEnsembleBackend, hyperopt.AnnifHyperoptBackend):
    """Ensemble backend that combines results from multiple projects.

    It has no model of its own. Its trained state and modification time are
    derived from the source projects, and training it directly is rejected.
    """

    name = "ensemble"

    @property
    def is_trained(self):
        """Whether every source project reports ``is_trained`` as true.

        With no sources configured, this is vacuously true."""
        return all(self._get_sources_attribute("is_trained"))

    @property
    def modification_time(self):
        """The newest modification time among the source projects.

        Sources whose value is falsy (e.g. ``None``) are skipped. Returns
        ``None`` when no source provides a value."""
        source_mtimes = self._get_sources_attribute("modification_time")
        return max((mtime for mtime in source_mtimes if mtime), default=None)

    def get_hp_optimizer(self, corpus, metric):
        """Return an optimizer that tunes this ensemble's source weights on
        ``corpus`` according to ``metric``."""
        optimizer = EnsembleOptimizer(self, corpus, metric)
        return optimizer

    def _train(self, corpus, params, jobs=0):
        """Always raise: an ensemble has nothing of its own to train."""
        raise NotSupportedException("Training ensemble backend is not possible.")
176