Passed
Pull Request — master (#414)
by Osma
02:32
created

ce()   A

Complexity

Conditions 2

Size

Total Lines 5
Code Lines 5

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 5
dl 0
loc 5
rs 10
c 0
b 0
f 0
cc 2
nop 1
1
"""Ensemble backend that combines results from multiple projects"""
2
3
4
from tqdm.auto import tqdm
5
import annif.suggestion
6
import annif.project
7
import annif.util
8
import annif.eval
9
from . import hyperopt
10
from annif.exception import NotSupportedException
11
12
13
class EnsembleOptimizer(hyperopt.HyperparameterOptimizer):
    """Hyperparameter optimizer for the ensemble backend.

    Tunes the relative weight of each source project listed in the
    backend's ``sources`` setting. Suggestions from every source are computed
    once per document in ``_prepare``. ``_objective`` then re-merges the
    cached suggestions with the weights proposed by each trial.
    """

    def __init__(self, backend, corpus, metric):
        super().__init__(backend, corpus, metric)
        # Only the project IDs are kept; the configured weights are not used
        # because every trial proposes its own weights.
        self._sources = [project_id for project_id, _
                         in annif.util.parse_sources(
                             backend.config_params['sources'])]

    def _prepare(self):
        """Cache the gold subjects and per-source suggestions of every document.

        Fills ``self._gold_subjects`` (one SubjectSet per document) and
        ``self._source_hits`` (one ``{project_id: hits}`` dict per document,
        in the same order) so that trials do not re-run the source projects.
        """
        # Import explicitly instead of relying on another module having
        # loaded ``annif.corpus`` as a side effect.
        from annif.corpus import SubjectSet

        self._gold_subjects = []
        self._source_hits = []
        # Resolve each source project once rather than once per document.
        source_projects = {project_id: annif.project.get_project(project_id)
                           for project_id in self._sources}

        for doc in self._corpus.documents:
            self._gold_subjects.append(SubjectSet((doc.uris, doc.labels)))
            self._source_hits.append(
                {project_id: project.suggest(doc.text)
                 for project_id, project in source_projects.items()})

    def _normalize(self, hps):
        """Return a copy of the weight dict ``hps`` scaled to sum to one."""
        total = sum(hps.values())
        return {source: weight / total for source, weight in hps.items()}

    def _format_cfg_line(self, hps):
        """Render weights as a ``sources=id:weight,...`` config line.

        Each weight is formatted with four decimals.
        """
        return 'sources=' + ','.join(f"{src}:{weight:.4f}"
                                     for src, weight in hps.items())

    def _objective(self, trial):
        """Score one trial's source weights on the cached documents.

        Samples a weight in [0, 1] for every source, merges each document's
        cached source hits with those weights, evaluates the merged result
        against the gold subjects and returns the score for ``self._metric``.
        """
        subjects = self._backend.project.subjects
        batch = annif.eval.EvaluationBatch(subjects)
        weights = {project_id: trial.suggest_uniform(project_id, 0.0, 1.0)
                   for project_id in self._sources}
        for goldsubj, srchits in zip(self._gold_subjects, self._source_hits):
            weighted_hits = [
                annif.suggestion.WeightedSuggestion(
                    hits=hits, weight=weights[project_id])
                for project_id, hits in srchits.items()]
            batch.evaluate(
                annif.util.merge_hits(weighted_hits, subjects), goldsubj)
        results = batch.results()
        return results[self._metric]

    def _postprocess(self, study):
        """Build a recommendation from the best trial of ``study``.

        The recommended ``sources=`` line carries the best weights,
        normalized to sum to one, together with the best score.
        """
        line = self._format_cfg_line(self._normalize(study.best_params))
        return hyperopt.HPRecommendation(lines=[line], score=study.best_value)


class EnsembleBackend(hyperopt.AnnifHyperoptBackend):
    """Ensemble backend that combines results from multiple projects"""
    name = "ensemble"

    def get_hp_optimizer(self, corpus, metric):
        """Create the optimizer that tunes this ensemble's source weights."""
        return EnsembleOptimizer(self, corpus, metric)

    def _normalize_hits(self, hits, source_project):
        """Post-process the hits from a single source project.

        The base implementation returns ``hits`` unchanged. Subclasses may
        override it.
        """
        return hits

    def _suggest_with_sources(self, text, sources):
        """Ask every source project for suggestions on ``text``.

        ``sources`` is an iterable of ``(project_id, weight)`` pairs. Returns
        a list with one WeightedSuggestion per pair, in the same order,
        holding that source's hits after ``_normalize_hits``.
        """
        def query(project_id, weight):
            project = annif.project.get_project(project_id)
            raw_hits = project.suggest(text)
            self.debug(
                'Got {} hits from project {}, weight {}'.format(
                    len(raw_hits), project.project_id, weight))
            return annif.suggestion.WeightedSuggestion(
                hits=self._normalize_hits(raw_hits, project), weight=weight)

        return [query(project_id, weight) for project_id, weight in sources]

    def _merge_hits_from_sources(self, hits_from_sources, params):
        """Combine the weighted per-source hits into a single result.

        This is a hook that subclasses may override. ``params`` is not used
        by this implementation.
        """
        subject_index = self.project.subjects
        return annif.util.merge_hits(hits_from_sources, subject_index)

    def _suggest(self, text, params):
        """Suggest subjects for ``text`` by merging the configured sources."""
        weighted = self._suggest_with_sources(
            text, annif.util.parse_sources(params['sources']))
        merged = self._merge_hits_from_sources(weighted, params)
        self.debug('{} hits after merging'.format(len(merged)))
        return merged

    def _train(self, corpus, params):
        """Always raise: an ensemble has no model of its own to train."""
        raise NotSupportedException('Training ensemble model is not possible.')