Passed
Push — issue678-refactor-suggestionre... ( 5b9518...548b8b )
by Osma
15:45 queued 13:03
created

annif.suggestion.VectorSuggestionResult.__len__()   A

Complexity

Conditions 1

Size

Total Lines 2
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 2
dl 0
loc 2
rs 10
c 0
b 0
f 0
cc 1
nop 1
1
"""Representing suggested subjects."""
2
3
import collections
4
import itertools
5
6
import numpy as np
7
from scipy.sparse import csr_array, dok_array
8
9
# One suggested subject. ``subject_id`` is used elsewhere in this module as a
# column index into the suggestion arrays; ``score`` is the suggestion's score.
SubjectSuggestion = collections.namedtuple("SubjectSuggestion", "subject_id score")

# Record with the fields ``hit_sets``, ``weight`` and ``subjects``.
# NOTE(review): not used within this module; presumably groups batches of
# suggestions with a weight for combining results -- confirm against callers.
WeightedSuggestionsBatch = collections.namedtuple(
    "WeightedSuggestionsBatch", "hit_sets weight subjects"
)
13
14
15
def filter_suggestion(preds, limit=None, threshold=0.0):
    """Filter a 2D sparse suggestion array, keeping only the best scores per row.

    For each row of ``preds`` the stored entries are ranked by descending
    score. At most ``limit`` of the top-ranked entries are considered, and of
    those only the ones whose score is greater than or equal to ``threshold``
    are retained; everything else is left as zero.

    Args:
        preds: 2D sparse array of scores (a ``csr_array``), one row per
            individual prediction.
        limit: maximum number of entries retained per row, or ``None`` for
            no limit.
        threshold: minimum score an entry needs in order to be retained.

    Returns:
        A new float32 ``csr_array`` with the same shape as ``preds``.
    """

    filtered = dok_array(preds.shape, dtype=np.float32)

    # Slice the CSR buffers directly instead of calling getrow(): that method
    # has been deprecated for sparse arrays in newer SciPy releases, while
    # indptr/indices/data are part of the stable CSR layout.
    csr = preds.tocsr()
    indptr, indices, data = csr.indptr, csr.indices, csr.data

    for row in range(csr.shape[0]):
        start, end = indptr[row], indptr[row + 1]
        row_scores = data[start:end]
        row_cols = indices[start:end]

        # Positions within this row, best score first.
        top_k = row_scores.argsort()[::-1]
        if limit is not None:
            top_k = top_k[:limit]

        for pos in top_k:
            val = row_scores[pos]
            if val < threshold:
                # Scores are visited in descending order, so every remaining
                # entry is below the threshold as well.
                break
            filtered[row, row_cols[pos]] = val

    return filtered.tocsr()
32
33
34
class SuggestionResult:
    """Suggestions for a single document, backed by a row of a sparse array."""

    def __init__(self, array, idx):
        """Wrap row ``idx`` of the 2D sparse ``array``; no data is copied."""
        self._array = array
        self._idx = idx

    def _nonzero_entries(self):
        """Return ``(columns, scores)`` for the nonzero entries of the row."""
        # Extract the row once in COO form so that columns and values come
        # out together, instead of one sparse lookup per column.
        row = self._array[[self._idx], :].tocoo()
        nonzero = row.data != 0  # ignore explicitly stored zeros
        return row.col[nonzero], row.data[nonzero]

    def __iter__(self):
        """Iterate over the suggestions as SubjectSuggestion tuples, in
        descending score order."""
        cols, scores = self._nonzero_entries()
        suggestions = [
            SubjectSuggestion(subject_id=col, score=float(score))
            for col, score in zip(cols, scores)
        ]
        return iter(
            sorted(suggestions, key=lambda suggestion: suggestion.score, reverse=True)
        )

    def as_vector(self):
        """Return the row as a dense 1D array with one score per column."""
        return self._array[[self._idx], :].toarray()[0]

    def __len__(self):
        """Return the number of nonzero suggestions in the row."""
        cols, _ = self._nonzero_entries()
        return len(cols)
57
58
59
class SuggestionBatch:
    """Subject suggestions for a batch of documents."""

    def __init__(self, array):
        """Create a new SuggestionBatch from a csr_array"""
        assert isinstance(array, csr_array)
        self.array = array

    @staticmethod
    def _vector_to_suggestions(vector):
        """Convert a dense score vector into a list of SubjectSuggestion
        tuples with positive scores, in descending score order."""
        hits = []
        for subject_id in np.argsort(vector)[::-1]:
            score = vector[subject_id]
            if score <= 0.0:
                break  # we can skip the remaining ones
            hits.append(SubjectSuggestion(subject_id=subject_id, score=float(score)))
        return hits

    @classmethod
    def from_sequence(cls, suggestion_results, subject_index, limit=None):
        """Create a new SuggestionBatch from a sequence of SuggestionResult objects.

        Each item is either an iterable of suggestions (objects with
        ``subject_id`` and ``score`` attributes) or a NumPy score vector.
        At most ``limit`` suggestions are read from each item; of those,
        suggestions for deprecated subjects and suggestions with a
        non-positive score are dropped, and scores are capped at 1.0.

        ``subject_index`` must support ``len()`` (used as the number of
        columns) and provide ``deprecated_ids()``.
        """

        deprecated = set(subject_index.deprecated_ids())

        ar = dok_array((len(suggestion_results), len(subject_index)), dtype=np.float32)
        for idx, result in enumerate(suggestion_results):
            if isinstance(result, np.ndarray):
                result = cls._vector_to_suggestions(result)
            # The limit is applied before the deprecated/score filtering below.
            for suggestion in itertools.islice(result, limit):
                if suggestion.subject_id in deprecated or suggestion.score <= 0.0:
                    continue
                ar[idx, suggestion.subject_id] = min(suggestion.score, 1.0)
        return cls(ar.tocsr())

    def filter(self, limit=None, threshold=0.0):
        """Return a subset of the hits, filtered by the given limit and
        score threshold, as another SuggestionBatch object."""

        return SuggestionBatch(filter_suggestion(self.array, limit, threshold))

    def __getitem__(self, idx):
        """Return the SuggestionResult for document ``idx``.

        Raises IndexError for negative or too-large indices; raising at the
        end of the range is also what terminates plain ``for`` iteration over
        the batch, since no ``__iter__`` is defined.
        """
        if idx < 0 or idx >= len(self):
            raise IndexError(f"batch index {idx} out of range")
        return SuggestionResult(self.array, idx)

    def __len__(self):
        """Return the number of documents (rows) in the batch."""
        return self.array.shape[0]
106
107
108
class SuggestionResults:
    """Subject suggestions for a potentially very large number of documents."""

    def __init__(self, batches):
        """Initialize a new SuggestionResults from an iterable that provides
        SuggestionBatch objects."""

        self.batches = batches

    def filter(self, limit=None, threshold=0.0):
        """Return a view of these suggestions, filtered by the given limit
        and/or threshold, as another SuggestionResults object.

        The view is lazy: each batch is filtered only when the result is
        iterated, and because it is backed by a generator it can be consumed
        only once.
        """

        return SuggestionResults(
            batch.filter(limit, threshold) for batch in self.batches
        )

    def __iter__(self):
        """Iterate over the items of all batches, one batch after another."""
        # chain.from_iterable already returns an iterator; no iter() needed.
        return itertools.chain.from_iterable(self.batches)
127