Passed
Pull Request — master (#677)
by Juho
03:22
created

BaseEnsembleBackend._suggest_batch()   A

Complexity

Conditions 1

Size

Total Lines 4
Code Lines 4

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 4
nop 3
dl 0
loc 4
rs 10
c 0
b 0
f 0
1
"""Ensemble backend that combines results from multiple projects"""
2
3
4
import annif.eval
5
import annif.parallel
6
import annif.suggestion
7
import annif.util
8
from annif.exception import NotSupportedException
9
10
from . import backend, hyperopt
11
12
13
class BaseEnsembleBackend(backend.AnnifBackend):
    """Base class for ensemble backends.

    An ensemble delegates suggestion to the source projects listed (with
    weights) in its ``sources`` parameter and merges their results.
    """

    def _get_sources_attribute(self, attr):
        """Read attribute ``attr`` from every configured source project.

        Returns a list of the values, in the order the sources appear in the
        ``sources`` parameter.
        """
        source_spec = self._get_backend_params(None)["sources"]
        values = []
        for source_id, _weight in annif.util.parse_sources(source_spec):
            source_project = self.project.registry.get_project(source_id)
            values.append(getattr(source_project, attr))
        return values

    def initialize(self, parallel=False):
        """Initialize each configured source project, passing ``parallel``."""
        source_spec = self._get_backend_params(None)["sources"]
        for source_id, _weight in annif.util.parse_sources(source_spec):
            self.project.registry.get_project(source_id).initialize(parallel)

    def _normalize_hits(self, hits, source_project):
        """Post-process the hits of one document from ``source_project``.

        The base implementation returns ``hits`` unchanged; subclasses may
        override it.
        """
        return hits

    def _suggest_with_sources(self, texts, sources):
        """Get suggestions for ``texts`` from every source.

        ``sources`` is an iterable of ``(project_id, weight)`` pairs. Returns a
        list with one ``WeightedSuggestionsBatch`` per source, in the same
        order, holding that source's hits after ``_normalize_hits``.
        """
        weighted_batches = []
        for source_id, source_weight in sources:
            project = self.project.registry.get_project(source_id)
            normalized = [
                self._normalize_hits(doc_hits, project)
                for doc_hits in project.suggest(texts)
            ]
            weighted_batches.append(
                annif.suggestion.WeightedSuggestionsBatch(
                    hit_sets=normalized,
                    weight=source_weight,
                    subjects=project.subjects,
                )
            )
        return weighted_batches

    def _merge_hit_sets_from_sources(self, hit_sets_from_sources, params):
        """Merge the per-source weighted batches into one result.

        Subclasses may override this hook. The default delegates to
        ``annif.util.merge_hits`` with the size of ``self.project.subjects``.
        """
        subject_count = len(self.project.subjects)
        return annif.util.merge_hits(hit_sets_from_sources, subject_count)

    def _suggest_batch(self, texts, params):
        """Suggest subjects for ``texts`` by querying all configured sources
        and merging their weighted results."""
        weighted_batches = self._suggest_with_sources(
            texts, annif.util.parse_sources(params["sources"])
        )
        return self._merge_hit_sets_from_sources(weighted_batches, params)
62
63
64
class EnsembleOptimizer(hyperopt.HyperparameterOptimizer):
    """Hyperparameter optimizer for the ensemble backend.

    Searches for a weight per source project that maximizes ``metric`` on
    ``corpus`` and reports the best weights as a ``sources=`` config line.
    """

    def __init__(self, backend, corpus, metric):
        super().__init__(backend, corpus, metric)
        # Only the source project IDs are kept; the configured weights are
        # ignored because they are what this optimizer searches for.
        self._sources = [
            project_id
            for project_id, _ in annif.util.parse_sources(
                backend.config_params["sources"]
            )
        ]

    def _prepare(self, n_jobs=1):
        """Collect gold subjects and per-source hits for the whole corpus.

        The stored hits are re-weighted and merged later in ``_objective``,
        so the source projects are queried only here.

        :param n_jobs: number of parallel jobs passed to
            ``annif.parallel.get_pool``.
        """
        self._gold_subjects = []
        self._source_hits = []

        for project_id in self._sources:
            project = self._backend.project.registry.get_project(project_id)
            project.initialize()

        psmap = annif.parallel.ProjectSuggestMap(
            self._backend.project.registry,
            self._sources,
            backend_params=None,
            limit=int(self._backend.params["limit"]),
            threshold=0.0,
        )

        jobs, pool_class = annif.parallel.get_pool(n_jobs)

        with pool_class(jobs) as pool:
            # Batches may finish in any order; gold subjects and hits from the
            # same batch are appended together so the two lists stay aligned.
            for hit_sets, subject_sets in pool.imap_unordered(
                psmap.suggest_batch, self._corpus.doc_batches
            ):
                self._gold_subjects.extend(subject_sets)
                self._source_hits.extend(self._hit_sets_to_list(hit_sets))

    def _hit_sets_to_list(self, hit_sets):
        """Convert a dict of lists of hits to a list of dicts of hits, e.g.
        {"proj-1": [p-1-doc-1-hits, p-1-doc-2-hits]
         "proj-2": [p-2-doc-1-hits, p-2-doc-2-hits]}
        to
        [{"proj-1": p-1-doc-1-hits, "proj-2": p-2-doc-1-hits},
         {"proj-1": p-1-doc-2-hits, "proj-2": p-2-doc-2-hits}]
        """
        return [
            dict(zip(hit_sets.keys(), doc_hits)) for doc_hits in zip(*hit_sets.values())
        ]

    def _normalize(self, hps):
        """Scale the weights in ``hps`` so that they sum to 1."""
        total = sum(hps.values())
        return {source: hps[source] / total for source in hps}

    def _format_cfg_line(self, hps):
        """Format weights as a ``sources=id:weight,...`` config line."""
        return "sources=" + ",".join(
            [f"{src}:{weight:.4f}" for src, weight in hps.items()]
        )

    def _objective(self, trial):
        """Evaluate one trial's source weights.

        Samples a weight in [0, 1] for each source, applies the weights to the
        prepared hits of every document, merges them, and evaluates the merged
        hits against the gold subjects.

        :returns: the value of ``self._metric`` over all prepared documents.
        """
        subjects = self._backend.project.subjects
        subject_count = len(subjects)
        batch = annif.eval.EvaluationBatch(subjects)
        # ``suggest_uniform`` is deprecated in Optuna; ``suggest_float`` with
        # the same bounds samples from the same uniform distribution.
        weights = {
            project_id: trial.suggest_float(project_id, 0.0, 1.0)
            for project_id in self._sources
        }
        for goldsubj, srchits in zip(self._gold_subjects, self._source_hits):
            weighted_hits = [
                annif.suggestion.WeightedSuggestionsBatch(
                    hit_sets=[hits],
                    weight=weights[project_id],
                    subjects=subjects,
                )
                for project_id, hits in srchits.items()
            ]
            # Each batch holds a single document, so take the first merged result.
            merged = annif.util.merge_hits(weighted_hits, subject_count)
            batch.evaluate(merged[0], goldsubj)
        results = batch.results(metrics=[self._metric])
        return results[self._metric]

    def _postprocess(self, study):
        """Turn the best trial into a recommended ``sources=`` config line."""
        line = self._format_cfg_line(self._normalize(study.best_params))
        return hyperopt.HPRecommendation(lines=[line], score=study.best_value)
150
151
152
class EnsembleBackend(BaseEnsembleBackend, hyperopt.AnnifHyperoptBackend):
    """Ensemble backend that combines results from multiple projects"""

    name = "ensemble"

    @property
    def is_trained(self):
        """True only when every source project reports ``is_trained``."""
        return all(self._get_sources_attribute("is_trained"))

    @property
    def modification_time(self):
        """Most recent modification time among the source projects.

        Falsy values (e.g. ``None``) are skipped; returns ``None`` when no
        source provides a modification time.
        """
        known_mtimes = [
            mtime
            for mtime in self._get_sources_attribute("modification_time")
            if mtime
        ]
        return max(known_mtimes) if known_mtimes else None

    def get_hp_optimizer(self, corpus, metric):
        """Return an optimizer that tunes the source weights."""
        return EnsembleOptimizer(self, corpus, metric)

    def _train(self, corpus, params, jobs=0):
        """Always raises; an ensemble has nothing of its own to train."""
        raise NotSupportedException("Training ensemble backend is not possible.")
172