Passed
Pull Request — master (#608)
by Osma
03:32
created

annif.corpus.subject.SubjectSet.__len__()   A

Complexity

Conditions 1

Size

Total Lines 2
Code Lines 2

Duplication

Lines 2
Ratio 100 %

Importance

Changes 0
Metric Value
cc 1
eloc 2
nop 1
dl 2
loc 2
rs 10
c 0
b 0
f 0
1
"""Classes for supporting subject corpora expressed as directories or files"""
2
3
import csv
4
import numpy as np
5
import annif.util
6
from annif import logger
7
from .types import Subject, SubjectCorpus
8
from .skos import serialize_subjects_to_skos
9
10
11 View Code Duplication
class SubjectFileTSV(SubjectCorpus):
    """A subject vocabulary in a single language, read from a TSV file.

    Every line holds a subject URI, optionally followed by a label and a
    notation code, separated by tab characters."""

    def __init__(self, path, language):
        """Remember the location of the TSV file and the language its labels
        are written in. The file itself is not opened here."""

        self.path = path
        self.language = language

    def _parse_line(self, line):
        """Generate the single Subject described by one line of the file."""
        fields = line.strip().split('\t', 2)
        uri = annif.util.cleanup_uri(fields[0])

        label = None
        notation = None
        if len(fields) >= 2:
            label = fields[1]
        if len(fields) >= 3:
            notation = fields[2]

        # a missing or empty label leaves the subject without any labels
        labels = {self.language: label} if label else None
        yield Subject(uri=uri, labels=labels, notation=notation)

    @property
    def languages(self):
        """The one-element list holding the language of this vocabulary."""
        return [self.language]

    @property
    def subjects(self):
        """Lazily generate a Subject for every line of the TSV file."""
        # utf-8-sig silently drops a byte order mark at the start of the file
        with open(self.path, encoding='utf-8-sig') as tsvfile:
            for row in tsvfile:
                for subject in self._parse_line(row):
                    yield subject

    def save_skos(self, path):
        """Write the subjects of this vocabulary as SKOS/Turtle into the
        file with the given path name."""
        serialize_subjects_to_skos(self.subjects, path)
45
46
47 View Code Duplication
class SubjectFileCSV(SubjectCorpus):
    """A multilingual subject vocabulary stored in a CSV file.

    The first row of the file names the columns: ``uri``, optionally
    ``notation``, and one ``label_<language>`` column per language."""

    # prefix that marks a CSV column as holding labels in some language
    _LABEL_PREFIX = 'label_'

    def __init__(self, path):
        """initialize the SubjectFileCSV given a path to a CSV file"""
        self.path = path

    @classmethod
    def _label_language(cls, fieldname):
        """Return the language of a ``label_<language>`` column name, or
        None if the name does not denote a label column."""
        # csv.DictReader collects surplus cells of over-long rows under the
        # key None, so the field name is not guaranteed to be a string
        if isinstance(fieldname, str) and \
                fieldname.startswith(cls._LABEL_PREFIX):
            # strip only the leading prefix, not later occurrences of it
            return fieldname[len(cls._LABEL_PREFIX):]
        return None

    def _parse_row(self, row):
        """Generate the single Subject described by one CSV row, given as a
        dict that maps column names to cell values."""
        labels = {}
        for fname, value in row.items():
            lang = self._label_language(fname)
            if lang is not None:
                # empty and missing cells both mean "no label"
                labels[lang] = value or None

        # A subject without a label in any language is deprecated, which is
        # expressed as labels=None; this keeps deprecated subjects intact
        # when an index is saved as CSV and loaded again.
        if labels and all(label is None for label in labels.values()):
            labels = None

        yield Subject(uri=annif.util.cleanup_uri(row['uri']),
                      labels=labels,
                      notation=row.get('notation', None) or None)

    @property
    def languages(self):
        """The languages of the vocabulary, inferred from the names of the
        ``label_<language>`` columns; an empty list for an empty file."""
        with open(self.path, encoding='utf-8-sig') as csvfile:
            reader = csv.reader(csvfile)
            fieldnames = next(reader, None)

        if fieldnames is None:
            # completely empty file: no header row, hence no languages
            return []

        return [lang
                for lang in map(self._label_language, fieldnames)
                if lang is not None]

    @property
    def subjects(self):
        """Lazily generate a Subject for every data row of the CSV file."""
        with open(self.path, encoding='utf-8-sig') as csvfile:
            reader = csv.DictReader(csvfile)
            for row in reader:
                yield from self._parse_row(row)

    def save_skos(self, path):
        """Save the contents of the subject vocabulary into a SKOS/Turtle
        file with the given path name."""
        serialize_subjects_to_skos(self.subjects, path)
86
87
88 View Code Duplication
class SubjectIndex:
    """An index that remembers the associations between integer subject IDs
    and their URIs and labels. The ID of a subject is its position in the
    order in which the subjects were appended."""

    def __init__(self):
        """Create an empty subject index."""
        self._subjects = []
        self._uri_idx = {}      # URI -> subject ID
        self._label_idx = {}    # (label, language) -> subject ID
        self._languages = None  # not known until subjects are added

    def load_subjects(self, corpus):
        """Initialize the subject index from a subject corpus"""

        self._languages = corpus.languages
        for subject in corpus.subjects:
            self.append(subject)

    def __len__(self):
        return len(self._subjects)

    def __getitem__(self, subject_id):
        return self._subjects[subject_id]

    def append(self, subject):
        """Add a subject to the end of the index and register its URI and
        labels for lookups."""

        # Infer the languages from the first subject that has labels; a
        # deprecated subject (labels is None) cannot tell us anything.
        if self._languages is None and subject.labels is not None:
            self._languages = list(subject.labels.keys())

        subject_id = len(self._subjects)
        self._uri_idx[subject.uri] = subject_id
        if subject.labels:
            for lang, label in subject.labels.items():
                # a label missing in one language must not become
                # retrievable through by_label(None, lang)
                if label is not None:
                    self._label_idx[(label, lang)] = subject_id
        self._subjects.append(subject)

    def contains_uri(self, uri):
        """return True if a subject with the given URI is in the index"""
        return uri in self._uri_idx

    def by_uri(self, uri, warnings=True):
        """return the subject ID of a subject by its URI, or None if not found.
        If warnings=True, log a warning message if the URI cannot be found."""
        try:
            return self._uri_idx[uri]
        except KeyError:
            if warnings:
                logger.warning('Unknown subject URI <%s>', uri)
            return None

    def by_label(self, label, language):
        """return the subject ID of a subject by its label in a given
        language, or None (after logging a warning) if not found"""
        try:
            return self._label_idx[(label, language)]
        except KeyError:
            logger.warning('Unknown subject label "%s"@%s', label, language)
            return None

    def deprecated_ids(self):
        """return indices of deprecated subjects, i.e. subjects whose
        labels attribute is None"""

        return [subject_id for subject_id, subject in enumerate(self._subjects)
                if subject.labels is None]

    @property
    def active(self):
        """return a list of (subject_id, subject) tuples of all subjects that
        are not deprecated"""

        return [(subj_id, subject)
                for subj_id, subject
                in enumerate(self._subjects)
                if subject.labels is not None]

    def save(self, path):
        """Save this subject index into a file with the given path name.

        The file is written as CSV with the columns ``uri`` and
        ``notation`` followed by one ``label_<language>`` column for each
        known language."""

        # the languages are still unknown (None) for an index that has no
        # labelled subjects; such an index is saved without label columns
        languages = self._languages or []
        fieldnames = ['uri', 'notation']
        fieldnames.extend(f'label_{lang}' for lang in languages)

        with open(path, 'w', encoding='utf-8', newline='') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()
            for subject in self._subjects:
                row = {'uri': subject.uri,
                       'notation': subject.notation or ''}
                if subject.labels:
                    for lang, label in subject.labels.items():
                        row[f'label_{lang}'] = label
                writer.writerow(row)

    @classmethod
    def load(cls, path):
        """Load a subject index from a CSV file and return it."""

        corpus = SubjectFileCSV(path)
        subject_index = cls()
        subject_index.load_subjects(corpus)
        return subject_index
186
187
188 View Code Duplication
class SubjectSet:
    """Represents a set of subjects for a document."""

    def __init__(self, subject_ids=None):
        """Create a SubjectSet and optionally initialize it from an iterable
        of subject IDs. Duplicates and None values are dropped."""

        if subject_ids:
            # use set comprehension to eliminate possible duplicates
            self._subject_ids = list({subject_id
                                      for subject_id in subject_ids
                                      if subject_id is not None})
        else:
            self._subject_ids = []

    def __len__(self):
        return len(self._subject_ids)

    def __getitem__(self, idx):
        return self._subject_ids[idx]

    def __bool__(self):
        return bool(self._subject_ids)

    def __eq__(self, other):
        """Two SubjectSets are equal if they contain the same subject IDs;
        anything that is not a SubjectSet compares unequal."""
        if isinstance(other, SubjectSet):
            # The internal lists get their order from set iteration, which
            # may differ between sets holding the same IDs, so the
            # comparison has to ignore the order.
            return set(self._subject_ids) == set(other._subject_ids)

        return False

    # instances are mutable-by-convention and define __eq__, so they are
    # deliberately unhashable (this is also what Python does implicitly)
    __hash__ = None

    @classmethod
    def from_string(cls, subj_data, subject_index, language):
        """Create a SubjectSet from text with one subject per line, where a
        subject is given either as a URI in angle brackets or as a label in
        the given language. Subjects are resolved to IDs using the given
        subject index; unresolvable ones are left out."""
        subject_ids = set()
        for line in subj_data.splitlines():
            uri, label = cls._parse_line(line)
            if uri is not None:
                subject_ids.add(subject_index.by_uri(uri))
            elif label is not None:
                subject_ids.add(subject_index.by_label(label, language))
            # else: blank line, nothing to look up
        return cls(subject_ids)

    @staticmethod
    def _parse_line(line):
        """Split a tab-separated line into a (uri, label) tuple. The URI is
        a value enclosed in angle brackets, the label is the first other
        non-empty value; either one is None if the line does not have it."""
        uri = label = None
        vals = line.split("\t")
        for val in vals:
            val = val.strip()
            if val == '':
                continue
            if val.startswith('<') and val.endswith('>'):  # URI
                uri = val[1:-1]
                continue
            label = val
            break
        return uri, label

    def as_vector(self, size=None, destination=None):
        """Return the hits as a one-dimensional NumPy array in sklearn
           multilabel indicator format. Use destination array if given (not
           None), otherwise create and return a new one of the given size.
           Raise ValueError if neither size nor destination is given."""

        if destination is None:
            if size is None:
                raise ValueError(
                    'either size or destination must be given')
            destination = np.zeros(size, dtype=bool)

        # the subject IDs act as an integer index array
        destination[self._subject_ids] = True

        return destination
255