Passed
Push — main ( da1836...1db6a8 )
by Osma
07:26 queued 04:14
created

HPObjective._objective_wrapper()   A

Complexity

Conditions 1

Size

Total Lines 3
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 3
nop 2
dl 0
loc 3
rs 10
c 0
b 0
f 0
1
"""Hyperparameter optimization functionality for backends"""
2
3
from __future__ import annotations

import abc
import collections
import os
import tempfile
from typing import TYPE_CHECKING, Any, Callable

import optuna
import optuna.exceptions

import annif.parallel

from .backend import AnnifBackend
16
17
if TYPE_CHECKING:
18
    from click.utils import LazyFile
19
    from optuna.study.study import Study
20
    from optuna.trial import Trial
21
22
    from annif.corpus.document import DocumentCorpus
23
24
# Result of a hyperparameter search: the recommended configuration lines
# together with the score they achieved.
HPRecommendation = collections.namedtuple("HPRecommendation", ["lines", "score"])
25
26
27
class TrialWriter:
    """Writes the outcome of hyperparameter optimization trials into a
    TSV results file."""

    def __init__(self, results_file: LazyFile, normalize_func: Callable) -> None:
        self.results_file = results_file
        self.normalize_func = normalize_func
        # the TSV header is emitted lazily, just before the first data row
        self.header_written = False

    def write(self, trial_data: dict[str, Any]) -> None:
        """Append one trial result as a TSV row to the results file,
        emitting the header line before the very first row."""

        if not self.header_written:
            columns = ["trial", "value"] + list(trial_data["params"].keys())
            print("\t".join(columns), file=self.results_file)
            self.header_written = True

        normalized = self.normalize_func(trial_data["params"])
        fields = [trial_data["number"], trial_data["value"]] + list(
            normalized.values()
        )
        print("\t".join(str(field) for field in fields), file=self.results_file)
54
55
56
class HPObjective(annif.parallel.BaseWorker):
    """Base class for hyperparameter optimizer objective functions"""

    @classmethod
    def objective(cls, trial: Trial, args) -> float:
        """Objective function to optimize. To be implemented by subclasses."""

        pass  # pragma: no cover

    @classmethod
    def _objective_wrapper(cls, trial: Trial) -> float:
        # bind the worker-wide shared args into the objective function
        return cls.objective(trial, cls.args)

    @classmethod
    def run_trial(
        cls, trial_id: int, storage_url: str, study_name: str
    ) -> dict[str, Any]:
        """Execute a single Optuna trial against the shared study storage
        and return its number, value and parameters as a dict."""

        # capture the finished trial via a callback to avoid race conditions
        finished_trials: list[Trial] = []

        def remember_trial(study: Study, trial: Trial) -> None:
            finished_trials.append(trial)

        study = optuna.load_study(storage=storage_url, study_name=study_name)
        study.optimize(
            cls._objective_wrapper,
            n_trials=1,
            callbacks=[remember_trial],
        )

        trial = finished_trials[0]
        return {
            "number": trial.number,
            "value": trial.value,
            "params": trial.params,
        }
92
93
94
class HyperparameterOptimizer:
    """Base class for hyperparameter optimizers"""

    def __init__(
        self,
        backend: AnnifBackend,
        corpus: DocumentCorpus,
        metric: str,
        objective: HPObjective,
    ) -> None:
        self._backend = backend
        self._corpus = corpus
        self._metric = metric
        self._objective = objective

    def _prepare(self, n_jobs: int = 1):
        """Prepare the optimizer for hyperparameter evaluation.  Up to
        n_jobs parallel threads or processes may be used during the
        operation. The return value will be passed to the objective function."""

        pass  # pragma: no cover

    @abc.abstractmethod
    def _postprocess(self, study: Study) -> HPRecommendation:
        """Convert the study results into hyperparameter recommendations"""
        pass  # pragma: no cover

    def _normalize(self, hps: dict[str, float]) -> dict[str, float]:
        """Normalize the given raw hyperparameters. Intended to be overridden
        by subclasses when necessary. The default is to keep them as-is."""
        return hps

    def optimize(
        self, n_trials: int, n_jobs: int, results_file: LazyFile | None
    ) -> HPRecommendation:
        """Find the optimal hyperparameters by testing up to the given number
        of hyperparameter combinations, using up to n_jobs parallel
        processes.  If results_file is given, per-trial results are written
        into it as TSV."""

        objective_args = self._prepare(n_jobs)
        self._objective.init(objective_args)

        # write per-trial results only when a results file was requested
        writer = TrialWriter(results_file, self._normalize) if results_file else None
        write_callback = writer.write if writer else None

        # Use a temporary SQLite database as shared Optuna storage so that
        # trials running in separate worker processes can report back to a
        # single study.  Close the handle right away: Optuna opens its own
        # connections, and keeping the handle open would block access to the
        # file on Windows.
        temp_db = tempfile.NamedTemporaryFile(suffix=".db", delete=False)
        temp_db.close()
        storage_url = f"sqlite:///{temp_db.name}"

        try:
            study = optuna.create_study(direction="maximize", storage=storage_url)

            jobs, pool_class = annif.parallel.get_pool(n_jobs)
            with pool_class(jobs) as pool:
                for i in range(n_trials):
                    pool.apply_async(
                        self._objective.run_trial,
                        args=(i, storage_url, study.study_name),
                        callback=write_callback,
                    )
                pool.close()
                pool.join()

            return self._postprocess(study)
        finally:
            # BUG FIX: the temporary database was previously never removed,
            # leaking one .db file per optimize() call.  Best-effort cleanup
            # after the study results have been postprocessed.
            try:
                os.unlink(temp_db.name)
            except OSError:  # pragma: no cover
                pass
155
156
157
class AnnifHyperoptBackend(AnnifBackend):
    """Base class for Annif backends that support hyperparameter
    optimization"""

    @abc.abstractmethod
    def get_hp_optimizer(self, corpus: DocumentCorpus, metric: str):
        """Return a HyperparameterOptimizer that searches for the best
        hyperparameter combination for the given corpus, evaluated using
        the given metric"""

        pass  # pragma: no cover
168