Passed
Push — issue678-refactor-suggestionre... ( 311240...092cdc )
by Osma
03:05
created

annif.suggestion.SuggestionResult.as_vector()   A

Complexity

Conditions 1

Size

Total Lines 2
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 2
nop 1
dl 0
loc 2
rs 10
c 0
b 0
f 0
1
"""Representing suggested subjects."""
2
3
import collections
4
import itertools
5
6
import numpy as np
7
from scipy.sparse import csr_array, dok_array
8
9
# One suggested subject for a document: ``subject_id`` is used as the column
# index into the suggestion arrays below, and ``score`` is its numeric score.
SubjectSuggestion = collections.namedtuple(
    "SubjectSuggestion", ["subject_id", "score"]
)

# Record bundling ``hit_sets``, a ``weight`` and ``subjects``. It is not used
# in the visible part of this module; presumably kept for callers elsewhere.
WeightedSuggestionsBatch = collections.namedtuple(
    "WeightedSuggestionsBatch", ["hit_sets", "weight", "subjects"]
)
13
14
15
def vector_to_suggestions(vector, limit):
    """Turn the highest entries of a dense score vector into suggestions.

    :param vector: 1-D array of scores where position ``i`` is the score of
        subject id ``i`` (presumably a NumPy array; ``np.argpartition`` is
        applied to it).
    :param limit: maximum number of suggestions to produce. Values larger than
        ``len(vector)`` are clamped; values <= 0 produce no suggestions.
    :returns: an iterator of ``SubjectSuggestion`` tuples for the ``limit``
        highest scores. The entries are NOT sorted by score, because
        ``np.argpartition`` only partitions around the k-th element.
    """
    limit = min(len(vector), limit)
    if limit <= 0:
        # argpartition(v, -0)[-0:] would select the *whole* vector (and raises
        # for an empty vector), so "nothing requested" must be handled here.
        return iter(())
    topk_idx = np.argpartition(vector, -limit)[-limit:]
    return (
        # plain int/float rather than NumPy scalars for the tuple fields
        SubjectSuggestion(subject_id=int(idx), score=float(vector[idx]))
        for idx in topk_idx
    )
21
22
23
def filter_suggestion(preds, limit=None, threshold=0.0):
    """Filter a 2D sparse suggestion array row by row.

    For each row (one prediction), keep only the top ``limit`` scores that are
    also >= ``threshold``; every other entry becomes zero.

    :param preds: 2D sparse suggestion array, normally a ``csr_array``
        (other sparse inputs are converted to CSR first).
    :param limit: maximum number of entries kept per row; None keeps all.
    :param threshold: minimum score for an entry to be kept.
    :returns: a new float32 ``csr_array`` with the same shape as ``preds``.
    """
    # Read the CSR buffers directly instead of calling ``getrow`` (not part of
    # the SciPy sparse *array* API) and avoid slow per-element DOK writes.
    preds = csr_array(preds)
    rows, cols, vals = [], [], []
    for row in range(preds.shape[0]):
        start, end = preds.indptr[row], preds.indptr[row + 1]
        data = preds.data[start:end]
        indices = preds.indices[start:end]
        top_k = data.argsort()[::-1]  # positions in descending score order
        if limit is not None:
            top_k = top_k[:limit]
        for idx in top_k:
            val = data[idx]
            # scores are visited in descending order, so the first score
            # below the threshold ends this row
            if val < threshold:
                break
            rows.append(row)
            cols.append(indices[idx])
            vals.append(val)
    filtered = csr_array(
        (
            np.asarray(vals, dtype=np.float32),
            (np.asarray(rows, dtype=np.int64), np.asarray(cols, dtype=np.int64)),
        ),
        shape=preds.shape,
    )
    # a zero score carries no suggestion; don't keep it as a stored entry
    filtered.eliminate_zeros()
    return filtered
40
41
42
class SuggestionResult:
    """Suggestions for a single document, backed by a row of a sparse array.

    This is a lightweight view: it keeps a reference to the shared 2D sparse
    array and a row index, and reads that row on demand. Column ``j`` of the
    row holds the score for subject id ``j``.
    """

    def __init__(self, array, idx):
        self._array = array
        self._idx = idx

    def _row(self):
        """Return this document's row as a 1 x n_subjects ``csr_array``."""
        return csr_array(self._array[[self._idx], :])

    def __iter__(self):
        """Iterate over the nonzero suggestions as ``SubjectSuggestion``
        tuples, highest score first (ties keep the row's storage order)."""
        row = self._row()
        # read ids and scores from the row buffers once, instead of one sparse
        # element lookup per column; explicit zeros are skipped like nonzero()
        present = row.data != 0
        suggestions = [
            SubjectSuggestion(subject_id=col, score=float(score))
            for col, score in zip(row.indices[present], row.data[present])
        ]
        return iter(
            sorted(suggestions, key=lambda suggestion: suggestion.score, reverse=True)
        )

    def as_vector(self):
        """Return the row as a dense 1-D NumPy array of length n_subjects."""
        return self._row().toarray()[0]

    def __len__(self):
        """Return the number of nonzero scores in this document's row."""
        return int(np.count_nonzero(self._row().data))
65
66
67
class SuggestionBatch:
    """Subject suggestions for a batch of documents.

    Backed by a 2D ``csr_array`` (``self.array``) with one row per document;
    column ``j`` holds the score of subject id ``j``. Indexing the batch gives
    a ``SuggestionResult`` view of one row.
    """

    def __init__(self, array):
        """Create a new SuggestionBatch from a csr_array"""
        assert isinstance(array, csr_array)
        self.array = array

    @classmethod
    def from_sequence(cls, suggestion_results, subject_index, limit=None):
        """Create a new SuggestionBatch from a sequence where each item is
        a sequence of SubjectSuggestion objects.

        :param suggestion_results: sized sequence (``len()`` is taken) with one
            iterable of ``SubjectSuggestion`` tuples per document.
        :param subject_index: object providing ``deprecated_ids()`` and
            ``len()``; its length sets the number of columns.
        :param limit: only the first ``limit`` suggestions of each document are
            considered (None = all). The limit is applied *before* dropping
            deprecated subjects and non-positive scores, so a row may end up
            with fewer than ``limit`` entries.
        :returns: a new batch whose stored scores are capped at 1.0.
        """
        deprecated = set(subject_index.deprecated_ids())

        ar = dok_array((len(suggestion_results), len(subject_index)), dtype=np.float32)
        for idx, result in enumerate(suggestion_results):
            for suggestion in itertools.islice(result, limit):
                # deprecated subjects and non-positive scores are not stored
                if suggestion.subject_id in deprecated or suggestion.score <= 0.0:
                    continue
                ar[idx, suggestion.subject_id] = min(suggestion.score, 1.0)
        return cls(ar.tocsr())

    @classmethod
    def from_averaged(cls, batches, weights):
        """Create a new SuggestionBatch where the subject scores are the
        weighted average of scores in several SuggestionBatches.

        :param batches: iterable of SuggestionBatch objects of equal shape.
        :param weights: iterable of numeric weights, paired with ``batches``
            in order; they must not sum to zero.
        """
        # weights are needed twice (pairing and normalizing); a one-shot
        # iterator would otherwise be exhausted before the normalizing sum
        weights = list(weights)
        weighted = [batch.array * weight for batch, weight in zip(batches, weights)]
        return cls(sum(weighted) / sum(weights))

    def filter(self, limit=None, threshold=0.0):
        """Return a subset of the hits, filtered by the given limit and
        score threshold, as another SuggestionBatch object."""

        return SuggestionBatch(filter_suggestion(self.array, limit, threshold))

    def __getitem__(self, idx):
        """Return the suggestions for document ``idx`` as a SuggestionResult.

        Negative indices count from the end, as for Python sequences. Raises
        IndexError when ``idx`` is out of range; this also ends iteration,
        because the class has no ``__iter__`` and relies on ``__getitem__``.
        """
        size = len(self)
        position = idx + size if idx < 0 else idx
        if position < 0 or position >= size:
            raise IndexError(
                f"document index {idx} out of range for batch of size {size}"
            )
        return SuggestionResult(self.array, position)

    def __len__(self):
        """Return the number of documents (rows) in the batch."""
        return self.array.shape[0]
113
114
115
class SuggestionResults:
    """Subject suggestions for a potentially very large number of documents.

    Wraps an iterable of SuggestionBatch objects and exposes the documents of
    all batches as one flat iterable. If ``batches`` is a one-shot iterator,
    the results can only be traversed once.
    """

    def __init__(self, batches):
        """Store ``batches``, an iterable yielding SuggestionBatch objects."""
        self.batches = batches

    def filter(self, limit=None, threshold=0.0):
        """Return a new SuggestionResults whose batches are this object's
        batches filtered by ``limit`` and/or ``threshold``.

        Filtering is lazy: each batch's ``filter`` is called only when the
        returned object is iterated.
        """
        filtered_batches = (
            batch.filter(limit, threshold) for batch in self.batches
        )
        return SuggestionResults(filtered_batches)

    def __iter__(self):
        """Iterate over the items of every batch in turn."""
        # chain objects are already iterators, so no extra iter() is needed
        return itertools.chain.from_iterable(self.batches)
134