Passed
Pull Request — master (#608)
by Osma
02:57
created

annif.corpus.subject.SubjectSet.__init__()   A

Complexity

Conditions 2

Size

Total Lines 11
Code Lines 6

Duplication

Lines 11
Ratio 100 %

Importance

Changes 0
Metric Value
cc 2
eloc 6
nop 2
dl 11
loc 11
rs 10
c 0
b 0
f 0
1
"""Classes for supporting subject corpora expressed as directories or files"""
2
3
import csv
4
import numpy as np
5
import annif.util
6
import os.path
7
from annif import logger
8
from .types import Subject, SubjectCorpus
9
from .skos import serialize_subjects_to_skos
10
11
12 View Code Duplication
class SubjectFileTSV(SubjectCorpus):
    """A monolingual subject vocabulary stored in a TSV file.

    Each line describes one subject: a URI, optionally followed by a
    tab-separated label and then a tab-separated notation."""

    def __init__(self, path, language):
        """initialize the SubjectFileTSV given a path to a TSV file and the
        language of the vocabulary"""

        self.language = language
        self.path = path

    def _parse_line(self, line):
        """Yield the single Subject described by one line of the TSV file.

        An absent or empty label gives ``labels=None``; an absent notation
        gives ``notation=None``."""
        # at most three fields: URI, label, notation (the notation keeps any
        # further tabs); pad short lines with None so unpacking always works
        fields = line.strip().split('\t', 2)
        uri, label, notation = fields + [None] * (3 - len(fields))
        yield Subject(uri=annif.util.cleanup_uri(uri),
                      labels={self.language: label} if label else None,
                      notation=notation)

    @property
    def languages(self):
        """A TSV vocabulary holds labels in exactly one language."""
        return [self.language]

    @property
    def subjects(self):
        """Lazily yield every subject in the file, one per line."""
        with open(self.path, encoding='utf-8-sig') as tsv_file:
            yield from (subject
                        for raw_line in tsv_file
                        for subject in self._parse_line(raw_line))

    def save_skos(self, path):
        """Write the subjects of this vocabulary as SKOS/Turtle to the file
        at the given path name."""
        serialize_subjects_to_skos(self.subjects, path)
46
47
48 View Code Duplication
class SubjectFileCSV(SubjectCorpus):
    """A multilingual subject vocabulary stored in a CSV file.

    The file has a ``uri`` column, optionally a ``notation`` column, and any
    number of ``label_<lang>`` columns (one per language)."""

    def __init__(self, path):
        """initialize the SubjectFileCSV given a path to a CSV file"""
        self.path = path

    def _parse_row(self, row):
        """Yield the Subject described by one CSV row (a dict keyed by
        column name). Empty label and notation cells become None."""
        labels = {
            fname.replace('label_', ''): value or None
            for fname, value in row.items()
            # DictReader stores surplus cells (row longer than the header)
            # under the key None; they cannot be labels, so skip them
            if fname is not None and fname.startswith('label_')
        }

        # if there are no labels in any language, set labels to None
        # indicating a deprecated subject
        if set(labels.values()) == {None}:
            labels = None

        yield Subject(uri=annif.util.cleanup_uri(row['uri']),
                      labels=labels,
                      notation=row.get('notation', None) or None)

    @property
    def languages(self):
        """The vocabulary languages, inferred from the ``label_<lang>``
        column names in the CSV header. An empty file has no languages."""
        # newline='' is required by the csv module for correct parsing of
        # quoted fields that contain line breaks
        with open(self.path, encoding='utf-8-sig', newline='') as csvfile:
            fieldnames = next(csv.reader(csvfile), None)

        if fieldnames is None:  # empty file: no header row at all
            return []

        return [fname.replace('label_', '')
                for fname in fieldnames
                if fname.startswith('label_')]

    @property
    def subjects(self):
        """Lazily yield every subject in the file, one per data row."""
        with open(self.path, encoding='utf-8-sig', newline='') as csvfile:
            reader = csv.DictReader(csvfile)
            for row in reader:
                yield from self._parse_row(row)

    def save_skos(self, path):
        """Save the contents of the subject vocabulary into a SKOS/Turtle
        file with the given path name."""
        serialize_subjects_to_skos(self.subjects, path)

    @staticmethod
    def is_csv_file(path):
        """return True if the path looks like a CSV file (i.e. its extension
        is .csv, case-insensitively)"""

        return os.path.splitext(path)[1].lower() == '.csv'
99
100
101 View Code Duplication
class SubjectIndex:
    """An index that remembers the associations between integer subject IDs
    and their URIs and labels. A subject's ID is its position in the index,
    assigned in the order subjects are appended."""

    def __init__(self):
        # subject objects; list position is the subject ID
        self._subjects = []
        # URI -> subject ID
        self._uri_idx = {}
        # (label, language) -> subject ID
        self._label_idx = {}
        # language codes of the vocabulary; None until known, set either by
        # load_subjects() or from the first appended subject with labels
        self._languages = None

    def load_subjects(self, corpus):
        """Initialize the subject index from a subject corpus"""

        self._languages = corpus.languages
        for subject in corpus.subjects:
            self.append(subject)

    def __len__(self):
        return len(self._subjects)

    def __getitem__(self, subject_id):
        return self._subjects[subject_id]

    def append(self, subject):
        """Add a subject to the index under the next free subject ID.

        A later subject with the same URI, or the same (label, language)
        pair, replaces the earlier one in the corresponding lookup."""
        if self._languages is None and subject.labels is not None:
            self._languages = list(subject.labels.keys())

        subject_id = len(self._subjects)
        self._uri_idx[subject.uri] = subject_id
        if subject.labels:
            for lang, label in subject.labels.items():
                self._label_idx[(label, lang)] = subject_id
        self._subjects.append(subject)

    def contains_uri(self, uri):
        """return True if a subject with the given URI is in the index"""
        return uri in self._uri_idx

    def by_uri(self, uri, warnings=True):
        """return the subject ID of a subject by its URI, or None if not found.
        If warnings=True, log a warning message if the URI cannot be found."""
        try:
            return self._uri_idx[uri]
        except KeyError:
            if warnings:
                logger.warning('Unknown subject URI <%s>', uri)
            return None

    def by_label(self, label, language, warnings=True):
        """return the subject ID of a subject by its label in a given
        language, or None if not found. If warnings=True, log a warning
        message if the label cannot be found."""
        try:
            return self._label_idx[(label, language)]
        except KeyError:
            if warnings:
                logger.warning('Unknown subject label "%s"@%s',
                               label, language)
            return None

    def deprecated_ids(self):
        """return indices of deprecated subjects (those without labels)"""

        return [subject_id for subject_id, subject in enumerate(self._subjects)
                if subject.labels is None]

    @property
    def active(self):
        """return a list of (subject_id, subject) tuples of all subjects that
        are not deprecated"""

        return [(subj_id, subject)
                for subj_id, subject
                in enumerate(self._subjects)
                if subject.labels is not None]

    def save(self, path):
        """Save this subject index into a file with the given path name.

        The file is a CSV with the columns uri, notation and one
        label_<lang> column per known language."""

        # _languages is still None when no labeled subject has been seen
        # (e.g. an empty index); then only uri/notation columns are written
        languages = self._languages or []
        fieldnames = ['uri', 'notation'] + \
            [f'label_{lang}' for lang in languages]

        with open(path, 'w', encoding='utf-8', newline='') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()
            for subject in self._subjects:
                row = {'uri': subject.uri,
                       'notation': subject.notation or ''}
                if subject.labels:
                    for lang, label in subject.labels.items():
                        row[f'label_{lang}'] = label
                writer.writerow(row)

    @classmethod
    def load(cls, path):
        """Load a subject index from a CSV file and return it."""

        corpus = SubjectFileCSV(path)
        subject_index = cls()
        subject_index.load_subjects(corpus)
        return subject_index
198
199
200 View Code Duplication
class SubjectSet:
    """Represents a set of subjects for a document."""

    def __init__(self, subject_ids=None):
        """Create a SubjectSet and optionally initialize it from an iterable
        of subject IDs. None values are skipped and duplicates collapsed;
        the resulting element order is unspecified."""

        if subject_ids is None:
            self._subject_ids = []
        else:
            # use set comprehension to eliminate possible duplicates; an
            # explicit None check (not truthiness) also accepts NumPy arrays
            self._subject_ids = list({subject_id
                                      for subject_id in subject_ids
                                      if subject_id is not None})

    def __len__(self):
        return len(self._subject_ids)

    def __getitem__(self, idx):
        return self._subject_ids[idx]

    def __bool__(self):
        return bool(self._subject_ids)

    def __eq__(self, other):
        """Two SubjectSets are equal when they hold the same subject IDs,
        regardless of order."""
        if isinstance(other, SubjectSet):
            # _subject_ids comes from a set, so its order depends on the
            # insertion history; compare as sets to be order-independent
            return set(self._subject_ids) == set(other._subject_ids)

        return False

    @classmethod
    def from_string(cls, subj_data, subject_index, language):
        """Build a SubjectSet from text with one subject per line.

        A line holding a <URI> field is looked up by URI, otherwise its
        first non-empty field is looked up as a label in `language`.
        Lookups that return None are dropped by the constructor."""
        subject_ids = set()
        for line in subj_data.splitlines():
            uri, label = cls._parse_line(line)
            if uri is not None:
                subject_ids.add(subject_index.by_uri(uri))
            else:
                subject_ids.add(subject_index.by_label(label, language))
        return cls(subject_ids)

    @staticmethod
    def _parse_line(line):
        """Split a tab-separated line into (uri, label).

        A field wrapped in <...> is taken as the URI (without brackets); the
        first other non-empty field becomes the label and ends parsing.
        Missing parts are returned as None."""
        uri = label = None
        vals = line.split("\t")
        for val in vals:
            val = val.strip()
            if val == '':
                continue
            if val.startswith('<') and val.endswith('>'):  # URI
                uri = val[1:-1]
                continue
            label = val
            break
        return uri, label

    def as_vector(self, size=None, destination=None):
        """Return the hits as a one-dimensional NumPy array in sklearn
           multilabel indicator format. Use destination array if given (not
           None), otherwise create and return a new one of the given size."""

        if destination is None:
            destination = np.zeros(size, dtype=bool)

        destination[list(self._subject_ids)] = True

        return destination
267