Passed
Push — issue678-refactor-suggestionre... ( 82f1b2...911e14 )
by Osma
02:35
created

SparseSuggestionResult.as_vector()   A

Complexity

Conditions 1

Size

Total Lines 2
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 2
nop 2
dl 0
loc 2
rs 10
c 0
b 0
f 0
1
"""Representing suggested subjects."""
2
3
import abc
4
import collections
5
import itertools
6
7
import numpy as np
8
from scipy.sparse import dok_array
9
10
# A single suggested subject: the subject's identifier paired with its score.
SubjectSuggestion = collections.namedtuple(
    "SubjectSuggestion", ["subject_id", "score"]
)
# Record with three fields: hit_sets, weight and subjects.
WeightedSuggestionsBatch = collections.namedtuple(
    "WeightedSuggestionsBatch", ["hit_sets", "weight", "subjects"]
)
14
15
16
def filter_suggestion(preds, limit=None, threshold=0.0):
    """Filter a 2D sparse suggestion array, keeping only the best scores of
    each row.

    For every row, the stored scores are ranked from highest to lowest, at
    most ``limit`` of them are considered, and of those only the scores above
    or equal to ``threshold`` are copied to the result; every other cell is
    left as zero.

    Args:
        preds: 2D sparse array of scores, one row per prediction (normally a
            csr_array; other sparse formats are converted with ``tocsr()``).
        limit: maximum number of scores to retain per row, or None to
            retain any number of them.
        threshold: minimum score a suggestion must have to be retained.

    Returns:
        The filtered scores as a float32 sparse array in CSR format, with
        the same shape as ``preds``.
    """

    csr = preds.tocsr()
    filtered = dok_array(csr.shape, dtype=np.float32)
    for row in range(csr.shape[0]):
        # Slice the stored entries of this row directly out of the CSR
        # buffers; this avoids getrow(), which recent SciPy versions no
        # longer provide on sparse arrays.
        start, end = csr.indptr[row], csr.indptr[row + 1]
        row_data = csr.data[start:end]
        row_indices = csr.indices[start:end]
        top_k = row_data.argsort()[::-1]
        if limit is not None:
            top_k = top_k[:limit]
        for idx in top_k:
            val = row_data[idx]
            if val < threshold:
                break  # scores are visited in descending order
            filtered[row, row_indices[idx]] = val
    return filtered.tocsr()
33
34
35
class SuggestionResult(metaclass=abc.ABCMeta):
    """Interface shared by the sets of hits that an analysis operation
    returns. Concrete subclasses must implement iteration and length."""

    @abc.abstractmethod
    def __iter__(self):
        """Iterate over the hits as SubjectSuggestion objects, ordered from
        the highest score to the lowest."""
        pass  # pragma: no cover

    @abc.abstractmethod
    def __len__(self):
        """Tell how many hits have a non-zero score."""
        pass  # pragma: no cover
49
50
51
class VectorSuggestionResult(SuggestionResult):
    """SuggestionResult whose scores are held in a NumPy vector indexed by
    subject ID."""

    def __init__(self, vector):
        """Store a float32 copy of the vector with every score clamped to
        the range 0.0 .. 1.0."""
        as_float32 = vector.astype(np.float32)
        non_negative = np.maximum(as_float32, 0.0)
        self._vector = np.minimum(non_negative, 1.0)
        # both caches are filled in lazily on first use
        self._subject_order = None
        self._lsr = None

    def _vector_to_list_suggestion(self):
        """Build the list of SubjectSuggestion objects, best score first,
        stopping at the first score that is zero or below."""
        ranked = (
            (subject_id, self._vector[subject_id])
            for subject_id in self.subject_order
        )
        # everything after the first non-positive score can be skipped
        positive = itertools.takewhile(lambda pair: not pair[1] <= 0.0, ranked)
        return [
            SubjectSuggestion(subject_id=subject_id, score=float(score))
            for subject_id, score in positive
        ]

    @property
    def subject_order(self):
        """Subject IDs arranged by descending score (computed once, then
        cached)."""
        if self._subject_order is None:
            ascending = np.argsort(self._vector)
            self._subject_order = ascending[::-1]
        return self._subject_order

    def __iter__(self):
        """Iterate over the cached list of suggestions, building it on the
        first call."""
        if self._lsr is None:
            self._lsr = self._vector_to_list_suggestion()
        return iter(self._lsr)

    def __len__(self):
        """Count the scores that are greater than zero."""
        is_positive = self._vector > 0.0
        return is_positive.sum()
83
84
85
class SparseSuggestionResult(SuggestionResult):
    """SuggestionResult implementation backed by a single row of a sparse array."""

    def __init__(self, array, idx):
        """Remember the 2D sparse array and the index of the row within it
        that holds the scores of this result."""
        self._array = array
        self._idx = idx

    def _nonzero_entries(self):
        """Return a (columns, scores) pair of arrays holding the non-zero
        entries of the row, extracted in one pass."""
        row = self._array[[self._idx], :].tocoo()
        # drop explicitly stored zeros, as nonzero() would
        mask = row.data != 0
        return row.col[mask], row.data[mask]

    def __iter__(self):
        """Return an iterator over the non-zero entries of the row as
        SubjectSuggestion objects, highest scores first."""
        # Read the columns and scores together from the extracted row rather
        # than doing a separate sparse lookup for every column.
        cols, scores = self._nonzero_entries()
        suggestions = [
            SubjectSuggestion(subject_id=col, score=float(score))
            for col, score in zip(cols, scores)
        ]
        suggestions.sort(key=lambda suggestion: suggestion.score, reverse=True)
        return iter(suggestions)

    def as_vector(self, size):
        """Return the row as a dense 1D array.

        The ``size`` argument is not used by this implementation; the length
        of the returned vector is the number of columns of the backing array.
        """
        return self._array[[self._idx], :].toarray()[0]

    def __len__(self):
        """Return the number of non-zero entries in the row."""
        cols, _ = self._nonzero_entries()
        return len(cols)
108
109
110
class SuggestionBatch:
    """Suggested subjects for a batch of documents, held as one sparse array
    with a row per document."""

    def __init__(self, array):
        """Wrap the given csr_array in a new SuggestionBatch"""
        self.array = array

    @classmethod
    def from_sequence(cls, suggestion_results, subject_index, limit=None):
        """Build a SuggestionBatch out of a sequence of SuggestionResult
        objects. At most limit suggestions are read from each result;
        suggestions for deprecated subjects or with negative scores are
        left out, and scores are capped at 1.0."""

        skipped_ids = set(subject_index.deprecated_ids())
        shape = (len(suggestion_results), len(subject_index))

        # a dok_array is quick to fill in cell by cell
        scores = dok_array(shape, dtype=np.float32)
        for row, result in enumerate(suggestion_results):
            for hit in itertools.islice(result, limit):
                if hit.subject_id not in skipped_ids and not hit.score < 0.0:
                    scores[row, hit.subject_id] = min(hit.score, 1.0)
        return cls(scores.tocsr())

    def filter(self, limit=None, threshold=0.0):
        """Produce another SuggestionBatch that contains only the hits
        passing the given limit and score threshold."""

        filtered_array = filter_suggestion(self.array, limit, threshold)
        return SuggestionBatch(filtered_array)

    def __getitem__(self, idx):
        """Return the suggestions of the document at row idx; indices
        outside 0 .. len(self) - 1 raise IndexError."""
        if not 0 <= idx < len(self):
            raise IndexError
        return SparseSuggestionResult(self.array, idx)

    def __len__(self):
        """Return the number of rows of the underlying array."""
        n_rows = self.array.shape[0]
        return n_rows
145
146
147
class SuggestionResults:
    """Suggested subjects for a number of documents that may be very large,
    organized as a series of batches."""

    def __init__(self, batches):
        """Set up a new SuggestionResults on top of an iterable yielding
        SuggestionBatch objects."""

        self.batches = batches

    def filter(self, limit=None, threshold=0.0):
        """Produce another SuggestionResults object whose batches are the
        batches of this one filtered by the given limit and/or threshold.
        The batches are filtered lazily, as the result is consumed."""

        filtered_batches = (
            batch.filter(limit, threshold) for batch in self.batches
        )
        return SuggestionResults(filtered_batches)

    def __iter__(self):
        """Iterate over the items of every batch, one batch after another."""
        return itertools.chain.from_iterable(self.batches)
166