Passed
Pull Request — master (#336)
by Osma
03:41
created

annif.corpus.subject.SubjectDirectory.__init__()   A

Complexity

Conditions 1

Size

Total Lines 3
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 3
dl 0
loc 3
rs 10
c 0
b 0
f 0
cc 1
nop 2
1
"""Classes for supporting subject corpora expressed as directories or files"""
2
3
import annif.util
4
import numpy as np
5
from annif import logger
6
from .types import Subject
7
8
9
class SubjectFileTSV:
    """A subject vocabulary stored in a TSV file.

    Each non-blank line holds a subject URI, then whitespace, then the
    subject label (the label itself may contain whitespace).
    """

    def __init__(self, path):
        """Remember the path of the TSV file; the file is not opened here."""
        self.path = path

    @property
    def subjects(self):
        """Lazily yield one Subject per non-blank line of the file.

        The file is opened anew each time the property is iterated. The URI
        is passed through annif.util.cleanup_uri and the text field is set
        to None. Blank or whitespace-only lines (for example a trailing
        empty line) are skipped.

        Raises ValueError for a non-blank line that has no label.
        """
        with open(self.path, encoding='utf-8') as subjfile:
            for line in subjfile:
                fields = line.strip().split(None, 1)
                if not fields:
                    # blank line: there is nothing to unpack, so skip it
                    # instead of failing on the tuple assignment below
                    continue
                uri, label = fields
                clean_uri = annif.util.cleanup_uri(uri)
                yield Subject(uri=clean_uri, label=label, text=None)
22
23
24
class SubjectIndex:
    """Two-way lookup between integer subject IDs and the URIs and labels
    of subjects."""

    def __init__(self, corpus):
        """Build the index by walking through corpus.subjects in order; each
        subject gets the next free integer ID, starting from zero."""
        self._uris = []
        self._labels = []
        self._uri_idx = {}
        self._label_idx = {}
        for subject in corpus.subjects:
            # the ID of a subject is its position in the parallel lists
            position = len(self._uris)
            self._uris.append(subject.uri)
            self._labels.append(subject.label)
            self._uri_idx[subject.uri] = position
            self._label_idx[subject.label] = position

    def __len__(self):
        """Return the number of subjects in the index."""
        return len(self._uris)

    def __getitem__(self, subject_id):
        """Return the (URI, label) tuple stored for the given subject ID."""
        uri = self._uris[subject_id]
        label = self._labels[subject_id]
        return (uri, label)

    def by_uri(self, uri):
        """return the subject ID matching a URI, or None (after logging a
        warning) when the URI is not in the index"""
        try:
            found = self._uri_idx[uri]
        except KeyError:
            logger.warning('Unknown subject URI <%s>', uri)
            return None
        return found

    def by_label(self, label):
        """return the subject ID matching a label, or None (after logging a
        warning) when the label is not in the index"""
        try:
            found = self._label_idx[label]
        except KeyError:
            logger.warning('Unknown subject label "%s"', label)
            return None
        return found

    def uris_to_labels(self, uris):
        """return the labels of the given URIs as a list, in the same order;
        URIs missing from the index are left out"""

        labels = []
        for uri in uris:
            subject_id = self.by_uri(uri)
            if subject_id is None:
                continue
            labels.append(self[subject_id][1])
        return labels

    def labels_to_uris(self, labels):
        """return the URIs of the given labels as a list, in the same order;
        labels missing from the index are left out"""

        uris = []
        for label in labels:
            subject_id = self.by_label(label)
            if subject_id is None:
                continue
            uris.append(self[subject_id][0])
        return uris

    def save(self, path):
        """Write the index into a file, one "<URI> TAB label" line per
        subject, in subject ID order."""

        with open(path, 'w', encoding='utf-8') as subjfile:
            for uri, label in zip(self._uris, self._labels):
                print("<{}>\t{}".format(uri, label), file=subjfile)

    @classmethod
    def load(cls, path):
        """Create and return a subject index read from a TSV file."""

        return cls(SubjectFileTSV(path))
93
94
95
class SubjectSet:
    """The subjects of one document, kept as a set of URIs and a set of
    labels."""

    def __init__(self, subj_data=None):
        """Create a SubjectSet; when subj_data is given (and not empty) it
        is unpacked as a (URIs, labels) pair used to fill the two sets."""

        if subj_data:
            uris, labels = subj_data
        else:
            uris, labels = [], []
        self.subject_uris = set(uris)
        self.subject_labels = set(labels)

    @classmethod
    def from_string(cls, subj_data):
        """Create a SubjectSet by parsing every line of the given string."""
        result = cls()
        for row in subj_data.splitlines():
            result._parse_line(row)
        return result

    def _parse_line(self, line):
        """Collect subjects from one tab-separated line.

        Fields wrapped in angle brackets are stored as URIs without the
        brackets. The first non-empty field that is not a URI is stored as
        a label, and the rest of the line is then ignored.
        """
        for field in line.split("\t"):
            token = field.strip()
            if not token:
                continue
            is_uri = token.startswith('<') and token.endswith('>')
            if not is_uri:
                self.subject_labels.add(token)
                return
            self.subject_uris.add(token[1:-1])

    def has_uris(self):
        """returns True when there are at least as many URIs as labels,
        i.e. the URIs for all subjects are taken to be known"""
        return len(self.subject_labels) <= len(self.subject_uris)

    def as_vector(self, subject_index):
        """Return the subjects as a one-dimensional NumPy int8 array as long
           as the subject index, with 1 at the position of every subject
           found in the index and 0 elsewhere. Subjects are looked up by
           URI when has_uris() is true and by label otherwise; subjects the
           index does not know are left out."""

        vector = np.zeros(len(subject_index), dtype=np.int8)
        if self.has_uris():
            subject_ids = (subject_index.by_uri(uri)
                           for uri in self.subject_uris)
        else:
            subject_ids = (subject_index.by_label(label)
                           for label in self.subject_labels)
        for subject_id in subject_ids:
            if subject_id is not None:
                vector[subject_id] = 1
        return vector
146