Completed
Pull Request — master (#344)
by Osma
06:44
created

annif.corpus.subject.SubjectFileTSV.save_skos()   A

Complexity

Conditions 1

Size

Total Lines 4
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 2
dl 0
loc 4
rs 10
c 0
b 0
f 0
cc 1
nop 3
1
"""Classes for supporting subject corpora expressed as directories or files"""
2
3
import annif.util
4
import numpy as np
5
from annif import logger
6
from .types import Subject
7
from .skos import serialize_subjects_to_skos
8
9
10
class SubjectFileTSV:
    """A subject vocabulary stored in a TSV file.

    Each non-blank line holds a subject URI, a run of whitespace, and the
    subject label (the label is everything after the first whitespace run).
    """

    def __init__(self, path):
        self.path = path

    @property
    def subjects(self):
        """Yield a Subject for each non-blank line of the vocabulary file.

        The first column is passed through annif.util.cleanup_uri and used
        as the URI; the remainder of the stripped line is the label. Blank
        or whitespace-only lines are skipped. A non-blank line with no label
        part raises ValueError."""
        # 'utf-8-sig' decodes exactly like 'utf-8', but additionally drops a
        # leading byte order mark; str.strip() does not remove U+FEFF, so
        # without this it would end up glued to the first URI.
        with open(self.path, encoding='utf-8-sig') as subjfile:
            for line in subjfile:
                line = line.strip()
                if not line:
                    # tolerate empty lines, e.g. trailing newlines at EOF
                    continue
                uri, label = line.split(None, 1)
                clean_uri = annif.util.cleanup_uri(uri)
                yield Subject(uri=clean_uri, label=label, text=None)

    def save_skos(self, path, language):
        """Save the contents of the subject vocabulary into a SKOS/Turtle
        file with the given path name.

        :param path: destination file path
        :param language: language passed on to serialize_subjects_to_skos
            for the labels"""
        serialize_subjects_to_skos(self.subjects, language, path)
28
29
30
class SubjectIndex:
    """Two-way mapping between integer subject IDs and the URIs and labels
    of the subjects in a vocabulary."""

    def __init__(self, corpus):
        """Build the index from ``corpus.subjects``; the position of each
        subject in that sequence becomes its integer ID."""
        entries = [(subj.uri, subj.label) for subj in corpus.subjects]
        self._uris = [uri for uri, _ in entries]
        self._labels = [label for _, label in entries]
        # If a URI or label occurs more than once, the later ID overwrites
        # the earlier one, so lookups return the last occurrence.
        self._uri_idx = {uri: sid for sid, uri in enumerate(self._uris)}
        self._label_idx = {label: sid
                           for sid, label in enumerate(self._labels)}

    def __len__(self):
        return len(self._uris)

    def __getitem__(self, subject_id):
        uri = self._uris[subject_id]
        label = self._labels[subject_id]
        return (uri, label)

    def by_uri(self, uri):
        """Return the integer ID of the subject with the given URI, or None
        (after logging a warning) if the URI is not indexed."""
        subject_id = self._uri_idx.get(uri)
        if subject_id is None:
            logger.warning('Unknown subject URI <%s>', uri)
        return subject_id

    def by_label(self, label):
        """Return the integer ID of the subject with the given label, or
        None (after logging a warning) if the label is not indexed."""
        subject_id = self._label_idx.get(label)
        if subject_id is None:
            logger.warning('Unknown subject label "%s"', label)
        return subject_id

    def uris_to_labels(self, uris):
        """Map URIs to labels, in input order; URIs missing from the index
        are skipped (each one logs a warning via by_uri)."""
        labels = []
        for uri in uris:
            subject_id = self.by_uri(uri)
            if subject_id is not None:
                labels.append(self[subject_id][1])
        return labels

    def labels_to_uris(self, labels):
        """Map labels to URIs, in input order; labels missing from the
        index are skipped (each one logs a warning via by_label)."""
        uris = []
        for label in labels:
            subject_id = self.by_label(label)
            if subject_id is not None:
                uris.append(self[subject_id][0])
        return uris

    def save(self, path):
        """Write the index to *path* as UTF-8 text, one ``<uri>\\tlabel``
        line per subject in ID order."""
        with open(path, 'w', encoding='utf-8') as subjfile:
            for uri, label in zip(self._uris, self._labels):
                print("<{}>\t{}".format(uri, label), file=subjfile)

    @classmethod
    def load(cls, path):
        """Read the TSV vocabulary at *path* through SubjectFileTSV and
        return a new index built from it."""
        return cls(SubjectFileTSV(path))
99
100
101
class SubjectSet:
    """The subjects assigned to one document, kept as a set of URIs and a
    separate set of labels."""

    def __init__(self, subj_data=None):
        """Create a SubjectSet, optionally seeded from a ``(uris, labels)``
        pair of iterables; a falsy *subj_data* gives empty sets."""
        if subj_data:
            uris, labels = subj_data
        else:
            uris, labels = [], []
        self.subject_uris = set(uris)
        self.subject_labels = set(labels)

    @classmethod
    def from_string(cls, subj_data):
        """Create a SubjectSet by feeding every line of *subj_data* to
        _parse_line."""
        result = cls()
        for row in subj_data.splitlines():
            result._parse_line(row)
        return result

    def _parse_line(self, line):
        """Consume one tab-separated line.

        Fields are whitespace-stripped and scanned left to right: empty
        fields are skipped, fields wrapped in ``<...>`` are added as URIs
        (without the brackets), and the first other field is added as a
        label, after which the rest of the line is ignored."""
        for field in (part.strip() for part in line.split("\t")):
            if not field:
                continue
            if field.startswith('<') and field.endswith('>'):
                self.subject_uris.add(field[1:-1])
            else:
                self.subject_labels.add(field)
                break

    def has_uris(self):
        """Return True when at least as many URIs as labels are known, in
        which case the URI set is used to represent the subjects."""
        return len(self.subject_labels) <= len(self.subject_uris)

    def as_vector(self, subject_index):
        """Return a boolean NumPy vector (sklearn multilabel indicator
        format) of length ``len(subject_index)`` with True at the ID of
        every subject in this set that the index knows.

        URIs are looked up when has_uris() is true, labels otherwise."""
        vector = np.zeros(len(subject_index), dtype=bool)
        if self.has_uris():
            lookup, keys = subject_index.by_uri, self.subject_uris
        else:
            lookup, keys = subject_index.by_label, self.subject_labels
        for key in keys:
            subject_id = lookup(key)
            if subject_id is not None:
                vector[subject_id] = True
        return vector
152