Passed
Push — issue735-subject-filtering (3cd194)
by Osma
03:54
created

annif.vocab.SubjectIndex.contains_uri() — rating A

Complexity

Conditions 1

Size

Total Lines 2
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 2
nop 2
dl 0
loc 2
rs 10
c 0
b 0
f 0
1
"""Vocabulary management functionality for Annif"""
2
3
from __future__ import annotations
4
5
import csv
6
import os.path
7
from typing import TYPE_CHECKING
8
9
import annif
10
import annif.corpus
11
import annif.util
12
from annif.datadir import DatadirMixin
13
from annif.exception import NotInitializedException
14
15
if TYPE_CHECKING:
16
    from rdflib.graph import Graph
17
18
    from annif.corpus.skos import SubjectFileSKOS
19
    from annif.corpus.subject import Subject, SubjectCorpus
20
21
22
logger = annif.logger
23
24
25
class AnnifVocabulary(DatadirMixin):
    """Class representing a subject vocabulary which can be used by multiple
    Annif projects.

    The vocabulary files (CSV subject index, SKOS/Turtle file and graph dump)
    are kept in ``self.datadir``."""

    # defaults for uninitialized instances
    _subjects = None

    # constants
    INDEX_FILENAME_DUMP = "subjects.dump.gz"
    INDEX_FILENAME_TTL = "subjects.ttl"
    INDEX_FILENAME_CSV = "subjects.csv"

    def __init__(self, vocab_id: str, datadir: str) -> None:
        DatadirMixin.__init__(self, datadir, "vocabs", vocab_id)
        self.vocab_id = vocab_id
        self._skos_vocab = None

    def _create_subject_index(self, subject_corpus: SubjectCorpus) -> SubjectIndex:
        """Build a fresh subject index from the corpus, save it to the CSV
        index file and return it."""
        subjects = SubjectIndex()
        subjects.load_subjects(subject_corpus)
        annif.util.atomic_save(subjects, self.datadir, self.INDEX_FILENAME_CSV)
        return subjects

    def _update_subject_index(self, subject_corpus: SubjectCorpus) -> SubjectIndex:
        """Merge a new subject corpus into the existing subject index.

        Existing subjects keep their subject IDs. A subject still present in
        the new corpus is replaced by its new version. A subject missing from
        it is kept as a deprecated entry (no labels, no notation). Subjects
        found only in the new corpus are appended at the end. The merged
        index is saved to the CSV index file and returned."""
        old_subjects = self.subjects
        new_subjects = SubjectIndex()
        new_subjects.load_subjects(subject_corpus)

        updated_subjects = SubjectIndex()
        # Use the languages of the new corpus, like _create_subject_index
        # does. Otherwise SubjectIndex infers them from the first labeled
        # subject, which leaves them as None when no subject has labels
        # (e.g. every old subject was removed), and saving would then fail.
        updated_subjects._languages = new_subjects.languages

        # Keep the old subjects first, in order, so their subject IDs
        # (= positions in the index) stay unchanged.
        for old_subject in old_subjects:
            subject_id = new_subjects.by_uri(old_subject.uri, warnings=False)
            if subject_id is not None:
                new_subject = new_subjects[subject_id]
            else:  # subject removed from new corpus: keep it as deprecated
                new_subject = annif.corpus.Subject(
                    uri=old_subject.uri, labels=None, notation=None
                )
            updated_subjects.append(new_subject)
        # Append subjects that did not exist in the old index.
        for new_subject in new_subjects:
            if not old_subjects.contains_uri(new_subject.uri):
                updated_subjects.append(new_subject)
        annif.util.atomic_save(updated_subjects, self.datadir, self.INDEX_FILENAME_CSV)
        return updated_subjects

    @property
    def subjects(self) -> SubjectIndex:
        """The subject index, lazily loaded from the CSV index file.

        Raises NotInitializedException if the CSV index file does not
        exist."""
        if self._subjects is None:
            path = os.path.join(self.datadir, self.INDEX_FILENAME_CSV)
            if os.path.exists(path):
                logger.debug("loading subjects from %s", path)
                self._subjects = SubjectIndex.load(path)
            else:
                raise NotInitializedException(f"subject file {path} not found")
        return self._subjects

    @property
    def skos(self) -> SubjectFileSKOS:
        """return the subject vocabulary from SKOS file

        The graph dump file is tried first. If it is missing or cannot be
        loaded, the Turtle file is parsed instead. Raises
        NotInitializedException if neither is available."""
        if self._skos_vocab is not None:
            return self._skos_vocab

        # attempt to load graph from dump file
        dumppath = os.path.join(self.datadir, self.INDEX_FILENAME_DUMP)
        if os.path.exists(dumppath):
            logger.debug(f"loading graph dump from {dumppath}")
            try:
                self._skos_vocab = annif.corpus.SubjectFileSKOS(dumppath)
            except ModuleNotFoundError:
                # Probably dump has been saved using a different rdflib version
                logger.debug("could not load graph dump, using turtle file")
            else:
                return self._skos_vocab

        # graph dump file not found or unusable - parse ttl file instead
        path = os.path.join(self.datadir, self.INDEX_FILENAME_TTL)
        if os.path.exists(path):
            logger.debug(f"loading graph from {path}")
            self._skos_vocab = annif.corpus.SubjectFileSKOS(path)
            # store the dump file so we can use it next time
            # NOTE(review): relies on save_skos() writing the dump alongside
            # the Turtle path it is given — confirm in SubjectFileSKOS.
            self._skos_vocab.save_skos(path)
            return self._skos_vocab

        raise NotInitializedException(f"graph file {path} not found")

    def __len__(self) -> int:
        """Number of subjects in the index, including deprecated ones."""
        return len(self.subjects)

    @property
    def languages(self) -> list[str]:
        """Languages of the subject index, or [] if it is not initialized."""
        try:
            return self.subjects.languages
        except NotInitializedException:
            return []

    def load_vocabulary(
        self,
        subject_corpus: SubjectCorpus,
        force: bool = False,
    ) -> None:
        """Load subjects from a subject corpus and save them into one
        or more subject index files as well as a SKOS/Turtle file for later
        use. If force=True, replace the existing subject index completely."""

        if not force and os.path.exists(
            os.path.join(self.datadir, self.INDEX_FILENAME_CSV)
        ):
            logger.info("updating existing subject index")
            self._subjects = self._update_subject_index(subject_corpus)
        else:
            logger.info("creating subject index")
            self._subjects = self._create_subject_index(subject_corpus)

        skosfile = os.path.join(self.datadir, self.INDEX_FILENAME_TTL)
        logger.info(f"saving vocabulary into SKOS file {skosfile}")
        subject_corpus.save_skos(skosfile)

    def as_graph(self) -> Graph:
        """return the vocabulary as an rdflib graph"""
        return self.skos.graph

    def dump(self) -> dict[str, str | list | int | bool]:
        """return this vocabulary as a dict

        For an uninitialized vocabulary, "languages" is empty, "size" is None
        and "loaded" is False."""

        try:
            languages = sorted(self.languages)
            size = len(self)
            loaded = True
        except NotInitializedException:
            languages = []
            size = None
            loaded = False

        return {
            "vocab_id": self.vocab_id,
            "languages": languages,
            "size": size,
            "loaded": loaded,
        }
162
163
164
class SubjectIndex:
    """An index that remembers the associations between integer subject IDs
    and their URIs and labels.

    Subject IDs are positions in insertion order. The first appended subject
    gets ID 0, the next ID 1, and so on. A subject whose ``labels`` is None
    is treated as deprecated."""

    def __init__(self) -> None:
        self._subjects = []  # subject objects; list position = subject ID
        self._uri_idx = {}  # URI -> subject ID
        self._label_idx = {}  # (label, language) -> subject ID
        self._languages = None  # label languages; None until known

    def load_subjects(self, corpus: SubjectCorpus) -> None:
        """Initialize the subject index from a subject corpus.

        The index languages are taken from the corpus. Every subject of the
        corpus is then appended in the order the corpus yields them."""

        self._languages = corpus.languages
        for subject in corpus.subjects:
            self.append(subject)

    def __len__(self) -> int:
        return len(self._subjects)

    def __iter__(self):
        # Iterate the backing list directly (in subject ID order) instead of
        # relying on the legacy __getitem__/IndexError iteration fallback.
        return iter(self._subjects)

    @property
    def languages(self) -> list[str] | None:
        """Languages of the subject labels, or None if not yet known."""
        return self._languages

    def __getitem__(self, subject_id: int) -> Subject:
        """Return the subject with the given subject ID."""
        return self._subjects[subject_id]

    def append(self, subject: Subject) -> None:
        """Add a subject to the index and assign it the next subject ID.

        If the index languages are still unknown, they are inferred from the
        label languages of this subject (if it has labels)."""
        if self._languages is None and subject.labels is not None:
            self._languages = list(subject.labels.keys())

        subject_id = len(self._subjects)
        self._uri_idx[subject.uri] = subject_id
        if subject.labels:
            for lang, label in subject.labels.items():
                self._label_idx[(label, lang)] = subject_id
        self._subjects.append(subject)

    def contains_uri(self, uri: str) -> bool:
        """Return True if a subject with the given URI is in the index."""
        return uri in self._uri_idx

    def by_uri(self, uri: str, warnings: bool = True) -> int | None:
        """return the subject ID of a subject by its URI, or None if not found.
        If warnings=True, log a warning message if the URI cannot be found."""
        try:
            return self._uri_idx[uri]
        except KeyError:
            if warnings:
                logger.warning("Unknown subject URI <%s>", uri)
            return None

    def by_label(
        self, label: str | None, language: str, warnings: bool = True
    ) -> int | None:
        """return the subject ID of a subject by its label in a given
        language, or None if not found. If warnings=True (the default), log
        a warning message if the label cannot be found."""
        try:
            return self._label_idx[(label, language)]
        except KeyError:
            if warnings:
                logger.warning('Unknown subject label "%s"@%s', label, language)
            return None

    def deprecated_ids(self) -> list[int]:
        """return indices of deprecated subjects (those without labels)"""

        return [
            subject_id
            for subject_id, subject in enumerate(self._subjects)
            if subject.labels is None
        ]

    @property
    def active(self) -> list[tuple[int, Subject]]:
        """return a list of (subject_id, subject) tuples of all subjects that
        are not deprecated"""

        return [
            (subj_id, subject)
            for subj_id, subject in enumerate(self._subjects)
            if subject.labels is not None
        ]

    def save(self, path: str) -> None:
        """Save this subject index into a CSV file with the given path name.

        The columns are ``uri``, ``notation`` and one ``label_<lang>`` column
        per index language. Missing notations and labels are written as
        empty fields."""

        # An index in which no subject has labels (e.g. only deprecated
        # subjects) may have no known languages. In that case write just the
        # uri and notation columns instead of failing on None.
        languages = self._languages if self._languages is not None else []
        fieldnames = ["uri", "notation"] + [f"label_{lang}" for lang in languages]

        with open(path, "w", encoding="utf-8", newline="") as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()
            for subject in self:
                row = {"uri": subject.uri, "notation": subject.notation or ""}
                if subject.labels:
                    for lang, label in subject.labels.items():
                        row[f"label_{lang}"] = label
                writer.writerow(row)

    @classmethod
    def load(cls, path: str) -> SubjectIndex:
        """Load a subject index from a CSV file and return it."""

        corpus = annif.corpus.SubjectFileCSV(path)
        subject_index = cls()
        subject_index.load_subjects(corpus)
        return subject_index
267