Passed
Push — issue678-refactor-suggestionre... ( 311240...092cdc )
by Osma
03:05
created

EnsembleOptimizer._postprocess()   A

Complexity

Conditions 1

Size

Total Lines 3
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 3
dl 0
loc 3
rs 10
c 0
b 0
f 0
cc 1
nop 2
1
"""Ensemble backend that combines results from multiple projects"""
2
3
4
import annif.eval
5
import annif.parallel
6
import annif.suggestion
7
import annif.util
8
from annif.exception import NotSupportedException
9
from annif.suggestion import SuggestionBatch, vector_to_suggestions
10
11
from . import backend, hyperopt
12
13
14
class BaseEnsembleBackend(backend.AnnifBackend):
    """Base class for ensemble backends"""

    def _get_sources_attribute(self, attr):
        """Return attribute ``attr`` of each configured source project.

        The values are returned as a list, in the order the sources appear
        in the ``sources`` backend parameter.
        """
        config = self._get_backend_params(None)
        source_specs = annif.util.parse_sources(config["sources"])
        return [
            getattr(self.project.registry.get_project(source_id), attr)
            for source_id, _weight in source_specs
        ]

    def initialize(self, parallel=False):
        """Initialize every source project, forwarding the ``parallel`` flag."""
        config = self._get_backend_params(None)
        source_specs = annif.util.parse_sources(config["sources"])
        for source_id, _weight in source_specs:
            self.project.registry.get_project(source_id).initialize(parallel)

    def _normalize_suggestion_batch(self, batch, source_project):
        """Post-process a suggestion batch returned by ``source_project``.

        The base implementation returns ``batch`` unchanged; subclasses may
        override this hook to rescale or otherwise adjust the suggestions.
        """
        return batch

    def _suggest_with_sources(self, texts, sources):
        """Get suggestions for ``texts`` from every source project.

        :param texts: the texts to get suggestions for
        :param sources: iterable of ``(project_id, weight)`` pairs, e.g. as
            returned by ``annif.util.parse_sources``
        :returns: list with one ``WeightedSuggestionsBatch`` per source, in
            source order
        """
        weighted_batches = []
        for source_id, source_weight in sources:
            project = self.project.registry.get_project(source_id)
            normalized = self._normalize_suggestion_batch(
                project.suggest(texts), project
            )
            weighted_batches.append(
                annif.suggestion.WeightedSuggestionsBatch(
                    hit_sets=normalized,
                    weight=source_weight,
                    subjects=project.subjects,
                )
            )
        return weighted_batches

    def _merge_hit_sets_from_sources(self, hit_sets_from_sources, params):
        """Combine the per-source results into merged rows.

        The base implementation delegates to ``annif.util.merge_hits`` and
        does not use ``params``; subclasses may override this hook.
        """
        return annif.util.merge_hits(hit_sets_from_sources)

    def _suggest_batch(self, texts, params):
        """Suggest subjects for ``texts`` by merging the source projects' results.

        Each merged row is converted to suggestions limited to
        ``params["limit"]`` and collected into a batch over this project's
        subject vocabulary.
        """
        sources = annif.util.parse_sources(params["sources"])
        weighted_batches = self._suggest_with_sources(texts, sources)
        merged_rows = self._merge_hit_sets_from_sources(weighted_batches, params)
        suggestion_rows = [
            vector_to_suggestions(row, int(params["limit"])) for row in merged_rows
        ]
        return annif.suggestion.SuggestionBatch.from_sequence(
            suggestion_rows, self.project.subjects
        )
69
70
71
class EnsembleOptimizer(hyperopt.HyperparameterOptimizer):
    """Hyperparameter optimizer for the ensemble backend.

    Searches for per-source weights that maximize ``metric`` on ``corpus``
    and recommends them as a ``sources=`` configuration line.
    """

    def __init__(self, backend, corpus, metric):
        super().__init__(backend, corpus, metric)
        # Only the source project IDs are kept: the configured weights are
        # exactly what this optimizer searches for.
        self._sources = [
            project_id
            for project_id, _ in annif.util.parse_sources(
                backend.config_params["sources"]
            )
        ]

    def _prepare(self, n_jobs=1):
        """Run every source project over the corpus once and cache the results.

        Fills ``self._gold_batches`` and ``self._source_batches`` as parallel
        lists with one entry per document batch, so that each trial only has
        to re-weight these cached suggestions.

        :param n_jobs: number of parallel jobs, passed to
            ``annif.parallel.get_pool``
        """
        self._gold_batches = []
        self._source_batches = []

        for project_id in self._sources:
            project = self._backend.project.registry.get_project(project_id)
            project.initialize()

        psmap = annif.parallel.ProjectSuggestMap(
            self._backend.project.registry,
            self._sources,
            backend_params=None,
            limit=int(self._backend.params["limit"]),
            threshold=0.0,
        )

        jobs, pool_class = annif.parallel.get_pool(n_jobs)

        with pool_class(jobs) as pool:
            # Results may arrive out of order, but each item carries its gold
            # batch together with the matching suggestions, so appending both
            # at the same time keeps the two lists aligned.
            for gold_batch, suggestions in pool.imap_unordered(
                psmap.suggest_batch, self._corpus.doc_batches
            ):
                self._gold_batches.append(gold_batch)
                self._source_batches.append(suggestions)

    def _normalize(self, hps):
        """Return the weights in ``hps`` scaled so that they sum to 1.

        Each weight is suggested from the range 0.0-1.0, so in principle all
        of them can be zero; equal weights are returned in that case instead
        of raising ``ZeroDivisionError``.
        """
        total = sum(hps.values())
        if total == 0:
            return {source: 1.0 / len(hps) for source in hps}
        return {source: weight / total for source, weight in hps.items()}

    def _format_cfg_line(self, hps):
        """Format weights as a config line, e.g. ``sources=a:0.2500,b:0.7500``."""
        return "sources=" + ",".join(
            f"{src}:{weight:.4f}" for src, weight in hps.items()
        )

    def _objective(self, trial):
        """Score one trial's choice of source weights.

        A weight in the range 0.0-1.0 is suggested for every source project.
        The cached source suggestions are averaged with those weights, cut to
        ``limit`` results and evaluated against the corresponding gold batch.

        :returns: the value of ``self._metric`` for this trial
        """
        limit = int(self._backend.params["limit"])
        # suggest_uniform is deprecated since Optuna 3.0; suggest_float with
        # the same bounds samples the same uniform range. The weights are the
        # same for every document batch, so they are computed once here.
        weights = [
            trial.suggest_float(project_id, 0.0, 1.0) for project_id in self._sources
        ]
        eval_batch = annif.eval.EvaluationBatch(self._backend.project.subjects)
        for gold_batch, src_batches in zip(self._gold_batches, self._source_batches):
            batches = [src_batches[project_id] for project_id in self._sources]
            avg_batch = SuggestionBatch.from_averaged(batches, weights).filter(
                limit=limit
            )
            eval_batch.evaluate_many(avg_batch, gold_batch)
        results = eval_batch.results(metrics=[self._metric])
        return results[self._metric]

    def _postprocess(self, study):
        """Turn the best trial into a recommended ``sources=`` config line."""
        line = self._format_cfg_line(self._normalize(study.best_params))
        return hyperopt.HPRecommendation(lines=[line], score=study.best_value)
136
137
138
class EnsembleBackend(BaseEnsembleBackend, hyperopt.AnnifHyperoptBackend):
    """Ensemble backend that combines results from multiple projects"""

    name = "ensemble"

    @property
    def is_trained(self):
        """True only when every source project reports itself as trained."""
        return all(self._get_sources_attribute("is_trained"))

    @property
    def modification_time(self):
        """Latest modification time among the source projects.

        Sources reporting a falsy value (such as ``None``) are ignored; if no
        source reports a time, ``None`` is returned.
        """
        known_mtimes = [
            mtime
            for mtime in self._get_sources_attribute("modification_time")
            if mtime
        ]
        return max(known_mtimes) if known_mtimes else None

    def get_hp_optimizer(self, corpus, metric):
        """Return an ``EnsembleOptimizer`` that tunes this ensemble's weights."""
        return EnsembleOptimizer(self, corpus, metric)

    def _train(self, corpus, params, jobs=0):
        """Always fails: an ensemble has nothing of its own to train."""
        raise NotSupportedException("Training ensemble backend is not possible.")
158