Passed
Push — issue678-refactor-suggestionre... ( 96588c...e2c657 )
by Osma
02:47
created

annif.suggestion.SparseSuggestionResult.__iter__()   A

Complexity

Conditions 2

Size

Total Lines 8
Code Lines 7

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 7
nop 1
dl 0
loc 8
rs 10
c 0
b 0
f 0
1
"""Representing suggested subjects."""
2
3
import abc
4
import collections
5
import itertools
6
7
import numpy as np
8
from scipy.sparse import dok_array
9
10
# One suggested subject: the id of the subject (used elsewhere in this module
# as an index into score vectors/arrays) and the score given to it.
SubjectSuggestion = collections.namedtuple(
    "SubjectSuggestion", ["subject_id", "score"]
)
# Record with three fields: hit_sets, weight and subjects.
WeightedSuggestionsBatch = collections.namedtuple(
    "WeightedSuggestionsBatch", ["hit_sets", "weight", "subjects"]
)
14
15
16
def filter_suggestion(preds, limit=None, threshold=0.0):
    """Filter a 2D sparse suggestion array (csr_array), keeping the best scores.

    For each row of preds, retain at most ``limit`` of the highest stored
    scores (all of them when limit is None) and drop every score that is
    below ``threshold``; all other cells are left as zeros.

    Returns a new float32 sparse array (CSR format) with the same shape as
    preds. The input array is not modified.
    """

    filtered = dok_array(preds.shape, dtype=np.float32)
    for row in range(preds.shape[0]):
        # Slice the row with a list index (the same way SparseSuggestionResult
        # does) instead of calling getrow(), which is deprecated for sparse
        # arrays.
        arow = preds[[row], :]
        # positions within arow.data, ordered from highest to lowest score
        top_k = arow.data.argsort()[::-1]
        if limit is not None:
            top_k = top_k[:limit]
        for idx in top_k:
            val = arow.data[idx]
            if val < threshold:
                # scores are visited in descending order, so none of the
                # remaining ones can reach the threshold either
                break
            filtered[row, arow.indices[idx]] = val
    return filtered.tocsr()
33
34
35
class SuggestionResult(abc.ABC):
    """Interface for a set of hits returned by an analysis operation.

    Concrete subclasses must implement every method declared here."""

    @abc.abstractmethod
    def __iter__(self):
        """Iterate over the hits as SubjectSuggestion objects, starting from
        the highest score."""
        pass  # pragma: no cover

    @abc.abstractmethod
    def __len__(self):
        """Tell how many hits have a non-zero score."""
        pass  # pragma: no cover

    @abc.abstractmethod
    def as_vector(self, size, destination=None):
        """Express the hits as a one-dimensional score vector of the given
        size. When a destination array is passed (not None) it is used for
        the result; otherwise a new array is created."""
        pass  # pragma: no cover

    @abc.abstractmethod
    def filter(self, subject_index, limit=None, threshold=0.0):
        """Produce another SuggestionResult object holding only the hits that
        pass the given limit and score threshold."""
        pass  # pragma: no cover
62
63
64
class VectorSuggestionResult(SuggestionResult):
    """SuggestionResult implementation based primarily on NumPy vectors."""

    def __init__(self, vector):
        """Store the given score vector (indexed by subject id) as float32,
        with all scores limited to the range 0.0 .. 1.0."""
        vector_f32 = vector.astype(np.float32)
        # limit scores to the range 0.0 .. 1.0
        self._vector = np.clip(vector_f32, 0.0, 1.0)
        # both of these are computed lazily on first use
        self._subject_order = None
        self._lsr = None

    def _vector_to_list_suggestion(self):
        """Convert the vector into a ListSuggestionResult that contains only
        the subjects with a positive score, highest scores first."""
        hits = []
        for subject_id in self.subject_order:
            score = self._vector[subject_id]
            if score <= 0.0:
                break  # we can skip the remaining ones
            hits.append(SubjectSuggestion(subject_id=subject_id, score=float(score)))
        return ListSuggestionResult(hits)

    @property
    def subject_order(self):
        """Subject ids sorted by descending score (computed once, then cached)."""
        if self._subject_order is None:
            self._subject_order = np.argsort(self._vector)[::-1]
        return self._subject_order

    def __iter__(self):
        """Iterate over the positive-score hits, highest scores first."""
        if self._lsr is None:
            self._lsr = self._vector_to_list_suggestion()
        return iter(self._lsr)

    def as_vector(self, size, destination=None):
        """Return the score vector. If destination is given, the scores are
        copied into it and destination is returned; otherwise the internal
        vector itself is returned. The size argument is not used here."""
        if destination is not None:
            np.copyto(destination, self._vector)
            return destination
        return self._vector

    def filter(self, subject_index, limit=None, threshold=0.0):
        """Return a ListSuggestionResult with at most ``limit`` hits whose
        score is at least ``threshold``. Subjects listed by
        subject_index.deprecated_ids() are never included."""
        # Keep scores at or above the threshold, the same ">=" comparison that
        # ListSuggestionResult.filter and filter_suggestion apply. Zero scores
        # that pass the mask are still dropped when the hits are listed.
        mask = self._vector >= threshold
        deprecated_ids = subject_index.deprecated_ids()
        if limit is not None:
            limit_mask = np.zeros_like(self._vector, dtype=bool)
            deprecated_set = set(deprecated_ids)
            # take the best subjects in score order, skipping deprecated ones
            top_k_subjects = itertools.islice(
                (subj for subj in self.subject_order if subj not in deprecated_set),
                limit,
            )
            limit_mask[list(top_k_subjects)] = True
            mask = mask & limit_mask
        else:
            deprecated_mask = np.ones_like(self._vector, dtype=bool)
            deprecated_mask[deprecated_ids] = False
            mask = mask & deprecated_mask
        # multiplying by the boolean mask zeroes out the rejected scores
        vsr = VectorSuggestionResult(self._vector * mask)
        return ListSuggestionResult(vsr)

    def __len__(self):
        """Return the number of subjects with a score above zero."""
        # convert the NumPy integer to a plain int
        return int((self._vector > 0.0).sum())
121
122
123
class ListSuggestionResult(SuggestionResult):
    """SuggestionResult implementation based primarily on lists of hits."""

    def __init__(self, hits):
        """Store the given hits (objects with subject_id and score), dropping
        those without a positive score and capping scores at 1.0."""
        self._list = [self._enforce_score_range(hit) for hit in hits if hit.score > 0.0]
        # cached result of as_vector() when called without a destination
        self._vector = None

    @staticmethod
    def _enforce_score_range(hit):
        """Return the hit with its score capped at 1.0."""
        if hit.score > 1.0:
            return hit._replace(score=1.0)
        return hit

    def _list_to_vector(self, size, destination):
        """Write the scores of the hits into destination at the positions
        given by their subject ids and return it. A new float32 vector of
        zeros with the given size is created when destination is None."""
        if destination is None:
            destination = np.zeros(size, dtype=np.float32)

        for hit in self._list:
            # hits without a subject id have no position in the vector
            if hit.subject_id is not None:
                destination[hit.subject_id] = hit.score
        return destination

    def __iter__(self):
        """Iterate over the stored hits in list order."""
        return iter(self._list)

    def as_vector(self, size, destination=None):
        """Return the hits as a one-dimensional score vector of given size.
        If destination array is given (not None) the scores are written into
        it and it is returned, otherwise a vector created on the first such
        call is returned."""
        if destination is not None:
            # Always fill the buffer the caller passed in, and never cache it:
            # the caller owns it, and returning a previously cached vector
            # instead would leave this destination unfilled.
            return self._list_to_vector(size, destination)
        if self._vector is None:
            self._vector = self._list_to_vector(size, None)
        return self._vector

    def filter(self, subject_index, limit=None, threshold=0.0):
        """Return a new ListSuggestionResult with at most ``limit`` hits,
        highest scores first, keeping only hits that have a subject id and a
        positive score of at least ``threshold``. The subject_index argument
        is not used by this implementation."""
        hits = sorted(self._list, key=lambda hit: hit.score, reverse=True)
        filtered_hits = [
            hit
            for hit in hits
            if hit.score >= threshold and hit.score > 0.0 and hit.subject_id is not None
        ]
        if limit is not None:
            filtered_hits = filtered_hits[:limit]
        return ListSuggestionResult(filtered_hits)

    def __len__(self):
        """Return the number of stored hits."""
        return len(self._list)
166
167
168
class SparseSuggestionResult(SuggestionResult):
    """SuggestionResult implementation backed by a single row of a sparse array."""

    def __init__(self, array, idx):
        """Wrap row number idx of the given 2D sparse array (no copy is made)."""
        self._array = array
        self._idx = idx

    def _row(self):
        """Return the wrapped row as a sparse array with a single row."""
        return self._array[[self._idx], :]

    def __iter__(self):
        """Iterate over the non-zero cells of the row as SubjectSuggestion
        objects, highest scores first."""
        # Read all columns and values of the row in one pass instead of doing
        # a separate sparse element lookup for every hit.
        row = self._row().tocoo()
        suggestions = [
            SubjectSuggestion(subject_id=col, score=float(score))
            for col, score in zip(row.col, row.data)
            if score != 0  # ignore explicitly stored zeros
        ]
        return iter(
            sorted(suggestions, key=lambda suggestion: suggestion.score, reverse=True)
        )

    def as_vector(self, size, destination=None):
        """Return the row as a dense one-dimensional score vector. If
        destination is given, the scores are copied into it and destination
        is returned. The size argument is not used here."""
        vector = self._row().toarray()[0]
        if destination is not None:
            np.copyto(destination, vector)
            return destination
        return vector

    def filter(self, subject_index, limit=None, threshold=0.0):
        """Return the hits filtered by limit and threshold, by delegating to
        ListSuggestionResult.filter."""
        lsr = ListSuggestionResult(self)
        return lsr.filter(subject_index, limit, threshold)

    def __len__(self):
        """Return the number of non-zero cells in the row."""
        _, cols = self._row().nonzero()
        return len(cols)
198
199
200
class SuggestionBatch:
    """Subject suggestions for a batch of documents."""

    def __init__(self, array):
        """Create a new SuggestionBatch from a csr_array"""
        # one row per document; columns are indexed by subject id
        self.array = array

    @classmethod
    def from_sequence(cls, suggestion_results, vocab_size):
        """Create a new SuggestionBatch from a sequence of SuggestionResult objects.

        Each item of suggestion_results becomes one row with vocab_size
        columns. Suggestions without a subject id (None) are skipped."""

        # create a dok_array for fast construction
        ar = dok_array((len(suggestion_results), vocab_size), dtype=np.float32)
        for idx, result in enumerate(suggestion_results):
            for suggestion in result:
                # a hit without a subject id has no column to be stored in;
                # ListSuggestionResult skips such hits in the same way
                if suggestion.subject_id is None:
                    continue
                ar[idx, suggestion.subject_id] = suggestion.score
        return cls(ar.tocsr())

    def filter(self, limit=None, threshold=0.0):
        """Return a subset of the hits, filtered by the given limit and
        score threshold, as another SuggestionBatch object."""

        return SuggestionBatch(filter_suggestion(self.array, limit, threshold))

    def __getitem__(self, idx):
        """Return the suggestions of document number idx as a
        SparseSuggestionResult. Raises IndexError when idx is negative or
        not smaller than the number of documents."""
        if idx < 0 or idx >= len(self):
            raise IndexError
        return SparseSuggestionResult(self.array, idx)

    def __len__(self):
        """Return the number of documents (rows) in the batch."""
        return self.array.shape[0]
231
232
233
class SuggestionResults:
    """Subject suggestions for a potentially very large number of documents."""

    def __init__(self, batches):
        """Set up a SuggestionResults on top of an iterable whose items are
        SuggestionBatch objects."""

        self.batches = batches

    def filter(self, limit=None, threshold=0.0):
        """Give back another SuggestionResults object that presents these
        suggestions filtered by the given limit and/or threshold. Each batch
        is filtered only when the result is iterated."""

        # map() is lazy: no batch is filtered before it is actually requested
        lazily_filtered = map(
            lambda batch: batch.filter(limit, threshold), self.batches
        )
        return SuggestionResults(lazily_filtered)

    def __iter__(self):
        """Iterate over the items of all batches, one batch after another."""
        return itertools.chain.from_iterable(self.batches)
252