|
1
|
|
|
"""Index for fast matching of token sets.""" |
|
2
|
|
|
|
|
3
|
|
|
from __future__ import annotations |
|
4
|
|
|
|
|
5
|
|
|
import collections |
|
6
|
|
|
from typing import TYPE_CHECKING |
|
7
|
|
|
|
|
8
|
|
|
if TYPE_CHECKING: |
|
9
|
|
|
from numpy import ndarray |
|
10
|
|
|
|
|
11
|
|
|
|
|
12
|
|
|
class TokenSet:
    """Represents a set of tokens (expressed as integer token IDs) that can
    be matched with another set of tokens. A TokenSet can optionally
    be associated with a subject from the vocabulary."""

    def __init__(
        self,
        tokens: ndarray,
        subject_id: int | None = None,
        is_pref: bool = False,
    ) -> None:
        # Freeze the token IDs so the set is immutable and hashable.
        self._tokens = frozenset(tokens)
        # The first token serves as the lookup key; an empty token
        # sequence has no key.
        if len(tokens):
            self.key = tokens[0]
        else:
            self.key = None
        self.subject_id = subject_id
        self.is_pref = is_pref

    def __len__(self) -> int:
        return len(self._tokens)

    def __iter__(self):
        return iter(self._tokens)

    @property
    def tokens(self) -> frozenset:
        """The token IDs of this TokenSet as a frozenset."""
        return self._tokens

    def contains(self, other: TokenSet) -> bool:
        """Returns True iff the tokens in the other TokenSet are all
        included within this TokenSet."""

        return self._tokens.issuperset(other._tokens)

    def __setstate__(self, state):
        """Restore pickled state, converting ``_tokens`` to a frozenset
        if it is a plain set (may happen when loading models saved with
        Annif 1.2 or older)."""
        self.__dict__ = state
        if isinstance(self._tokens, set):
            self._tokens = frozenset(self._tokens)  # pragma: no cover
|
50
|
|
|
|
|
51
|
|
|
|
|
52
|
|
|
class TokenSetIndex:
    """A searchable index of TokenSets (representing vocabulary terms)"""

    def __init__(self) -> None:
        # Maps a key token ID to the set of TokenSets that have it as key.
        self._index = collections.defaultdict(set)

    def __len__(self) -> int:
        return len(self._index)

    def add(self, tset: TokenSet) -> None:
        """Add a TokenSet into this index"""
        # A TokenSet without a key (i.e. with no tokens) cannot be indexed.
        if tset.key is not None:
            self._index[tset.key].add(tset)

    def _find_subj_tsets(self, tset: TokenSet) -> dict[int | None, TokenSet]:
        """return a dict (subject_id : TokenSet) of matches contained in the
        given TokenSet"""

        subj_tsets: dict[int | None, TokenSet] = {}

        for token in tset:
            # Use .get() rather than subscripting: indexing the defaultdict
            # would insert an empty set for every unseen token, silently
            # growing the index (and len(self)) as a side effect of a search.
            for ts in self._index.get(token, ()):
                # Prefer a preferred-label match: replace a previous match
                # for the same subject only if it was not preferred.
                if tset.contains(ts) and (
                    ts.subject_id not in subj_tsets
                    or not subj_tsets[ts.subject_id].is_pref
                ):
                    subj_tsets[ts.subject_id] = ts

        return subj_tsets

    def _find_subj_ambiguity(self, tsets: list[TokenSet]):
        """calculate the ambiguity values (the number of other TokenSets
        that also match the same tokens) for the given TokenSets and return
        them as a dict-like object (subject_id : ambiguity_value)"""

        # group the TokenSets by their tokens, so that TokenSets with
        # identical tokens can be considered together all in one go
        elim_tsets = collections.defaultdict(set)
        for ts in tsets:
            elim_tsets[ts.tokens].add(ts.subject_id)

        subj_ambiguity = collections.Counter()

        for tokens1, subjs1 in elim_tsets.items():
            for tokens2, subjs2 in elim_tsets.items():
                # only a superset of tokens1 can "shadow" its subjects
                if not tokens2.issuperset(tokens1):
                    continue
                for subj in subjs1:
                    # count the other subjects matching a superset of our
                    # tokens; do not count the subject itself
                    subj_ambiguity[subj] += len(subjs2) - int(subj in subjs2)

        return subj_ambiguity

    def search(self, tset: TokenSet) -> list[tuple[TokenSet, int]]:
        """Return the TokenSets that are contained in the given TokenSet.
        The matches are returned as a list of (TokenSet, ambiguity) pairs
        where ambiguity is an integer indicating the number of other TokenSets
        that also match the same tokens."""

        subj_tsets = self._find_subj_tsets(tset)
        subj_ambiguity = self._find_subj_ambiguity(subj_tsets.values())

        return [
            (ts, subj_ambiguity[subject_id]) for subject_id, ts in subj_tsets.items()
        ]
|
116
|
|
|
|