Passed
Pull Request — master (#414)
by Osma
02:11
created

HyperparameterOptimizer.optimize()   A

Complexity

Conditions 1

Size

Total Lines 12
Code Lines 9

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 9
dl 0
loc 12
rs 9.95
c 0
b 0
f 0
cc 1
nop 3
1
"""Hyperparameter optimization functionality for backends"""
2
3
import abc
4
import collections
5
import optuna
6
from .backend import AnnifBackend
7
from annif import logger
8
9
10
# Immutable result record for hyperparameter optimization with the fields
# ``lines`` and ``score`` (presumably the recommended configuration lines and
# the score they achieved -- confirm against the _postprocess implementations).
HPRecommendation = collections.namedtuple('HPRecommendation',
                                          ['lines', 'score'])
11
12
13
class HyperparameterOptimizer(abc.ABC):
    """Base class for hyperparameter optimizers.

    Subclasses must implement :meth:`_objective` and :meth:`_postprocess`.
    They may also override :meth:`_prepare` to do setup work before the
    trials start. Inheriting from :class:`abc.ABC` is what makes the
    ``@abc.abstractmethod`` markers below take effect: a subclass that
    forgets one of them fails at instantiation with ``TypeError``, instead
    of silently using the ``None``-returning stubs.
    """

    def __init__(self, backend, corpus, metric):
        """Store the backend to tune, the evaluation corpus and the metric
        used to score each hyperparameter combination."""
        self._backend = backend
        self._corpus = corpus
        self._metric = metric

    def _prepare(self):
        """Prepare the optimizer for hyperparameter evaluation.

        Optional hook that :meth:`optimize` calls once before any trial
        runs. The default implementation does nothing."""
        pass  # pragma: no cover

    @abc.abstractmethod
    def _objective(self, trial):
        """Objective function to optimize.

        Optuna calls this once per trial with a trial object. The returned
        score is maximized (see :meth:`optimize`)."""
        pass  # pragma: no cover

    @abc.abstractmethod
    def _postprocess(self, study):
        """Convert the study results into hyperparameter recommendations.

        Receives the finished Optuna study. The return value is passed
        through unchanged as the result of :meth:`optimize`."""
        pass  # pragma: no cover

    def optimize(self, n_trials, n_jobs):
        """Find the optimal hyperparameters by testing up to the given number
        of hyperparameter combinations.

        :param n_trials: maximum number of trials (hyperparameter
            combinations) to evaluate
        :param n_jobs: number of trials Optuna may run in parallel
        :returns: whatever :meth:`_postprocess` returns for the study
        """
        self._prepare()
        # The objective returns a score where higher is better.
        study = optuna.create_study(direction='maximize')
        study.optimize(self._objective,
                       n_trials=n_trials,
                       n_jobs=n_jobs,
                       # Skip Optuna's forced garbage collection after
                       # every trial.
                       gc_after_trial=False,
                       show_progress_bar=True)
        return self._postprocess(study)
47
48
49
class AnnifHyperoptBackend(AnnifBackend):
    """Base class for Annif backends that support hyperparameter
    optimization.

    Concrete backends supply :meth:`get_hp_optimizer`, which returns the
    optimizer object used to search this backend's hyperparameters.
    """

    @abc.abstractmethod
    def get_hp_optimizer(self, corpus, metric):
        """Return a HyperparameterOptimizer that searches for the best
        hyperparameter combination on ``corpus``, scoring each candidate
        with ``metric``."""
        ...  # pragma: no cover
60