Passed
Pull Request — main (#681)
by Osma
10:12 queued 07:10
created

annif.suggestion.SuggestionBatch.filter()   A

Complexity

Conditions 1

Size

Total Lines 5
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 2
nop 3
dl 0
loc 5
rs 10
c 0
b 0
f 0
1
"""Representing suggested subjects."""
2
3
import collections
4
import itertools
5
6
import numpy as np
7
from scipy.sparse import csr_array, dok_array
8
9
# Lightweight immutable record pairing a subject ID with its score.
SubjectSuggestion = collections.namedtuple(
    "SubjectSuggestion", ["subject_id", "score"]
)
10
11
12
def vector_to_suggestions(vector, limit):
    """Turn a score vector into suggestions for its highest-scoring entries.

    Args:
        vector: sequence or array of scores; the position of each score is
            used as the subject ID of the resulting suggestion.
        limit: maximum number of suggestions to produce. It is capped at the
            length of the vector.

    Returns:
        An iterator of SubjectSuggestion tuples for the ``limit`` positions
        with the highest scores. The suggestions are NOT sorted by score
        (argpartition only guarantees which entries are selected), and
        entries with a zero score are not filtered out. The iterator is
        empty when the effective limit is zero or negative.
    """
    limit = min(len(vector), limit)
    if limit <= 0:
        # Without this guard a limit of 0 would slice with [-0:], which is
        # [0:] and selects *every* index; an empty vector would hand
        # argpartition an out-of-range kth.
        return iter(())
    topk_idx = np.argpartition(vector, -limit)[-limit:]
    return (
        SubjectSuggestion(subject_id=idx, score=float(vector[idx])) for idx in topk_idx
    )
18
19
20
def filter_suggestion(preds, limit=None, threshold=0.0):
    """Filter a 2D sparse suggestion array, keeping the best scores per row.

    Args:
        preds: 2D sparse suggestion array (csr_array); each row holds the
            scores of one individual prediction.
        limit: maximum number of suggestions retained on each row, or None
            to retain all of them.
        threshold: minimum score (inclusive) a suggestion needs in order
            to be retained.

    Returns:
        A new float32 csr_array with the same shape as ``preds`` in which
        only the top ``limit`` suggestions scoring at least ``threshold``
        are kept on each row; the rest are left as zeros.
    """

    filtered = dok_array(preds.shape, dtype=np.float32)
    for row in range(preds.shape[0]):
        # Indexing with a list yields a 2D single-row slice. This is used
        # instead of getrow(), which sparse *arrays* do not provide in
        # current SciPy releases (only the legacy matrix classes do).
        arow = preds[[row], :]
        # Positions into arow.data, ordered from highest to lowest score.
        top_k = arow.data.argsort()[::-1]
        if limit is not None:
            top_k = top_k[:limit]
        for idx in top_k:
            val = arow.data[idx]
            if val < threshold:
                # Scores are in descending order, so nothing after this
                # one can reach the threshold either.
                break
            filtered[row, arow.indices[idx]] = val
    return filtered.tocsr()
37
38
39
class SuggestionResult:
    """Suggestions for a single document, backed by a row of a sparse array."""

    def __init__(self, array, idx):
        """Create a view of row ``idx`` of the 2D sparse ``array``.

        Only references to the array and the row index are stored; the row
        is read each time the object is iterated or measured.
        """
        self._array = array
        self._idx = idx

    def _nonzero_entries(self):
        """Return (columns, scores) of the non-zero entries of the row."""
        # One row slice gives both the columns and their values, so no
        # per-element sparse lookup is needed afterwards.
        row = self._array[[self._idx], :].tocoo()
        # Mask out explicitly stored zeros, as nonzero() does.
        mask = row.data != 0
        return row.col[mask], row.data[mask]

    def __iter__(self):
        """Iterate over the SubjectSuggestions of the row, best score first."""
        cols, scores = self._nonzero_entries()
        suggestions = [
            SubjectSuggestion(subject_id=col, score=float(score))
            for col, score in zip(cols, scores)
        ]
        # list.sort is stable: equal scores keep their column order.
        suggestions.sort(key=lambda suggestion: suggestion.score, reverse=True)
        return iter(suggestions)

    def as_vector(self):
        """Return the row as a dense one-dimensional array of scores."""
        return self._array[[self._idx], :].toarray()[0]

    def __len__(self):
        """Return the number of non-zero scores on the row."""
        cols, _ = self._nonzero_entries()
        return len(cols)
62
63
64
class SuggestionBatch:
    """Subject suggestions for a batch of documents."""

    def __init__(self, array):
        """Create a new SuggestionBatch from a csr_array

        The array has one row per document; it is stored as-is in the
        ``array`` attribute.
        """
        assert isinstance(array, csr_array)
        self.array = array

    @classmethod
    def from_sequence(cls, suggestion_results, subject_index, limit=None):
        """Create a new SuggestionBatch from a sequence where each item is
        a sequence of SubjectSuggestion objects.

        Args:
            suggestion_results: sequence (it must support len()) with one
                iterable of suggestions per document.
            subject_index: object providing len() and deprecated_ids();
                its length sets the number of columns of the array.
            limit: if given, only the first ``limit`` suggestions of each
                document are considered. The limit is applied before the
                filtering below, so skipped suggestions count towards it.

        Suggestions for deprecated subjects and suggestions with a score
        of zero or less are skipped; scores above 1.0 are capped to 1.0.
        """

        deprecated = set(subject_index.deprecated_ids())

        ar = dok_array((len(suggestion_results), len(subject_index)), dtype=np.float32)
        for idx, result in enumerate(suggestion_results):
            for suggestion in itertools.islice(result, limit):
                if suggestion.subject_id in deprecated or suggestion.score <= 0.0:
                    continue
                ar[idx, suggestion.subject_id] = min(suggestion.score, 1.0)
        return cls(ar.tocsr())

    @classmethod
    def from_averaged(cls, batches, weights):
        """Create a new SuggestionBatch where the subject scores are the
        weighted average of scores in several SuggestionBatches

        Args:
            batches: SuggestionBatch objects whose arrays are averaged.
            weights: one numeric weight per batch, paired up positionally.
        """

        avg_array = sum(
            [batch.array * weight for batch, weight in zip(batches, weights)]
        ) / sum(weights)
        # Build the result through cls (like from_sequence does) so that
        # subclasses get an instance of their own type.
        return cls(avg_array)

    def filter(self, limit=None, threshold=0.0):
        """Return a subset of the hits, filtered by the given limit and
        score threshold, as another SuggestionBatch object."""

        return SuggestionBatch(filter_suggestion(self.array, limit, threshold))

    def __getitem__(self, idx):
        """Return the suggestions of document ``idx`` as a SuggestionResult.

        Raises:
            IndexError: if ``idx`` is negative or not below len(self).
                Negative indices are rejected, not counted from the end.
        """
        if idx < 0 or idx >= len(self):
            raise IndexError
        return SuggestionResult(self.array, idx)

    def __len__(self):
        """Return the number of documents (rows) in the batch."""
        return self.array.shape[0]
110
111
112
class SuggestionResults:
    """Subject suggestions for a potentially very large number of documents."""

    def __init__(self, batches):
        """Set up the results from an iterable yielding SuggestionBatch
        objects; the iterable is stored as given, not consumed here."""

        self.batches = batches

    def filter(self, limit=None, threshold=0.0):
        """Produce another SuggestionResults whose batches are filtered by
        the given limit and/or threshold; each batch is filtered lazily,
        when the returned object is iterated."""

        filtered_batches = (
            each.filter(limit, threshold) for each in self.batches
        )
        return SuggestionResults(filtered_batches)

    def __iter__(self):
        """Iterate over the per-document items of all batches in order."""
        flattened = itertools.chain.from_iterable(self.batches)
        return iter(flattened)
131