Passed
Pull Request — master (#608)
by Osma
02:49
created

annif.corpus.subject.SubjectIndex.__len__()   A

Complexity

Conditions 1

Size

Total Lines 2
Code Lines 2

Duplication

Lines 2
Ratio 100 %

Importance

Changes 0
Metric Value
cc 1
eloc 2
nop 1
dl 2
loc 2
rs 10
c 0
b 0
f 0
1
"""Classes for supporting subject corpora expressed as directories or files"""
2
3
import csv
4
import numpy as np
5
import annif.util
6
from annif import logger
7
from .types import Subject, SubjectCorpus
8
from .skos import serialize_subjects_to_skos
9
10
11 View Code Duplication
class SubjectFileTSV(SubjectCorpus):
    """A monolingual subject vocabulary stored in a TSV file.

    Each non-blank line holds a URI, optionally followed by a tab-separated
    label and a tab-separated notation.
    """

    def __init__(self, path, language):
        """initialize the SubjectFileTSV given a path to a TSV file and the
        language of the vocabulary"""

        self.path = path
        self.language = language

    def _parse_line(self, line):
        """Parse a single line of the TSV file.

        Yields one Subject for a line with content and nothing at all for a
        blank (or whitespace-only) line. A subject without a label gets
        labels=None.
        """

        stripped = line.strip()
        if not stripped:
            # blank lines (e.g. a trailing empty line) carry no subject;
            # without this guard they would turn into a subject whose URI
            # is the empty string
            return

        vals = stripped.split('\t', 2)
        clean_uri = annif.util.cleanup_uri(vals[0])
        label = vals[1] if len(vals) >= 2 else None
        # an empty label column is treated the same as a missing one
        labels = {self.language: label} if label else None
        notation = vals[2] if len(vals) >= 3 else None
        yield Subject(uri=clean_uri,
                      labels=labels,
                      notation=notation)

    @property
    def languages(self):
        """A single-element list with the language given at construction."""
        return [self.language]

    @property
    def subjects(self):
        """Iterate over the subjects in the file, in file order.

        The file is read as UTF-8; a leading byte order mark is ignored
        (utf-8-sig)."""
        with open(self.path, encoding='utf-8-sig') as subjfile:
            for line in subjfile:
                yield from self._parse_line(line)

    def save_skos(self, path):
        """Save the contents of the subject vocabulary into a SKOS/Turtle
        file with the given path name."""
        serialize_subjects_to_skos(self.subjects, path)
45
46
47 View Code Duplication
class SubjectFileCSV(SubjectCorpus):
    """A multilingual subject vocabulary stored in a CSV file.

    The header row names the columns: 'uri', optionally 'notation', and one
    'label_<language>' column per language.
    """

    # column name prefix that marks a label column; the remainder of the
    # column name is the language
    _LABEL_PREFIX = 'label_'

    def __init__(self, path):
        """initialize the SubjectFileCSV given a path to a CSV file"""
        self.path = path

    def _parse_row(self, row):
        """Convert one csv.DictReader row into a Subject and yield it.

        Empty label and notation values become None. When label columns
        exist but every one of them is empty, labels is set to None so that
        the subject has the same shape as a label-less (deprecated) subject
        elsewhere in this module.
        """

        prefix = self._LABEL_PREFIX
        labels = {
            # strip only the leading prefix, not every occurrence of it
            fname[len(prefix):]: value or None
            for fname, value in row.items()
            # DictReader stores surplus fields of an over-long row under
            # the key None, so guard against non-string keys
            if isinstance(fname, str) and fname.startswith(prefix)
        }

        if labels and all(label is None for label in labels.values()):
            # no label in any language: mark as having no labels at all
            labels = None

        yield Subject(uri=annif.util.cleanup_uri(row['uri']),
                      labels=labels,
                      notation=row.get('notation', None) or None)

    @property
    def languages(self):
        """The languages inferred from the 'label_<language>' column names.

        Returns an empty list when the file has no header row at all."""

        # newline='' lets the csv module do its own newline handling
        with open(self.path, encoding='utf-8-sig', newline='') as csvfile:
            fieldnames = next(csv.reader(csvfile), None)

        if fieldnames is None:
            # completely empty file: no columns, hence no languages
            return []

        prefix = self._LABEL_PREFIX
        return [fname[len(prefix):]
                for fname in fieldnames
                if fname.startswith(prefix)]

    @property
    def subjects(self):
        """Iterate over the subjects in the file, one per data row."""
        with open(self.path, encoding='utf-8-sig', newline='') as csvfile:
            for row in csv.DictReader(csvfile):
                yield from self._parse_row(row)

    def save_skos(self, path):
        """Save the contents of the subject vocabulary into a SKOS/Turtle
        file with the given path name."""
        serialize_subjects_to_skos(self.subjects, path)
86
87
88 View Code Duplication
class SubjectIndex:
    """An index that remembers the associations between integers subject IDs
    and their URIs and labels."""

    def __init__(self):
        self._subjects = []     # subject objects; list position == subject ID
        self._uri_idx = {}      # URI -> subject ID
        self._label_idx = {}    # (label, language) -> subject ID
        self._languages = None  # list of languages, None until known

    def load_subjects(self, corpus):
        """Initialize the subject index from a subject corpus"""

        self._languages = corpus.languages
        for subject in corpus.subjects:
            self.append(subject)

    def __len__(self):
        return len(self._subjects)

    def __getitem__(self, subject_id):
        return self._subjects[subject_id]

    def append(self, subject):
        """Add a subject to the index; its ID is the index size before the
        call.

        If the languages of the index are not yet known, they are taken from
        the first appended subject that has labels. Subjects with
        labels=None are indexed by URI only.
        """

        if self._languages is None and subject.labels is not None:
            # a label-less first subject must not crash the index; wait for
            # a subject that actually carries labels
            self._languages = list(subject.labels.keys())

        subject_id = len(self._subjects)
        self._uri_idx[subject.uri] = subject_id
        if subject.labels:
            for lang, label in subject.labels.items():
                if label is None:
                    # a missing label in one language must not become
                    # findable through by_label(None, lang)
                    continue
                self._label_idx[(label, lang)] = subject_id
        self._subjects.append(subject)

    def contains_uri(self, uri):
        """return True if a subject with the given URI is in the index"""
        return uri in self._uri_idx

    def by_uri(self, uri, warnings=True):
        """return the subject ID of a subject by its URI, or None if not found.
        If warnings=True, log a warning message if the URI cannot be found."""
        try:
            return self._uri_idx[uri]
        except KeyError:
            if warnings:
                logger.warning('Unknown subject URI <%s>', uri)
            return None

    def by_label(self, label, language):
        """return the subject ID of a subject by its label in a given
        language, or None (after logging a warning) if not found"""
        try:
            return self._label_idx[(label, language)]
        except KeyError:
            logger.warning('Unknown subject label "%s"@%s', label, language)
            return None

    def deprecated_ids(self):
        """return indices of deprecated subjects"""

        return [subject_id for subject_id, subject in enumerate(self._subjects)
                if subject.labels is None]

    @property
    def active(self):
        """return a list of (subject_id, subject) tuples of all subjects that
        are not deprecated"""

        return [(subj_id, subject)
                for subj_id, subject
                in enumerate(self._subjects)
                if subject.labels is not None]

    def save(self, path):
        """Save this subject index into a file with the given path name.

        The file is a CSV with the columns 'uri', 'notation' and one
        'label_<language>' column per language of the index."""

        # the languages are still unknown (None) if no subject with labels
        # was ever added; write just the uri and notation columns then
        languages = self._languages or []
        fieldnames = ['uri', 'notation'] + \
            [f'label_{lang}' for lang in languages]

        with open(path, 'w', encoding='utf-8', newline='') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()
            for subject in self:
                row = {'uri': subject.uri,
                       'notation': subject.notation or ''}
                if subject.labels:
                    for lang, label in subject.labels.items():
                        row[f'label_{lang}'] = label
                writer.writerow(row)

    @classmethod
    def load(cls, path):
        """Load a subject index from a CSV file and return it."""

        corpus = SubjectFileCSV(path)
        subject_index = cls()
        subject_index.load_subjects(corpus)
        return subject_index
185
186
187 View Code Duplication
class SubjectSet:
    """Represents a set of subjects for a document."""

    def __init__(self, subject_ids=None):
        """Create a SubjectSet and optionally initialize it from an iterable
        of subject IDs. Duplicate IDs and None values are dropped."""

        if subject_ids:
            # use set comprehension to eliminate possible duplicates
            self._subject_ids = list({subject_id
                                      for subject_id in subject_ids
                                      if subject_id is not None})
        else:
            self._subject_ids = []

    def __len__(self):
        return len(self._subject_ids)

    def __getitem__(self, idx):
        return self._subject_ids[idx]

    def __bool__(self):
        return bool(self._subject_ids)

    def __eq__(self, other):
        """Two SubjectSets are equal when they hold the same subject IDs;
        anything that is not a SubjectSet compares unequal."""
        if isinstance(other, SubjectSet):
            # compare as sets: the internal list order comes from set
            # iteration and may differ between sets with the same members
            return set(self._subject_ids) == set(other._subject_ids)

        return False

    @classmethod
    def from_string(cls, subj_data, subject_index, language):
        """Create a SubjectSet from text with one subject per line.

        A line containing a URI in angle brackets is resolved with
        subject_index.by_uri(); otherwise the first non-empty tab-separated
        value is resolved with subject_index.by_label() in the given
        language. Lines with no content are skipped, and subjects that
        cannot be resolved (lookup returns None) are left out.
        """
        subject_ids = set()
        for line in subj_data.splitlines():
            uri, label = cls._parse_line(line)
            if uri is not None:
                subject_ids.add(subject_index.by_uri(uri))
            elif label is not None:
                subject_ids.add(subject_index.by_label(label, language))
            # neither URI nor label: blank line, nothing to look up
        return cls(subject_ids)

    @staticmethod
    def _parse_line(line):
        """Split a line on tabs and return a (uri, label) tuple.

        Scanning stops at the first non-empty value that is not a URI in
        angle brackets; that value is the label. Either element is None if
        not present before scanning stops.
        """
        uri = label = None
        vals = line.split("\t")
        for val in vals:
            val = val.strip()
            if val == '':
                continue
            if val.startswith('<') and val.endswith('>'):  # URI
                uri = val[1:-1]
                continue
            label = val
            break
        return uri, label

    def as_vector(self, size=None, destination=None):
        """Return the hits as a one-dimensional NumPy array in sklearn
           multilabel indicator format. Use destination array if given (not
           None), otherwise create and return a new one of the given size."""

        if destination is None:
            destination = np.zeros(size, dtype=bool)

        # _subject_ids is already a list, so it can index the array directly
        destination[self._subject_ids] = True

        return destination
254