Passed
Pull Request — master (#414)
by Osma
01:45
created

AnnifHyperoptBackend.get_hp_optimizer()   A

Complexity

Conditions 1

Size

Total Lines 7
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 3
dl 0
loc 7
rs 10
c 0
b 0
f 0
cc 1
nop 3
1
"""Hyperparameter optimization functionality for backends"""
2
3
import abc
4
import collections
5
import warnings
6
import optuna
7
import optuna.exceptions
8
from .backend import AnnifBackend
9
from annif import logger
10
11
12
# Immutable result record with two fields, ``lines`` and ``score``.
# Presumably ``lines`` holds suggested configuration lines and ``score`` the
# objective value they reached -- NOTE(review): confirm against the
# optimizer subclasses that construct it.
HPRecommendation = collections.namedtuple('HPRecommendation',
                                          ['lines', 'score'])
13
14
15
class TrialWriter:
    """Object that writes hyperparameter optimization trial results into a
    TSV file."""

    def __init__(self, results_file, normalize_func):
        """Create a writer.

        :param results_file: writable text file object that receives the TSV
            output (written with ``print(..., file=results_file)``)
        :param normalize_func: callable applied to each trial's raw
            hyperparameter dict; its result supplies both the header column
            names and the per-trial values
        """
        self.results_file = results_file
        self.normalize_func = normalize_func
        self.header_written = False

    def write(self, study, trial):
        """Write the results of one trial into the results file.  On the
        first run, write the header line first.

        The ``(study, trial)`` signature matches an Optuna study callback;
        ``study`` is not used here.

        :param study: the study the trial belongs to (unused)
        :param trial: finished trial; its ``number``, ``value`` and
            ``params`` attributes are read
        """

        # Normalize once and derive both the header and the row from the
        # same dict. That keeps the column names aligned with the values,
        # even when normalize_func renames, adds or reorders keys.
        hps = self.normalize_func(trial.params)
        if not self.header_written:
            print('\t'.join(['trial', 'value'] + list(hps.keys())),
                  file=self.results_file)
            self.header_written = True
        print('\t'.join(str(e) for e in [trial.number, trial.value] +
                        list(hps.values())),
              file=self.results_file)
36
37
38
class HyperparameterOptimizer(metaclass=abc.ABCMeta):
    """Base class for hyperparameter optimizers.

    Subclasses must implement ``_objective`` and ``_postprocess``.
    ``_prepare`` and ``_normalize`` may optionally be overridden. The
    ABCMeta metaclass makes the ``@abc.abstractmethod`` markers below
    actually enforced at instantiation time."""

    def __init__(self, backend, corpus, metric):
        """Store the backend being tuned, the corpus used for evaluation
        and the metric used to score hyperparameter combinations."""
        self._backend = backend
        self._corpus = corpus
        self._metric = metric

    def _prepare(self, n_jobs=1):
        """Prepare the optimizer for hyperparameter evaluation.  Up to
        n_jobs parallel threads or processes may be used during the
        operation.  The default implementation does nothing."""

        pass  # pragma: no cover

    @abc.abstractmethod
    def _objective(self, trial):
        """Objective function to optimize; called once per trial and its
        result is maximized by the study."""
        pass  # pragma: no cover

    @abc.abstractmethod
    def _postprocess(self, study):
        """Convert the study results into hyperparameter recommendations"""
        pass  # pragma: no cover

    def _normalize(self, hps):
        """Normalize the given raw hyperparameters. Intended to be overridden
        by subclasses when necessary. The default is to keep them as-is."""
        return hps

    def optimize(self, n_trials, n_jobs, results_file):
        """Find the optimal hyperparameters by testing up to the given number
        of hyperparameter combinations.

        :param n_trials: number of trials passed to Optuna
        :param n_jobs: number of parallel jobs passed to Optuna; the
            progress bar is only shown when this is 1
        :param results_file: optional writable text file that receives one
            TSV line per trial; a falsy value disables result writing
        :returns: whatever ``_postprocess`` returns for the finished study
        """

        self._prepare(n_jobs)

        if results_file:
            callbacks = [TrialWriter(results_file, self._normalize).write]
        else:
            callbacks = []

        study = optuna.create_study(direction='maximize')
        # Silence the ExperimentalWarning raised when the Optuna progress
        # bar is used. catch_warnings() restores the previous filters on
        # exit, so the suppression does not leak into the rest of the
        # process, and repeated calls do not keep stacking filters.
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore", category=optuna.exceptions.ExperimentalWarning)
            study.optimize(self._objective,
                           n_trials=n_trials,
                           n_jobs=n_jobs,
                           callbacks=callbacks,
                           gc_after_trial=False,
                           show_progress_bar=(n_jobs == 1))
        return self._postprocess(study)
90
91
92
class AnnifHyperoptBackend(AnnifBackend):
    """Abstract base for Annif backends that are able to run
    hyperparameter optimization."""

    @abc.abstractmethod
    def get_hp_optimizer(self, corpus, metric):
        """Return a HyperparameterOptimizer that searches for good
        hyperparameter combinations on ``corpus``, scoring each candidate
        with ``metric``.

        Marked abstract: concrete hyperopt-capable backends are expected to
        provide their own implementation."""
103