Passed
Push — issue631-rest-api-language-det... ( 34c253...1cd800 )
by Osma
04:27
created

annif.suggestion.vector_to_suggestions()   A

Complexity

Conditions 1

Size

Total Lines 5
Code Lines 5

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 5
nop 2
dl 0
loc 5
rs 10
c 0
b 0
f 0
1
"""Representing suggested subjects."""
2
3
from __future__ import annotations
4
5
import collections
6
import itertools
7
from typing import TYPE_CHECKING
8
9
import numpy as np
10
from scipy.sparse import csr_array
11
12
if TYPE_CHECKING:
13
    from collections.abc import Iterable, Iterator, Sequence
14
15
    from annif.corpus.subject import SubjectIndex
16
17
# Lightweight record pairing a subject identifier (subject_id) with the
# score assigned to it (score); used throughout this module to represent
# a single suggested subject.
SubjectSuggestion = collections.namedtuple("SubjectSuggestion", "subject_id score")
18
19
20
def vector_to_suggestions(vector: np.ndarray, limit: int) -> Iterator:
    """Convert a 1D score vector, indexed by subject ID, into an iterator of
    SubjectSuggestion objects for at most `limit` of the highest-scoring
    entries. The suggestions are produced in no particular score order.
    A limit of zero or less, or an empty vector, yields no suggestions."""

    limit = min(len(vector), limit)
    if limit <= 0:
        # Guard needed because slicing with [-0:] selects *every* element
        # rather than none, and argpartition cannot handle an empty vector.
        return iter(())
    # argpartition moves the `limit` largest values to the tail of the
    # index array without fully sorting it
    topk_idx = np.argpartition(vector, -limit)[-limit:]
    return (
        SubjectSuggestion(subject_id=idx, score=float(vector[idx])) for idx in topk_idx
    )
26
27
28
def filter_suggestion(
    preds: csr_array,
    limit: int | None = None,
    threshold: float = 0.0,
) -> csr_array:
    """Prune a 2D sparse suggestion array (csr_array) row by row: each row
    keeps at most `limit` of its highest-scoring stored entries, and of those
    only the ones scoring at least `threshold`; everything else is zero in
    the returned float32 array"""

    if limit == 0:
        # nothing can be retained, so skip the row scan entirely
        return csr_array(preds.shape, dtype=np.float32)

    kept_scores, kept_rows, kept_cols = [], [], []
    for row_no in range(preds.shape[0]):
        row_slice = preds[[row_no]]
        scores, subject_ids = row_slice.data, row_slice.indices
        n_stored = len(scores)
        if limit is None or limit >= n_stored:
            # no truncation needed: consider every stored entry of the row
            candidates = range(n_stored)
        else:
            # positions of the `limit` largest stored scores (unordered)
            candidates = scores.argpartition(-limit)[-limit:]
        for pos in candidates:
            score = scores[pos]
            if score >= threshold:
                kept_scores.append(score)
                kept_rows.append(row_no)
                kept_cols.append(subject_ids[pos])
    return csr_array(
        (kept_scores, (kept_rows, kept_cols)), shape=preds.shape, dtype=np.float32
    )
53
54
55
class SuggestionResult:
    """The suggestions of one document, read from a single row of a shared
    sparse array."""

    def __init__(self, array: csr_array, idx: int) -> None:
        self._idx = idx
        self._array = array

    def _nonzero_columns(self):
        """Return the column indices that hold nonzero scores in this row."""
        return self._array[[self._idx], :].nonzero()[1]

    def __iter__(self):
        # build one SubjectSuggestion per nonzero column, then order them
        # from the highest score to the lowest
        ranked = []
        for col in self._nonzero_columns():
            score = float(self._array[self._idx, col])
            ranked.append(SubjectSuggestion(subject_id=col, score=score))
        ranked.sort(key=lambda item: item.score, reverse=True)
        return iter(ranked)

    def as_vector(self) -> np.ndarray:
        # densify only this document's row and unwrap it to one dimension
        dense_row = self._array[[self._idx], :].toarray()
        return dense_row[0]

    def __len__(self) -> int:
        return len(self._nonzero_columns())
78
79
80
class SuggestionBatch:
    """A batch of documents' subject suggestions, one sparse-array row per
    document."""

    def __init__(self, array: csr_array) -> None:
        """Wrap an existing csr_array as a SuggestionBatch"""
        assert isinstance(array, csr_array)
        self.array = array

    @classmethod
    def from_sequence(
        cls,
        suggestion_results: Sequence[Iterable[SubjectSuggestion]],
        subject_index: SubjectIndex,
        limit: int | None = None,
    ) -> SuggestionBatch:
        """Build a SuggestionBatch out of a sequence whose items are each an
        iterable of SubjectSuggestion objects. At most `limit` suggestions
        are read per item; deprecated subjects and non-positive scores are
        dropped, and scores are capped at 1.0."""

        deprecated = set(subject_index.deprecated_ids())
        scores, doc_rows, subject_cols = [], [], []
        for doc_no, doc_suggestions in enumerate(suggestion_results):
            # the limit applies to what is consumed, before any filtering
            considered = itertools.islice(doc_suggestions, limit)
            for suggestion in considered:
                if suggestion.subject_id in deprecated or suggestion.score <= 0.0:
                    continue
                doc_rows.append(doc_no)
                subject_cols.append(suggestion.subject_id)
                scores.append(min(suggestion.score, 1.0))
        n_docs, n_subjects = len(suggestion_results), len(subject_index)
        matrix = csr_array(
            (scores, (doc_rows, subject_cols)),
            shape=(n_docs, n_subjects),
            dtype=np.float32,
        )
        return cls(matrix)

    @classmethod
    def from_averaged(
        cls, batches: list[SuggestionBatch], weights: list[float]
    ) -> SuggestionBatch:
        """Build a SuggestionBatch whose scores are the weighted average of
        the scores found in the given SuggestionBatches"""

        weighted_arrays = [batch.array * weight for batch, weight in zip(batches, weights)]
        total_weight = sum(weights)
        return SuggestionBatch(sum(weighted_arrays) / total_weight)

    def filter(
        self, limit: int | None = None, threshold: float = 0.0
    ) -> SuggestionBatch:
        """Produce another SuggestionBatch holding only the hits that pass
        the given limit and score threshold."""

        filtered_array = filter_suggestion(self.array, limit, threshold)
        return SuggestionBatch(filtered_array)

    def __getitem__(self, idx: int) -> SuggestionResult:
        # negative indices are rejected rather than counted from the end
        out_of_range = idx < 0 or idx >= len(self)
        if out_of_range:
            raise IndexError
        return SuggestionResult(self.array, idx)

    def __len__(self) -> int:
        n_rows = self.array.shape[0]
        return n_rows
142
143
144
class SuggestionResults:
    """Subject suggestions spanning a possibly very large set of documents,
    held as an iterable of batches."""

    def __init__(self, batches: Iterable[SuggestionBatch]) -> None:
        """Set up a SuggestionResults around an iterable that yields
        SuggestionBatch objects."""

        self.batches = batches

    def filter(
        self, limit: int | None = None, threshold: float = 0.0
    ) -> SuggestionResults:
        """Produce another SuggestionResults that views these suggestions
        through the given limit and/or threshold."""

        # generator expression: each batch is filtered only when consumed
        filtered_batches = (batch.filter(limit, threshold) for batch in self.batches)
        return SuggestionResults(filtered_batches)

    def __iter__(self) -> itertools.chain:
        # flatten the batches into one stream of their items
        flattened = itertools.chain.from_iterable(self.batches)
        return iter(flattened)
165