Passed
Pull Request — master (#414)
by Osma
01:45
created

EnsembleBackend._get_sources_attribute()   A

Complexity

Conditions 1

Size

Total Lines 5
Code Lines 5

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 5
dl 0
loc 5
rs 10
c 0
b 0
f 0
cc 1
nop 2
1
"""Ensemble backend that combines results from multiple projects"""
2
3
4
import annif.parallel
5
import annif.suggestion
6
import annif.util
7
import annif.eval
8
from . import backend
9
from . import hyperopt
10
from annif.exception import NotSupportedException
11
12
13
class BaseEnsembleBackend(backend.AnnifBackend):
    """Shared machinery for backends that combine suggestions obtained
    from several source projects."""

    def _get_sources_attribute(self, attr):
        """Collect the attribute named `attr` from every source project.

        The source projects are the ones listed in the 'sources' backend
        parameter; the returned list follows the order of that listing."""
        params = self._get_backend_params(None)
        values = []
        for project_id, _weight in annif.util.parse_sources(
                params['sources']):
            source_project = self.project.registry.get_project(project_id)
            values.append(getattr(source_project, attr))
        return values

    def initialize(self):
        """Initialize each project listed in the 'sources' parameter."""
        source_spec = self._get_backend_params(None)['sources']
        for project_id, _weight in annif.util.parse_sources(source_spec):
            self.project.registry.get_project(project_id).initialize()

    def _normalize_hits(self, hits, source_project):
        """Post-process the hits received from one source project.

        This base implementation hands the hits back untouched; it exists
        as an extension point for subclasses."""
        return hits

    def _suggest_with_sources(self, text, sources):
        """Ask every source project for suggestions on `text`.

        `sources` is an iterable of (project_id, weight) pairs. Returns a
        list with one WeightedSuggestion per source, in the same order."""
        collected = []
        for project_id, weight in sources:
            source_project = self.project.registry.get_project(project_id)
            raw_hits = source_project.suggest(text)
            message = 'Got {} hits from project {}, weight {}'.format(
                len(raw_hits), source_project.project_id, weight)
            self.debug(message)
            # give subclasses a chance to adjust the hits before weighting
            adjusted_hits = self._normalize_hits(raw_hits, source_project)
            weighted = annif.suggestion.WeightedSuggestion(
                hits=adjusted_hits,
                weight=weight,
                subjects=source_project.subjects)
            collected.append(weighted)
        return collected

    def _merge_hits_from_sources(self, hits_from_sources, params):
        """Combine the per-source weighted hits into a single result.

        Delegates to annif.util.merge_hits using this project's subjects;
        `params` is not used here. Subclasses may override this method."""
        subject_index = self.project.subjects
        return annif.util.merge_hits(hits_from_sources, subject_index)

    def _suggest(self, text, params):
        """Produce merged suggestions for `text` from the projects named in
        params['sources']."""
        hits_from_sources = self._suggest_with_sources(
            text, annif.util.parse_sources(params['sources']))
        merged = self._merge_hits_from_sources(hits_from_sources, params)
        self.debug('{} hits after merging'.format(len(merged)))
        return merged
61
62
63
class EnsembleOptimizer(hyperopt.HyperparameterOptimizer):
    """Searches for source weights for the ensemble backend by trying out
    weight combinations against a document corpus."""

    def __init__(self, backend, corpus, metric):
        """Set up the optimizer and remember the ids of the projects listed
        in the backend's configured 'sources' (configured weights are
        discarded)."""
        super().__init__(backend, corpus, metric)
        parsed = annif.util.parse_sources(backend.config_params['sources'])
        self._sources = [project_id for project_id, _weight in parsed]

    def _prepare(self, n_jobs=1):
        """Collect gold standard subjects and per-source hits for every
        document in the corpus, using `n_jobs` parallel jobs."""
        self._gold_subjects = []
        self._source_hits = []

        # make sure every source project is initialized before suggesting
        for project_id in self._sources:
            self._backend.project.registry.get_project(
                project_id).initialize()

        psmap = annif.parallel.ProjectSuggestMap(
            self._backend.project.registry,
            self._sources,
            backend_params=None,
            limit=int(self._backend.params['limit']),
            threshold=0.0)

        jobs, pool_class = annif.parallel.get_pool(n_jobs)

        with pool_class(jobs) as pool:
            outcomes = pool.imap_unordered(
                psmap.suggest, self._corpus.documents)
            # Gold subjects and hits are appended in the same iteration so
            # that the two lists stay index-aligned.
            for hits, uris, labels in outcomes:
                gold = annif.corpus.SubjectSet((uris, labels))
                self._gold_subjects.append(gold)
                self._source_hits.append(hits)

    def _normalize(self, hps):
        """Return a copy of the `hps` mapping with every value divided by
        the sum of all values."""
        total = sum(hps.values())
        return {source: value / total for source, value in hps.items()}

    def _format_cfg_line(self, hps):
        """Render the source -> weight mapping `hps` as a 'sources=...'
        configuration line, weights shown with four decimals."""
        pairs = [f"{src}:{weight:.4f}" for src, weight in hps.items()]
        return 'sources=' + ','.join(pairs)

    def _objective(self, trial):
        """Score one trial: draw a weight in [0.0, 1.0] for every source,
        merge the stored hits with those weights and return the value of
        the configured metric over all prepared documents."""
        batch = annif.eval.EvaluationBatch(self._backend.project.subjects)
        weights = {}
        for project_id in self._sources:
            weights[project_id] = trial.suggest_uniform(project_id, 0.0, 1.0)
        for goldsubj, srchits in zip(self._gold_subjects, self._source_hits):
            weighted_hits = [
                annif.suggestion.WeightedSuggestion(
                    hits=hits,
                    weight=weights[project_id],
                    subjects=self._backend.project.subjects)
                for project_id, hits in srchits.items()]
            batch.evaluate(
                annif.util.merge_hits(
                    weighted_hits,
                    self._backend.project.subjects),
                goldsubj)
        return batch.results(metrics=[self._metric])[self._metric]

    def _postprocess(self, study):
        """Turn the best parameters of `study` into a recommendation that
        holds a single normalized 'sources=...' line and the best score."""
        best_weights = self._normalize(study.best_params)
        cfg_line = self._format_cfg_line(best_weights)
        return hyperopt.HPRecommendation(
            lines=[cfg_line], score=study.best_value)
126
127
128
class EnsembleBackend(BaseEnsembleBackend, hyperopt.AnnifHyperoptBackend):
    """Backend named "ensemble": merges the suggestions of several source
    projects and offers hyperparameter optimization of their weights."""
    name = "ensemble"

    @property
    def is_trained(self):
        """True only when every source project reports is_trained."""
        return all(self._get_sources_attribute('is_trained'))

    @property
    def modification_time(self):
        """Latest modification time among the source projects, ignoring
        falsy values; None when no source provides one."""
        known_mtimes = [
            mtime
            for mtime in self._get_sources_attribute('modification_time')
            if mtime]
        return max(known_mtimes, default=None)

    def get_hp_optimizer(self, corpus, metric):
        """Create an EnsembleOptimizer bound to this backend for the given
        corpus and metric."""
        return EnsembleOptimizer(self, corpus, metric)

    def _train(self, corpus, params):
        """Always raises NotSupportedException: this backend cannot be
        trained."""
        message = 'Training ensemble backend is not possible.'
        raise NotSupportedException(message)
148