Passed
Pull Request — main (#708)
by Juho
05:36 queued 02:48
created

annif.suggestion.SuggestionBatch.__init__()   A

Complexity

Conditions 1

Size

Total Lines 4
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 3
nop 2
dl 0
loc 4
rs 10
c 0
b 0
f 0
1
"""Representing suggested subjects."""
2
from __future__ import annotations
3
4
import collections
5
import itertools
6
from typing import TYPE_CHECKING, Iterator, List, Optional
7
8
import numpy as np
9
from scipy.sparse import csr_array
10
11
if TYPE_CHECKING:
12
    from annif.corpus.subject import SubjectIndex
13
14
# Immutable (subject_id, score) record describing one suggested subject;
# elsewhere in this module subject_id is used as the subject's column
# position and score as its suggestion score.
SubjectSuggestion = collections.namedtuple(
    typename="SubjectSuggestion", field_names="subject_id score"
)
15
16
17
def vector_to_suggestions(vector: np.ndarray, limit: int) -> Iterator:
    """Yield the top ``limit`` entries of a dense score vector as
    SubjectSuggestion objects.

    The subject id of each suggestion is the position of its score in
    ``vector``. The selection is done with ``np.argpartition``, so the
    suggestions are not sorted by score. A ``limit`` of zero or less, or an
    empty ``vector``, yields nothing.
    """
    limit = min(len(vector), limit)
    if limit <= 0:
        # np.argpartition(v, -0)[-0:] would select the *whole* vector (and
        # raises ValueError on an empty one), so handle this case explicitly.
        return iter(())
    topk_idx = np.argpartition(vector, -limit)[-limit:]
    return (
        SubjectSuggestion(subject_id=idx, score=float(vector[idx])) for idx in topk_idx
    )
23
24
25
def filter_suggestion(
    preds: csr_array,
    limit: Optional[int] = None,
    threshold: float = 0.0,
) -> csr_array:
    """Filter a 2D sparse suggestion array (csr_array), retaining only the
    top K suggestions with a score above or equal to the threshold for each
    individual prediction; the rest will be left as zeros.

    :param preds: sparse score array, one row per prediction
    :param limit: maximum number of stored entries to keep per row; ``None``
        keeps all of them and ``0`` returns an all-zero array
    :param threshold: minimum score (inclusive) for an entry to be kept
    :return: a new float32 csr_array with the same shape as ``preds``
    """

    if limit == 0:
        return csr_array(preds.shape, dtype=np.float32)  # empty

    # Work on the CSR buffers directly instead of creating one sparse row
    # object per prediction; getrow() is a legacy spmatrix method that newer
    # SciPy releases no longer offer on sparse arrays. tocsr() is a no-op for
    # input that is already CSR.
    preds = preds.tocsr()
    indptr, indices, values = preds.indptr, preds.indices, preds.data

    data, rows, cols = [], [], []
    for row in range(preds.shape[0]):
        start, end = indptr[row], indptr[row + 1]
        row_values = values[start:end]
        if limit is not None and limit < len(row_values):
            topk_idx = row_values.argpartition(-limit)[-limit:]
        else:
            topk_idx = np.arange(len(row_values))
        # Apply the threshold to the selected entries in one vectorized step.
        keep = topk_idx[row_values[topk_idx] >= threshold]
        data.extend(row_values[keep])
        cols.extend(indices[start:end][keep])
        rows.extend([row] * len(keep))
    return csr_array((data, (rows, cols)), shape=preds.shape, dtype=np.float32)
50
51
52
class SuggestionResult:
    """Suggestions for a single document, backed by a row of a sparse array."""

    def __init__(self, array: csr_array, idx: int) -> None:
        # The (possibly shared) batch array and the row of this document.
        self._array = array
        self._idx = idx

    def _nonzero_entries(self) -> tuple[np.ndarray, np.ndarray]:
        """Return the column indices and the scores of the nonzero entries
        stored in this document's row, in storage order."""
        row = self._array[[self._idx], :]  # single-row sparse array
        mask = row.data != 0  # skip explicitly stored zeros, like nonzero()
        return row.indices[mask], row.data[mask]

    def __iter__(self):
        """Iterate over SubjectSuggestion objects in descending score order."""
        cols, scores = self._nonzero_entries()
        # Scores come from the extracted row in one pass instead of a
        # separate scalar lookup into the sparse array for every column.
        suggestions = [
            SubjectSuggestion(subject_id=col, score=float(score))
            for col, score in zip(cols, scores)
        ]
        return iter(
            sorted(suggestions, key=lambda suggestion: suggestion.score, reverse=True)
        )

    def as_vector(self) -> np.ndarray:
        """Return this document's row as a dense 1D array."""
        return self._array[[self._idx], :].toarray()[0]

    def __len__(self) -> int:
        """Return the number of nonzero scores in this document's row."""
        cols, _ = self._nonzero_entries()
        return len(cols)
75
76
77
class SuggestionBatch:
    """Subject suggestions for a batch of documents.

    The scores live in ``self.array``, a csr_array with one row per document;
    as built by ``from_sequence`` it has one column per subject of the
    subject index.
    """

    def __init__(self, array: csr_array) -> None:
        """Create a new SuggestionBatch from a csr_array"""
        assert isinstance(array, csr_array)
        self.array = array

    @classmethod
    def from_sequence(
        cls,
        suggestion_results: List[List[SubjectSuggestion]],
        subject_index: SubjectIndex,
        limit: Optional[int] = None,
    ) -> SuggestionBatch:
        """Create a new SuggestionBatch from a sequence where each item is
        a sequence of SubjectSuggestion objects.

        Only the first ``limit`` suggestions of each item are considered (all
        of them when ``limit`` is None). Among those, suggestions for
        deprecated subjects and suggestions with a score <= 0 are dropped, and
        scores above 1.0 are capped at 1.0. The resulting array has one row per
        item and ``len(subject_index)`` columns, indexed by ``subject_id``.
        """

        deprecated = set(subject_index.deprecated_ids())
        data, rows, cols = [], [], []
        for idx, result in enumerate(suggestion_results):
            # NOTE: the limit is applied before filtering, so a row may end up
            # with fewer than ``limit`` entries.
            for suggestion in itertools.islice(result, limit):
                if suggestion.subject_id in deprecated or suggestion.score <= 0.0:
                    continue
                data.append(min(suggestion.score, 1.0))
                rows.append(idx)
                cols.append(suggestion.subject_id)
        return cls(
            csr_array(
                (data, (rows, cols)),
                shape=(len(suggestion_results), len(subject_index)),
                dtype=np.float32,
            )
        )

    @classmethod
    def from_averaged(
        cls, batches: List[SuggestionBatch], weights: List[float]
    ) -> SuggestionBatch:
        """Create a new SuggestionBatch where the subject scores are the
        weighted average of scores in several SuggestionBatches.

        ``batches`` and ``weights`` are paired positionally. ``weights`` is
        iterated twice (for the pairing and for the normalizing sum), so it
        must be a re-iterable sequence matching ``batches`` in length.
        """

        weighted = (batch.array * weight for batch, weight in zip(batches, weights))
        avg_array = sum(weighted) / sum(weights)
        # Use cls, as from_sequence does, so subclasses get their own type.
        return cls(avg_array)

    def filter(
        self, limit: Optional[int] = None, threshold: float = 0.0
    ) -> SuggestionBatch:
        """Return a subset of the hits, filtered by the given limit and
        score threshold, as another SuggestionBatch object of the same class
        as this one."""

        return type(self)(filter_suggestion(self.array, limit, threshold))

    def __getitem__(self, idx: int) -> SuggestionResult:
        """Return the suggestions of document ``idx``; negative indices are
        not supported and raise IndexError like out-of-range ones."""
        if idx < 0 or idx >= len(self):
            raise IndexError
        return SuggestionResult(self.array, idx)

    def __len__(self) -> int:
        """Return the number of documents (rows) in this batch."""
        return self.array.shape[0]
139
140
141
class SuggestionResults:
    """Subject suggestions for a potentially very large number of documents,
    held as a collection (or lazy iterable) of SuggestionBatch objects."""

    def __init__(self, batches: List[SuggestionBatch]) -> None:
        """Initialize a new SuggestionResults from an iterable that provides
        SuggestionBatch objects; the iterable is stored as-is, not copied."""

        self.batches = batches

    def filter(
        self, limit: Optional[int] = None, threshold: float = 0.0
    ) -> SuggestionResults:
        """Return a lazily filtered view of these suggestions.

        Each batch is filtered by the given limit and/or threshold only when
        the returned SuggestionResults is iterated, and because the filtered
        batches come from a one-shot iterator the view can be consumed once.
        """

        lazily_filtered = map(
            lambda batch: batch.filter(limit, threshold), self.batches
        )
        return SuggestionResults(lazily_filtered)

    def __iter__(self) -> itertools.chain:
        """Iterate over the per-document results of all batches, in order."""
        # chain objects are their own iterators, so no extra iter() is needed.
        return itertools.chain.from_iterable(self.batches)
162