Passed
Pull Request — main (#840)
by Osma
07:08 queued 03:40
created

annif.vocab.subject_index   A

Complexity

Total Complexity 39

Size/Duplication

Total Lines 170
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 117
dl 0
loc 170
rs 9.28
c 0
b 0
f 0
wmc 39

20 Methods

Rating   Name   Duplication   Size   Complexity  
A SubjectIndexFile.load_subjects() 0 6 2
A SubjectIndexFile.languages() 0 3 1
A SubjectIndexFile.__init__() 0 5 1
A SubjectIndexFile.__len__() 0 2 1
A SubjectIndexFilter.by_uri() 0 7 2
A SubjectIndexFilter.contains_uri() 0 4 2
A SubjectIndexFile.save() 0 14 5
A SubjectIndexFilter.by_label() 0 9 3
A SubjectIndexFilter.languages() 0 3 1
A SubjectIndexFile.by_label() 0 6 2
A SubjectIndexFilter.active() 0 9 1
A SubjectIndexFile.__getitem__() 0 5 2
A SubjectIndexFile.contains_uri() 0 2 1
A SubjectIndexFile.load() 0 8 1
A SubjectIndexFilter.__getitem__() 0 5 2
A SubjectIndexFile.active() 0 6 1
A SubjectIndexFilter.__init__() 0 3 1
A SubjectIndexFile.append() 0 10 5
A SubjectIndexFile.by_uri() 0 10 4
A SubjectIndexFilter.__len__() 0 2 1
1
"""Subject index functionality for Annif"""
2
3
from __future__ import annotations
4
5
import csv
6
from typing import TYPE_CHECKING
7
8
import annif
9
import annif.corpus
10
import annif.util
11
12
from .types import SubjectIndex
13
14
if TYPE_CHECKING:
15
16
    from annif.corpus.subject import Subject, SubjectCorpus
17
18
19
# Reuse Annif's package-wide logger and attach a DuplicateFilter to it
# (presumably so repeated identical records, e.g. the same "Unknown subject
# URI" warning, are not logged over and over — see annif.util to confirm).
_duplicate_filter = annif.util.DuplicateFilter()
logger = annif.logger
logger.addFilter(_duplicate_filter)
21
22
23
class SubjectIndexFile(SubjectIndex):
    """SubjectIndex implementation backed by a file.

    Subjects are kept in insertion order; a subject's position in that order
    is its integer subject ID. A subject whose ``labels`` is None is treated
    as deprecated: it keeps its ID (so the IDs of later subjects do not
    shift) but is hidden by ``__getitem__``, ``by_uri`` and ``active``.
    """

    def __init__(self) -> None:
        self._subjects: list[Subject] = []
        # URI -> subject ID (deprecated subjects included)
        self._uri_idx: dict[str, int] = {}
        # (label, language) -> subject ID
        self._label_idx: dict[tuple[str, str], int] = {}
        self._languages: list[str] | None = None

    def load_subjects(self, corpus: SubjectCorpus) -> None:
        """Initialize the subject index from a subject corpus.

        The index languages are taken from ``corpus.languages`` and every
        subject of the corpus is appended in order."""

        self._languages = corpus.languages
        for subject in corpus.subjects:
            self.append(subject)

    def __len__(self) -> int:
        """Return the number of subjects, deprecated ones included."""
        return len(self._subjects)

    @property
    def languages(self) -> list[str] | None:
        """The label languages of this index, or None if not yet known."""
        return self._languages

    def __getitem__(self, subject_id: int) -> Subject:
        """Return the subject with the given ID.

        Raises IndexError if the ID is out of range or if the subject is
        deprecated (its labels are None). Because of the latter, iterating
        over the index via ``for s in index`` stops at the first deprecated
        subject; internal code iterates ``self._subjects`` instead."""
        subject = self._subjects[subject_id]
        if subject.labels is None:
            raise IndexError(f"Subject is deprecated: {subject_id}")
        return subject

    def append(self, subject: Subject) -> None:
        """Add a subject to the end of the index, giving it the next free ID."""
        # Infer the languages from the first subject that actually has labels.
        # An empty label dict must not pin the languages to [], or later
        # labeled subjects could never set them.
        if self._languages is None and subject.labels:
            self._languages = list(subject.labels.keys())

        subject_id = len(self._subjects)
        self._uri_idx[subject.uri] = subject_id
        if subject.labels:
            for lang, label in subject.labels.items():
                self._label_idx[(label, lang)] = subject_id
        self._subjects.append(subject)

    def contains_uri(self, uri: str) -> bool:
        """Return True if the URI is in the index (even if deprecated)."""
        return uri in self._uri_idx

    def by_uri(self, uri: str, warnings: bool = True) -> int | None:
        """Return the ID of the subject with the given URI, or None.

        None is returned for unknown URIs (logging a warning when
        ``warnings`` is true) and, without a warning, for deprecated
        subjects."""
        try:
            subject_id = self._uri_idx[uri]
        except KeyError:
            if warnings:
                logger.warning("Unknown subject URI <%s>", uri)
            return None
        if self._subjects[subject_id].labels is None:  # deprecated
            return None
        return subject_id

    def by_label(self, label: str | None, language: str) -> int | None:
        """Return the ID of the subject with the given label in the given
        language, or None (logging a warning) if there is no such label."""
        try:
            return self._label_idx[(label, language)]
        except KeyError:
            logger.warning('Unknown subject label "%s"@%s', label, language)
            return None

    @property
    def active(self) -> list[tuple[int, Subject]]:
        """List of (subject_id, Subject) tuples for non-deprecated subjects."""
        return [
            (subj_id, subject)
            for subj_id, subject in enumerate(self._subjects)
            if subject.labels is not None
        ]

    def save(self, path: str) -> None:
        """Save this subject index into a file with the given path name.

        Every subject is written, deprecated ones included (with empty
        labels), so that each subject's row position matches its ID."""

        # An index with no known languages (e.g. an empty one) gets only
        # the uri and notation columns instead of crashing on None.
        languages = self._languages or []
        fieldnames = ["uri", "notation"] + [f"label_{lang}" for lang in languages]

        with open(path, "w", encoding="utf-8", newline="") as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()
            # Iterate the backing list, not ``self``: __getitem__ raises
            # IndexError for deprecated subjects, which would end iteration
            # early and silently drop every later subject from the file.
            for subject in self._subjects:
                row = {"uri": subject.uri, "notation": subject.notation or ""}
                if subject.labels:
                    for lang, label in subject.labels.items():
                        row[f"label_{lang}"] = label
                writer.writerow(row)

    @classmethod
    def load(cls, path: str) -> SubjectIndex:
        """Load a subject index from a CSV file and return it."""

        vocab_file = annif.vocab.VocabFileCSV(path)
        subject_index = cls()
        subject_index.load_subjects(vocab_file)
        return subject_index
115
116
117
class SubjectIndexFilter(SubjectIndex):
    """SubjectIndex implementation that filters another SubjectIndex based
    on a list of subject URIs to exclude.

    Subject IDs are those of the wrapped index; excluded subjects keep their
    IDs (``__len__`` still counts them) but are hidden from lookups."""

    def __init__(self, subject_index: SubjectIndex, exclude: list[str]) -> None:
        self._subject_index = subject_index
        # set for O(1) membership checks on every lookup
        self._exclude = set(exclude)

    def __len__(self) -> int:
        """Return the length of the wrapped index (excluded subjects included)."""
        return len(self._subject_index)

    @property
    def languages(self) -> list[str] | None:
        """The label languages of the wrapped index."""
        return self._subject_index.languages

    def __getitem__(self, subject_id: int) -> Subject:
        """Return the subject with the given ID from the wrapped index.

        Raises IndexError if the subject's URI is excluded."""
        subject = self._subject_index[subject_id]
        if subject.uri in self._exclude:
            raise IndexError(f"Subject is excluded: {subject.uri}")
        return subject

    def contains_uri(self, uri: str) -> bool:
        """Return True if the URI is not excluded and the wrapped index has it."""
        if uri in self._exclude:
            return False
        return self._subject_index.contains_uri(uri)

    def by_uri(self, uri: str, warnings: bool = True) -> int | None:
        """return the subject ID of a subject by its URI, or None if not found.
        If warnings=True, log a warning message if the URI cannot be found.
        Excluded URIs return None without a warning."""

        if uri in self._exclude:
            return None
        return self._subject_index.by_uri(uri, warnings)

    def by_label(self, label: str | None, language: str) -> int | None:
        """return the subject ID of a subject by its label in a given
        language, or None if the label is unknown or its subject is
        excluded"""

        subject_id = self._subject_index.by_label(label, language)
        if subject_id is None:
            # Unknown label: None must not be used as an index into the
            # wrapped index (a list-backed index raises TypeError for it).
            return None
        subject = self._subject_index[subject_id]
        if subject is not None and subject.uri not in self._exclude:
            return subject_id
        return None

    @property
    def active(self) -> list[tuple[int, Subject]]:
        """return a list of (subject_id, Subject) tuples of all subjects that
        are available for use"""

        return [
            (subject_id, subject)
            for subject_id, subject in self._subject_index.active
            if subject.uri not in self._exclude
        ]
171