Passed
Pull Request — main (#681)
by Osma
05:43 queued 03:03
created

PAVBackend._normalize_suggestion_batch()   A

Complexity

Conditions 2

Size

Total Lines 17
Code Lines 13

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 13
nop 3
dl 0
loc 17
rs 9.75
c 0
b 0
f 0
"""PAV ensemble backend that combines results from multiple projects and
learns which concept suggestions from each backend are trustworthy using the
PAV algorithm, a.k.a. isotonic regression, to turn raw scores returned by
individual backends into probabilities."""

import os.path

import joblib
import numpy as np
from scipy.sparse import coo_matrix, csc_matrix
from sklearn.isotonic import IsotonicRegression

import annif.corpus
import annif.util
from annif.exception import NotInitializedException, NotSupportedException
from annif.suggestion import SubjectSuggestion, SuggestionBatch

from . import backend, ensemble

class PAVBackend(ensemble.BaseEnsembleBackend):
    """PAV ensemble backend that combines results from multiple projects"""

    name = "pav"

    MODEL_FILE_PREFIX = "pav-model-"

    # defaults for uninitialized instances
    _models = None

    DEFAULT_PARAMETERS = {"min-docs": 10}

    def default_params(self):
        """Return default parameters: the generic backend defaults overlaid
        with the PAV-specific ones (min-docs)."""
        params = backend.AnnifBackend.DEFAULT_PARAMETERS.copy()
        params.update(self.DEFAULT_PARAMETERS)
        return params

    def initialize(self, parallel=False):
        """Load the per-source-project PAV regression models from the data
        directory.

        Raises NotInitializedException if the model file for any configured
        source project is missing."""
        super().initialize(parallel)
        if self._models is not None:
            return  # already initialized
        self._models = {}
        sources = annif.util.parse_sources(self.params["sources"])
        for source_project_id, _ in sources:
            model_filename = self.MODEL_FILE_PREFIX + source_project_id
            path = os.path.join(self.datadir, model_filename)
            if os.path.exists(path):
                self.debug("loading PAV model from {}".format(path))
                self._models[source_project_id] = joblib.load(path)
            else:
                raise NotInitializedException(
                    "PAV model file '{}' not found".format(path),
                    backend_id=self.backend_id,
                )

    def _get_model(self, source_project_id):
        # lazy-load models on first access before looking one up
        self.initialize()
        return self._models[source_project_id]

    def _normalize_suggestion_batch(self, batch, source_project):
        """Map the raw scores in batch through the per-subject isotonic
        regression models trained for source_project; subjects without a
        trained model keep their raw score."""
        reg_models = self._get_model(source_project.project_id)
        pav_batch = [
            [
                SubjectSuggestion(
                    subject_id=sugg.subject_id,
                    score=reg_models[sugg.subject_id].predict([sugg.score])[0],
                )
                if sugg.subject_id in reg_models
                else SubjectSuggestion(
                    subject_id=sugg.subject_id, score=sugg.score
                )  # default to raw score
                for sugg in result
            ]
            for result in batch
        ]
        return SuggestionBatch.from_sequence(pav_batch, self.project.subjects)

    @staticmethod
    def _suggest_train_corpus(source_project, corpus):
        """Suggest subjects for every document in corpus using source_project
        and return two sparse CSC matrices of shape (ndocs, nsubjects): the
        raw suggestion scores and the boolean gold-standard labels."""
        # lists for constructing score matrix
        data, row, col = [], [], []
        # lists for constructing true label matrix
        trow, tcol = [], []

        ndocs = 0
        for docid, doc in enumerate(corpus.documents):
            hits = source_project.suggest([doc.text])[0]
            vector = hits.as_vector()
            # record only the nonzero suggestion scores
            for cid in np.flatnonzero(vector):
                data.append(vector[cid])
                row.append(docid)
                col.append(cid)
            # record the gold-standard subject assignments
            for cid in np.flatnonzero(
                doc.subject_set.as_vector(len(source_project.subjects))
            ):
                trow.append(docid)
                tcol.append(cid)
            ndocs += 1
        scores = coo_matrix(
            (data, (row, col)),
            shape=(ndocs, len(source_project.subjects)),
            dtype=np.float32,
        )
        true = coo_matrix(
            (np.ones(len(trow), dtype=bool), (trow, tcol)),
            shape=(ndocs, len(source_project.subjects)),
            dtype=bool,
        )
        # CSC format gives efficient per-subject column slicing during training
        return csc_matrix(scores), csc_matrix(true)

    def _create_pav_model(self, source_project_id, min_docs, corpus):
        """Train one isotonic regression model per subject (skipping subjects
        with fewer than min_docs positive examples) for the given source
        project and save the model dict atomically with joblib."""
        self.info(
            "creating PAV model for source {}, min_docs={}".format(
                source_project_id, min_docs
            )
        )
        source_project = self.project.registry.get_project(source_project_id)
        # suggest subjects for the training corpus
        scores, true = self._suggest_train_corpus(source_project, corpus)
        # create the concept-specific PAV regression models
        pav_regressions = {}
        for cid in range(len(source_project.subjects)):
            if true[:, cid].sum() < min_docs:
                continue  # don't create model b/c of too few examples
            reg = IsotonicRegression(out_of_bounds="clip")
            cid_scores = scores[:, cid].toarray().flatten().astype(np.float64)
            reg.fit(cid_scores, true[:, cid].toarray().flatten())
            pav_regressions[cid] = reg
        self.info("created PAV model for {} concepts".format(len(pav_regressions)))
        model_filename = self.MODEL_FILE_PREFIX + source_project_id
        annif.util.atomic_save(
            pav_regressions, self.datadir, model_filename, method=joblib.dump
        )

    def _train(self, corpus, params, jobs=0):
        """Train PAV models for all configured source projects.

        Raises NotSupportedException when asked to train from cached data or
        from an empty corpus."""
        if corpus == "cached":
            raise NotSupportedException(
                "Training pav project from cached data not supported."
            )
        if corpus.is_empty():
            raise NotSupportedException(
                "training backend {} with no documents".format(self.backend_id)
            )
        self.info("creating PAV models")
        sources = annif.util.parse_sources(self.params["sources"])
        min_docs = int(params["min-docs"])
        for source_project_id, _ in sources:
            self._create_pav_model(source_project_id, min_docs, corpus)