Passed
Pull Request — main (#840)
by Osma
08:16 queued 04:08
created

annif.vocab.SubjectIndex.active()   A

Complexity

Conditions 1

Size

Total Lines 9
Code Lines 6

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 6
nop 1
dl 0
loc 9
rs 10
c 0
b 0
f 0
1
"""Vocabulary management functionality for Annif"""
2
3
from __future__ import annotations
4
5
import csv
6
import os.path
7
from typing import TYPE_CHECKING
8
9
import annif
10
import annif.corpus
11
import annif.util
12
from annif.datadir import DatadirMixin
13
from annif.exception import NotInitializedException
14
15
if TYPE_CHECKING:
16
    from rdflib.graph import Graph
17
18
    from annif.corpus.skos import SubjectFileSKOS
19
    from annif.corpus.subject import Subject, SubjectCorpus
20
21
22
# Module logger shared with the rest of Annif. The DuplicateFilter is
# presumably meant to suppress repeated identical messages (such as the
# per-lookup "Unknown subject ..." warnings below) — confirm against
# annif.util.DuplicateFilter.
logger = annif.logger
logger.addFilter(annif.util.DuplicateFilter())
24
25
26
class AnnifVocabulary(DatadirMixin):
    """Class representing a subject vocabulary which can be used by multiple
    Annif projects.

    The vocabulary lives in its own data directory (``vocabs/<vocab_id>``)
    and is stored there as a CSV subject index plus a SKOS/Turtle file; a
    graph dump file is used, when present, as a faster way to load the
    SKOS graph."""

    # defaults for uninitialized instances
    _subjects = None

    # constants
    INDEX_FILENAME_DUMP = "subjects.dump.gz"
    INDEX_FILENAME_TTL = "subjects.ttl"
    INDEX_FILENAME_CSV = "subjects.csv"

    def __init__(self, vocab_id: str, datadir: str) -> None:
        DatadirMixin.__init__(self, datadir, "vocabs", vocab_id)
        self.vocab_id = vocab_id
        # lazily loaded SKOS vocabulary, see the `skos` property
        self._skos_vocab = None

    def _create_subject_index(self, subject_corpus: SubjectCorpus) -> SubjectIndex:
        """Build a new subject index from the given corpus, save it as the
        CSV index file in the data directory and return it."""
        subjects = SubjectIndex()
        subjects.load_subjects(subject_corpus)
        annif.util.atomic_save(subjects, self.datadir, self.INDEX_FILENAME_CSV)
        return subjects

    def _update_subject_index(self, subject_corpus: SubjectCorpus) -> SubjectIndex:
        """Merge the given corpus into the existing subject index so that
        existing subject IDs stay stable, save the result as the CSV index
        file and return it.

        Subjects present in both keep their old ID but take their data from
        the new corpus; subjects missing from the new corpus are kept as
        deprecated (labels=None); subjects only in the new corpus are
        appended after all old ones.

        Raises NotInitializedException if there is no existing index."""
        old_subjects = self.subjects
        new_subjects = SubjectIndex()
        new_subjects.load_subjects(subject_corpus)
        updated_subjects = SubjectIndex()
        # Take the language list from the new corpus, like
        # _create_subject_index does via load_subjects(). Otherwise append()
        # would guess it from the first labelled subject, and it would stay
        # None if no appended subject has labels.
        updated_subjects._languages = new_subjects.languages

        for old_subject in old_subjects:
            if new_subjects.contains_uri(old_subject.uri):
                new_subject = new_subjects[new_subjects.by_uri(old_subject.uri)]
            else:  # subject removed from new corpus
                new_subject = annif.corpus.Subject(
                    uri=old_subject.uri, labels=None, notation=None
                )
            updated_subjects.append(new_subject)
        for new_subject in new_subjects:
            if not old_subjects.contains_uri(new_subject.uri):
                updated_subjects.append(new_subject)
        annif.util.atomic_save(updated_subjects, self.datadir, self.INDEX_FILENAME_CSV)
        return updated_subjects

    @property
    def subjects(self) -> SubjectIndex:
        """The subject index, loaded lazily from the CSV index file.

        Raises NotInitializedException if the index file does not exist."""
        if self._subjects is None:
            path = os.path.join(self.datadir, self.INDEX_FILENAME_CSV)
            if os.path.exists(path):
                logger.debug("loading subjects from %s", path)
                self._subjects = SubjectIndex.load(path)
            else:
                raise NotInitializedException(f"subject file {path} not found")
        return self._subjects

    @property
    def skos(self) -> SubjectFileSKOS:
        """return the subject vocabulary from SKOS file

        The graph dump file is tried first; if it is missing or cannot be
        unpickled, the Turtle file is parsed instead. The result is cached.
        Raises NotInitializedException if neither file is usable."""
        if self._skos_vocab is not None:
            return self._skos_vocab

        # attempt to load graph from dump file
        dumppath = os.path.join(self.datadir, self.INDEX_FILENAME_DUMP)
        if os.path.exists(dumppath):
            logger.debug("loading graph dump from %s", dumppath)
            try:
                self._skos_vocab = annif.corpus.SubjectFileSKOS(dumppath)
            except ModuleNotFoundError:
                # Probably dump has been saved using a different rdflib version
                logger.debug("could not load graph dump, using turtle file")
            else:
                return self._skos_vocab

        # graph dump file not found - parse ttl file instead
        path = os.path.join(self.datadir, self.INDEX_FILENAME_TTL)
        if os.path.exists(path):
            logger.debug("loading graph from %s", path)
            self._skos_vocab = annif.corpus.SubjectFileSKOS(path)
            # store the dump file so we can use it next time
            # NOTE(review): this passes the .ttl path, so save_skos is
            # presumably what derives and writes the dump file next to it —
            # confirm in SubjectFileSKOS.save_skos.
            self._skos_vocab.save_skos(path)
            return self._skos_vocab

        raise NotInitializedException(f"graph file {path} not found")

    def __len__(self) -> int:
        """Number of subjects in the index, including deprecated ones."""
        return len(self.subjects)

    @property
    def languages(self) -> list[str]:
        """Languages of the subject index, or an empty list if the index has
        not been initialized."""
        try:
            return self.subjects.languages
        except NotInitializedException:
            return []

    def load_vocabulary(
        self,
        subject_corpus: SubjectCorpus,
        force: bool = False,
    ) -> None:
        """Load subjects from a subject corpus and save them into one
        or more subject index files as well as a SKOS/Turtle file for later
        use. If force=True, replace the existing subject index completely."""

        if not force and os.path.exists(
            os.path.join(self.datadir, self.INDEX_FILENAME_CSV)
        ):
            logger.info("updating existing subject index")
            self._subjects = self._update_subject_index(subject_corpus)
        else:
            logger.info("creating subject index")
            self._subjects = self._create_subject_index(subject_corpus)

        skosfile = os.path.join(self.datadir, self.INDEX_FILENAME_TTL)
        logger.info("saving vocabulary into SKOS file %s", skosfile)
        subject_corpus.save_skos(skosfile)
        # Drop any cached graph: it reflects the previous vocabulary files
        # and would otherwise keep being returned by `skos` / `as_graph`.
        self._skos_vocab = None

    def as_graph(self) -> Graph:
        """return the vocabulary as an rdflib graph"""
        return self.skos.graph

    def dump(self) -> dict[str, str | list | int | bool]:
        """return this vocabulary as a dict

        If the subject index is not initialized, languages is empty, size is
        None and loaded is False."""

        try:
            languages = sorted(self.languages)
            size = len(self)  # raises NotInitializedException if not loaded
            loaded = True
        except NotInitializedException:
            languages = []
            size = None
            loaded = False

        return {
            "vocab_id": self.vocab_id,
            "languages": languages,
            "size": size,
            "loaded": loaded,
        }
163
164
165
class SubjectIndex:
    """An index that remembers the associations between integer subject IDs
    and their URIs and labels.

    Subject IDs are positions in insertion order: the first appended subject
    gets ID 0, the next one 1, and so on. A subject whose ``labels`` is None
    is treated as deprecated: it keeps its ID but is excluded from
    :attr:`active`."""

    def __init__(self) -> None:
        self._subjects = []  # subject ID -> subject
        self._uri_idx = {}  # URI -> subject ID
        self._label_idx = {}  # (label, language) -> subject ID
        self._languages = None  # language codes, unknown until loaded/appended

    def load_subjects(self, corpus: SubjectCorpus) -> None:
        """Initialize the subject index from a subject corpus"""

        self._languages = corpus.languages
        for subject in corpus.subjects:
            self.append(subject)

    def __len__(self) -> int:
        return len(self._subjects)

    def __iter__(self):
        """Iterate over all subjects, including deprecated ones, in ID order."""
        return iter(self._subjects)

    @property
    def languages(self) -> list[str] | None:
        """Language codes of this index, or None if not known yet."""
        return self._languages

    def __getitem__(self, subject_id: int) -> Subject:
        return self._subjects[subject_id]

    def append(self, subject: Subject) -> None:
        """Add a subject to the index under the next free subject ID.

        If the index languages are still unknown, they are taken from the
        label languages of the first subject that has labels. A later subject
        with the same URI, or the same label in the same language, takes
        over that URI/label lookup."""
        if self._languages is None and subject.labels:
            self._languages = list(subject.labels.keys())

        subject_id = len(self._subjects)
        self._uri_idx[subject.uri] = subject_id
        if subject.labels:
            for lang, label in subject.labels.items():
                # Skip missing labels: indexing (None, lang) would make
                # by_label(None, lang) return whichever subject lacking a
                # label in that language was appended last.
                if label is not None:
                    self._label_idx[(label, lang)] = subject_id
        self._subjects.append(subject)

    def contains_uri(self, uri: str) -> bool:
        """Return True if a subject with the given URI is in the index."""
        return uri in self._uri_idx

    def by_uri(self, uri: str, warnings: bool = True) -> int | None:
        """return the subject ID of a subject by its URI, or None if not found.
        If warnings=True, log a warning message if the URI cannot be found."""
        try:
            return self._uri_idx[uri]
        except KeyError:
            if warnings:
                logger.warning("Unknown subject URI <%s>", uri)
            return None

    def by_label(
        self, label: str | None, language: str, warnings: bool = True
    ) -> int | None:
        """return the subject ID of a subject by its label in a given
        language, or None if not found. If warnings=True, log a warning
        message if the label cannot be found."""
        try:
            return self._label_idx[(label, language)]
        except KeyError:
            if warnings:
                logger.warning('Unknown subject label "%s"@%s', label, language)
            return None

    def deprecated_ids(self) -> list[int]:
        """return indices of deprecated subjects"""

        return [
            subject_id
            for subject_id, subject in enumerate(self._subjects)
            if subject.labels is None
        ]

    @property
    def active(self) -> list[tuple[int, Subject]]:
        """return a list of (subject_id, subject) tuples of all subjects that
        are not deprecated"""

        return [
            (subj_id, subject)
            for subj_id, subject in enumerate(self._subjects)
            if subject.labels is not None
        ]

    def save(self, path: str) -> None:
        """Save this subject index into a CSV file with the given path name.

        Columns are ``uri``, ``notation`` and one ``label_<lang>`` column per
        index language. A missing notation or label is written as an empty
        cell, so deprecated subjects (labels None) get no label values."""

        # An index whose languages are unknown (e.g. an empty one) is written
        # with just the uri/notation columns instead of failing on None.
        languages = self._languages or []
        fieldnames = ["uri", "notation"] + [f"label_{lang}" for lang in languages]

        with open(path, "w", encoding="utf-8", newline="") as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()
            for subject in self:
                row = {"uri": subject.uri, "notation": subject.notation or ""}
                if subject.labels:
                    for lang, label in subject.labels.items():
                        row[f"label_{lang}"] = label
                writer.writerow(row)

    @classmethod
    def load(cls, path: str) -> SubjectIndex:
        """Load a subject index from a CSV file and return it."""

        corpus = annif.corpus.SubjectFileCSV(path)
        subject_index = cls()
        subject_index.load_subjects(corpus)
        return subject_index
268