Passed
Pull Request — master (#414)
by Osma
08:05
created

HyperparameterOptimizer.__init__()   A

Complexity

Conditions 1

Size

Total Lines 3
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 3
dl 0
loc 3
rs 10
c 0
b 0
f 0
cc 1
nop 3
1
"""Hyperparameter optimization functionality for backends"""
2
3
import abc
4
import hyperopt
5
from .backend import AnnifBackend
6
7
8
class HyperparameterOptimizer(abc.ABC):
    """Base class for hyperparameter optimizers.

    Subclasses must implement get_hp_space() and _test(); they may override
    the _prepare() and _postprocess() hooks. Inheriting from abc.ABC is what
    makes the @abc.abstractmethod markers below enforced: without it, a
    subclass missing those methods could still be instantiated."""

    def __init__(self, backend, corpus):
        """Store the backend to be tuned and the corpus used to evaluate
        hyperparameter combinations."""
        self._backend = backend
        self._corpus = corpus

    @abc.abstractmethod
    def get_hp_space(self):
        """Get the hyperparameter space definition of this backend.

        The returned value is passed as the ``space`` argument of
        ``hyperopt.fmin`` by optimize()."""
        pass  # pragma: no cover

    def _prepare(self):
        """Prepare the optimizer for hyperparameter evaluation.

        Called once at the start of optimize(), before the search space is
        requested. The default implementation does nothing."""
        pass  # pragma: no cover

    @abc.abstractmethod
    def _test(self, hps):
        """Evaluate a set of hyperparameters.

        Used as the objective function of ``hyperopt.fmin``, which minimizes
        it. The default _postprocess() reports the best score as
        ``1 - loss``, so the loss is expected to be ``1 - score``."""
        pass  # pragma: no cover

    def _postprocess(self, best, trials):
        """Convert the trial results into a final hyperparameter
        recommendation.

        :param best: the best hyperparameters as returned by ``hyperopt.fmin``
        :param trials: the ``hyperopt.Trials`` object that recorded the search
        :returns: tuple ``(best, score)`` where ``score`` is ``1 - loss`` of
            the best trial"""
        return (best, 1 - trials.best_trial['result']['loss'])

    def optimize(self, n_trials):
        """Find the optimal hyperparameters by testing up to the given number
        of hyperparameter combinations.

        :param n_trials: maximum number of hyperparameter combinations to
            evaluate; must be at least 1
        :returns: the result of _postprocess(), by default a
            ``(best_hyperparameters, best_score)`` tuple
        :raises ValueError: if ``n_trials`` is less than 1"""
        # Validate before _prepare(), which subclasses may use for costly
        # setup; with zero evaluations there would be no best trial anyway.
        if n_trials < 1:
            raise ValueError(
                'n_trials must be at least 1, got {}'.format(n_trials))
        self._prepare()
        space = self.get_hp_space()
        trials = hyperopt.Trials()
        best = hyperopt.fmin(
            fn=self._test,
            space=space,
            algo=hyperopt.tpe.suggest,
            max_evals=n_trials,
            trials=trials)
        return self._postprocess(best, trials)
48
49
50
class AnnifHyperoptBackend(AnnifBackend):
    """Base class for Annif backends that support hyperparameter
    optimization.

    Such a backend provides a HyperparameterOptimizer through
    get_hp_optimizer()."""

    @abc.abstractmethod
    def get_hp_optimizer(self, corpus):
        """Return a HyperparameterOptimizer that searches for good
        hyperparameter combinations using the given corpus.

        Subclasses are expected to override this method."""
        ...  # pragma: no cover
60