Completed
Push — master ( 1877f4...a8999e )
by Osma
17s queued 10s
created

annif.corpus.subject.SubjectIndex.uris_to_labels()   A

Complexity

Conditions 1

Size

Total Lines 7
Code Lines 4

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 4
dl 0
loc 7
rs 10
c 0
b 0
f 0
cc 1
nop 2
1
"""Classes for supporting subject corpora expressed as directories or files"""
2
3
import glob
4
import os.path
5
import annif.util
6
import numpy as np
7
from annif import logger
8
from .types import Subject, SubjectCorpus
9
from .convert import SubjectToDocumentCorpusMixin
10
11
12
class SubjectDirectory(SubjectCorpus, SubjectToDocumentCorpusMixin):
    """Subject corpus backed by a directory holding one .txt file per
    subject."""

    def __init__(self, path):
        """Remember the directory and collect its *.txt files in sorted
        order."""
        self.path = path
        pattern = os.path.join(path, '*.txt')
        self._filenames = sorted(glob.glob(pattern))

    @property
    def subjects(self):
        """Yield one Subject for each .txt file found in the directory.

        The first line of a file is read as "URI label", split at the
        first space. The remaining lines are joined with spaces and used
        as the subject text."""
        for txt_path in self._filenames:
            with open(txt_path, encoding='utf-8') as handle:
                header = handle.readline().strip()
                uri, label = header.split(' ', 1)
                remaining_lines = handle.readlines()
                text = ' '.join(remaining_lines)
                yield Subject(uri=uri, label=label, text=text)
26
27
28
class SubjectFileTSV(SubjectCorpus, SubjectToDocumentCorpusMixin):
    """A subject corpus stored in a TSV file."""

    def __init__(self, path):
        """Remember the path of the TSV file; the file is not opened until
        the subjects are iterated."""
        self.path = path

    @property
    def subjects(self):
        """Yield one Subject for each non-blank line of the file.

        Every line holds a URI, a run of whitespace and then a label. The
        URI is passed through annif.util.cleanup_uri() and the text of the
        yielded subject is always None. Blank lines are skipped; a line
        that has a URI but no label raises ValueError."""
        with open(self.path, encoding='utf-8') as subjfile:
            for line in subjfile:
                stripped = line.strip()
                if not stripped:
                    # An empty line (e.g. a trailing newline at the end of
                    # the file) carries no subject; unpacking it below
                    # would raise ValueError.
                    continue
                uri, label = stripped.split(None, 1)
                clean_uri = annif.util.cleanup_uri(uri)
                yield Subject(uri=clean_uri, label=label, text=None)
41
42
43
class SubjectIndex:
    """An index that remembers the associations between integer subject IDs
    and their URIs and labels."""

    def __init__(self, corpus):
        """Initialize the subject index from a subject corpus.

        Subjects are numbered from zero in the order in which
        corpus.subjects yields them. If a URI or label occurs more than
        once, the reverse lookup resolves to its last occurrence."""
        self._uris = []
        self._labels = []
        self._uri_idx = {}
        self._label_idx = {}
        for subject_id, subject in enumerate(corpus.subjects):
            self._uris.append(subject.uri)
            self._labels.append(subject.label)
            self._uri_idx[subject.uri] = subject_id
            self._label_idx[subject.label] = subject_id

    def __len__(self):
        """return the number of subjects in the index"""
        return len(self._uris)

    def __getitem__(self, subject_id):
        """return the (URI, label) tuple of the subject with the given ID"""
        return (self._uris[subject_id], self._labels[subject_id])

    def by_uri(self, uri):
        """return the subject index of a subject by its URI; an unknown
        URI logs a warning and gives None"""
        try:
            return self._uri_idx[uri]
        except KeyError:
            logger.warning('Unknown subject URI <%s>', uri)
            return None

    def by_label(self, label):
        """return the subject index of a subject by its label; an unknown
        label logs a warning and gives None"""
        try:
            return self._label_idx[label]
        except KeyError:
            logger.warning('Unknown subject label "%s"', label)
            return None

    def uris_to_labels(self, uris):
        """return a list of labels corresponding to the given URIs; unknown
        URIs are ignored"""

        subject_ids = (self.by_uri(uri) for uri in uris)
        return [self._labels[subject_id]
                for subject_id in subject_ids
                if subject_id is not None]

    def labels_to_uris(self, labels):
        """return a list of URIs corresponding to the given labels; unknown
        labels are ignored"""

        subject_ids = (self.by_label(label) for label in labels)
        return [self._uris[subject_id]
                for subject_id in subject_ids
                if subject_id is not None]

    def save(self, path):
        """Save this subject index into a file.

        One line is written per subject, in subject ID order, in the form
        "<URI>", a tab character and the label."""

        with open(path, 'w', encoding='utf-8') as subjfile:
            # The two lists are filled in lockstep by __init__, so zipping
            # them walks the subjects in ID order.
            for uri, label in zip(self._uris, self._labels):
                print("<{}>\t{}".format(uri, label), file=subjfile)

    @classmethod
    def load(cls, path):
        """Load a subject index from a TSV file and return it."""

        corpus = SubjectFileTSV(path)
        return cls(corpus)
112
113
114
class SubjectSet:
    """The collection of subjects (URIs and/or labels) attached to one
    document."""

    def __init__(self, subj_data=None):
        """Create a SubjectSet, optionally seeding it from a tuple
        (URIs, labels); a missing or falsy value gives an empty set."""

        if not subj_data:
            subj_data = ([], [])
        uris, labels = subj_data
        self.subject_uris = set(uris)
        self.subject_labels = set(labels)

    @classmethod
    def from_string(cls, subj_data):
        """Build a SubjectSet by parsing every line of the given string."""
        parsed = cls()
        for row in subj_data.splitlines():
            parsed._parse_line(row)
        return parsed

    def _parse_line(self, line):
        """Parse one tab-separated line.

        Empty fields are skipped, fields of the form <...> are stored as
        URIs without the angle brackets, and the first other field is
        stored as a label, after which the rest of the line is ignored."""
        for field in (raw.strip() for raw in line.split("\t")):
            if field == '':
                continue
            looks_like_uri = field.startswith('<') and field.endswith('>')
            if not looks_like_uri:
                self.subject_labels.add(field)
                break
            self.subject_uris.add(field[1:-1])

    def has_uris(self):
        """returns True if there are at least as many URIs as labels,
        i.e. the URIs for all subjects are taken to be known"""
        return len(self.subject_labels) <= len(self.subject_uris)

    def as_vector(self, subject_index):
        """Return the subjects as a one-dimensional NumPy array in sklearn
           multilabel indicator format, using a subject index as the source
           of subjects. URIs are used for the lookup when has_uris() is
           true, labels otherwise; entries the index returns None for are
           left at zero."""

        indicator = np.zeros(len(subject_index), dtype=np.int8)
        # Lazy generators keep each lookup interleaved with the assignment
        # that follows it.
        if self.has_uris():
            found = (subject_index.by_uri(uri)
                     for uri in self.subject_uris)
        else:
            found = (subject_index.by_label(label)
                     for label in self.subject_labels)
        for subject_id in found:
            if subject_id is not None:
                indicator[subject_id] = 1
        return indicator
165