Passed
Pull Request — master (#414)
by Osma
02:11
created

EnsembleOptimizer._prepare()   A

Complexity

Conditions 3

Size

Total Lines 14
Code Lines 13

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 13
dl 0
loc 14
rs 9.75
c 0
b 0
f 0
cc 3
nop 1
1
"""Ensemble backend that combines results from multiple projects"""
2
3
4
import annif.suggestion
5
import annif.util
6
import annif.eval
7
from . import backend
8
from . import hyperopt
9
from annif.exception import NotSupportedException
10
11
12
class BaseEnsembleBackend(backend.AnnifBackend):
    """Shared functionality for backends that combine the suggestions of
    several source projects."""

    def _get_sources_attribute(self, attr):
        """Return a list holding the value of attribute `attr` read from
        each source project, in the order the sources are configured."""
        backend_params = self._get_backend_params(None)
        source_specs = annif.util.parse_sources(backend_params['sources'])
        values = []
        for source_id, _ in source_specs:
            source = self.project.registry.get_project(source_id)
            values.append(getattr(source, attr))
        return values

    def initialize(self):
        """Initialize every source project named in the 'sources'
        parameter."""
        backend_params = self._get_backend_params(None)
        source_specs = annif.util.parse_sources(backend_params['sources'])
        for source_id, _ in source_specs:
            self.project.registry.get_project(source_id).initialize()

    def _normalize_hits(self, hits, source_project):
        """Extension point for post-processing the hits received from a
        source project. The base implementation returns them unchanged;
        subclasses may override it."""
        return hits

    def _suggest_with_sources(self, text, sources):
        """Request suggestions for `text` from each source project.

        `sources` is an iterable of (project_id, weight) pairs. Returns a
        list with one WeightedSuggestion per source, in the same order."""
        weighted = []
        for source_id, source_weight in sources:
            source = self.project.registry.get_project(source_id)
            raw_hits = source.suggest(text)
            message = 'Got {} hits from project {}, weight {}'.format(
                len(raw_hits), source.project_id, source_weight)
            self.debug(message)
            # give subclasses a chance to adjust the hits before weighting
            normalized = self._normalize_hits(raw_hits, source)
            suggestion = annif.suggestion.WeightedSuggestion(
                hits=normalized,
                weight=source_weight,
                subjects=source.subjects)
            weighted.append(suggestion)
        return weighted

    def _merge_hits_from_sources(self, hits_from_sources, params):
        """Extension point for combining the per-source weighted hits into
        a single result. The base implementation delegates to
        annif.util.merge_hits; subclasses may override it."""
        project_subjects = self.project.subjects
        return annif.util.merge_hits(hits_from_sources, project_subjects)

    def _suggest(self, text, params):
        """Collect hits from all sources configured in params['sources']
        and return the merged result."""
        weighted = self._suggest_with_sources(
            text, annif.util.parse_sources(params['sources']))
        merged = self._merge_hits_from_sources(weighted, params)
        self.debug('{} hits after merging'.format(len(merged)))
        return merged
60
61
62
class EnsembleOptimizer(hyperopt.HyperparameterOptimizer):
    """Hyperparameter optimizer for the ensemble backend.

    The hyperparameters being optimized are the weights given to each
    source project when their suggestions are merged."""

    def __init__(self, backend, corpus, metric):
        """Store the optimization setup and extract the IDs of the source
        projects from the backend's 'sources' configuration parameter."""
        super().__init__(backend, corpus, metric)
        self._sources = [project_id for project_id, _
                         in annif.util.parse_sources(
                             backend.config_params['sources'])]

    def _prepare(self):
        """Collect, for every document in the corpus, its gold standard
        subjects and the suggestions of each source project.

        The results are stored in self._gold_subjects and
        self._source_hits (parallel lists, one entry per document), which
        _objective reads when evaluating a set of weights."""
        self._gold_subjects = []
        self._source_hits = []

        # Resolve the source projects once up front; the lookups do not
        # depend on the document being processed.
        registry = self._backend.project.registry
        source_projects = [(project_id, registry.get_project(project_id))
                           for project_id in self._sources]

        # NOTE(review): annif.corpus is not among the imports visible in
        # this module; this presumably relies on it having been imported
        # elsewhere in the package - confirm.
        for doc in self._corpus.documents:
            self._gold_subjects.append(
                annif.corpus.SubjectSet((doc.uris, doc.labels)))
            self._source_hits.append(
                {project_id: source_project.suggest(doc.text)
                 for project_id, source_project in source_projects})

    def _normalize(self, hps):
        """Return a copy of the weight mapping `hps` scaled so that the
        weights sum to 1.

        Raises ZeroDivisionError if `hps` is non-empty and its values sum
        to zero."""
        total = sum(hps.values())
        return {source: hps[source] / total for source in hps}

    def _format_cfg_line(self, hps):
        """Format the weight mapping `hps` as a 'sources=' configuration
        line, with each weight written using four decimals."""
        return 'sources=' + ','.join([f"{src}:{weight:.4f}"
                                      for src, weight in hps.items()])

    def _objective(self, trial):
        """Evaluate one set of source weights proposed by `trial` against
        the cached source hits and return the value of the chosen metric.
        """
        # The subject index is the same for every document and source.
        subjects = self._backend.project.subjects
        batch = annif.eval.EvaluationBatch(subjects)
        # NOTE(review): `trial` is presumably an Optuna trial; newer Optuna
        # releases deprecate suggest_uniform in favour of suggest_float -
        # confirm against the Optuna version in use before changing.
        weights = {project_id: trial.suggest_uniform(project_id, 0.0, 1.0)
                   for project_id in self._sources}
        for goldsubj, srchits in zip(self._gold_subjects, self._source_hits):
            weighted_hits = [
                annif.suggestion.WeightedSuggestion(
                    hits=hits,
                    weight=weights[project_id],
                    subjects=subjects)
                for project_id, hits in srchits.items()]
            batch.evaluate(
                annif.util.merge_hits(weighted_hits, subjects),
                goldsubj)
        results = batch.results(metrics=[self._metric])
        return results[self._metric]

    def _postprocess(self, study):
        """Turn the best parameters found by `study` into a recommendation
        consisting of a normalized 'sources=' configuration line and the
        best score."""
        line = self._format_cfg_line(self._normalize(study.best_params))
        return hyperopt.HPRecommendation(lines=[line], score=study.best_value)
116
117
118
class EnsembleBackend(BaseEnsembleBackend, hyperopt.AnnifHyperoptBackend):
    """Ensemble backend that combines results from multiple projects"""
    name = "ensemble"

    @property
    def is_trained(self):
        """True when every source project reports being trained."""
        return all(self._get_sources_attribute('is_trained'))

    @property
    def modification_time(self):
        """The latest modification time among the source projects, or
        None when no source reports a (truthy) modification time."""
        known_mtimes = [
            mtime
            for mtime in self._get_sources_attribute('modification_time')
            if mtime]
        return max(known_mtimes, default=None)

    def get_hp_optimizer(self, corpus, metric):
        """Return an EnsembleOptimizer for this backend, set up with the
        given corpus and metric."""
        optimizer = EnsembleOptimizer(self, corpus, metric)
        return optimizer

    def _train(self, corpus, params):
        """Always raises NotSupportedException: the ensemble backend
        cannot be trained."""
        raise NotSupportedException(
            'Training ensemble backend is not possible.')
138