Completed
Push — master ( 3e1d4c...f31f72 )
by Bart
27s
created

TextFile.get_data()   C

Complexity

Conditions 8

Size

Total Lines 16

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 8
dl 0
loc 16
rs 6.6666
1
from picklable_itertools import iter_, chain
2
3
from fuel.datasets import Dataset
4
from fuel.utils.formats import open_
5
6
7
class TextFile(Dataset):
    r"""Reads text files and numberizes them given a dictionary.

    Parameters
    ----------
    files : list of str
        The names of the files in the order in which they should be read.
        Each file is expected to have a sentence per line. If the filename
        ends with `.gz` it will be opened using `gzip`. Note however that
        `gzip` file handles aren't picklable on legacy Python.
    dictionary : str or dict
        Either the path to a pickled dictionary mapping tokens to integers,
        or the dictionary itself. Every token passed as ``bos_token``,
        ``eos_token`` or ``unk_token`` (unless ``None``) must be a key of
        this dictionary.
    bos_token : str or None, optional
        The beginning-of-sentence (BOS) token in the dictionary that
        denotes the beginning of a sentence. ``<S>`` by default. If
        ``None`` is passed, no beginning-of-sentence markers will be added.
    eos_token : str or None, optional
        The end-of-sentence (EOS) token, ``</S>`` by default; see
        ``bos_token``.
    unk_token : str or None, optional
        The token in the dictionary to fall back on when a token could not
        be found in the dictionary. ``<UNK>`` by default. Pass ``None`` if
        the dataset doesn't contain any out-of-vocabulary words/characters;
        requesting data will then raise a ``KeyError`` as soon as an
        unknown symbol is encountered.
    level : 'word' or 'character', optional
        If 'word' the dictionary is expected to contain full words. The
        sentences in the text file will be split at the spaces, and each
        word replaced with its number as given by the dictionary, resulting
        in each example being a single list of numbers. If 'character' the
        dictionary is expected to contain single letters as keys. A single
        example will be a list of character numbers, starting with the
        first non-whitespace character and finishing with the last one. The
        default is 'word'.
    preprocess : function, optional
        A function which takes a sentence (string) as an input and returns
        a modified string. For example ``str.lower`` in order to lowercase
        the sentence before numberizing.
    encoding : str, optional
        The encoding to use to read the file. Defaults to ``None``. Use
        UTF-8 if the dictionary you pass contains UTF-8 characters, but
        note that this makes the dataset unpicklable on legacy Python.

    Raises
    ------
    ValueError
        If ``bos_token``, ``eos_token`` or ``unk_token`` is not ``None``
        and is missing from the dictionary, or if ``level`` is neither
        ``'word'`` nor ``'character'``.

    Examples
    --------
    >>> with open('sentences.txt', 'w') as f:
    ...     _ = f.write("This is a sentence\n")
    ...     _ = f.write("This another one")
    >>> dictionary = {'<UNK>': 0, '</S>': 1, 'this': 2, 'a': 3, 'one': 4}
    >>> def lower(s):
    ...     return s.lower()
    >>> text_data = TextFile(files=['sentences.txt'],
    ...                      dictionary=dictionary, bos_token=None,
    ...                      preprocess=lower)
    >>> from fuel.streams import DataStream
    >>> for data in DataStream(text_data).get_epoch_iterator():
    ...     print(data)
    ([2, 0, 3, 0, 1],)
    ([2, 0, 4, 1],)
    >>> full_dictionary = {'this': 0, 'a': 3, 'is': 4, 'sentence': 5,
    ...                    'another': 6, 'one': 7}
    >>> text_data = TextFile(files=['sentences.txt'],
    ...                      dictionary=full_dictionary, bos_token=None,
    ...                      eos_token=None, unk_token=None,
    ...                      preprocess=lower)
    >>> for data in DataStream(text_data).get_epoch_iterator():
    ...     print(data)
    ([0, 4, 3, 5],)
    ([0, 6, 7],)

    .. doctest::
       :hide:

       >>> import os
       >>> os.remove('sentences.txt')

    """
    provides_sources = ('features',)
    example_iteration_scheme = None

    def __init__(self, files, dictionary, bos_token='<S>', eos_token='</S>',
                 unk_token='<UNK>', level='word', preprocess=None,
                 encoding=None):
        self.files = files
        if isinstance(dictionary, str):
            # The documented contract accepts a path to a pickled
            # dictionary. Load it before the membership checks below, which
            # would otherwise run substring tests against the path itself.
            # NOTE: unpickling can execute arbitrary code -- only pass
            # paths to trusted files.
            import pickle
            with open(dictionary, 'rb') as dictionary_file:
                dictionary = pickle.load(dictionary_file)
        self.dictionary = dictionary
        # Every special token that is enabled must be known to the
        # dictionary, otherwise get_data would fail later on lookup.
        for name, token in (('BOS', bos_token), ('EOS', eos_token),
                            ('UNK', unk_token)):
            if token is not None and token not in dictionary:
                raise ValueError("{} token '{}' is not in the dictionary"
                                 .format(name, token))
        self.bos_token = bos_token
        self.eos_token = eos_token
        self.unk_token = unk_token
        if level not in ('word', 'character'):
            raise ValueError(
                "level should be 'word' or 'character', not '{}'"
                .format(level))
        self.level = level
        self.preprocess = preprocess
        self.encoding = encoding
        super(TextFile, self).__init__()

    def open(self):
        """Return an iterator over the lines of all ``files``, in order.

        Every file is opened immediately (the list comprehension is
        evaluated eagerly); their lines are then chained one file after
        the other. ``iter_``/``chain`` come from ``picklable_itertools``,
        presumably so the iteration state stays picklable -- confirm.
        """
        return chain(*[iter_(open_(f, encoding=self.encoding))
                       for f in self.files])

    def _get_from_dictionary(self, symbol):
        """Map ``symbol`` to its integer, falling back on ``unk_token``.

        Raises
        ------
        KeyError
            If ``symbol`` is not in the dictionary and ``unk_token`` is
            ``None``.
        """
        value = self.dictionary.get(symbol)
        if value is not None:
            return value
        if self.unk_token is None:
            raise KeyError("token '{}' not found in dictionary and no "
                           "`unk_token` given".format(symbol))
        return self.dictionary[self.unk_token]

    def get_data(self, state=None, request=None):
        """Numberize the next line read from ``state``.

        Parameters
        ----------
        state : iterator
            The iterator returned by :meth:`open`.
        request : None
            Must be ``None``; this dataset only supports reading lines
            sequentially.

        Returns
        -------
        tuple
            A 1-tuple ``(data,)`` where ``data`` is the list of integers
            for the sentence, wrapped in BOS/EOS markers when those tokens
            are not ``None``.

        Raises
        ------
        ValueError
            If ``request`` is not ``None``.
        StopIteration
            When ``state`` is exhausted.
        KeyError
            If an unknown symbol is met and ``unk_token`` is ``None``.
        """
        if request is not None:
            raise ValueError("TextFile does not support requests; "
                             "`request` must be None")
        sentence = next(state)
        if self.preprocess is not None:
            sentence = self.preprocess(sentence)
        # Compare against None rather than relying on truthiness, so this
        # agrees with the validation performed in __init__.
        data = []
        if self.bos_token is not None:
            data.append(self.dictionary[self.bos_token])
        if self.level == 'word':
            symbols = sentence.split()
        else:
            # Individual characters, with leading/trailing whitespace
            # (e.g. the line's newline) stripped off.
            symbols = sentence.strip()
        data.extend(self._get_from_dictionary(symbol) for symbol in symbols)
        if self.eos_token is not None:
            data.append(self.dictionary[self.eos_token])
        return (data,)
145