Passed
Pull Request — main (#825)
by Osma
02:53
created

annif.lexical.tokenset.TokenSetIndex.__init__()   A

Complexity

Conditions 1

Size

Total Lines 2
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 2
nop 1
dl 0
loc 2
rs 10
c 0
b 0
f 0
1
"""Index for fast matching of token sets."""
2
3
from __future__ import annotations
4
5
import collections
6
from typing import TYPE_CHECKING
7
8
if TYPE_CHECKING:
9
    from numpy import ndarray
10
11
12
class TokenSet:
    """An immutable collection of integer token IDs that supports containment
    matching against other TokenSets. Optionally carries the subject from the
    vocabulary that the tokens were derived from."""

    def __init__(
        self,
        tokens: ndarray,
        subject_id: int | None = None,
        is_pref: bool = False,
    ) -> None:
        # Tokens are kept as a frozenset so they are hashable and support
        # fast subset tests.
        self._tokens = frozenset(tokens)
        # The first token acts as the lookup key in a TokenSetIndex;
        # an empty token sequence has no key.
        if len(tokens):
            self.key = tokens[0]
        else:
            self.key = None
        self.subject_id = subject_id
        self.is_pref = is_pref

    def __len__(self) -> int:
        return len(self._tokens)

    def __iter__(self):
        return iter(self._tokens)

    @property
    def tokens(self) -> frozenset:
        return self._tokens

    def contains(self, other: TokenSet) -> bool:
        """Returns True iff the tokens in the other TokenSet are all
        included within this TokenSet."""

        # superset test: every token of `other` appears in this set
        return self._tokens >= other._tokens
class TokenSetIndex:
    """A searchable index of TokenSets (representing vocabulary terms)"""

    def __init__(self) -> None:
        # Maps a TokenSet's key (its first token ID) to the set of TokenSets
        # filed under that key.
        self._index = collections.defaultdict(set)

    def __len__(self) -> int:
        # Number of distinct keys, not the number of stored TokenSets.
        return len(self._index)

    def add(self, tset: TokenSet) -> None:
        """Add a TokenSet into this index"""
        # An empty TokenSet has no key and can never match, so skip it.
        if tset.key is not None:
            self._index[tset.key].add(tset)

    def _find_subj_tsets(self, tset: TokenSet) -> dict[int | None, TokenSet]:
        """return a dict (subject_id : TokenSet) of matches contained in the
        given TokenSet"""

        subj_tsets = {}

        for token in tset:
            # Use .get() rather than indexing: subscripting the defaultdict
            # would insert an empty set for every unseen token, so a mere
            # search would mutate the index and leak memory over time.
            for ts in self._index.get(token, ()):
                # Record the match; for a subject that already matched via a
                # non-preferred term, a later match may replace it so that a
                # preferred-term match wins.
                if tset.contains(ts) and (
                    ts.subject_id not in subj_tsets
                    or not subj_tsets[ts.subject_id].is_pref
                ):
                    subj_tsets[ts.subject_id] = ts

        return subj_tsets

    def _find_subj_ambiguity(self, tsets: list[TokenSet]) -> collections.Counter:
        """calculate the ambiguity values (the number of other TokenSets
        that also match the same tokens) for the given TokenSets and return
        them as a dict-like object (subject_id : ambiguity_value)"""

        # group the TokenSets by their tokens, so that TokenSets with
        # identical tokens can be considered together all in one go
        elim_tsets = collections.defaultdict(set)
        for ts in tsets:
            elim_tsets[ts.tokens].add(ts.subject_id)

        subj_ambiguity = collections.Counter()

        for tokens1, subjs1 in elim_tsets.items():
            for tokens2, subjs2 in elim_tsets.items():
                # only TokenSets whose tokens cover tokens1 contribute
                if not tokens2.issuperset(tokens1):
                    continue
                for subj in subjs1:
                    # count the other subjects matching the same tokens;
                    # subtract one so a subject never counts itself
                    subj_ambiguity[subj] += len(subjs2) - int(subj in subjs2)

        return subj_ambiguity

    def search(self, tset: TokenSet) -> list[tuple[TokenSet, int]]:
        """Return the TokenSets that are contained in the given TokenSet.
        The matches are returned as a list of (TokenSet, ambiguity) pairs
        where ambiguity is an integer indicating the number of other TokenSets
        that also match the same tokens."""

        subj_tsets = self._find_subj_tsets(tset)
        subj_ambiguity = self._find_subj_ambiguity(subj_tsets.values())

        return [
            (ts, subj_ambiguity[subject_id]) for subject_id, ts in subj_tsets.items()
        ]