Passed
Pull Request — master (#414)
by Osma
02:10
created

HyperparameterOptimizer._postprocess()   A

Complexity

Conditions 1

Size

Total Lines 4
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 3
dl 0
loc 4
rs 10
c 0
b 0
f 0
cc 1
nop 2
1
"""Hyperparameter optimization functionality for backends"""
2
3
import abc
4
import collections
5
import optuna
6
from .backend import AnnifBackend
7
from annif import logger
8
9
10
# Immutable record with two fields, `lines` and `score`.
# NOTE(review): presumably built by HyperparameterOptimizer._postprocess()
# implementations, with `lines` holding recommended configuration lines and
# `score` the metric value reached — confirm against the subclasses.
HPRecommendation = collections.namedtuple(
    'HPRecommendation', ['lines', 'score'])
11
12
13
class HyperparameterOptimizer:
    """Base class for hyperparameter optimizers.

    Subclasses implement _objective() (the function optuna maximizes) and
    _postprocess() (which turns the finished study into hyperparameter
    recommendations). They may also override _prepare() to do setup before
    the trials run.
    """

    # NOTE(review): this class does not use abc.ABCMeta (nor inherit from
    # abc.ABC), so the @abc.abstractmethod markers below document intent but
    # are not enforced when an instance is created.

    def __init__(self, backend, corpus, metric):
        """Store the backend to tune, the corpus to evaluate on and the
        metric used to score each trial."""
        self._backend = backend
        self._corpus = corpus
        self._metric = metric

    def _prepare(self, n_jobs=1):
        """Prepare the optimizer for hyperparameter evaluation.

        optimize() calls this with the same n_jobs value it passes to
        optuna, so overrides may use up to n_jobs parallel workers. The
        base implementation does nothing. It must still accept n_jobs,
        otherwise optimize() fails on subclasses that do not override it.
        """
        pass  # pragma: no cover

    @abc.abstractmethod
    def _objective(self, trial):
        """Objective function for a single optuna trial. optimize() creates
        the study with direction='maximize', so larger return values are
        better."""
        pass  # pragma: no cover

    @abc.abstractmethod
    def _postprocess(self, study):
        """Convert the study results into hyperparameter recommendations"""
        pass  # pragma: no cover

    def _write_trials_header(self, results_file, param_names):
        """Write the tab-separated header row ('trial', 'value', then the
        parameter names) to results_file."""
        print('\t'.join(['trial', 'value'] + list(param_names)),
              file=results_file)

    def _write_trial(self, results_file, trial, param_names=None):
        """Write one tab-separated row describing an optuna trial.

        :param results_file: writable text stream
        :param trial: trial object exposing .number, .value and .params
        :param param_names: optional column order for parameter values.
            When given, values are written in that order, and parameters the
            trial lacks are written as 'None', so the row lines up with the
            header. When omitted, the trial's own parameter order is used.
        """
        if param_names is None:
            values = list(trial.params.values())
        else:
            values = [trial.params.get(name) for name in param_names]
        fields = [trial.number, trial.value] + values
        print('\t'.join(str(field) for field in fields), file=results_file)

    def optimize(self, n_trials, n_jobs, results_file):
        """Find the optimal hyperparameters by testing up to the given number
        of hyperparameter combinations.

        :param n_trials: number of optuna trials to run
        :param n_jobs: parallelism passed to _prepare() and to optuna
        :param results_file: optional writable text stream; when truthy, a
            tab-separated log with one row per trial is written to it
        :returns: whatever _postprocess() returns for the finished study
        """
        self._prepare(n_jobs)
        study = optuna.create_study(direction='maximize')
        study.optimize(self._objective,
                       n_trials=n_trials,
                       n_jobs=n_jobs,
                       gc_after_trial=False,
                       show_progress_bar=True)
        if results_file:
            # Collect parameter names from every trial, in first-seen order,
            # instead of study.best_params. best_params raises when no trial
            # completed, and it can miss conditionally suggested parameters.
            param_names = list(dict.fromkeys(
                name for trial in study.trials for name in trial.params))
            self._write_trials_header(results_file, param_names)
            for trial in study.trials:
                self._write_trial(results_file, trial, param_names)
        return self._postprocess(study)
60
61
62
class AnnifHyperoptBackend(AnnifBackend):
    """Abstract base for Annif backends that support hyperparameter
    optimization.

    Concrete backends supply a HyperparameterOptimizer through
    get_hp_optimizer().
    """

    @abc.abstractmethod
    def get_hp_optimizer(self, corpus, metric):
        """Return a HyperparameterOptimizer for this backend.

        The optimizer searches for good hyperparameter combinations,
        evaluating candidates on `corpus` and scoring them with `metric`.
        """
        pass  # pragma: no cover
73