Passed
Pull Request — master (#414) by Osma
created 02:27

annif.backend.ensemble   A

Complexity

Total Complexity 18

Size/Duplication

Total Lines 112
Duplicated Lines 0 %

Importance

Changes 0

Metric   Value
wmc      18
eloc     87
dl       0
loc      112
rs       10
c        0
b        0
f        0

12 Methods

Rating   Name   Duplication   Size   Complexity  
A EnsembleOptimizer._prepare() 0 13 3
A EnsembleOptimizer.__init__() 0 5 1
A EnsembleOptimizer.get_hp_space() 0 5 2
A EnsembleOptimizer._normalize() 0 3 1
A EnsembleBackend._train() 0 2 1
A EnsembleBackend.get_hp_optimizer() 0 2 1
A EnsembleOptimizer._postprocess() 0 6 1
A EnsembleBackend._suggest() 0 6 1
A EnsembleBackend._merge_hits_from_sources() 0 4 1
A EnsembleBackend._suggest_with_sources() 0 13 2
A EnsembleOptimizer._test() 0 16 3
A EnsembleBackend._normalize_hits() 0 4 1
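
The EnsembleOptimizer methods listed above (get_hp_space, _test, _postprocess) drive a hyperopt search over one weight per source project. As a rough standalone illustration of that pattern, not code from this module, the sketch below builds a similar hp.uniform search space and explores it with hyperopt's fmin; the project ids and the toy objective are hypothetical stand-ins for the real 1 - NDCG loss computed in _test():

from hyperopt import fmin, hp, tpe

# Hypothetical source project ids; the real ones come from the 'sources' setting.
space = {
    "tfidf-en": hp.uniform("tfidf-en", 0.0, 1.0),
    "fasttext-en": hp.uniform("fasttext-en", 0.0, 1.0),
}

def objective(weights):
    # Toy loss standing in for 1 - NDCG of the merged suggestions.
    total = weights["tfidf-en"] + weights["fasttext-en"] + 1e-9
    return 1.0 - weights["tfidf-en"] / total

best = fmin(fn=objective, space=space, algo=tpe.suggest, max_evals=25)
print(best)  # e.g. {'fasttext-en': 0.02, 'tfidf-en': 0.98}

In the actual optimizer, _prepare() caches the gold subjects and per-source suggestions for every document, so each trial only re-weights and merges the cached results instead of calling the source projects again.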
"""Ensemble backend that combines results from multiple projects"""


from hyperopt import hp
from tqdm.auto import tqdm
import annif.corpus
import annif.suggestion
import annif.project
import annif.util
import annif.eval
from . import hyperopt
from annif.exception import NotSupportedException


class EnsembleOptimizer(hyperopt.HyperparameterOptimizer):
    """Hyperparameter optimizer for the ensemble backend"""

    def __init__(self, backend, corpus):
        super().__init__(backend, corpus)
        self._sources = [project_id for project_id, _
                         in annif.util.parse_sources(
                             backend.config_params['sources'])]

    def get_hp_space(self):
        space = {}
        for project_id in self._sources:
            space[project_id] = hp.uniform(project_id, 0.0, 1.0)
        return space

    def _prepare(self):
        # Cache gold standard subjects and per-source suggestions once, so
        # each optimization trial only has to re-weight and merge them.
        self._gold_subjects = []
        self._source_hits = []

        for doc in self._corpus.documents:
            self._gold_subjects.append(
                annif.corpus.SubjectSet((doc.uris, doc.labels)))
            srchits = {}
            for project_id in self._sources:
                source_project = annif.project.get_project(project_id)
                hits = source_project.suggest(doc.text)
                srchits[project_id] = hits
            self._source_hits.append(srchits)

    def _normalize(self, hps):
        # Scale the source weights so that they sum to 1.0
        total = sum(hps.values())
        return {source: hps[source] / total for source in hps}

    def _test(self, hps):
        batch = annif.eval.EvaluationBatch(self._backend.project.subjects)
        for goldsubj, srchits in zip(self._gold_subjects, self._source_hits):
            weighted_hits = []
            for project_id, hits in srchits.items():
                weighted_hits.append(annif.suggestion.WeightedSuggestion(
                    hits=hits, weight=hps[project_id]))
            batch.evaluate(
                annif.util.merge_hits(
                    weighted_hits,
                    self._backend.project.subjects),
                goldsubj)
        results = batch.results()
        scaled = self._normalize(hps)
        tqdm.write(f"Trial: {scaled} score: {results['NDCG']}")
        # hyperopt minimizes the returned loss, so invert the NDCG score
        return 1 - results['NDCG']

    def _postprocess(self, best, trials):
        scaled = self._normalize(best)
        lines = 'sources=' + ','.join([f"{src}:{weight:.4f}"
                                       for src, weight in scaled.items()])
        score = 1 - trials.best_trial['result']['loss']
        return hyperopt.HPRecommendation(lines=[lines], score=score)


class EnsembleBackend(hyperopt.AnnifHyperoptBackend):
    """Ensemble backend that combines results from multiple projects"""
    name = "ensemble"

    def get_hp_optimizer(self, corpus):
        return EnsembleOptimizer(self, corpus)

    def _normalize_hits(self, hits, source_project):
        """Hook for processing hits from backends. Intended to be overridden
        by subclasses."""
        return hits

    def _suggest_with_sources(self, text, sources):
        hits_from_sources = []
        for project_id, weight in sources:
            source_project = annif.project.get_project(project_id)
            hits = source_project.suggest(text)
            self.debug(
                'Got {} hits from project {}, weight {}'.format(
                    len(hits), source_project.project_id, weight))
            norm_hits = self._normalize_hits(hits, source_project)
            hits_from_sources.append(
                annif.suggestion.WeightedSuggestion(
                    hits=norm_hits, weight=weight))
        return hits_from_sources

    def _merge_hits_from_sources(self, hits_from_sources, params):
        """Hook for merging hits from sources. Can be overridden by
        subclasses."""
        return annif.util.merge_hits(hits_from_sources, self.project.subjects)

    def _suggest(self, text, params):
        sources = annif.util.parse_sources(params['sources'])
        hits_from_sources = self._suggest_with_sources(text, sources)
        merged_hits = self._merge_hits_from_sources(hits_from_sources, params)
        self.debug('{} hits after merging'.format(len(merged_hits)))
        return merged_hits

    def _train(self, corpus, params):
        raise NotSupportedException('Training ensemble model is not possible.')
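
Both classes depend on a 'sources' configuration value that pairs source project ids with weights, parsed via annif.util.parse_sources and rescaled by EnsembleOptimizer._normalize. The following standalone sketch, not the actual annif.util implementation, shows the assumed "project-id:weight" format and the normalization step; the project ids are again hypothetical:

def parse_sources(value):
    # Assumed format: "tfidf-en:1,fasttext-en:2"; weight defaults to 1.0 if omitted.
    sources = []
    for item in value.split(','):
        project_id, _, weight = item.strip().partition(':')
        sources.append((project_id, float(weight) if weight else 1.0))
    return sources

def normalize(weights):
    # Scale weights so they sum to 1.0, as EnsembleOptimizer._normalize() does.
    total = sum(weights.values())
    return {source: value / total for source, value in weights.items()}

pairs = parse_sources("tfidf-en:1,fasttext-en:2")
print(pairs)                   # [('tfidf-en', 1.0), ('fasttext-en', 2.0)]
print(normalize(dict(pairs)))  # {'tfidf-en': 0.333..., 'fasttext-en': 0.666...}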