Passed
Push — issue735-subject-filtering ( 3cd194 )
by Osma
03:54
created

annif.corpus.subject   A

Complexity

Total Complexity 37

Size/Duplication

Total Lines 194
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 120
dl 0
loc 194
rs 9.44
c 0
b 0
f 0
wmc 37

19 Methods

Rating   Name   Duplication   Size   Complexity  
A SubjectFileTSV.__init__() 0 6 1
A SubjectSet.__bool__() 0 2 1
A SubjectFileTSV.languages() 0 3 1
A SubjectFileCSV._parse_row() 0 16 2
A SubjectFileCSV.__init__() 0 3 1
A SubjectSet.__len__() 0 2 1
A SubjectSet.__eq__() 0 5 2
A SubjectFileTSV.subjects() 0 5 3
A SubjectFileCSV.is_csv_file() 0 5 1
A SubjectFileCSV.languages() 0 11 2
A SubjectFileTSV.save_skos() 0 4 1
A SubjectSet.as_vector() 0 16 2
A SubjectFileCSV.subjects() 0 6 3
A SubjectFileCSV.save_skos() 0 4 1
A SubjectSet.__getitem__() 0 2 1
A SubjectFileTSV._parse_line() 0 7 4
A SubjectSet.from_string() 0 12 3
A SubjectSet.__init__() 0 11 2
A SubjectSet._parse_line() 0 16 5
1
"""Classes for supporting subject corpora expressed as directories or files"""
2
3
from __future__ import annotations
4
5
import csv
6
import os.path
7
from typing import TYPE_CHECKING, Any
8
9
import annif
10
import annif.util
11
12
from .skos import serialize_subjects_to_skos
13
from .types import Subject, SubjectCorpus
14
15
if TYPE_CHECKING:
16
    from collections.abc import Generator, Iterator
17
18
    import numpy as np
19
20
    from annif.vocab import SubjectIndex
21
22
23
logger = annif.logger.getChild("subject")
24
logger.addFilter(annif.util.DuplicateFilter())
25
26
27
class SubjectFileTSV(SubjectCorpus):
    """A monolingual subject vocabulary stored in a TSV file."""

    def __init__(self, path: str, language: str) -> None:
        """initialize the SubjectFileTSV given a path to a TSV file and the
        language of the vocabulary"""

        self.path = path
        self.language = language

    def _parse_line(self, line: str) -> Iterator[Subject]:
        """Yield a single Subject parsed from one TSV line of the form
        "<uri>\\tlabel\\tnotation" where label and notation are optional."""
        fields = line.strip().split("\t", 2)
        uri = annif.util.cleanup_uri(fields[0])
        label = fields[1] if len(fields) > 1 else None
        notation = fields[2] if len(fields) > 2 else None
        # a missing or empty label yields labels=None, marking the subject
        # as deprecated
        yield Subject(
            uri=uri,
            labels={self.language: label} if label else None,
            notation=notation,
        )

    @property
    def languages(self) -> list[str]:
        """Return the (single) language of this monolingual vocabulary."""
        return [self.language]

    @property
    def subjects(self) -> Generator:
        """Generate Subjects by reading the TSV file one line at a time."""
        with open(self.path, encoding="utf-8-sig") as subjfile:
            for line in subjfile:
                yield from self._parse_line(line)

    def save_skos(self, path: str) -> None:
        """Save the contents of the subject vocabulary into a SKOS/Turtle
        file with the given path name."""
        serialize_subjects_to_skos(self.subjects, path)
61
class SubjectFileCSV(SubjectCorpus):
    """A multilingual subject vocabulary stored in a CSV file."""

    def __init__(self, path: str) -> None:
        """initialize the SubjectFileCSV given a path to a CSV file"""
        self.path = path

    def _parse_row(self, row: dict[str, str]) -> Iterator[Subject]:
        """Yield a single Subject parsed from one CSV row; labels come from
        the "label_<language>" columns, with empty strings mapped to None."""
        labels = {}
        for fname, value in row.items():
            if fname.startswith("label_"):
                labels[fname.replace("label_", "")] = value or None

        # if there are no labels in any language, set labels to None
        # indicating a deprecated subject
        if set(labels.values()) == {None}:
            labels = None

        yield Subject(
            uri=annif.util.cleanup_uri(row["uri"]),
            labels=labels,
            notation=row.get("notation", None) or None,
        )

    @property
    def languages(self) -> list[str]:
        """Return the languages declared by the CSV "label_*" header columns."""
        with open(self.path, encoding="utf-8-sig") as csvfile:
            header = next(csv.reader(csvfile), None)

        return [col.replace("label_", "") for col in header if col.startswith("label_")]

    @property
    def subjects(self) -> Generator:
        """Generate Subjects by reading the CSV file one row at a time."""
        with open(self.path, encoding="utf-8-sig") as csvfile:
            for row in csv.DictReader(csvfile):
                yield from self._parse_row(row)

    def save_skos(self, path: str) -> None:
        """Save the contents of the subject vocabulary into a SKOS/Turtle
        file with the given path name."""
        serialize_subjects_to_skos(self.subjects, path)

    @staticmethod
    def is_csv_file(path: str) -> bool:
        """return True if the path looks like a CSV file"""
        _, extension = os.path.splitext(path)
        return extension.lower() == ".csv"
118
class SubjectSet:
    """Represents a set of subjects for a document."""

    def __init__(self, subject_ids: Any | None = None) -> None:
        """Create a SubjectSet and optionally initialize it from an iterable
        of subject IDs"""

        if not subject_ids:
            self._subject_ids = []
        else:
            # deduplicate via a set comprehension; None entries (unresolved
            # subjects) are dropped
            unique_ids = {sid for sid in subject_ids if sid is not None}
            self._subject_ids = list(unique_ids)

    def __len__(self) -> int:
        return len(self._subject_ids)

    def __getitem__(self, idx: int) -> int:
        return self._subject_ids[idx]

    def __bool__(self) -> bool:
        return bool(self._subject_ids)

    def __eq__(self, other: Any) -> bool:
        # only another SubjectSet with the same IDs (in the same order)
        # compares equal
        if not isinstance(other, SubjectSet):
            return False
        return self._subject_ids == other._subject_ids

    @classmethod
    def from_string(
        cls, subj_data: str, subject_index: SubjectIndex, language: str
    ) -> SubjectSet:
        """Create a SubjectSet by parsing a newline-separated string of
        subjects, resolving each one against the given subject index either
        by URI or by label in the given language."""
        ids = set()
        for line in subj_data.splitlines():
            uri, label = cls._parse_line(line)
            if uri is None:
                ids.add(subject_index.by_label(label, language))
            else:
                ids.add(subject_index.by_uri(uri))
        return cls(ids)

    @staticmethod
    def _parse_line(
        line: str,
    ) -> tuple[str | None, str | None]:
        """Split a tab-separated line into a (uri, label) pair, either of
        which may be None.  A value wrapped in angle brackets is taken as the
        URI; the first other non-empty value is taken as the label."""
        uri = label = None
        for token in (part.strip() for part in line.split("\t")):
            if not token:
                continue
            if token.startswith("<") and token.endswith(">"):  # URI
                uri = token[1:-1]
                continue
            label = token
            break
        return uri, label

    def as_vector(
        self, size: int | None = None, destination: np.ndarray | None = None
    ) -> np.ndarray:
        """Return the hits as a one-dimensional NumPy array in sklearn
        multilabel indicator format. Use destination array if given (not
        None), otherwise create and return a new one of the given size."""

        if destination is None:
            # deferred import: numpy is only needed when allocating here
            import numpy as np

            assert size is not None and size > 0
            destination = np.zeros(size, dtype=bool)
        destination[list(self._subject_ids)] = True
        return destination