Passed
Pull Request — master (#414)
by Osma
02:14
created

EnsembleOptimizer._prepare()   A

Complexity

Conditions 3

Size

Total Lines 13
Code Lines 12

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 12
dl 0
loc 13
rs 9.8
c 0
b 0
f 0
cc 3
nop 1
1
"""Ensemble backend that combines results from multiple projects"""
2
3
4
from hyperopt import hp

import annif.corpus
import annif.eval
import annif.project
import annif.suggestion
import annif.util
from annif.exception import NotSupportedException
from . import hyperopt
11
12
13
class EnsembleOptimizer(hyperopt.HyperparameterOptimizer):
    """Hyperparameter optimizer for the ensemble backend.

    The hyperparameters being tuned are the weights of the source projects.
    Source suggestions for every corpus document are collected in
    ``_prepare``. ``_test`` then re-weights and merges them for each
    candidate weight combination.
    """

    def __init__(self, backend, corpus):
        super().__init__(backend, corpus)
        # Only the project IDs are kept: the weights configured in
        # 'sources' are exactly what this optimizer searches over.
        self._sources = [project_id for project_id, _
                         in annif.util.parse_sources(
                             backend.config_params['sources'])]

    def get_hp_space(self):
        """Return the search space.

        The space has one uniform weight in [0.0, 1.0] per source project,
        keyed by project ID.
        """
        return {project_id: hp.uniform(project_id, 0.0, 1.0)
                for project_id in self._sources}

    def _prepare(self):
        """Collect gold subjects and per-source suggestions for the corpus.

        Sets ``self._gold_subjects`` to a list with one SubjectSet per
        document. Sets ``self._source_hits`` to a parallel list of dicts
        mapping source project ID to that project's suggestions for the
        document.
        """
        # The project lookup does not depend on the document, so do it once
        # rather than once per document per source.
        source_projects = [(project_id, annif.project.get_project(project_id))
                           for project_id in self._sources]
        self._gold_subjects = []
        self._source_hits = []

        for doc in self._corpus.documents:
            self._gold_subjects.append(
                annif.corpus.SubjectSet((doc.uris, doc.labels)))
            self._source_hits.append(
                {project_id: project.suggest(doc.text)
                 for project_id, project in source_projects})

    def _test(self, hps):
        """Score one weight combination and return its loss.

        hps maps each source project ID to a weight. For each document, the
        weighted source suggestions are merged and evaluated against the
        gold subjects. The result is 1 - NDCG, so a lower value means a
        higher NDCG.
        """
        subjects = self._backend.project.subjects
        batch = annif.eval.EvaluationBatch(subjects)
        for goldsubj, srchits in zip(self._gold_subjects, self._source_hits):
            weighted_hits = [
                annif.suggestion.WeightedSuggestion(
                    hits=hits, weight=hps[project_id])
                for project_id, hits in srchits.items()]
            batch.evaluate(
                annif.util.merge_hits(weighted_hits, subjects),
                goldsubj)
        results = batch.results()
        return 1 - results['NDCG']
57
58
class EnsembleBackend(hyperopt.AnnifHyperoptBackend):
    """Ensemble backend that combines results from multiple projects.

    Each configured source project is queried. Its hits are passed through a
    normalization hook and wrapped with the source weight, and the weighted
    results are merged into a single suggestion list.
    """

    name = "ensemble"

    def get_hp_optimizer(self, corpus):
        """Return an EnsembleOptimizer that tunes this backend's source
        weights on the given corpus."""
        optimizer = EnsembleOptimizer(self, corpus)
        return optimizer

    def _normalize_hits(self, hits, source_project):
        """Post-process hits coming from a source project.

        The default is the identity. Subclasses may override this to
        rescale or filter hits before merging.
        """
        return hits

    def _suggest_with_sources(self, text, sources):
        """Query every (project_id, weight) pair in sources with text.

        Returns a list of WeightedSuggestion objects, in the same order as
        sources.
        """
        weighted = []
        for project_id, weight in sources:
            project = annif.project.get_project(project_id)
            raw_hits = project.suggest(text)
            self.debug(
                'Got {} hits from project {}, weight {}'.format(
                    len(raw_hits), project.project_id, weight))
            weighted.append(annif.suggestion.WeightedSuggestion(
                hits=self._normalize_hits(raw_hits, project),
                weight=weight))
        return weighted

    def _merge_hits_from_sources(self, hits_from_sources, params):
        """Merge weighted source hits into one result.

        Subclasses may override this hook. The default delegates to
        annif.util.merge_hits over this project's subject vocabulary.
        """
        merged = annif.util.merge_hits(hits_from_sources,
                                       self.project.subjects)
        return merged

    def _suggest(self, text, params):
        """Suggest subjects for text using the sources listed in params."""
        hits_from_sources = self._suggest_with_sources(
            text, annif.util.parse_sources(params['sources']))
        merged = self._merge_hits_from_sources(hits_from_sources, params)
        self.debug('{} hits after merging'.format(len(merged)))
        return merged

    def _train(self, corpus, params):
        """Always fail: an ensemble has no model of its own to train."""
        raise NotSupportedException('Training ensemble model is not possible.')