Passed
Push — issue735-subject-filtering ( 4656c4...e8dd91 )
by Osma
12:23
created

annif.vocab.SubjectIndexFile.__init__()   A

Complexity

Conditions 1

Size

Total Lines 5
Code Lines 5

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 5
nop 1
dl 0
loc 5
rs 10
c 0
b 0
f 0
1
"""Vocabulary management functionality for Annif"""
2
3
from __future__ import annotations
4
5
import abc
6
import csv
7
import os.path
8
from typing import TYPE_CHECKING
9
10
import annif
11
import annif.corpus
12
import annif.util
13
from annif.datadir import DatadirMixin
14
from annif.exception import NotInitializedException
15
16
if TYPE_CHECKING:
17
    from rdflib.graph import Graph
18
19
    from annif.corpus.skos import SubjectFileSKOS
20
    from annif.corpus.subject import Subject, SubjectCorpus
21
22
23
# Attach annif.util.DuplicateFilter to the shared package logger (presumably
# to suppress repeated identical messages), then alias it for this module.
annif.logger.addFilter(annif.util.DuplicateFilter())
logger = annif.logger
25
26
27
class AnnifVocabulary(DatadirMixin):
    """Class representing a subject vocabulary which can be used by multiple
    Annif projects.

    The vocabulary lives in its data directory as a CSV subject index
    (INDEX_FILENAME_CSV), a SKOS/Turtle file (INDEX_FILENAME_TTL) and
    optionally a graph dump (INDEX_FILENAME_DUMP) that is preferred when
    loading the SKOS graph."""

    # defaults for uninitialized instances
    _subjects = None

    # constants
    INDEX_FILENAME_DUMP = "subjects.dump.gz"
    INDEX_FILENAME_TTL = "subjects.ttl"
    INDEX_FILENAME_CSV = "subjects.csv"

    def __init__(self, vocab_id: str, datadir: str) -> None:
        DatadirMixin.__init__(self, datadir, "vocabs", vocab_id)
        self.vocab_id = vocab_id
        # SKOS vocabulary, loaded lazily by the ``skos`` property
        self._skos_vocab = None

    def _create_subject_index(self, subject_corpus: SubjectCorpus) -> SubjectIndex:
        """Build a fresh subject index from the corpus, save it as CSV into
        the data directory and return it."""
        subjects = SubjectIndexFile()
        subjects.load_subjects(subject_corpus)
        annif.util.atomic_save(subjects, self.datadir, self.INDEX_FILENAME_CSV)
        return subjects

    def _update_subject_index(self, subject_corpus: SubjectCorpus) -> SubjectIndex:
        """Merge the corpus into the existing subject index, save the result
        as CSV into the data directory and return it.

        Existing subjects keep their position in the index (and therefore
        their integer ID). Subjects that are missing from the new corpus are
        kept as deprecated entries with labels=None. Subjects that appear only
        in the new corpus are appended at the end.

        Raises NotInitializedException if there is no existing index."""
        old_subjects = self.subjects
        new_subjects = SubjectIndexFile()
        new_subjects.load_subjects(subject_corpus)
        updated_subjects = SubjectIndexFile()

        for old_subject in old_subjects:
            if new_subjects.contains_uri(old_subject.uri):
                new_subject = new_subjects[new_subjects.by_uri(old_subject.uri)]
            else:  # subject removed from new corpus
                new_subject = annif.corpus.Subject(
                    uri=old_subject.uri, labels=None, notation=None
                )
            updated_subjects.append(new_subject)
        for new_subject in new_subjects:
            if not old_subjects.contains_uri(new_subject.uri):
                updated_subjects.append(new_subject)
        annif.util.atomic_save(updated_subjects, self.datadir, self.INDEX_FILENAME_CSV)
        return updated_subjects

    @property
    def subjects(self) -> SubjectIndex:
        """The subject index. It is loaded lazily from the CSV file in the
        data directory and then cached.

        Raises NotInitializedException if the CSV file does not exist."""
        if self._subjects is None:
            path = os.path.join(self.datadir, self.INDEX_FILENAME_CSV)
            if os.path.exists(path):
                logger.debug("loading subjects from %s", path)
                self._subjects = SubjectIndexFile.load(path)
            else:
                raise NotInitializedException(f"subject file {path} not found")
        return self._subjects

    @property
    def skos(self) -> SubjectFileSKOS:
        """return the subject vocabulary from SKOS file

        The result is cached after the first successful load. The graph dump
        file is tried first. If it is missing, or loading it raises
        ModuleNotFoundError, the Turtle file is parsed instead.

        Raises NotInitializedException if the Turtle file is needed but does
        not exist."""
        if self._skos_vocab is not None:
            return self._skos_vocab

        # attempt to load graph from dump file
        dumppath = os.path.join(self.datadir, self.INDEX_FILENAME_DUMP)
        if os.path.exists(dumppath):
            logger.debug("loading graph dump from %s", dumppath)
            try:
                self._skos_vocab = annif.corpus.SubjectFileSKOS(dumppath)
            except ModuleNotFoundError:
                # Probably dump has been saved using a different rdflib version
                logger.debug("could not load graph dump, using turtle file")
            else:
                return self._skos_vocab

        # graph dump file not found or unusable - parse ttl file instead
        path = os.path.join(self.datadir, self.INDEX_FILENAME_TTL)
        if os.path.exists(path):
            logger.debug("loading graph from %s", path)
            self._skos_vocab = annif.corpus.SubjectFileSKOS(path)
            # store the dump file so we can use it next time
            # NOTE(review): this passes the .ttl path; presumably save_skos
            # also writes the dump file next to it - confirm in SubjectFileSKOS
            self._skos_vocab.save_skos(path)
            return self._skos_vocab

        raise NotInitializedException(f"graph file {path} not found")

    def __len__(self) -> int:
        """return the number of subjects, including deprecated ones

        Raises NotInitializedException if there is no subject index."""
        return len(self.subjects)

    @property
    def languages(self) -> list[str]:
        """Languages of the subject index. Returns an empty list if the index
        is not initialized or does not record any languages."""
        try:
            languages = self.subjects.languages
        except NotInitializedException:
            return []
        # SubjectIndex.languages may be None; keep the list[str] contract so
        # callers such as dump() can sort the result safely
        return languages if languages is not None else []

    def load_vocabulary(
        self,
        subject_corpus: SubjectCorpus,
        force: bool = False,
    ) -> None:
        """Load subjects from a subject corpus and save them into one
        or more subject index files as well as a SKOS/Turtle file for later
        use. If force=True, replace the existing subject index completely."""

        if not force and os.path.exists(
            os.path.join(self.datadir, self.INDEX_FILENAME_CSV)
        ):
            logger.info("updating existing subject index")
            self._subjects = self._update_subject_index(subject_corpus)
        else:
            logger.info("creating subject index")
            self._subjects = self._create_subject_index(subject_corpus)

        skosfile = os.path.join(self.datadir, self.INDEX_FILENAME_TTL)
        logger.info(f"saving vocabulary into SKOS file {skosfile}")
        subject_corpus.save_skos(skosfile)
        # the SKOS file on disk has changed: drop any cached graph so that the
        # next access to ``skos`` reloads the new vocabulary
        self._skos_vocab = None

    def as_graph(self) -> Graph:
        """return the vocabulary as an rdflib graph"""
        return self.skos.graph

    def dump(self) -> dict[str, str | list | int | bool | None]:
        """return this vocabulary as a dict

        "size" is None and "loaded" is False when the subject index has not
        been initialized."""

        try:
            languages = sorted(self.languages)
            size = len(self)
            loaded = True
        except NotInitializedException:
            languages = []
            size = None
            loaded = False

        return {
            "vocab_id": self.vocab_id,
            "languages": languages,
            "size": size,
            "loaded": loaded,
        }
164
165
166
class SubjectIndex(metaclass=abc.ABCMeta):
    """Base class for an index that remembers the associations between
    integer subject IDs and their URIs and labels.

    Subclasses are expected to override the members below. The base
    implementations are placeholders that return None."""

    def __len__(self) -> int:
        """return the number of subjects in the index"""
        pass

    @property
    def languages(self) -> list[str] | None:
        """return the languages of the index, or None if not known"""
        pass

    def __getitem__(self, subject_id: int) -> Subject:
        """return the subject with the given integer ID"""
        pass

    def contains_uri(self, uri: str) -> bool:
        """return True if a subject with the given URI is in the index"""
        pass

    def by_uri(self, uri: str, warnings: bool = True) -> int | None:
        """return the subject ID of a subject by its URI, or None if not found.
        If warnings=True, log a warning message if the URI cannot be found."""
        pass

    def by_label(self, label: str | None, language: str) -> int | None:
        """return the subject ID of a subject by its label in a given
        language"""
        pass

    @property
    def active(self) -> list[tuple[int, Subject]]:
        """return a list of (subject_id, subject) tuples of all subjects that
        are not deprecated"""
        # declared as a property to match the SubjectIndexFile implementation,
        # so callers can use ``index.active`` regardless of the concrete class
        pass
196
197
198
class SubjectIndexFile(SubjectIndex):
    """In-memory SubjectIndex that can be saved to and loaded from a CSV file.

    Subjects are kept in a list, and the integer ID of a subject is its
    position in that list. Dictionaries provide lookups by URI and by
    (label, language)."""

    def __init__(self) -> None:
        self._subjects = []  # Subject objects; list position == subject ID
        self._uri_idx = {}  # uri -> subject ID
        self._label_idx = {}  # (label, language) -> subject ID
        self._languages = None  # language codes; None until known

    def load_subjects(self, corpus: SubjectCorpus) -> None:
        """Initialize the subject index from a subject corpus"""

        self._languages = corpus.languages
        for subject in corpus.subjects:
            self.append(subject)

    def __len__(self) -> int:
        return len(self._subjects)

    @property
    def languages(self) -> list[str] | None:
        return self._languages

    def __getitem__(self, subject_id: int) -> Subject:
        # raising IndexError past the end also lets ``for s in index`` work
        # through Python's legacy sequence iteration protocol
        return self._subjects[subject_id]

    def append(self, subject: Subject) -> None:
        """Add a subject at the end of the index and register its URI and
        labels in the lookup tables. The subject's ID is its position."""
        if self._languages is None and subject.labels is not None:
            # no corpus languages given: take them from the first labeled
            # subject
            self._languages = list(subject.labels.keys())

        subject_id = len(self._subjects)
        self._uri_idx[subject.uri] = subject_id
        if subject.labels:
            for lang, label in subject.labels.items():
                self._label_idx[(label, lang)] = subject_id
        self._subjects.append(subject)

    def contains_uri(self, uri: str) -> bool:
        return uri in self._uri_idx

    def by_uri(self, uri: str, warnings: bool = True) -> int | None:
        try:
            return self._uri_idx[uri]
        except KeyError:
            if warnings:
                logger.warning("Unknown subject URI <%s>", uri)
            return None

    def by_label(self, label: str | None, language: str) -> int | None:
        try:
            return self._label_idx[(label, language)]
        except KeyError:
            logger.warning('Unknown subject label "%s"@%s', label, language)
            return None

    def deprecated_ids(self) -> list[int]:
        """return indices of deprecated subjects"""

        # a subject is deprecated when it has no labels
        return [
            subject_id
            for subject_id, subject in enumerate(self._subjects)
            if subject.labels is None
        ]

    @property
    def active(self) -> list[tuple[int, Subject]]:
        """return a list of (subject_id, subject) tuples of all subjects that
        are not deprecated (i.e. have labels)"""
        return [
            (subj_id, subject)
            for subj_id, subject in enumerate(self._subjects)
            if subject.labels is not None
        ]

    def save(self, path: str) -> None:
        """Save this subject index into a file with the given path name.

        The CSV file has the columns uri, notation and one label_<lang>
        column per index language. Subjects without labels get empty label
        columns."""

        # _languages is still None for an empty index, or one in which every
        # subject is deprecated; write no label columns instead of failing
        # with TypeError when iterating None
        languages = self._languages or []
        fieldnames = ["uri", "notation"] + [f"label_{lang}" for lang in languages]

        with open(path, "w", encoding="utf-8", newline="") as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()
            for subject in self:
                row = {"uri": subject.uri, "notation": subject.notation or ""}
                if subject.labels:
                    for lang, label in subject.labels.items():
                        row[f"label_{lang}"] = label
                writer.writerow(row)

    @classmethod
    def load(cls, path: str) -> SubjectIndex:
        """Load a subject index from a CSV file and return it."""

        corpus = annif.corpus.SubjectFileCSV(path)
        subject_index = cls()
        subject_index.load_subjects(corpus)
        return subject_index
291