Passed
Push — issue678-refactor-suggestionre... ( ec0260...82f1b2 )
by Osma
05:25 queued 02:48
created

annif.suggestion   A

Complexity

Total Complexity 39

Size/Duplication

Total Lines 182
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 111
dl 0
loc 182
rs 9.28
c 0
b 0
f 0
wmc 39

21 Methods

Rating   Name   Duplication   Size   Complexity  
A SuggestionResult.__len__() 0 4 1
A SuggestionResult.as_vector() 0 6 1
A SuggestionResult.__iter__() 0 5 1
A VectorSuggestionResult._vector_to_list_suggestion() 0 8 3
A VectorSuggestionResult.as_vector() 0 5 2
A VectorSuggestionResult.__len__() 0 2 1
A VectorSuggestionResult.subject_order() 0 5 2
A VectorSuggestionResult.__iter__() 0 4 2
A VectorSuggestionResult.__init__() 0 6 1
A SuggestionBatch.__init__() 0 3 1
A SparseSuggestionResult.as_vector() 0 5 2
A SuggestionBatch.filter() 0 5 1
A SuggestionBatch.__len__() 0 2 1
A SuggestionBatch.__getitem__() 0 4 3
A SuggestionResults.filter() 0 6 1
A SparseSuggestionResult.__init__() 0 3 1
A SuggestionResults.__iter__() 0 2 1
A SuggestionResults.__init__() 0 5 1
A SparseSuggestionResult.__iter__() 0 8 2
A SuggestionBatch.from_sequence() 0 14 5
A SparseSuggestionResult.__len__() 0 3 1

1 Function

Rating   Name   Duplication   Size   Complexity  
A filter_suggestion() 0 17 5
1
"""Representing suggested subjects."""
2
3
import abc
4
import collections
5
import itertools
6
7
import numpy as np
8
from scipy.sparse import dok_array
9
10
# A single suggested subject: its integer subject id and its score.
SubjectSuggestion = collections.namedtuple(
    "SubjectSuggestion", ["subject_id", "score"]
)
# Record with fields hit_sets, weight and subjects; presumably pairs a batch
# of suggestions with a weight (not used in the code shown here — TODO confirm).
WeightedSuggestionsBatch = collections.namedtuple(
    "WeightedSuggestionsBatch", ["hit_sets", "weight", "subjects"]
)
14
15
16
def filter_suggestion(preds, limit=None, threshold=0.0):
    """Filter a 2D sparse suggestion array row by row.

    For each row (one prediction), keep only the ``limit`` highest-scoring
    stored entries whose score is greater than or equal to ``threshold``;
    every other position is left as zero.

    Args:
        preds: 2D sparse array of scores (typically a csr_array). It is
            converted to CSR first, which is a no-op for CSR input.
        limit: maximum number of entries to keep per row, or None for no
            limit.
        threshold: minimum score an entry must have to be kept.

    Returns:
        A new float32 CSR array with the same shape as ``preds``.
    """
    # Walk the CSR structure directly instead of calling getrow(): that
    # method is deprecated/removed for sparse *arrays* (csr_array) in recent
    # SciPy, while indptr/indices/data work for csr_array and csr_matrix.
    preds = preds.tocsr()
    filtered = dok_array(preds.shape, dtype=np.float32)
    for row in range(preds.shape[0]):
        start, end = preds.indptr[row], preds.indptr[row + 1]
        scores = preds.data[start:end]
        columns = preds.indices[start:end]
        # positions within this row, highest score first
        top_k = scores.argsort()[::-1]
        if limit is not None:
            top_k = top_k[:limit]
        for pos in top_k:
            score = scores[pos]
            if score < threshold:
                break  # scores are descending, so no later entry can pass
            filtered[row, columns[pos]] = score
    return filtered.tocsr()
33
34
35
class SuggestionResult(metaclass=abc.ABCMeta):
    """Abstract interface for the subject hits produced by one analysis.

    Concrete subclasses must implement iteration over the hits, conversion
    to a dense score vector, and a length (number of non-zero hits).
    """

    @abc.abstractmethod
    def __iter__(self):
        """Iterate over SubjectSuggestion objects in descending score order."""
        ...  # pragma: no cover

    @abc.abstractmethod
    def as_vector(self, size, destination=None):
        """Return the hits as a one-dimensional score vector of length size.

        When ``destination`` is not None the scores are written into that
        array; otherwise the implementation returns an array of its own.
        """
        ...  # pragma: no cover

    @abc.abstractmethod
    def __len__(self):
        """Return how many hits have a non-zero score."""
        ...  # pragma: no cover
56
57
58
class VectorSuggestionResult(SuggestionResult):
    """SuggestionResult implementation based primarily on NumPy vectors.

    Scores are stored as a float32 vector indexed by subject id and clipped
    to the range 0.0 .. 1.0. The descending subject order and the list of
    SubjectSuggestion objects are computed on first use and then cached.
    """

    def __init__(self, vector):
        # np.asarray accepts any array-like (not only objects with .astype);
        # np.clip allocates a new array, so the caller's vector is never
        # aliased by this object.
        vector_f32 = np.asarray(vector, dtype=np.float32)
        # limit scores to the range 0.0 .. 1.0
        self._vector = np.clip(vector_f32, 0.0, 1.0)
        self._subject_order = None  # cached by the subject_order property
        self._lsr = None  # cached list of SubjectSuggestion built by __iter__

    def _vector_to_list_suggestion(self):
        """Build the list of hits with a positive score, highest first."""
        hits = []
        for subject_id in self.subject_order:
            score = self._vector[subject_id]
            if score <= 0.0:
                break  # order is descending, so the remaining ones are <= 0 too
            hits.append(SubjectSuggestion(subject_id=subject_id, score=float(score)))
        return hits

    @property
    def subject_order(self):
        """Subject ids sorted by descending score (computed once, then cached)."""
        if self._subject_order is None:
            self._subject_order = np.argsort(self._vector)[::-1]
        return self._subject_order

    def __iter__(self):
        """Iterate over the positive-score hits, highest score first."""
        if self._lsr is None:
            self._lsr = self._vector_to_list_suggestion()
        return iter(self._lsr)

    def as_vector(self, size, destination=None):
        """Return the score vector, copying it into ``destination`` if given.

        ``size`` is unused here. Without a destination the internal array is
        returned (no copy), so callers must not modify it in place.
        """
        if destination is not None:
            np.copyto(destination, self._vector)
            return destination
        return self._vector

    def __len__(self):
        """Return the number of subjects with a positive score."""
        # int() so that __len__ returns a built-in int, not a NumPy scalar
        return int(np.count_nonzero(self._vector > 0.0))
96
97
98
class SparseSuggestionResult(SuggestionResult):
    """SuggestionResult implementation backed by a single row of a sparse array.

    The sparse array is shared, not copied; this object only remembers which
    row it represents.
    """

    def __init__(self, array, idx):
        self._array = array  # 2D sparse array holding many results
        self._idx = idx  # row index of this result within self._array

    def _row(self):
        """Return this result's row as a 1 x n sparse array."""
        return self._array[[self._idx], :]

    def __iter__(self):
        """Iterate over the non-zero hits as SubjectSuggestion objects,
        highest score first (ties keep their storage order)."""
        # Read the row once instead of doing one sparse lookup per column.
        row = self._row().tocoo()
        nonzero = row.data != 0
        suggestions = [
            SubjectSuggestion(subject_id=col, score=float(score))
            for col, score in zip(row.col[nonzero], row.data[nonzero])
        ]
        return iter(
            sorted(suggestions, key=lambda suggestion: suggestion.score, reverse=True)
        )

    def as_vector(self, size, destination=None):
        """Return the row as a dense 1D score vector.

        If ``destination`` is given, the scores are copied into it and it is
        returned; otherwise a new dense array is returned. ``size`` is unused
        here; the vector length is the column count of the backing array.
        """
        vector = self._row().toarray()[0]
        if destination is not None:
            # Previously this printed a debug message and returned None,
            # silently breaking callers that pass a destination buffer.
            np.copyto(destination, vector)
            return destination
        return vector

    def __len__(self):
        """Return the number of non-zero hits in this row."""
        _, cols = self._row().nonzero()
        return len(cols)
124
125
126
class SuggestionBatch:
    """Subject suggestions for a batch of documents, held in a sparse array.

    Row ``i`` holds the scores for document ``i``; the column index is the
    subject id.
    """

    def __init__(self, array):
        """Wrap an existing csr_array (rows = documents, columns = subjects)."""
        self.array = array

    @classmethod
    def from_sequence(cls, suggestion_results, subject_index, limit=None):
        """Build a SuggestionBatch from a sequence of SuggestionResult objects.

        Only the first ``limit`` suggestions of each result are considered
        (all of them when limit is None). Suggestions for deprecated subjects
        and suggestions with negative scores are dropped; scores above 1.0
        are capped at 1.0.
        """
        skipped_ids = set(subject_index.deprecated_ids())
        shape = (len(suggestion_results), len(subject_index))

        # dok_array makes element-by-element assignment cheap; it is
        # converted to CSR once everything has been filled in.
        builder = dok_array(shape, dtype=np.float32)
        for row, result in enumerate(suggestion_results):
            for suggestion in itertools.islice(result, limit):
                subject_id = suggestion.subject_id
                score = suggestion.score
                if subject_id in skipped_ids or score < 0.0:
                    continue
                builder[row, subject_id] = min(score, 1.0)
        return cls(builder.tocsr())

    def filter(self, limit=None, threshold=0.0):
        """Return a new SuggestionBatch keeping, per document, at most
        ``limit`` hits whose score is at least ``threshold``."""
        filtered_array = filter_suggestion(self.array, limit, threshold)
        return SuggestionBatch(filtered_array)

    def __getitem__(self, idx):
        """Return the result for document ``idx`` as a SparseSuggestionResult.

        Negative or too-large indices raise IndexError. Since the class has
        no __iter__, that IndexError is also what ends Python's fallback
        iteration through __getitem__.
        """
        if not 0 <= idx < len(self):
            raise IndexError
        return SparseSuggestionResult(self.array, idx)

    def __len__(self):
        """Return the number of documents (rows) in the batch."""
        n_documents = self.array.shape[0]
        return n_documents
161
162
163
class SuggestionResults:
    """Subject suggestions for a potentially very large number of documents.

    Wraps an iterable of SuggestionBatch objects without materializing it.
    If that iterable is a one-shot iterator, the results (and any filtered
    view of them) can only be iterated once.
    """

    def __init__(self, batches):
        """Store an iterable that yields SuggestionBatch objects."""
        self.batches = batches

    def filter(self, limit=None, threshold=0.0):
        """Return a lazily filtered view as a new SuggestionResults object.

        Each batch is filtered with the given limit and/or threshold only when
        it is reached during iteration.
        """
        filtered_batches = (
            batch.filter(limit, threshold) for batch in self.batches
        )
        return SuggestionResults(filtered_batches)

    def __iter__(self):
        """Yield the per-document results of every batch, batch by batch."""
        return itertools.chain.from_iterable(self.batches)
182