Passed
Push — issue678-refactor-suggestionre... ( 0d7003...d7e7fa )
by Osma
02:49
created

BaseEnsembleBackend._merge_source_batches()   A

Complexity

Conditions 1

Size

Total Lines 9
Code Lines 4

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 4
nop 4
dl 0
loc 9
rs 10
c 0
b 0
f 0
1
"""Ensemble backend that combines results from multiple projects"""
2
3
4
import annif.eval
5
import annif.parallel
6
import annif.suggestion
7
import annif.util
8
from annif.exception import NotSupportedException
9
from annif.suggestion import SuggestionBatch, vector_to_suggestions
10
11
from . import backend, hyperopt
12
13
14
class BaseEnsembleBackend(backend.AnnifBackend):
    """Shared machinery for backends that merge the suggestions of several
    source projects listed in the ``sources`` backend parameter."""

    def _get_sources_attribute(self, attr):
        """Collect attribute ``attr`` from every source project, in the order
        the projects appear in the ``sources`` parameter."""
        source_spec = self._get_backend_params(None)["sources"]
        collected = []
        for src_id, _weight in annif.util.parse_sources(source_spec):
            src_project = self.project.registry.get_project(src_id)
            collected.append(getattr(src_project, attr))
        return collected

    def initialize(self, parallel=False):
        """Initialize each source project, passing ``parallel`` through."""
        source_spec = self._get_backend_params(None)["sources"]
        for src_id, _weight in annif.util.parse_sources(source_spec):
            self.project.registry.get_project(src_id).initialize(parallel)

    def _suggest_with_sources(self, texts, sources):
        """Run ``texts`` through every source project.

        Returns a dict mapping each source project ID (taken from the
        (project_id, weight) pairs in ``sources``) to the result of that
        project's ``suggest`` call."""
        suggestions_by_id = {}
        for src_id, _weight in sources:
            src_project = self.project.registry.get_project(src_id)
            suggestions_by_id[src_id] = src_project.suggest(texts)
        return suggestions_by_id

    def _merge_source_batches(self, batch_by_source, sources, params):
        """Combine the per-source SuggestionBatches in ``batch_by_source``
        into a single SuggestionBatch.

        This default takes the weighted average of the batches, using the
        weight from each (project_id, weight) pair in ``sources``; ``params``
        is not used here. Subclasses are meant to override this with their
        own merging strategy."""
        source_batches = []
        source_weights = []
        for src_id, weight in sources:
            source_batches.append(batch_by_source[src_id])
            source_weights.append(weight)
        return SuggestionBatch.from_averaged(source_batches, source_weights)

    def _suggest_batch(self, texts, params):
        """Query all sources for ``texts``, merge their batches, then filter
        the merged batch using ``params["limit"]``."""
        sources = annif.util.parse_sources(params["sources"])
        merged = self._merge_source_batches(
            self._suggest_with_sources(texts, sources), sources, params
        )
        return merged.filter(limit=int(params["limit"]))
53
54
55
class EnsembleOptimizer(hyperopt.HyperparameterOptimizer):
    """Hyperparameter optimizer for the ensemble backend.

    Searches for the per-source weights that maximize the chosen evaluation
    metric when the source projects' suggestion batches are averaged.
    """

    def __init__(self, backend, corpus, metric):
        super().__init__(backend, corpus, metric)
        # Keep only the project IDs; the weights are what this optimizer
        # searches for, so the configured ones are ignored.
        self._sources = [
            project_id
            for project_id, _ in annif.util.parse_sources(
                backend.config_params["sources"]
            )
        ]

    def _prepare(self, n_jobs=1):
        """Run every source project over the corpus once and cache the
        per-source suggestions together with the gold-standard batches, so
        that each trial only needs to re-weight cached results."""
        self._gold_batches = []
        self._source_batches = []

        for project_id in self._sources:
            project = self._backend.project.registry.get_project(project_id)
            project.initialize()

        psmap = annif.parallel.ProjectSuggestMap(
            self._backend.project.registry,
            self._sources,
            backend_params=None,
            limit=int(self._backend.params["limit"]),
            threshold=0.0,
        )

        jobs, pool_class = annif.parallel.get_pool(n_jobs)

        with pool_class(jobs) as pool:
            # imap_unordered may yield document batches out of order, but each
            # (suggestions, gold_batch) pair is appended together, so the two
            # lists stay aligned index-by-index.
            for suggestions, gold_batch in pool.imap_unordered(
                psmap.suggest_batch, self._corpus.doc_batches
            ):
                self._source_batches.append(suggestions)
                self._gold_batches.append(gold_batch)

    def _normalize(self, hps):
        """Scale the weights in ``hps`` so that they sum to 1.

        If every weight is zero, fall back to equal weights rather than
        raising ZeroDivisionError. An empty ``hps`` yields an empty dict."""
        total = sum(hps.values())
        if total == 0:
            return {source: 1 / len(hps) for source in hps}
        return {source: weight / total for source, weight in hps.items()}

    def _format_cfg_line(self, hps):
        """Render ``hps`` as a ``sources=id:weight,...`` configuration line."""
        return "sources=" + ",".join(
            [f"{src}:{weight:.4f}" for src, weight in hps.items()]
        )

    def _objective(self, trial):
        """Score one trial: average the cached source batches with the
        trial's weights and evaluate them against the gold standard."""
        eval_batch = annif.eval.EvaluationBatch(self._backend.project.subjects)
        # suggest_float replaces suggest_uniform, which is deprecated since
        # Optuna 3.0; both sample uniformly from [low, high].
        proj_weights = {
            project_id: trial.suggest_float(project_id, 0.0, 1.0)
            for project_id in self._sources
        }
        # The weights and the limit are the same for every document batch,
        # so compute them once instead of inside the loop.
        weights = [proj_weights[project_id] for project_id in self._sources]
        limit = int(self._backend.params["limit"])
        for gold_batch, src_batches in zip(self._gold_batches, self._source_batches):
            batches = [src_batches[project_id] for project_id in self._sources]
            avg_batch = SuggestionBatch.from_averaged(batches, weights).filter(
                limit=limit
            )
            eval_batch.evaluate_many(avg_batch, gold_batch)
        results = eval_batch.results(metrics=[self._metric])
        return results[self._metric]

    def _postprocess(self, study):
        """Turn the best trial into a recommended ``sources=`` config line."""
        line = self._format_cfg_line(self._normalize(study.best_params))
        return hyperopt.HPRecommendation(lines=[line], score=study.best_value)
120
121
122
class EnsembleBackend(BaseEnsembleBackend, hyperopt.AnnifHyperoptBackend):
    """Ensemble backend: merges the suggestions of several source projects
    and supports hyperparameter optimization of their weights."""

    name = "ensemble"

    @property
    def is_trained(self):
        """True only when every source project reports ``is_trained``."""
        return all(self._get_sources_attribute("is_trained"))

    @property
    def modification_time(self):
        """Latest modification time among the source projects.

        Falsy values (e.g. ``None``) are skipped. Returns ``None`` if no
        source has a modification time."""
        timestamps = [
            stamp
            for stamp in self._get_sources_attribute("modification_time")
            if stamp
        ]
        return max(timestamps) if timestamps else None

    def get_hp_optimizer(self, corpus, metric):
        """Return an optimizer that tunes the per-source weights."""
        return EnsembleOptimizer(self, corpus, metric)

    def _train(self, corpus, params, jobs=0):
        """Always refuses: an ensemble has no model of its own to train."""
        raise NotSupportedException("Training ensemble backend is not possible.")
142