Completed
Push — master ( 299d84...ccff81 )
by Osma
13s queued 11s
created

annif.hit.LazyAnalysisResult.__getitem__()   A

Complexity

Conditions 1

Size

Total Lines 3
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 3
dl 0
loc 3
rs 10
c 0
b 0
f 0
cc 1
nop 2
1
"""Representing hits from analysis."""
2
3
import abc
4
import collections
5
import numpy as np
6
7
8
# A single hit: the URI and label of a subject together with its score.
AnalysisHit = collections.namedtuple(
    'AnalysisHit', ['uri', 'label', 'score'])
# Pairs a `hits` value with a `weight` value.
WeightedHits = collections.namedtuple(
    'WeightedHits', ['hits', 'weight'])
10
11
12
class HitFilter:
    """Callable object that applies a stored limit and score threshold to
    analysis results."""

    def __init__(self, limit=None, threshold=0.0):
        """Store the limit and threshold that will be handed to the
        filter() method of the results this object is called with."""
        self._threshold = threshold
        self._limit = limit

    def __call__(self, orighits):
        """Return a LazyAnalysisResult wrapping the filtered orighits.

        Nothing is filtered at call time: orighits.filter() is invoked
        only when the returned lazy result is first accessed."""

        def apply_filter():
            # the stored parameters are read when the filtering actually
            # happens, not when the HitFilter is called
            return orighits.filter(self._limit, self._threshold)

        return LazyAnalysisResult(apply_filter)
23
24
25
class AnalysisResult(metaclass=abc.ABCMeta):
    """Abstract interface for the set of hits produced by an analysis
    operation."""

    @property
    @abc.abstractmethod
    def hits(self):
        """The hits as an ordered sequence of AnalysisHit objects, with the
        highest scores first."""

    @property
    @abc.abstractmethod
    def vector(self):
        """The hits as a one-dimensional score vector whose indexes match
        the given subject index."""

    @abc.abstractmethod
    def filter(self, limit=None, threshold=0.0):
        """Return another AnalysisResult object holding the subset of the
        hits selected by the given limit and score threshold."""

    @abc.abstractmethod
    def __len__(self):
        """Return the number of hits that have a non-zero score."""

    def __getitem__(self, idx):
        """Return whatever indexing the hits sequence with idx yields."""
        found = self.hits
        return found[idx]
56
57
58
class LazyAnalysisResult(AnalysisResult):
    """Proxy AnalysisResult whose wrapped AnalysisResult is created only
    when it is first needed. Every method call and property access is
    forwarded to the wrapped object."""

    def __init__(self, construct):
        """Set up the proxy. construct is a function taking no arguments;
        it is called to build the real AnalysisResult on first access."""
        self._object = None
        self._construct = construct

    def _initialize(self):
        """Build the wrapped object unless one has been stored already."""
        if self._object is not None:
            return
        self._object = self._construct()

    def _resolved(self):
        """Return the wrapped object, building it first when necessary."""
        self._initialize()
        return self._object

    @property
    def hits(self):
        return self._resolved().hits

    @property
    def vector(self):
        return self._resolved().vector

    def filter(self, limit=None, threshold=0.0):
        return self._resolved().filter(limit, threshold)

    def __len__(self):
        return len(self._resolved())

    def __getitem__(self, idx):
        return self._resolved()[idx]
94
95
96
class VectorAnalysisResult(AnalysisResult):
    """AnalysisResult implementation based primarily on NumPy vectors."""

    def __init__(self, vector, subject_index):
        """Wrap a score vector.

        vector -- score vector indexed by subject id (used with NumPy
            comparison and argsort operations, so presumably a 1-D
            ndarray -- confirm against callers)
        subject_index -- object indexable by subject id; each item is
            indexed with [0] for the URI and [1] for the label
        """
        self._vector = vector
        self._subject_index = subject_index
        self._subject_order = None  # filled in lazily by subject_order
        self._hits = None  # filled in lazily by hits

    def _vector_to_hits(self):
        """Convert the positive scores of the vector into a
        ListAnalysisResult of AnalysisHit objects, highest scores first."""
        hits = []
        for subject_id in self.subject_order:
            score = self._vector[subject_id]
            if score <= 0.0:
                # subject_order is sorted by descending score, so none of
                # the remaining subjects can have a positive score either
                break
            subject = self._subject_index[subject_id]
            hits.append(
                AnalysisHit(
                    uri=subject[0],
                    label=subject[1],
                    score=score))
        return ListAnalysisResult(hits, self._subject_index)

    @property
    def subject_order(self):
        """Subject ids ordered by descending score (computed once, then
        cached)."""
        if self._subject_order is None:
            self._subject_order = np.argsort(self._vector)[::-1]
        return self._subject_order

    @property
    def hits(self):
        """The hits as a ListAnalysisResult, highest scores first
        (computed once, then cached)."""
        if self._hits is None:
            self._hits = self._vector_to_hits()
        return self._hits

    @property
    def vector(self):
        """The score vector this result was created with."""
        return self._vector

    def filter(self, limit=None, threshold=0.0):
        """Return a new VectorAnalysisResult in which every score that is
        not strictly above threshold, or that falls outside the first
        `limit` entries of subject_order, has been set to zero."""
        mask = (self._vector > threshold)
        if limit is not None:
            # use the builtin bool: the np.bool alias was deprecated in
            # NumPy 1.20 and removed in NumPy 1.24
            limit_mask = np.zeros(len(self._vector), dtype=bool)
            top_k_subjects = self.subject_order[:limit]
            limit_mask[top_k_subjects] = True
            mask = mask & limit_mask
        return VectorAnalysisResult(self._vector * mask, self._subject_index)

    def __len__(self):
        """Return the number of hits with a score above zero."""
        # convert the NumPy integer into a plain Python int
        return int((self._vector > 0.0).sum())
146
147
148
class ListAnalysisResult(AnalysisResult):
    """AnalysisResult implementation that keeps its hits in a list."""

    def __init__(self, hits, subject_index):
        """Keep only the given hits whose score is above zero, together
        with the subject index used when building the score vector."""
        positive_hits = [hit for hit in hits if hit.score > 0.0]
        self._hits = positive_hits
        self._subject_index = subject_index
        self._vector = None  # filled in lazily by vector

    def _hits_to_vector(self):
        """Build a score vector with one slot per subject index entry and
        the score of each hit stored at the position of its URI."""
        scores = np.zeros(len(self._subject_index))
        for entry in self._hits:
            position = self._subject_index.by_uri(entry.uri)
            if position is None:
                # by_uri gave no position for this URI; leave it out
                continue
            scores[position] = entry.score
        return scores

    @property
    def hits(self):
        return self._hits

    @property
    def vector(self):
        if self._vector is None:
            self._vector = self._hits_to_vector()
        return self._vector

    def filter(self, limit=None, threshold=0.0):
        """Return a new ListAnalysisResult holding at most `limit` of the
        highest scoring hits, restricted to positive scores that are at
        least `threshold`."""
        ranked = list(self.hits)
        ranked.sort(key=lambda hit: hit.score, reverse=True)
        if limit is not None:
            ranked = ranked[:limit]
        kept = [hit for hit in ranked
                if hit.score >= threshold and
                hit.score > 0.0]
        return ListAnalysisResult(kept, self._subject_index)

    def __len__(self):
        return len(self._hits)
185