Passed
Pull Request — master (#606)
by Osma
03:06
created

annif.corpus.subject.SubjectSet.__len__()   A

Complexity

Conditions 1

Size

Total Lines 2
Code Lines 2

Duplication

Lines 2
Ratio 100 %

Importance

Changes 0
Metric Value
cc 1
eloc 2
nop 1
dl 2
loc 2
rs 10
c 0
b 0
f 0
1
"""Classes for supporting subject corpora expressed as directories or files"""
2
3
import annif.util
4
import numpy as np
5
from annif import logger
6
from .types import Subject
7
from .skos import serialize_subjects_to_skos
8
9
10 View Code Duplication
class SubjectFileTSV:
    """A subject vocabulary read from a tab-separated (TSV) file.

    Every line is split into at most three tab-separated fields: a URI,
    an optional label and an optional notation."""

    def __init__(self, path):
        self.path = path

    def _parse_line(self, line):
        """Yield a single Subject built from one line of the TSV file."""
        # maxsplit=2 means at most three fields, so `extra` has 0-2 items
        uri_field, *extra = line.strip().split('\t', 2)
        uri = annif.util.cleanup_uri(uri_field)
        label = extra[0] if extra else None
        notation = extra[1] if len(extra) == 2 else None
        yield Subject(uri=uri, label=label, notation=notation)

    @property
    def languages(self):
        # the file format carries no information about label languages
        return None

    def subjects(self, language):
        """Yield the Subjects of the vocabulary, one per line of the file.

        The language argument is accepted but not used here; labels are
        taken as they appear in the file. A leading UTF-8 BOM is ignored
        thanks to the utf-8-sig codec."""
        with open(self.path, encoding='utf-8-sig') as tsv:
            for raw in tsv:
                yield from self._parse_line(raw)

    def save_skos(self, path, language):
        """Write the subject vocabulary to a SKOS/Turtle file at the given
        path name."""
        subjects = self.subjects(language)
        serialize_subjects_to_skos(subjects, language, path)
37
38
39 View Code Duplication
class SubjectIndex:
    """An index that remembers the associations between integer subject IDs
    and their URIs, labels and notations.

    Subject IDs are positions in the index: the n-th subject added gets
    ID n. A subject whose label is None is treated as deprecated."""

    def __init__(self):
        self._uris = []
        self._labels = []
        self._notations = []
        self._uri_idx = {}  # URI -> subject ID
        self._label_idx = {}  # label -> subject ID (labelled subjects only)

    def load_subjects(self, corpus, language):
        """Initialize the subject index from a subject corpus using labels
        in the given language."""

        # append() assigns IDs from the current length, so IDs always equal
        # list positions (enumerate() from 0 would not on a non-empty index)
        for subject in corpus.subjects(language):
            self.append(subject)

    def __len__(self):
        return len(self._uris)

    def __getitem__(self, subject_id):
        """Return the Subject with the given ID. IndexError is raised for
        out-of-range IDs, which also terminates iteration over the index."""
        return Subject(uri=self._uris[subject_id],
                       label=self._labels[subject_id],
                       notation=self._notations[subject_id])

    def _append(self, subject_id, subject):
        """Store a subject under the given ID and update the lookup tables."""
        self._uris.append(subject.uri)
        self._labels.append(subject.label)
        self._notations.append(subject.notation)
        self._uri_idx[subject.uri] = subject_id
        # Deprecated subjects have no label. Indexing them under None would
        # make by_label(None) (e.g. caused by a blank line in subject data)
        # resolve to whichever deprecated subject was added last.
        if subject.label is not None:
            self._label_idx[subject.label] = subject_id

    def append(self, subject):
        """Add a subject at the end of the index with the next free ID."""
        subject_id = len(self._uris)
        self._append(subject_id, subject)

    def contains_uri(self, uri):
        """Return True if a subject with the given URI is in the index."""
        return uri in self._uri_idx

    def by_uri(self, uri, warnings=True):
        """Return the subject ID of a subject by its URI, or None if not found.
        If warnings=True, log a warning message if the URI cannot be found."""
        try:
            return self._uri_idx[uri]
        except KeyError:
            if warnings:
                logger.warning('Unknown subject URI <%s>', uri)
            return None

    def by_label(self, label):
        """Return the subject ID of a subject by its label, or None (after
        logging a warning) if no non-deprecated subject has that label."""
        try:
            return self._label_idx[label]
        except KeyError:
            logger.warning('Unknown subject label "%s"', label)
            return None

    def deprecated_ids(self):
        """Return the IDs of deprecated subjects, i.e. those without a label."""
        return [subject_id for subject_id, label in enumerate(self._labels)
                if label is None]

    @property
    def active(self):
        """Return a list of (subject_id, uri, label, notation) tuples of all
        subjects that are not deprecated."""
        return [(subj_id, uri, label, notation)
                for subj_id, (uri, label, notation)
                in enumerate(zip(self._uris, self._labels, self._notations))
                if label is not None]

    def save(self, path):
        """Save this subject index into a UTF-8 encoded TSV file.

        Each line is "<uri>", followed by a tab and the label for
        non-deprecated subjects, and then by a tab and the notation when one
        is set. Notations of deprecated subjects are not written."""
        # iterate the parallel lists directly instead of relying on the
        # legacy __getitem__ iteration protocol and Subject's field order
        with open(path, 'w', encoding='utf-8') as subjfile:
            for uri, label, notation in zip(self._uris, self._labels,
                                            self._notations):
                line = "<{}>".format(uri)
                if label is not None:
                    line += '\t' + label
                    if notation is not None:
                        line += '\t' + notation
                print(line, file=subjfile)

    @classmethod
    def load(cls, path):
        """Load a subject index from a TSV file and return it."""
        corpus = SubjectFileTSV(path)
        subject_index = cls()
        subject_index.load_subjects(corpus, None)
        return subject_index
133
134
135 View Code Duplication
class SubjectSet:
    """Represents a set of subjects for a document."""

    def __init__(self, subject_ids=None):
        """Create a SubjectSet and optionally initialize it from an iterable
        of subject IDs. None entries are dropped and duplicates removed."""

        # `is not None` rather than truthiness, so that e.g. NumPy arrays
        # (whose truth value is ambiguous) are accepted as input
        if subject_ids is not None:
            # use set comprehension to eliminate possible duplicates
            self._subject_ids = list({subject_id
                                      for subject_id in subject_ids
                                      if subject_id is not None})
        else:
            self._subject_ids = []

    def __len__(self):
        return len(self._subject_ids)

    def __getitem__(self, idx):
        return self._subject_ids[idx]

    def __bool__(self):
        return bool(self._subject_ids)

    def __eq__(self, other):
        if isinstance(other, SubjectSet):
            # The stored order comes from set iteration order, which depends
            # on insertion order when hashes collide; compare as sets so the
            # same subjects always compare equal.
            return set(self._subject_ids) == set(other._subject_ids)

        return False

    @classmethod
    def from_string(cls, subj_data, subject_index):
        """Build a SubjectSet from text with one subject per line.

        A line naming a <URI> is resolved with subject_index.by_uri(),
        otherwise its label is resolved with subject_index.by_label().
        Lines with neither (e.g. blank lines) are skipped; subjects that
        cannot be resolved (None) are dropped by the constructor."""
        subject_ids = set()
        for line in subj_data.splitlines():
            uri, label = cls._parse_line(line)
            if uri is not None:
                subject_ids.add(subject_index.by_uri(uri))
            elif label is not None:
                subject_ids.add(subject_index.by_label(label))
        return cls(subject_ids)

    @staticmethod
    def _parse_line(line):
        """Split one line of subject data into a (uri, label) pair.

        Tab-separated fields are scanned in order, skipping empty ones. A
        field wrapped in <> is taken as the URI; the first other field is
        taken as the label and ends the scan. Missing parts are None."""
        uri = label = None
        vals = line.split("\t")
        for val in vals:
            val = val.strip()
            if val == '':
                continue
            if val.startswith('<') and val.endswith('>'):  # URI
                uri = val[1:-1]
                continue
            label = val
            break
        return uri, label

    def as_vector(self, size=None, destination=None):
        """Return the hits as a one-dimensional NumPy array in sklearn
           multilabel indicator format. Use destination array if given (not
           None), otherwise create and return a new one of the given size."""

        if destination is None:
            destination = np.zeros(size, dtype=bool)

        # _subject_ids is already a list, usable directly as an index array
        destination[self._subject_ids] = True

        return destination
202