Completed
Push — main ( 693ab2...415d94 )
by Osma
18s queued 16s
created

annif.suggestion.SuggestionResults.__init__()   A

Complexity

Conditions 1

Size

Total Lines 5
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 2
nop 2
dl 0
loc 5
rs 10
c 0
b 0
f 0
1
"""Representing suggested subjects."""
2
3
import collections
4
import itertools
5
6
import numpy as np
7
from scipy.sparse import csr_array
8
9
# Lightweight record pairing a subject identifier with its score.
SubjectSuggestion = collections.namedtuple(
    "SubjectSuggestion", ["subject_id", "score"]
)
10
11
12
def vector_to_suggestions(vector, limit):
    """Convert a score vector into at most `limit` SubjectSuggestion objects.

    Args:
        vector: 1D NumPy array of scores; the position of each score is
            used as the subject ID.
        limit: maximum number of suggestions to produce. Values larger than
            the vector length are capped; zero or negative values produce
            no suggestions.

    Returns:
        A generator of SubjectSuggestion tuples for the highest-scoring
        positions, in no particular order (the selection is a partition,
        not a sort).
    """
    limit = min(len(vector), limit)
    if limit <= 0:
        # -0 == 0, so the slice [-limit:] would select *every* element, and
        # argpartition raises on an empty vector; short-circuit instead.
        topk_idx = ()
    else:
        topk_idx = np.argpartition(vector, -limit)[-limit:]
    return (
        SubjectSuggestion(subject_id=idx, score=float(vector[idx])) for idx in topk_idx
    )
18
19
20
def filter_suggestion(preds, limit=None, threshold=0.0):
    """Filter a 2D sparse suggestion array (csr_array), retaining only the
    top K suggestions with a score above or equal to the threshold for each
    individual prediction; the rest will be left as zeros.

    Args:
        preds: 2D sparse array with one row per prediction; the stored
            values are the scores.
        limit: maximum number of suggestions to keep per row. None keeps
            every stored value; 0 returns an all-zero array.
        threshold: minimum score (inclusive) a suggestion needs to be kept.

    Returns:
        A new float32 csr_array with the same shape as `preds`.
    """

    if limit == 0:
        return csr_array(preds.shape, dtype=np.float32)  # empty

    data, rows, cols = [], [], []
    for row in range(preds.shape[0]):
        # Slice the row with indexing rather than getrow(): getrow() is a
        # legacy sparse-matrix method that is not reliably available on
        # sparse arrays across SciPy releases.
        arow = preds[[row], :]
        scores, indices = arow.data, arow.indices
        if limit is not None and limit < len(scores):
            # Positions of the `limit` largest stored scores (unordered).
            topk_idx = scores.argpartition(-limit)[-limit:]
        else:
            topk_idx = range(len(scores))
        for idx in topk_idx:
            if scores[idx] >= threshold:
                data.append(scores[idx])
                rows.append(row)
                cols.append(indices[idx])
    return csr_array((data, (rows, cols)), shape=preds.shape, dtype=np.float32)
41
42
43
class SuggestionResult:
    """Suggestions for a single document, backed by a row of a sparse array."""

    def __init__(self, array, idx):
        # Keep a reference to the shared array and remember which row
        # belongs to this document.
        self._array, self._idx = array, idx

    def __iter__(self):
        """Iterate over the nonzero entries of the row as SubjectSuggestion
        tuples, ordered from the highest score to the lowest."""
        row_no = self._idx
        nonzero_cols = self._array[[row_no], :].nonzero()[1]
        found = []
        for col in nonzero_cols:
            score = float(self._array[row_no, col])
            found.append(SubjectSuggestion(subject_id=col, score=score))
        found.sort(key=lambda item: item.score, reverse=True)
        return iter(found)

    def as_vector(self):
        """Return the row as a dense 1D array."""
        dense = self._array[[self._idx], :].toarray()
        return dense[0]

    def __len__(self):
        """Return the number of nonzero entries in the row."""
        return len(self._array[[self._idx], :].nonzero()[1])
66
67
68
class SuggestionBatch:
    """Subject suggestions for a batch of documents."""

    def __init__(self, array):
        """Wrap a csr_array whose rows hold the suggestions of each document."""
        assert isinstance(array, csr_array)
        self.array = array

    @classmethod
    def from_sequence(cls, suggestion_results, subject_index, limit=None):
        """Build a SuggestionBatch from a sequence in which every item is
        itself a sequence of SubjectSuggestion objects.

        At most `limit` suggestions are read per item; suggestions for
        deprecated subjects or with a non-positive score are dropped, and
        scores are capped at 1.0."""

        deprecated = set(subject_index.deprecated_ids())
        scores = []
        row_ids = []
        col_ids = []
        for row_no, doc_suggestions in enumerate(suggestion_results):
            for hit in itertools.islice(doc_suggestions, limit):
                is_deprecated = hit.subject_id in deprecated
                if is_deprecated or hit.score <= 0.0:
                    continue
                scores.append(min(hit.score, 1.0))
                row_ids.append(row_no)
                col_ids.append(hit.subject_id)
        shape = (len(suggestion_results), len(subject_index))
        array = csr_array((scores, (row_ids, col_ids)), shape=shape, dtype=np.float32)
        return cls(array)

    @classmethod
    def from_averaged(cls, batches, weights):
        """Build a SuggestionBatch whose scores are the weighted average of
        the scores in the given SuggestionBatches."""

        weighted = [batch.array * weight for batch, weight in zip(batches, weights)]
        total_weight = sum(weights)
        return SuggestionBatch(sum(weighted) / total_weight)

    def filter(self, limit=None, threshold=0.0):
        """Return another SuggestionBatch holding only the hits that pass
        the given limit and score threshold."""

        filtered = filter_suggestion(self.array, limit, threshold)
        return SuggestionBatch(filtered)

    def __getitem__(self, idx):
        # Only indices 0 .. len-1 are valid; anything else raises IndexError.
        if not (idx >= 0 and idx < len(self)):
            raise IndexError
        return SuggestionResult(self.array, idx)

    def __len__(self):
        # One row per document.
        n_rows = self.array.shape[0]
        return n_rows
121
122
123
class SuggestionResults:
    """Subject suggestions for a potentially very large number of documents."""

    def __init__(self, batches):
        """Store an iterable that provides SuggestionBatch objects."""

        self.batches = batches

    def filter(self, limit=None, threshold=0.0):
        """Return another SuggestionResults that applies the given limit
        and/or threshold to each batch as it is consumed."""

        # Generator expression keeps the filtering lazy, batch by batch.
        filtered_batches = (
            each.filter(limit, threshold) for each in self.batches
        )
        return SuggestionResults(filtered_batches)

    def __iter__(self):
        # Flatten the batches into one stream of per-document results.
        return itertools.chain.from_iterable(self.batches)
142