Passed
Push — upgrade-to-connexion3 ( e417e0...5d7ec9 )
by Juho
09:39 queued 05:10
created

annif.suggestion.SuggestionBatch.from_sequence()   B

Complexity

Conditions 5

Size

Total Lines 24
Code Lines 20

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 5
eloc 20
nop 4
dl 0
loc 24
rs 8.9332
c 0
b 0
f 0
1
"""Representing suggested subjects."""
2
from __future__ import annotations
3
4
import collections
5
import itertools
6
from typing import TYPE_CHECKING
7
8
import numpy as np
9
from scipy.sparse import csr_array
10
11
if TYPE_CHECKING:
12
    from collections.abc import Iterable, Iterator, Sequence
13
14
    from annif.corpus.subject import SubjectIndex
15
16
# Lightweight record for one suggested subject: its id and its score.
SubjectSuggestion = collections.namedtuple(
    "SubjectSuggestion", ["subject_id", "score"]
)
17
18
19
def vector_to_suggestions(vector: np.ndarray, limit: int) -> Iterator:
    """Return the highest-scoring entries of a score vector as suggestions.

    vector -- 1D array of scores, indexed by subject id
    limit -- maximum number of suggestions to produce

    Returns an iterator of SubjectSuggestion objects for (at most) the
    `limit` largest scores. The suggestions are NOT sorted by score; they
    come in the order produced by np.argpartition. A limit of zero or less,
    or an empty vector, yields no suggestions."""

    limit = min(len(vector), limit)
    if limit <= 0:
        # Without this guard a limit of 0 would evaluate
        # argpartition(vector, -0)[-0:], i.e. select *every* index, and an
        # empty vector would make argpartition raise.
        return iter(())
    topk_idx = np.argpartition(vector, -limit)[-limit:]
    return (
        SubjectSuggestion(subject_id=idx, score=float(vector[idx])) for idx in topk_idx
    )
25
26
27
def filter_suggestion(
28
    preds: csr_array,
29
    limit: int | None = None,
30
    threshold: float = 0.0,
31
) -> csr_array:
32
    """filter a 2D sparse suggestion array (csr_array), retaining only the
33
    top K suggestions with a score above or equal to the threshold for each
34
    individual prediction; the rest will be left as zeros"""
35
36
    if limit == 0:
37
        return csr_array(preds.shape, dtype=np.float32)  # empty
38
39
    data, rows, cols = [], [], []
40
    for row in range(preds.shape[0]):
41
        arow = preds.getrow(row)
42
        if limit is not None and limit < len(arow.data):
43
            topk_idx = arow.data.argpartition(-limit)[-limit:]
44
        else:
45
            topk_idx = range(len(arow.data))
46
        for idx in topk_idx:
47
            if arow.data[idx] >= threshold:
48
                data.append(arow.data[idx])
49
                rows.append(row)
50
                cols.append(arow.indices[idx])
51
    return csr_array((data, (rows, cols)), shape=preds.shape, dtype=np.float32)
52
53
54
class SuggestionResult:
    """Suggestions for one document, stored as a single row of a sparse array."""

    def __init__(self, array: csr_array, idx: int) -> None:
        """Wrap row number idx of the given sparse array."""
        self._array, self._idx = array, idx

    def _row(self):
        """Return this document's row as a sparse slice with one row."""
        return self._array[[self._idx], :]

    def __iter__(self):
        """Iterate over the nonzero suggestions, highest score first."""
        _, subject_ids = self._row().nonzero()
        ranked = []
        for subject_id in subject_ids:
            score = float(self._array[self._idx, subject_id])
            ranked.append(SubjectSuggestion(subject_id=subject_id, score=score))
        ranked.sort(key=lambda item: item.score, reverse=True)
        return iter(ranked)

    def as_vector(self) -> np.ndarray:
        """Return the scores of this row as a dense 1D array."""
        dense = self._row().toarray()
        return dense[0]

    def __len__(self) -> int:
        """Return the number of nonzero suggestions in this row."""
        return len(self._row().nonzero()[1])
77
78
79
class SuggestionBatch:
    """A batch of subject suggestions: one sparse array row per document."""

    def __init__(self, array: csr_array) -> None:
        """Wrap an existing csr_array as a SuggestionBatch."""
        assert isinstance(array, csr_array)
        self.array = array

    @classmethod
    def from_sequence(
        cls,
        suggestion_results: Sequence[Iterable[SubjectSuggestion]],
        subject_index: SubjectIndex,
        limit: int | None = None,
    ) -> SuggestionBatch:
        """Build a SuggestionBatch from a sequence whose items are iterables
        of SubjectSuggestion objects (one item per document).

        Only the first `limit` suggestions of each item are looked at.
        Suggestions for deprecated subjects or with a score of zero or less
        are dropped, and scores are capped at 1.0."""

        skip_ids = set(subject_index.deprecated_ids())
        scores, doc_rows, subject_cols = [], [], []
        for doc_no, doc_suggestions in enumerate(suggestion_results):
            for item in itertools.islice(doc_suggestions, limit):
                if item.subject_id in skip_ids:
                    continue
                if item.score <= 0.0:
                    continue
                scores.append(min(item.score, 1.0))
                doc_rows.append(doc_no)
                subject_cols.append(item.subject_id)

        batch_shape = (len(suggestion_results), len(subject_index))
        sparse = csr_array(
            (scores, (doc_rows, subject_cols)),
            shape=batch_shape,
            dtype=np.float32,
        )
        return cls(sparse)

    @classmethod
    def from_averaged(
        cls, batches: list[SuggestionBatch], weights: list[float]
    ) -> SuggestionBatch:
        """Build a SuggestionBatch whose scores are the weighted average of
        the scores in the given SuggestionBatches."""

        weighted_total = sum(
            batch.array * weight for batch, weight in zip(batches, weights)
        )
        avg_array = weighted_total / sum(weights)
        return SuggestionBatch(avg_array)

    def filter(
        self, limit: int | None = None, threshold: float = 0.0
    ) -> SuggestionBatch:
        """Return a new SuggestionBatch holding only the hits that pass the
        given limit and score threshold."""

        filtered = filter_suggestion(self.array, limit, threshold)
        return SuggestionBatch(filtered)

    def __getitem__(self, idx: int) -> SuggestionResult:
        """Return the suggestions of document number idx; indices outside
        0..len-1 raise IndexError."""
        if not 0 <= idx < len(self):
            raise IndexError
        return SuggestionResult(self.array, idx)

    def __len__(self) -> int:
        """Return the number of documents (rows) in the batch."""
        n_docs = self.array.shape[0]
        return n_docs
141
142
143
class SuggestionResults:
    """Subject suggestions for a (possibly very large) stream of documents,
    held as an iterable of batches."""

    def __init__(self, batches: Iterable[SuggestionBatch]) -> None:
        """Store the iterable that supplies the SuggestionBatch objects."""

        self.batches = batches

    def filter(
        self, limit: int | None = None, threshold: float = 0.0
    ) -> SuggestionResults:
        """Return another SuggestionResults that gives a view of these
        suggestions filtered by the given limit and/or threshold. Each
        batch is filtered lazily, when the result is iterated."""

        filtered_batches = (
            each_batch.filter(limit, threshold) for each_batch in self.batches
        )
        return SuggestionResults(filtered_batches)

    def __iter__(self) -> itertools.chain:
        """Iterate over the items of every batch, one batch after another."""
        # a chain object is its own iterator, so it can be returned directly
        return itertools.chain.from_iterable(self.batches)
164