Passed
Pull Request — main (#708)
by Juho
10:17 queued 06:14
created

annif.suggestion.SuggestionResult.__iter__()   A

Complexity

Conditions 2

Size

Total Lines 8
Code Lines 7

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 7
nop 1
dl 0
loc 8
rs 10
c 0
b 0
f 0
1
"""Representing suggested subjects."""
2
from __future__ import annotations
3
4
import collections
5
import itertools
6
from collections.abc import Iterable, Iterator, Sequence
7
from typing import TYPE_CHECKING
8
9
import numpy as np
10
from scipy.sparse import csr_array
11
12
if TYPE_CHECKING:
13
    from annif.corpus.subject import SubjectIndex
14
15
# Lightweight immutable record pairing a subject identifier (`subject_id`)
# with the score assigned to it (`score`).
SubjectSuggestion = collections.namedtuple("SubjectSuggestion", "subject_id score")
16
17
18
def vector_to_suggestions(vector: np.ndarray, limit: int) -> Iterator:
    """Return an iterator over the `limit` highest-scoring entries of a vector.

    Each yielded item is a SubjectSuggestion whose `subject_id` is the index
    into `vector` and whose `score` is the value at that index, converted to
    a Python float. The items are the top-scoring ones, but they are NOT
    sorted by score: np.argpartition only guarantees which elements end up
    in the selected partition, not their relative order.

    `limit` is capped at the length of the vector. A limit of zero or less,
    or an empty vector, produces an empty iterator.
    """
    limit = min(len(vector), limit)
    if limit <= 0:
        # Without this guard a limit of 0 would evaluate
        # np.argpartition(vector, -0)[-0:], i.e. a slice selecting *every*
        # element, and an empty vector would make argpartition raise.
        return iter(())
    topk_idx = np.argpartition(vector, -limit)[-limit:]
    return (
        SubjectSuggestion(subject_id=idx, score=float(vector[idx])) for idx in topk_idx
    )
24
25
26
def filter_suggestion(
    preds: csr_array,
    limit: int | None = None,
    threshold: float = 0.0,
) -> csr_array:
    """Filter a 2D sparse suggestion array (csr_array), retaining only the
    top K suggestions with a score above or equal to the threshold for each
    individual prediction (row); the rest will be left as zeros.

    preds: sparse array with one row per prediction.
    limit: maximum number of stored entries to keep per row; None keeps all,
        zero or a negative value keeps none.
    threshold: minimum score (inclusive) an entry needs in order to be kept.

    Returns a new float32 csr_array with the same shape as `preds`.
    """

    if limit is not None and limit <= 0:
        return csr_array(preds.shape, dtype=np.float32)  # empty

    # Read each row straight from the CSR buffers instead of calling
    # getrow(): that is a sparse-matrix era method which newer SciPy releases
    # no longer provide on sparse arrays, and slicing the buffers also avoids
    # building a temporary sparse object per row.
    preds = csr_array(preds)  # no-op for CSR input, converts other formats
    indptr = preds.indptr
    data, rows, cols = [], [], []
    for row in range(preds.shape[0]):
        start, end = indptr[row], indptr[row + 1]
        row_data = preds.data[start:end]
        row_cols = preds.indices[start:end]
        if limit is not None and limit < len(row_data):
            # argpartition selects the `limit` largest entries without
            # fully sorting the row
            topk_idx = row_data.argpartition(-limit)[-limit:]
            row_data = row_data[topk_idx]
            row_cols = row_cols[topk_idx]
        keep = row_data >= threshold
        kept_scores = row_data[keep]
        data.extend(kept_scores)
        cols.extend(row_cols[keep])
        rows.extend([row] * len(kept_scores))
    return csr_array((data, (rows, cols)), shape=preds.shape, dtype=np.float32)
51
52
53
class SuggestionResult:
    """Suggestions for a single document, backed by a row of a sparse array."""

    def __init__(self, array: csr_array, idx: int) -> None:
        """Remember the shared sparse array and the row index of this document."""
        self._array = array
        self._idx = idx

    def _row(self) -> csr_array:
        """Return this document's row as a separate 1 x N sparse array."""
        return self._array[[self._idx], :]

    def __iter__(self) -> Iterator:
        """Iterate over the non-zero suggestions, highest score first.

        Yields SubjectSuggestion objects; entries with equal scores keep the
        order in which they are stored in the row.
        """
        # Slice the row once and read scores directly from its buffers; a
        # scalar lookup into the sparse array for every column is far slower.
        # NOTE(review): assumes the backing array holds no duplicate entries
        # for the same cell - confirm for arrays built outside this module.
        row = self._row()
        nonzero = row.data != 0  # skip explicitly stored zeros
        suggestions = [
            SubjectSuggestion(subject_id=col, score=float(score))
            for col, score in zip(row.indices[nonzero], row.data[nonzero])
        ]
        suggestions.sort(key=lambda suggestion: suggestion.score, reverse=True)
        return iter(suggestions)

    def as_vector(self) -> np.ndarray:
        """Return the scores of this document as a dense 1D array."""
        return self._row().toarray()[0]

    def __len__(self) -> int:
        """Return the number of non-zero suggestions for this document."""
        return int(np.count_nonzero(self._row().data))
76
77
78
class SuggestionBatch:
    """Subject suggestions for a batch of documents.

    The scores are held in `self.array`, a csr_array with one row per
    document and one column per subject.
    """

    def __init__(self, array: csr_array) -> None:
        """Create a new SuggestionBatch from a csr_array"""
        assert isinstance(array, csr_array)
        self.array = array

    @classmethod
    def from_sequence(
        cls,
        suggestion_results: Sequence[Iterable[SubjectSuggestion]],
        subject_index: SubjectIndex,
        limit: int | None = None,
    ) -> SuggestionBatch:
        """Create a new SuggestionBatch from a sequence where each item is
        a sequence of SubjectSuggestion objects.

        Only the first `limit` suggestions of each item are considered (all
        of them when `limit` is None). Suggestions for deprecated subjects
        and suggestions with a score of zero or less are dropped; scores
        above 1.0 are capped at 1.0.
        """

        deprecated = set(subject_index.deprecated_ids())
        data, rows, cols = [], [], []
        for idx, result in enumerate(suggestion_results):
            for suggestion in itertools.islice(result, limit):
                if suggestion.subject_id in deprecated or suggestion.score <= 0.0:
                    continue
                data.append(min(suggestion.score, 1.0))
                rows.append(idx)
                cols.append(suggestion.subject_id)
        return cls(
            csr_array(
                (data, (rows, cols)),
                shape=(len(suggestion_results), len(subject_index)),
                dtype=np.float32,
            )
        )

    @classmethod
    def from_averaged(
        cls, batches: list[SuggestionBatch], weights: list[float]
    ) -> SuggestionBatch:
        """Create a new SuggestionBatch where the subject scores are the
        weighted average of scores in several SuggestionBatches.

        Raises ValueError if `batches` and `weights` differ in length.
        """

        if len(batches) != len(weights):
            # zip() would silently drop the surplus items while the divisor
            # below still included every weight, skewing the average.
            raise ValueError(
                f"got {len(batches)} batches but {len(weights)} weights"
            )
        weighted_arrays = (
            batch.array * weight for batch, weight in zip(batches, weights)
        )
        avg_array = sum(weighted_arrays) / sum(weights)
        # use cls, like from_sequence does, so subclasses get their own type
        return cls(avg_array)

    def filter(
        self, limit: int | None = None, threshold: float = 0.0
    ) -> SuggestionBatch:
        """Return a subset of the hits, filtered by the given limit and
        score threshold, as another SuggestionBatch object."""

        return type(self)(filter_suggestion(self.array, limit, threshold))

    def __getitem__(self, idx: int) -> SuggestionResult:
        """Return the suggestions of the document at position `idx`.

        Raises IndexError for negative or out-of-range positions; the
        IndexError past the last row is also what ends plain iteration over
        the batch, since the class defines no __iter__ of its own.
        """
        if idx < 0 or idx >= len(self):
            raise IndexError
        return SuggestionResult(self.array, idx)

    def __len__(self) -> int:
        """Return the number of documents (rows) in the batch."""
        return self.array.shape[0]
140
141
142
class SuggestionResults:
    """Subject suggestions for a potentially very large number of documents,
    held as an iterable of batches rather than as one big collection."""

    def __init__(self, batches: Iterable[SuggestionBatch]) -> None:
        """Store the iterable of SuggestionBatch objects as given; it is not
        consumed here."""

        self.batches = batches

    def filter(
        self, limit: int | None = None, threshold: float = 0.0
    ) -> SuggestionResults:
        """Build another SuggestionResults whose batches are the batches of
        this one with the given limit and/or threshold applied. Each batch
        is filtered only when the returned object is iterated."""

        filtered_batches = (
            each.filter(limit, threshold) for each in self.batches
        )
        return SuggestionResults(filtered_batches)

    def __iter__(self) -> itertools.chain:
        """Iterate over the items of every batch, one batch after another."""
        flattened = itertools.chain.from_iterable(self.batches)
        return iter(flattened)
163