Passed
Pull Request — master (#414)
by Osma
01:58
created

HyperparameterOptimizer._write_trial()   A

Complexity

Conditions 1

Size

Total Lines 4
Code Lines 4

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 4
dl 0
loc 4
rs 10
c 0
b 0
f 0
cc 1
nop 3
1
"""Hyperparameter optimization functionality for backends"""
2
3
import abc
4
import collections
5
import optuna
6
from .backend import AnnifBackend
7
from annif import logger
8
9
10
# Immutable (lines, score) pair describing one hyperparameter recommendation.
# NOTE(review): presumably ``lines`` holds suggested configuration lines and
# ``score`` the metric value achieved with them -- confirm against the
# _postprocess implementations that build it.
HPRecommendation = collections.namedtuple(
    'HPRecommendation', ['lines', 'score'])
11
12
13
class HyperparameterOptimizer:
    """Base class for hyperparameter optimizers.

    Subclasses implement _objective() (the per-trial score, which Optuna is
    asked to maximize) and _postprocess() (turning the finished study into
    recommendations), and may override _prepare() for setup work.

    NOTE(review): this class does not use abc.ABC / abc.ABCMeta, so the
    @abc.abstractmethod decorators below are not enforced when a subclass is
    instantiated; they only document intent.
    """

    def __init__(self, backend, corpus, metric):
        self._backend = backend
        self._corpus = corpus
        self._metric = metric

    def _prepare(self, n_jobs=1):
        """Prepare the optimizer for hyperparameter evaluation.  Up to
        n_jobs parallel threads or processes may be used during the
        operation. The base implementation does nothing."""

        pass  # pragma: no cover

    @abc.abstractmethod
    def _objective(self, trial):
        """Objective function to optimize; passed to study.optimize() and
        called once per trial. The returned value is maximized."""
        pass  # pragma: no cover

    @abc.abstractmethod
    def _postprocess(self, study):
        """Convert the study results into hyperparameter recommendations"""
        pass  # pragma: no cover

    @staticmethod
    def _trial_param_names(study):
        """Return the names of all hyperparameters used by any trial in
        study.trials, in first-seen order and without duplicates."""
        names = {}  # dict used as an insertion-ordered set
        for trial in study.trials:
            for name in trial.params:
                names.setdefault(name, None)
        return list(names)

    def _write_trials_header(self, results_file, param_names):
        """Write the tab-separated header row to results_file: 'trial',
        'value', then one column per name in param_names."""
        print('\t'.join(['trial', 'value'] + list(param_names)),
              file=results_file)

    def _write_trial(self, results_file, trial, param_names=None):
        """Write one tab-separated row describing a trial to results_file.

        The row is the trial number, its value, then its hyperparameter
        values. If param_names is given, values are emitted in exactly that
        order (an empty cell for a parameter this trial did not use), so
        rows line up with the header from _write_trials_header(). If it is
        None, the trial's own parameter order is used.
        """
        if param_names is None:
            param_values = list(trial.params.values())
        else:
            param_values = [trial.params.get(name, '')
                            for name in param_names]
        row = [trial.number, trial.value] + param_values
        print('\t'.join(str(e) for e in row), file=results_file)

    def optimize(self, n_trials, n_jobs, results_file):
        """Find the optimal hyperparameters by testing up to the given number
        of hyperparameter combinations.

        Runs an Optuna study that maximizes _objective() using up to n_jobs
        parallel jobs. If results_file is truthy, a tab-separated log of every
        trial is written to it. Returns the value of _postprocess(study).
        """

        self._prepare(n_jobs)
        study = optuna.create_study(direction='maximize')
        study.optimize(self._objective,
                       n_trials=n_trials,
                       n_jobs=n_jobs,
                       gc_after_trial=False,
                       show_progress_bar=True)
        if results_file:
            # Take the columns from all trials rather than study.best_params:
            # best_params raises when no trial completed and omits params the
            # best trial did not use. Every row is written in this same
            # column order so the values line up with the header.
            param_names = self._trial_param_names(study)
            self._write_trials_header(results_file, param_names)
            for trial in study.trials:
                self._write_trial(results_file, trial, param_names)
        return self._postprocess(study)
63
64
65
class AnnifHyperoptBackend(AnnifBackend):
    """Abstract base for Annif backends whose hyperparameters can be tuned
    automatically.

    Concrete backends are expected to provide get_hp_optimizer(), which
    hands back the optimizer object used to drive the search."""

    @abc.abstractmethod
    def get_hp_optimizer(self, corpus, metric):
        """Return a HyperparameterOptimizer that searches for good
        hyperparameter combinations on ``corpus``, scoring each candidate
        with ``metric``.

        Abstract: subclasses are expected to override this method."""
76