Passed
Pull Request — master (#677)
by Juho
02:53
created

BaseEnsembleBackend._suggest_batch()   A

Complexity
  Conditions: 1

Size
  Total Lines: 4
  Code Lines: 4

Duplication
  Lines: 0
  Ratio: 0 %

Importance
  Changes: 0

Metric  Value
cc      1
eloc    4
nop     3
dl      0
loc     4
rs      10
c       0
b       0
f       0
"""Ensemble backend that combines results from multiple projects"""


import annif.eval
import annif.parallel
import annif.suggestion
import annif.util
from annif.exception import NotSupportedException

from . import backend, hyperopt


class BaseEnsembleBackend(backend.AnnifBackend):
    """Base class for ensemble backends"""

    def _get_sources_attribute(self, attr):
        params = self._get_backend_params(None)
        sources = annif.util.parse_sources(params["sources"])
        return [
            getattr(self.project.registry.get_project(project_id), attr)
            for project_id, _ in sources
        ]

    def initialize(self, parallel=False):
        # initialize all the source projects
        params = self._get_backend_params(None)
        for project_id, _ in annif.util.parse_sources(params["sources"]):
            project = self.project.registry.get_project(project_id)
            project.initialize(parallel)

    def _normalize_hits(self, hits, source_project):
        """Hook for processing hits from backends. Intended to be overridden
        by subclasses."""
        return hits
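    # _normalize_hits above is a no-op hook meant to be overridden by
    # subclasses that need to rescale or recalibrate the scores of a single
    # source before weighting and merging. A minimal hypothetical sketch
    # (class name and calibration step assumed, not part of this file):
    #
    #     class CalibratedEnsembleBackend(BaseEnsembleBackend):
    #         def _normalize_hits(self, hits, source_project):
    #             calibrator = self._calibrators[source_project.project_id]
    #             return calibrator.transform(hits)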
    def _suggest_with_sources(self, texts, sources):
        hit_sets_from_sources = []
        for project_id, weight in sources:
            source_project = self.project.registry.get_project(project_id)
            hit_sets = source_project.suggest(texts)
            norm_hit_sets = [
                self._normalize_hits(hits, source_project) for hits in hit_sets
            ]
            hit_sets_from_sources.append(
                [
                    annif.suggestion.WeightedSuggestion(
                        hits=norm_hits, weight=weight, subjects=source_project.subjects
                    )
                    for norm_hits in norm_hit_sets
                ]
            )
        return hit_sets_from_sources

    def _merge_hit_sets_from_sources(self, hit_sets_from_sources, params):
        """Hook for merging hits from sources. Can be overridden by
        subclasses."""
        # transpose so that each merge combines one hit set per source
        # for the same input text
        return [
            annif.util.merge_hits(hits, len(self.project.subjects))
            for hits in zip(*hit_sets_from_sources)
        ]
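    # Shape walk-through for the two methods above (illustrative values): with
    # sources parsed as [("proj-a", 1.0), ("proj-b", 2.0)] and three input
    # texts, _suggest_with_sources returns a 2 x 3 nested list, one inner list
    # of WeightedSuggestion objects per source. zip(*hit_sets_from_sources)
    # transposes this into three tuples holding one suggestion per source for
    # the same text, so annif.util.merge_hits produces exactly one merged
    # result per input text, which is what _suggest_batch below returns.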
    def _suggest_batch(self, texts, params):
        sources = annif.util.parse_sources(params["sources"])
        hit_sets_from_sources = self._suggest_with_sources(texts, sources)
        return self._merge_hit_sets_from_sources(hit_sets_from_sources, params)


class EnsembleOptimizer(hyperopt.HyperparameterOptimizer):
    """Hyperparameter optimizer for the ensemble backend"""

    def __init__(self, backend, corpus, metric):
        super().__init__(backend, corpus, metric)
        self._sources = [
            project_id
            for project_id, _ in annif.util.parse_sources(
                backend.config_params["sources"]
            )
        ]

    def _prepare(self, n_jobs=1):
        self._gold_subjects = []
        self._source_hits = []

        for project_id in self._sources:
            project = self._backend.project.registry.get_project(project_id)
            project.initialize()

        psmap = annif.parallel.ProjectSuggestMap(
            self._backend.project.registry,
            self._sources,
            backend_params=None,
            limit=int(self._backend.params["limit"]),
            threshold=0.0,
        )

        jobs, pool_class = annif.parallel.get_pool(n_jobs)

        with pool_class(jobs) as pool:
            for hit_sets, subject_sets in pool.imap_unordered(
                psmap.suggest_batch, self._corpus.doc_batches
            ):
                self._gold_subjects.extend(subject_sets)
                self._source_hits.extend(self._hit_sets_to_list(hit_sets))

    def _hit_sets_to_list(self, hit_sets):
        """Convert a dict of lists of hits to a list of dicts of hits"""
        return [dict(zip(hit_sets.keys(), hit)) for hit in zip(*hit_sets.values())]
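    # Illustrative example for _hit_sets_to_list: given
    # hit_sets = {"proj-a": [hits_a1, hits_a2], "proj-b": [hits_b1, hits_b2]}
    # (one list of per-document hits per source project), the method returns
    # [{"proj-a": hits_a1, "proj-b": hits_b1},
    #  {"proj-a": hits_a2, "proj-b": hits_b2}], i.e. one dict per document.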
    def _normalize(self, hps):
        total = sum(hps.values())
        return {source: hps[source] / total for source in hps}

    def _format_cfg_line(self, hps):
        return "sources=" + ",".join(
            [f"{src}:{weight:.4f}" for src, weight in hps.items()]
        )
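    # Illustrative example: _normalize({"proj-a": 0.3, "proj-b": 0.9}) returns
    # {"proj-a": 0.25, "proj-b": 0.75}, and _format_cfg_line applied to that
    # result yields the configuration line "sources=proj-a:0.2500,proj-b:0.7500".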
    def _objective(self, trial):
        batch = annif.eval.EvaluationBatch(self._backend.project.subjects)
        weights = {
            project_id: trial.suggest_uniform(project_id, 0.0, 1.0)
            for project_id in self._sources
        }
        for goldsubj, srchits in zip(self._gold_subjects, self._source_hits):
            weighted_hits = []
            for project_id, hits in srchits.items():
                weighted_hits.append(
                    annif.suggestion.WeightedSuggestion(
                        hits=hits,
                        weight=weights[project_id],
                        subjects=self._backend.project.subjects,
                    )
                )
            batch.evaluate(
                annif.util.merge_hits(
                    weighted_hits, len(self._backend.project.subjects)
                ),
                goldsubj,
            )
        results = batch.results(metrics=[self._metric])
        return results[self._metric]
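    # _objective samples one weight in [0, 1] per source project for each
    # Optuna trial, wraps the pre-computed source suggestions with those
    # weights, merges and evaluates them against the gold standard subjects,
    # and returns the chosen metric (for example NDCG) as the trial value.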
    def _postprocess(self, study):
        line = self._format_cfg_line(self._normalize(study.best_params))
        return hyperopt.HPRecommendation(lines=[line], score=study.best_value)


class EnsembleBackend(BaseEnsembleBackend, hyperopt.AnnifHyperoptBackend):
    """Ensemble backend that combines results from multiple projects"""

    name = "ensemble"

    @property
    def is_trained(self):
        sources_trained = self._get_sources_attribute("is_trained")
        return all(sources_trained)

    @property
    def modification_time(self):
        mtimes = self._get_sources_attribute("modification_time")
        return max(filter(None, mtimes), default=None)

    def get_hp_optimizer(self, corpus, metric):
        return EnsembleOptimizer(self, corpus, metric)

    def _train(self, corpus, params, jobs=0):
        raise NotSupportedException("Training ensemble backend is not possible.")
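The "sources" parameter parsed throughout this file is the comma-separated list of source project IDs and optional weights taken from the Annif project configuration. A minimal sketch of such a configuration (the project IDs, vocabulary and weights below are hypothetical, not taken from this pull request):

[ensemble-en]
name=Ensemble English
language=en
backend=ensemble
vocab=yso
sources=tfidf-en:1,fasttext-en:2
limit=100

With a configuration along these lines, _suggest_batch combines the suggestions of tfidf-en and fasttext-en using weights 1 and 2, and the annif hyperopt command can use EnsembleOptimizer to search for better weights and report them in the sources=... format produced by _format_cfg_line.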
168