Passed
Pull Request — master (#257)
by Osma
02:44
created

annif.corpus.subject.SubjectSet.has_uris()   A

Complexity

Conditions 1

Size

Total Lines 3
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 2
dl 0
loc 3
rs 10
c 0
b 0
f 0
cc 1
nop 1
1
"""Classes for supporting subject corpora expressed as directories or files"""
2
3
import glob
4
import os.path
5
import annif.util
6
import numpy as np
7
from annif import logger
8
from .types import Subject, SubjectCorpus
9
from .convert import SubjectToDocumentCorpusMixin
10
11
12
class SubjectDirectory(SubjectCorpus, SubjectToDocumentCorpusMixin):
    """A subject corpus in the form of a directory with .txt files.

    Each ``*.txt`` file in the directory describes one subject. Its first
    line holds the subject URI and the label, separated by the first space.
    The remaining lines hold free text about the subject.
    """

    def __init__(self, path):
        """Collect the ``*.txt`` files directly under ``path``."""
        self.path = path
        # glob() returns matches in arbitrary order; sort them so subjects
        # are always produced in the same (filename) order.
        self._filenames = sorted(glob.glob(os.path.join(path, '*.txt')))

    @property
    def subjects(self):
        """Yield one Subject per .txt file, in sorted filename order.

        Raises ValueError if a file's first line contains no space
        separating the URI from the label.
        """
        for filename in self._filenames:
            # Decode explicitly as UTF-8 instead of the locale-dependent
            # default, so non-ASCII labels and text read the same on every
            # platform.
            with open(filename, encoding='utf-8') as subjfile:
                uri, label = subjfile.readline().strip().split(' ', 1)
                # readlines() keeps the line terminators, so the joined text
                # retains the original newlines plus a separating space.
                text = ' '.join(subjfile.readlines())
                yield Subject(uri=uri, label=label, text=text)
26
27
28
class SubjectFileTSV(SubjectCorpus, SubjectToDocumentCorpusMixin):
    """A subject corpus stored in a TSV file.

    Each non-blank line holds a subject URI, then whitespace, then the
    subject label (the label may itself contain spaces). Subjects read from
    this format carry no text.
    """

    def __init__(self, path):
        """Remember the path of the TSV file; it is read lazily."""
        self.path = path

    @property
    def subjects(self):
        """Yield a Subject (with ``text=None``) for each non-blank line.

        Raises ValueError if a non-blank line has a URI but no label.
        """
        # NOTE(review): opened with the platform default encoding, which is
        # also what SubjectIndex.save writes with; switch both to an explicit
        # encoding together to keep save/load round-trips consistent.
        with open(self.path) as subjfile:
            for line in subjfile:
                stripped = line.strip()
                # Skip blank lines (e.g. trailing empty lines) rather than
                # failing to unpack them into a URI and a label.
                if not stripped:
                    continue
                uri, label = stripped.split(None, 1)
                # presumably normalizes URI syntax such as the <...> wrapping
                # written by SubjectIndex.save — TODO confirm in annif.util
                clean_uri = annif.util.cleanup_uri(uri)
                yield Subject(uri=clean_uri, label=label, text=None)
41
42
43
class SubjectIndex:
    """An index that remembers the associations between integer subject IDs
    and their URIs and labels.

    Subject IDs are assigned consecutively from 0, in the order in which the
    subjects are produced by the source corpus.
    """

    def __init__(self, corpus):
        """Initialize the subject index from a subject corpus.

        ``corpus`` must provide a ``subjects`` iterable whose items have
        ``uri`` and ``label`` attributes.
        """
        self._uris = []
        self._labels = []
        self._uri_idx = {}
        self._label_idx = {}
        for subject_id, subject in enumerate(corpus.subjects):
            self._uris.append(subject.uri)
            self._labels.append(subject.label)
            # If a URI or label occurs more than once, the reverse lookup
            # keeps the ID of its last occurrence.
            self._uri_idx[subject.uri] = subject_id
            self._label_idx[subject.label] = subject_id

    def __len__(self):
        """Return the number of subjects in the index."""
        return len(self._uris)

    def __getitem__(self, subject_id):
        """Return the ``(uri, label)`` tuple of the given subject ID."""
        return (self._uris[subject_id], self._labels[subject_id])

    def by_uri(self, uri):
        """Return the subject ID of a subject by its URI.

        Logs a warning and returns None if the URI is not in the index.
        """
        try:
            return self._uri_idx[uri]
        except KeyError:
            logger.warning('Unknown subject URI <%s>', uri)
            return None

    def by_label(self, label):
        """Return the subject ID of a subject by its label.

        Logs a warning and returns None if the label is not in the index.
        """
        try:
            return self._label_idx[label]
        except KeyError:
            logger.warning('Unknown subject label "%s"', label)
            return None

    def save(self, path):
        """Save this subject index into a file.

        Writes one line per subject, in subject ID order: the URI wrapped in
        angle brackets, a tab, then the label.
        """

        with open(path, 'w') as subjfile:
            # The URI and label lists are parallel, so zip walks them in
            # subject ID order without index arithmetic.
            for uri, label in zip(self._uris, self._labels):
                line = "<{}>\t{}".format(uri, label)
                print(line, file=subjfile)

    @classmethod
    def load(cls, path):
        """Load a subject index from a TSV file (such as one written by
        ``save``) and return it."""

        corpus = SubjectFileTSV(path)
        return cls(corpus)
96
97
98
class SubjectSet:
    """The subjects assigned to a single document.

    Subjects are kept in two independent sets, ``subject_uris`` and
    ``subject_labels``; depending on the input, either or both may be filled.
    """

    def __init__(self, subj_data=None):
        """Build a SubjectSet, optionally seeded from a ``(uris, labels)``
        pair of iterables. A falsy ``subj_data`` yields an empty set."""

        if subj_data:
            uris, labels = subj_data
        else:
            uris, labels = [], []
        self.subject_uris = set(uris)
        self.subject_labels = set(labels)

    @classmethod
    def from_string(cls, subj_data):
        """Parse a multi-line string, one subject per line, into a new
        SubjectSet."""
        result = cls()
        for row in subj_data.splitlines():
            result._parse_line(row)
        return result

    def _parse_line(self, line):
        """Add the subject information found on one tab-separated line.

        Fields are stripped and empty fields are skipped. A field wrapped in
        angle brackets is recorded as a URI (without the brackets). The first
        field that is not such a URI is recorded as a label, and any fields
        after it are ignored.
        """
        for field in (part.strip() for part in line.split("\t")):
            if not field:
                continue
            is_uri = field.startswith('<') and field.endswith('>')
            if not is_uri:
                self.subject_labels.add(field)
                break
            self.subject_uris.add(field[1:-1])

    def has_uris(self):
        """Return True if there are at least as many URIs as labels.

        This count comparison is used as the signal that the URIs for all
        subjects are known.
        """
        return len(self.subject_labels) <= len(self.subject_uris)

    def as_vector(self, subject_index):
        """Return the subjects as a one-dimensional int8 NumPy array in
        sklearn multilabel indicator format, with one position per subject
        in ``subject_index``.

        URIs are looked up when ``has_uris()`` is true, otherwise labels are;
        subjects the index does not know are left as 0."""

        vector = np.zeros(len(subject_index), dtype=np.int8)
        if self.has_uris():
            subjects, lookup = self.subject_uris, subject_index.by_uri
        else:
            subjects, lookup = self.subject_labels, subject_index.by_label
        for subj in subjects:
            subject_id = lookup(subj)
            if subject_id is not None:
                vector[subject_id] = 1
        return vector
149