Passed
Push — issue735-subject-filtering ( d4533d...f9dfa6 )
by Osma
03:38
created

SubjectIndexFile.load_subjects()   A

Complexity

Conditions 2

Size

Total Lines 6
Code Lines 4

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 4
nop 2
dl 0
loc 6
rs 10
c 0
b 0
f 0
1
"""Subject index functionality for Annif"""
2
3
from __future__ import annotations
4
5
import csv
6
from typing import TYPE_CHECKING
7
8
import annif
9
import annif.corpus
10
import annif.util
11
12
from .types import SubjectIndex
13
14
if TYPE_CHECKING:
15
16
    from annif.corpus.subject import Subject, SubjectCorpus
17
18
19
# Module-level logger: reuse the logger object exposed by the annif package.
logger = annif.logger
20
# Attach an annif.util.DuplicateFilter instance to that logger.
# NOTE(review): presumably this suppresses repeated identical messages;
# confirm against annif.util.
logger.addFilter(annif.util.DuplicateFilter())
21
22
23
class SubjectIndexFile(SubjectIndex):
    """SubjectIndex implementation backed by a file.

    Subjects are held in an in-memory list, and a subject's ID is its
    position in that list. Two dictionaries map URIs and (label, language)
    pairs back to subject IDs.
    """

    def __init__(self) -> None:
        """Create an empty subject index with no languages recorded."""
        self._subjects: list[Subject] = []
        self._uri_idx: dict[str, int] = {}
        self._label_idx: dict[tuple[str, str], int] = {}
        self._languages: list[str] | None = None

    def load_subjects(self, corpus: SubjectCorpus) -> None:
        """Initialize the subject index from a subject corpus"""

        self._languages = corpus.languages
        for subject in corpus.subjects:
            self.append(subject)

    def __len__(self) -> int:
        """Return the number of subjects stored, whether active or not."""
        return len(self._subjects)

    @property
    def languages(self) -> list[str] | None:
        """Return the recorded label languages, or None if none are known."""
        return self._languages

    def __getitem__(self, subject_id: int) -> Subject:
        """Return the subject stored at position subject_id."""
        return self._subjects[subject_id]

    def append(self, subject: Subject) -> None:
        """Add a subject to the end of the index and register its lookups.

        If no languages have been recorded yet, they are taken from the
        label keys of the first appended subject that has labels.
        """
        if self._languages is None and subject.labels is not None:
            self._languages = list(subject.labels.keys())

        # The new subject's ID is the position it is about to occupy.
        subject_id = len(self._subjects)
        self._uri_idx[subject.uri] = subject_id
        if subject.labels:
            for lang, label in subject.labels.items():
                self._label_idx[(label, lang)] = subject_id
        self._subjects.append(subject)

    def contains_uri(self, uri: str) -> bool:
        """Return True if a subject with the given URI has been indexed."""
        return uri in self._uri_idx

    def by_uri(self, uri: str, warnings: bool = True) -> int | None:
        """Return the subject ID for the given URI, or None if not found.

        If warnings is true, a warning is logged when the URI is unknown.
        """
        try:
            return self._uri_idx[uri]
        except KeyError:
            if warnings:
                logger.warning("Unknown subject URI <%s>", uri)
            return None

    def by_label(self, label: str | None, language: str) -> int | None:
        """Return the subject ID for a label in the given language.

        Logs a warning and returns None when the pair is unknown.
        """
        try:
            return self._label_idx[(label, language)]
        except KeyError:
            logger.warning('Unknown subject label "%s"@%s', label, language)
            return None

    @property
    def active(self) -> list[tuple[int, Subject]]:
        """Return (subject_id, Subject) tuples for subjects that have labels.

        Subjects whose labels attribute is None are left out.
        """
        return [
            (subj_id, subject)
            for subj_id, subject in enumerate(self._subjects)
            if subject.labels is not None
        ]

    def save(self, path: str) -> None:
        """Save this subject index into a file with the given path name.

        The file is CSV with the columns "uri" and "notation" followed by
        one "label_<lang>" column per recorded language.
        """

        # _languages stays None until a corpus or a labelled subject has
        # provided it; fall back to no label columns instead of failing
        # with a TypeError when iterating None.
        languages = self._languages or []
        fieldnames = ["uri", "notation"] + [f"label_{lang}" for lang in languages]

        with open(path, "w", encoding="utf-8", newline="") as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()
            for subject in self:
                row = {"uri": subject.uri, "notation": subject.notation or ""}
                if subject.labels:
                    for lang, label in subject.labels.items():
                        row[f"label_{lang}"] = label
                writer.writerow(row)

    @classmethod
    def load(cls, path: str) -> SubjectIndex:
        """Load a subject index from a CSV file and return it."""

        corpus = annif.corpus.SubjectFileCSV(path)
        subject_index = cls()
        subject_index.load_subjects(corpus)
        return subject_index
109
110
111
class SubjectIndexFilter(SubjectIndex):
    """SubjectIndex implementation that filters another SubjectIndex based
    on a list of subject URIs to exclude."""

    def __init__(self, subject_index: SubjectIndex, exclude: list[str]):
        """Wrap subject_index, hiding the subjects whose URI is in exclude."""
        self._subject_index = subject_index
        # A set gives constant-time membership tests in the lookups below.
        self._exclude = set(exclude)

    def __len__(self) -> int:
        """Return the length of the wrapped index.

        Excluded subjects are still counted, so subject IDs stay the same
        as in the wrapped index.
        """
        return len(self._subject_index)

    @property
    def languages(self) -> list[str] | None:
        """Return the languages reported by the wrapped index."""
        return self._subject_index.languages

    def __getitem__(self, subject_id: int) -> Subject | None:
        """Return the subject with the given ID, or None if it is excluded
        (or if the wrapped index returned a falsy value for that ID)."""
        subject = self._subject_index[subject_id]
        if subject and subject.uri not in self._exclude:
            return subject
        return None

    def contains_uri(self, uri: str) -> bool:
        """Return True if the URI is known to the wrapped index and is not
        excluded."""
        if uri in self._exclude:
            return False
        return self._subject_index.contains_uri(uri)

    def by_uri(self, uri: str, warnings: bool = True) -> int | None:
        """return the subject ID of a subject by its URI, or None if not found.
        If warnings=True, log a warning message if the URI cannot be found."""

        # Excluded URIs are reported as not found without consulting the
        # wrapped index, so no warning is logged for them here.
        if uri in self._exclude:
            return None
        return self._subject_index.by_uri(uri, warnings)

    def by_label(self, label: str | None, language: str) -> int | None:
        """return the subject ID of a subject by its label in a given
        language, or None if the label is unknown or the subject is
        excluded"""

        subject_id = self._subject_index.by_label(label, language)
        # The wrapped index returns None for an unknown label; indexing it
        # with None would raise instead of reporting "not found".
        if subject_id is None:
            return None
        subject = self._subject_index[subject_id]
        if subject is not None and subject.uri not in self._exclude:
            return subject_id
        return None

    @property
    def active(self) -> list[tuple[int, Subject]]:
        """return a list of (subject_id, Subject) tuples of all subjects that
        are available for use"""

        return [
            (subject_id, subject)
            for subject_id, subject in self._subject_index.active
            if subject.uri not in self._exclude
        ]
165