Passed
Push — issue678-refactor-suggestionre... ( 818ba2...3a8eec )
by Osma
03:00
created

annif.suggestion.SuggestionResult.filter()   A

Complexity

Conditions 1

Size

Total Lines 5
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 3
dl 0
loc 5
rs 10
c 0
b 0
f 0
cc 1
nop 4
1
"""Representing suggested subjects."""
2
3
import abc
4
import collections
5
import itertools
6
7
import numpy as np
8
from scipy.sparse import dok_array
9
10
# One suggested subject: the numeric subject ID and its score.
SubjectSuggestion = collections.namedtuple(
    "SubjectSuggestion", ["subject_id", "score"]
)
# Record bundling hit sets with a weight and the subjects they refer to
# (presumably consumed by callers outside this module — not used here).
WeightedSuggestionsBatch = collections.namedtuple(
    "WeightedSuggestionsBatch", ["hit_sets", "weight", "subjects"]
)
14
15
16
class SuggestionResult(metaclass=abc.ABCMeta):
    """Interface for the set of subject hits produced by an analysis
    operation.

    Concrete subclasses must provide every abstract method declared here."""

    @abc.abstractmethod
    def as_list(self):
        """Give the hits as an ordered sequence of SubjectSuggestion objects,
        with the best score first."""
        ...  # pragma: no cover

    @abc.abstractmethod
    def as_vector(self, size, destination=None):
        """Give the hits as a one-dimensional score vector of length ``size``.

        When ``destination`` is an array (not None) it is used for the
        result; otherwise a fresh array is allocated."""
        ...  # pragma: no cover

    @abc.abstractmethod
    def filter(self, subject_index, limit=None, threshold=0.0):
        """Give a subset of the hits, restricted by ``limit`` and the score
        ``threshold``, as a new SuggestionResult."""
        ...  # pragma: no cover

    @abc.abstractmethod
    def __len__(self):
        """Give the count of hits whose score is non-zero."""
        ...  # pragma: no cover
43
44
45
class LazySuggestionResult(SuggestionResult):
    """Proxy for a SuggestionResult that is only built when first needed.

    Each SuggestionResult method call is forwarded to the wrapped object,
    which is produced on demand by the factory passed to the constructor."""

    def __init__(self, construct):
        """Remember the zero-argument factory ``construct``; it is invoked
        when the wrapped SuggestionResult is first required."""
        self._construct = construct
        self._object = None

    def _initialize(self):
        # Build the wrapped result only if it does not exist yet.
        if self._object is not None:
            return
        self._object = self._construct()

    def _wrapped(self):
        # Make sure the wrapped result exists, then hand it back.
        self._initialize()
        return self._object

    def as_list(self):
        return self._wrapped().as_list()

    def as_vector(self, size, destination=None):
        return self._wrapped().as_vector(size, destination)

    def filter(self, subject_index, limit=None, threshold=0.0):
        return self._wrapped().filter(subject_index, limit, threshold)

    def __len__(self):
        return len(self._wrapped())
75
76
77
class VectorSuggestionResult(SuggestionResult):
    """SuggestionResult implementation based primarily on NumPy vectors."""

    def __init__(self, vector):
        """Wrap a one-dimensional score vector indexed by subject ID.

        The scores are converted to float32 and clamped to 0.0 .. 1.0."""
        vector_f32 = vector.astype(np.float32)
        # limit scores to the range 0.0 .. 1.0
        self._vector = np.minimum(np.maximum(vector_f32, 0.0), 1.0)
        # Lazily computed caches: subject IDs ordered by descending score,
        # and the equivalent ListSuggestionResult.
        self._subject_order = None
        self._lsr = None

    def _vector_to_list_suggestion(self):
        """Build a ListSuggestionResult from the positive scores, highest
        first."""
        hits = []
        for subject_id in self.subject_order:
            score = self._vector[subject_id]
            if score <= 0.0:
                break  # we can skip the remaining ones
            hits.append(SubjectSuggestion(subject_id=subject_id, score=float(score)))
        return ListSuggestionResult(hits)

    @property
    def subject_order(self):
        """Subject IDs sorted by descending score (computed once, cached)."""
        if self._subject_order is None:
            self._subject_order = np.argsort(self._vector)[::-1]
        return self._subject_order

    def as_list(self):
        """Return the positive-scoring hits as SubjectSuggestion objects,
        highest scores first. The underlying list is cached."""
        if self._lsr is None:
            self._lsr = self._vector_to_list_suggestion()
        return self._lsr.as_list()

    def as_vector(self, size, destination=None):
        """Return the score vector. If ``destination`` is given, the scores
        are copied into it and it is returned. ``size`` is not used."""
        if destination is not None:
            np.copyto(destination, self._vector)
            return destination
        return self._vector

    def filter(self, subject_index, limit=None, threshold=0.0):
        """Return the hits scoring above ``threshold`` whose subjects are not
        listed by ``subject_index.deprecated_ids()``, keeping at most
        ``limit`` of the highest-scoring ones, as a ListSuggestionResult."""
        # NOTE(review): scores must be strictly above the threshold here,
        # whereas ListSuggestionResult.filter keeps scores >= threshold;
        # confirm which is intended before relying on exact-threshold hits.
        mask = self._vector > threshold
        deprecated_ids = subject_index.deprecated_ids()
        if limit is not None:
            # Allow only the `limit` best-scoring non-deprecated subjects.
            limit_mask = np.zeros_like(self._vector, dtype=bool)
            deprecated_set = set(deprecated_ids)
            top_k_subjects = itertools.islice(
                (subj for subj in self.subject_order if subj not in deprecated_set),
                limit,
            )
            limit_mask[list(top_k_subjects)] = True
            mask = mask & limit_mask
        else:
            # No limit: just exclude every deprecated subject.
            deprecated_mask = np.ones_like(self._vector, dtype=bool)
            deprecated_mask[deprecated_ids] = False
            mask = mask & deprecated_mask
        vsr = VectorSuggestionResult(self._vector * mask)
        return ListSuggestionResult(vsr.as_list())

    def __len__(self):
        """Return the number of subjects with a positive score."""
        # Convert to a plain int: summing a boolean array yields a NumPy
        # integer, but __len__ is expected to return a Python int.
        return int(np.count_nonzero(self._vector > 0.0))
134
135
136
class ListSuggestionResult(SuggestionResult):
    """SuggestionResult implementation based primarily on lists of hits."""

    def __init__(self, hits):
        """Keep only the hits with a positive score, capping scores at 1.0."""
        self._list = [
            self._enforce_score_range(hit) for hit in hits if hit.score > 0.0
        ]
        # Cached vector form; only filled for vectors this object allocated
        # itself (never for a caller-supplied destination array).
        self._vector = None

    @staticmethod
    def _enforce_score_range(hit):
        """Return ``hit`` with its score capped at 1.0."""
        if hit.score > 1.0:
            return hit._replace(score=1.0)
        return hit

    def _list_to_vector(self, size, destination):
        """Write the hits into ``destination`` (or a new float32 vector of
        length ``size`` when it is None) and return that array. Hits without
        a subject_id are skipped."""
        if destination is None:
            destination = np.zeros(size, dtype=np.float32)
        else:
            # Clear stale contents so the result holds exactly these hits,
            # as VectorSuggestionResult.as_vector overwrites the whole array.
            destination.fill(0)

        for hit in self._list:
            if hit.subject_id is not None:
                destination[hit.subject_id] = hit.score
        return destination

    def as_list(self):
        """Return the (positive-score) hits in their stored order."""
        return self._list

    def as_vector(self, size, destination=None):
        """Return the hits as a score vector of length ``size``.

        A given ``destination`` is filled on every call and returned; it is
        never cached, because the caller owns it and may reuse it. Without a
        destination, a vector is built once per size and cached."""
        if destination is not None:
            return self._list_to_vector(size, destination)
        if self._vector is None or len(self._vector) != size:
            self._vector = self._list_to_vector(size, None)
        return self._vector

    def filter(self, subject_index, limit=None, threshold=0.0):
        """Return the hits with a subject_id and score >= ``threshold``
        (and > 0), best first, truncated to ``limit`` if given, as a new
        ListSuggestionResult."""
        # NOTE(review): subject_index is not consulted here, so deprecated
        # subjects are not removed, unlike VectorSuggestionResult.filter.
        hits = sorted(self._list, key=lambda hit: hit.score, reverse=True)
        filtered_hits = [
            hit
            for hit in hits
            if hit.score >= threshold and hit.score > 0.0 and hit.subject_id is not None
        ]
        if limit is not None:
            filtered_hits = filtered_hits[:limit]
        return ListSuggestionResult(filtered_hits)

    def __len__(self):
        """Return the number of stored hits."""
        return len(self._list)
179
180
181
class SparseSuggestionResult(SuggestionResult):
    """SuggestionResult implementation backed by a single row of a sparse array."""

    def __init__(self, array, idx):
        """Wrap row ``idx`` of the 2D sparse ``array``; the array is shared,
        not copied."""
        self._array = array
        self._idx = idx

    def _row(self):
        # Extract this result's row as a 1 x N sparse array.
        return self._array[[self._idx], :]

    def as_list(self):
        """Return the row's non-zero entries as SubjectSuggestion objects,
        highest scores first."""
        row = self._row().tocoo()
        # Read column indices and scores straight from the extracted row
        # rather than indexing the full array once per column; explicitly
        # stored zeros are skipped, as nonzero() does.
        keep = row.data != 0
        suggestions = [
            SubjectSuggestion(subject_id=col, score=float(score))
            for col, score in zip(row.col[keep], row.data[keep])
        ]
        return sorted(
            suggestions, key=lambda suggestion: suggestion.score, reverse=True
        )

    def as_vector(self, size, destination=None):
        """Return the row as a dense 1D vector. If ``destination`` is given,
        the row is copied into it and it is returned, mirroring
        VectorSuggestionResult.as_vector. ``size`` is not used."""
        vector = self._row().toarray()[0]
        if destination is None:
            return vector
        # Previously this printed a debug message and returned None.
        np.copyto(destination, vector)
        return destination

    def filter(self, subject_index, limit=None, threshold=0.0):
        """Filter via a ListSuggestionResult built from this row's hits."""
        lsr = ListSuggestionResult(self.as_list())
        return lsr.filter(subject_index, limit, threshold)

    def __len__(self):
        """Return the number of non-zero entries in the row."""
        _, cols = self._row().nonzero()
        return len(cols)
211
212
213
class SuggestionBatch:
    """Subject suggestions for a batch of documents."""

    def __init__(self, array):
        """Create a new SuggestionBatch from a csr_array"""
        # One row per document, one column per subject ID.
        self.array = array

    @classmethod
    def from_sequence(cls, suggestion_results, vocab_size):
        """Create a new SuggestionBatch from a sequence of SuggestionResult objects.

        ``suggestion_results`` must support len(); row ``i`` of the result
        holds the scores from ``suggestion_results[i].as_list()``. Hits
        without a subject_id are skipped."""

        # create a dok_array for fast incremental construction
        ar = dok_array((len(suggestion_results), vocab_size), dtype=np.float32)
        for idx, result in enumerate(suggestion_results):
            for suggestion in result.as_list():
                # A hit with no subject ID has no column to go in; skip it,
                # as ListSuggestionResult._list_to_vector does.
                if suggestion.subject_id is None:
                    continue
                ar[idx, suggestion.subject_id] = suggestion.score
        return cls(ar.tocsr())

    def filter(self, limit=None, threshold=0.0):
        """Return a subset of the hits, filtered by the given limit and
        score threshold, as another SuggestionBatch object."""

        # Imported here rather than at module level (presumably to avoid an
        # import cycle with annif.util — confirm).
        from annif.util import filter_suggestion

        return SuggestionBatch(filter_suggestion(self.array, limit, threshold))

    def __getitem__(self, idx):
        """Return the suggestions for document ``idx`` as a
        SparseSuggestionResult; negative indices count from the end.

        Raises IndexError when ``idx`` is out of range, which also ends
        iteration over the batch via the sequence protocol."""
        length = len(self)
        if idx < 0:
            idx += length
        if idx < 0 or idx >= length:
            raise IndexError(
                f"SuggestionBatch index out of range for batch of size {length}"
            )
        return SparseSuggestionResult(self.array, idx)

    def __len__(self):
        """Return the number of documents in the batch."""
        return self.array.shape[0]
246
247
248
class SuggestionResults:
    """Suggestions for a potentially very large number of documents, stored
    as an iterable of SuggestionBatch objects."""

    def __init__(self, batches):
        """Wrap ``batches``, an iterable that yields SuggestionBatch objects."""

        self.batches = batches

    def filter(self, limit=None, threshold=0.0):
        """Give a lazily filtered view of these suggestions (by ``limit``
        and/or ``threshold``) as a new SuggestionResults object."""

        # A generator keeps the view lazy: each batch is filtered on demand.
        filtered_batches = (batch.filter(limit, threshold) for batch in self.batches)
        return SuggestionResults(filtered_batches)

    def __iter__(self):
        # Flatten the batches into one stream of per-document results; the
        # chain object is already an iterator.
        return itertools.chain.from_iterable(self.batches)
267