annif.suggestion.filter_suggestion()   B
last analyzed

Complexity

Conditions 7

Size

Total Lines 25
Code Lines 18

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 7
eloc 18
nop 3
dl 0
loc 25
rs 8
c 0
b 0
f 0
1
"""Representing suggested subjects."""
2
3
from __future__ import annotations
4
5
import collections
6
import itertools
7
from typing import TYPE_CHECKING
8
9
import numpy as np
10
from scipy.sparse import csr_array
11
12
if TYPE_CHECKING:
13
    from collections.abc import Iterable, Iterator, Sequence
14
15
    from annif.corpus.subject import SubjectIndex
16
17
# Immutable record for one suggested subject: a named tuple with the two
# fields ``subject_id`` and ``score``.
SubjectSuggestion = collections.namedtuple("SubjectSuggestion", "subject_id score")
18
19
20
def vector_to_suggestions(vector: np.ndarray, limit: int) -> Iterator:
    """Convert a score vector into at most ``limit`` SubjectSuggestion objects.

    The positions of the ``limit`` largest values of ``vector`` are selected
    with ``np.argpartition``. Each selected position becomes the
    ``subject_id`` of a suggestion and the value stored there its ``score``
    (converted to a Python float). The suggestions are produced lazily and
    are NOT sorted by score.

    A ``limit`` larger than ``len(vector)`` is clamped to the vector length.
    A ``limit`` of zero or less, or an empty vector, yields no suggestions."""

    limit = min(len(vector), limit)
    if limit <= 0:
        # Without this guard a limit of 0 would select *every* element,
        # because -0 == 0 turns the slice below into [0:]; an empty vector
        # or a negative limit would fail or pick an arbitrary subset.
        return iter(())
    topk_idx = np.argpartition(vector, -limit)[-limit:]
    return (
        SubjectSuggestion(subject_id=idx, score=float(vector[idx])) for idx in topk_idx
    )
26
27
28
def filter_suggestion(
    preds: csr_array,
    limit: int | None = None,
    threshold: float = 0.0,
) -> csr_array:
    """Filter a 2D sparse suggestion array (csr_array), retaining only the
    top K suggestions with a score above or equal to the threshold for each
    individual prediction; the rest will be left as zeros.

    ``limit`` is the K above; ``None`` means no limit, and a limit of zero
    or less produces an all-zero result. For every row the top K stored
    scores are selected first and the threshold is applied to that selection
    afterwards. The input is not modified; a new float32 csr_array with the
    same shape as ``preds`` is returned."""

    if limit is not None and limit <= 0:
        # A negative limit would otherwise reach argpartition(-limit)[-limit:]
        # below and select a meaningless subset of each row.
        return csr_array(preds.shape, dtype=np.float32)  # empty

    data, rows, cols = [], [], []
    indptr = preds.indptr
    for row in range(preds.shape[0]):
        # Read the stored entries of this row straight from the CSR buffers
        # instead of building a temporary one-row sparse array.
        start, end = indptr[row], indptr[row + 1]
        scores = preds.data[start:end]
        subjects = preds.indices[start:end]
        if limit is not None and limit < len(scores):
            # argpartition puts the `limit` largest scores last (unordered)
            topk_idx = scores.argpartition(-limit)[-limit:]
            scores = scores[topk_idx]
            subjects = subjects[topk_idx]
        for score, subject in zip(scores, subjects):
            if score >= threshold:
                data.append(score)
                rows.append(row)
                cols.append(subject)
    return csr_array((data, (rows, cols)), shape=preds.shape, dtype=np.float32)
53
54
55
class SuggestionResult:
    """View over one row of a sparse suggestion array, i.e. the suggestions
    for a single document."""

    def __init__(self, array: csr_array, idx: int) -> None:
        """Remember the backing array and the index of the row to expose."""
        self._idx = idx
        self._array = array

    def __iter__(self):
        """Iterate over the nonzero entries of the row as SubjectSuggestion
        objects, highest score first."""
        row = self._idx
        nonzero_cols = self._array[[row], :].nonzero()[1]
        ranked = [
            SubjectSuggestion(subject_id=col, score=float(self._array[row, col]))
            for col in nonzero_cols
        ]
        # stable sort: entries with equal scores keep their column order
        ranked.sort(key=lambda item: item.score, reverse=True)
        return iter(ranked)

    def as_vector(self) -> np.ndarray:
        """Return the row as a dense 1D array."""
        dense = self._array[[self._idx], :].toarray()
        return dense[0]

    def __len__(self) -> int:
        """Return the number of nonzero entries in the row."""
        return len(self._array[[self._idx], :].nonzero()[1])
78
79
80
class SuggestionBatch:
    """A batch of subject suggestions, one row of a sparse array per
    document."""

    def __init__(self, array: csr_array) -> None:
        """Wrap an existing csr_array as a SuggestionBatch."""
        assert isinstance(array, csr_array)
        self.array = array

    @classmethod
    def from_sequence(
        cls,
        suggestion_results: Sequence[Iterable[SubjectSuggestion]],
        subject_index: SubjectIndex,
        limit: int | None = None,
    ) -> SuggestionBatch:
        """Build a SuggestionBatch from a sequence whose items are iterables
        of SubjectSuggestion objects (one item per document).

        At most ``limit`` suggestions are read per document. Suggestions
        with a score of zero or less, or whose subject_id cannot be looked
        up in ``subject_index``, are dropped; scores are capped at 1.0."""

        scores, doc_rows, subject_cols = [], [], []
        for doc_no, doc_suggestions in enumerate(suggestion_results):
            for item in itertools.islice(doc_suggestions, limit):
                if item.score <= 0.0:
                    continue
                try:  # check for deprecated subjects
                    subject_index[item.subject_id]
                except IndexError:
                    continue
                else:
                    scores.append(min(item.score, 1.0))
                    doc_rows.append(doc_no)
                    subject_cols.append(item.subject_id)

        shape = (len(suggestion_results), len(subject_index))
        matrix = csr_array(
            (scores, (doc_rows, subject_cols)), shape=shape, dtype=np.float32
        )
        return cls(matrix)

    @classmethod
    def from_averaged(
        cls, batches: list[SuggestionBatch], weights: list[float]
    ) -> SuggestionBatch:
        """Build a SuggestionBatch whose scores are the weighted average of
        the scores in the given SuggestionBatches."""

        weighted_arrays = [
            batch.array * weight for batch, weight in zip(batches, weights)
        ]
        weighted_total = sum(weighted_arrays)
        return SuggestionBatch(weighted_total / sum(weights))

    def filter(
        self, limit: int | None = None, threshold: float = 0.0
    ) -> SuggestionBatch:
        """Return a new SuggestionBatch holding only the hits that pass the
        given limit and score threshold."""

        filtered_array = filter_suggestion(self.array, limit, threshold)
        return SuggestionBatch(filtered_array)

    def __getitem__(self, idx: int) -> SuggestionResult:
        """Return the suggestions of document ``idx``; indices outside
        ``0 .. len(self) - 1`` (including negative ones) raise IndexError."""
        if not 0 <= idx < len(self):
            raise IndexError
        return SuggestionResult(self.array, idx)

    def __len__(self) -> int:
        """Return the number of documents (rows) in the batch."""
        n_rows = self.array.shape[0]
        return n_rows
145
146
147
class SuggestionResults:
    """Subject suggestions for a possibly very large number of documents,
    supplied as an iterable of batches."""

    def __init__(self, batches: Iterable[SuggestionBatch]) -> None:
        """Store the iterable that provides the SuggestionBatch objects."""

        self.batches = batches

    def filter(
        self, limit: int | None = None, threshold: float = 0.0
    ) -> SuggestionResults:
        """Return another SuggestionResults object whose batches are the
        batches of this one filtered by the given limit and/or threshold.
        The batches are filtered lazily, as they are iterated."""

        filtered_batches = map(
            lambda batch: batch.filter(limit, threshold), self.batches
        )
        return SuggestionResults(filtered_batches)

    def __iter__(self) -> itertools.chain:
        """Iterate over the items of all batches, one batch after another."""
        # a chain object is its own iterator
        return itertools.chain.from_iterable(self.batches)
168