Passed
Pull Request — main (#840)
by Osma
06:44 queued 03:36
created

annif.corpus.subject.SubjectFileTSV.subjects()   A

Complexity

Conditions 3

Size

Total Lines 5
Code Lines 5

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 5
dl 0
loc 5
rs 10
c 0
b 0
f 0
cc 3
nop 1
1
"""Classes for supporting subject corpora expressed as directories or files"""
2
3
from __future__ import annotations
4
5
import csv
6
import os.path
7
from typing import TYPE_CHECKING, Any
8
9
import annif
10
import annif.util
11
12
from .skos import serialize_subjects_to_skos
13
from .types import Subject, SubjectCorpus
14
15
if TYPE_CHECKING:
16
    from collections.abc import Generator, Iterator
17
18
    import numpy as np
19
20
    from annif.vocab import SubjectIndex
21
22
23
class SubjectFileTSV(SubjectCorpus):
    """A monolingual subject vocabulary stored in a TSV file.

    Each non-blank line holds a subject URI, optionally followed by a
    tab-separated label and a tab-separated notation.
    """

    def __init__(self, path: str, language: str) -> None:
        """initialize the SubjectFileTSV given a path to a TSV file and the
        language of the vocabulary"""

        self.path = path
        self.language = language

    def _parse_line(self, line: str) -> Iterator[Subject]:
        """Parse a single line of the TSV file, yielding at most one Subject.

        The line is stripped of surrounding whitespace and split on tabs
        into up to three fields: URI, label and notation. The URI is passed
        through annif.util.cleanup_uri. A missing or empty label results in
        labels=None; a missing notation results in notation=None.

        Blank (whitespace-only) lines yield nothing.
        """
        stripped = line.strip()
        if not stripped:
            # a blank line carries no URI at all; skip it instead of
            # emitting a bogus subject whose URI is the empty string
            return

        vals = stripped.split("\t", 2)
        clean_uri = annif.util.cleanup_uri(vals[0])
        label = vals[1] if len(vals) >= 2 else None
        labels = {self.language: label} if label else None
        notation = vals[2] if len(vals) >= 3 else None
        yield Subject(uri=clean_uri, labels=labels, notation=notation)

    @property
    def languages(self) -> list[str]:
        """The single language of this vocabulary, as a one-element list."""
        return [self.language]

    @property
    def subjects(self) -> Generator:
        """Generate the subjects of the vocabulary, one per non-blank line
        of the TSV file, in file order."""
        # utf-8-sig transparently skips a leading byte order mark, if any
        with open(self.path, encoding="utf-8-sig") as subjfile:
            for line in subjfile:
                yield from self._parse_line(line)

    def save_skos(self, path: str) -> None:
        """Save the contents of the subject vocabulary into a SKOS/Turtle
        file with the given path name."""
        serialize_subjects_to_skos(self.subjects, path)
55
56
57
class SubjectFileCSV(SubjectCorpus):
    """A multilingual subject vocabulary stored in a CSV file.

    The file must have a header row. The "uri" column holds the subject
    URI, an optional "notation" column holds the notation, and every
    column named "label_<language>" holds the label in that language.
    """

    # prefix that marks a CSV column as a label column
    _LABEL_PREFIX = "label_"

    def __init__(self, path: str) -> None:
        """initialize the SubjectFileCSV given a path to a CSV file"""
        self.path = path

    @classmethod
    def _label_language(cls, fname: str | None) -> str | None:
        """Return the language code of a label column name, or None if the
        column is not a label column.

        Only the leading prefix is removed, so any later occurrence of the
        prefix text inside the column name is preserved. Non-string names
        are rejected: csv.DictReader files surplus cells of an overlong row
        under the key None.
        """
        if isinstance(fname, str) and fname.startswith(cls._LABEL_PREFIX):
            return fname[len(cls._LABEL_PREFIX) :]
        return None

    def _parse_row(self, row: dict[str, str]) -> Iterator[Subject]:
        """Convert one CSV row (as produced by csv.DictReader) into a
        Subject and yield it.

        Empty or missing label cells become None, and so does an empty or
        missing notation. Raises KeyError if the row has no "uri" column.
        """
        labels = {}
        for fname, value in row.items():
            lang = self._label_language(fname)
            if lang is not None:
                labels[lang] = value or None

        # if there are no labels in any language, set labels to None
        # indicating a deprecated subject
        if set(labels.values()) == {None}:
            labels = None

        yield Subject(
            uri=annif.util.cleanup_uri(row["uri"]),
            labels=labels,
            notation=row.get("notation", None) or None,
        )

    @property
    def languages(self) -> list[str]:
        """The languages supported by the vocabulary, inferred from the
        label column names of the CSV header row, in column order.

        Returns an empty list if the file is empty (has no header row).
        """
        # newline="" is what the csv module expects of files handed to it;
        # utf-8-sig transparently skips a leading byte order mark, if any
        with open(self.path, encoding="utf-8-sig", newline="") as csvfile:
            fieldnames = next(csv.reader(csvfile), None)

        if fieldnames is None:
            # completely empty file: no header, hence no languages
            return []

        return [
            lang
            for lang in map(self._label_language, fieldnames)
            if lang is not None
        ]

    @property
    def subjects(self) -> Generator:
        """Generate the subjects of the vocabulary, one per data row of the
        CSV file, in file order."""
        with open(self.path, encoding="utf-8-sig", newline="") as csvfile:
            for row in csv.DictReader(csvfile):
                yield from self._parse_row(row)

    def save_skos(self, path: str) -> None:
        """Save the contents of the subject vocabulary into a SKOS/Turtle
        file with the given path name."""
        serialize_subjects_to_skos(self.subjects, path)

    @staticmethod
    def is_csv_file(path: str) -> bool:
        """return True if the path looks like a CSV file"""

        # only the file extension is inspected (case-insensitively); the
        # file itself is never opened
        return os.path.splitext(path)[1].lower() == ".csv"
112
113
114
class SubjectSet:
    """Represents a set of subjects for a document.

    The subjects are stored as a duplicate-free list of subject IDs. The
    order of that list is arbitrary and carries no meaning.
    """

    def __init__(self, subject_ids: Any | None = None) -> None:
        """Create a SubjectSet and optionally initialize it from an iterable
        of subject IDs. None values within the iterable are dropped."""

        # Compare against None instead of relying on truthiness: the truth
        # value of a NumPy array is either ambiguous (raises ValueError) or,
        # for a single element such as array([0]), silently False.
        if subject_ids is not None:
            # use set comprehension to eliminate possible duplicates
            self._subject_ids = list(
                {subject_id for subject_id in subject_ids if subject_id is not None}
            )
        else:
            self._subject_ids = []

    def __len__(self) -> int:
        """Return the number of distinct subject IDs in the set."""
        return len(self._subject_ids)

    def __getitem__(self, idx: int) -> int:
        """Return the subject ID stored at position idx of the internal
        list (the ordering is arbitrary)."""
        return self._subject_ids[idx]

    def __bool__(self) -> bool:
        """Return True if the set contains at least one subject ID."""
        return bool(self._subject_ids)

    def __eq__(self, other: Any) -> bool:
        """Return True if other is a SubjectSet holding exactly the same
        subject IDs, regardless of their internal order."""
        if isinstance(other, SubjectSet):
            # The internal lists come from set iteration, whose order can
            # depend on insertion order, so compare them as sets.
            return set(self._subject_ids) == set(other._subject_ids)

        return False

    @classmethod
    def from_string(
        cls, subj_data: str, subject_index: SubjectIndex, language: str
    ) -> SubjectSet:
        """Create a SubjectSet from text with one subject per line.

        Each line is parsed with _parse_line. If it contains a URI, the
        subject ID is looked up with subject_index.by_uri; otherwise the
        label is looked up with subject_index.by_label in the given
        language. Lookups that return None are dropped by the constructor.
        """
        subject_ids = set()
        for line in subj_data.splitlines():
            uri, label = cls._parse_line(line)
            if uri is not None:
                subject_ids.add(subject_index.by_uri(uri))
            else:
                subject_ids.add(subject_index.by_label(label, language))
        return cls(subject_ids)

    @staticmethod
    def _parse_line(
        line: str,
    ) -> tuple[str | None, str | None]:
        """Extract a (uri, label) pair from a tab-separated line.

        Values are stripped and empty ones are skipped. A value enclosed in
        angle brackets is taken as the URI (brackets removed). The first
        other non-empty value is taken as the label and ends the scan.
        Either element of the result is None if it was not found.
        """
        uri = label = None
        vals = line.split("\t")
        for val in vals:
            val = val.strip()
            if val == "":
                continue
            if val.startswith("<") and val.endswith(">"):  # URI
                uri = val[1:-1]
                continue
            label = val
            break
        return uri, label

    def as_vector(
        self, size: int | None = None, destination: np.ndarray | None = None
    ) -> np.ndarray:
        """Return the hits as a one-dimensional NumPy array in sklearn
        multilabel indicator format. Use destination array if given (not
        None), otherwise create and return a new one of the given size."""

        if destination is None:
            # At module level numpy is imported only under TYPE_CHECKING.
            # That is safe for the annotations above because they are never
            # evaluated at runtime (from __future__ import annotations), so
            # the real import is deferred until an array must be created.
            import numpy as np

            assert size is not None and size > 0
            destination = np.zeros(size, dtype=bool)

        # mark the position of every subject ID in the set
        destination[list(self._subject_ids)] = True

        return destination
190