LMDataset   A

Complexity

Total Complexity 18

Size/Duplication

Total Lines 60
Duplicated Lines 0%

Importance

Changes 1
Bugs 0
Features 0
Metric   Value
c        1    (changes)
b        0    (bugs)
f        0    (features)
dl       0    (duplicated lines)
loc      60   (lines of code)
rs       10
wmc      18   (weighted method complexity)

5 Methods

Rating   Name                Duplication   Size (lines)   Complexity
A        valid_set()         0             2              1
A        __init__()          0             16             2
A        convert_to_data()   0             5              2
A        train_set()         0             2              1
F        read_data()         0             28             12
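
The per-method complexities above (1 + 2 + 2 + 1 + 12) sum to 18, matching the total complexity (wmc) reported for the class; read_data() alone accounts for two thirds of it.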
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import logging as loggers
import numpy as np
from deepy.dataset import Dataset
from deepy.core.env import FLOATX
logging = loggers.getLogger(__name__)


class LMDataset(Dataset):

    def __init__(self, vocab, train_path, valid_path, history_len=-1, char_based=False, max_tokens=999, min_tokens=0, sort=True):
        """
        Generate data for training a language model with an RNN.
        :type vocab: vocab.Vocab
        """
        assert history_len == -1 or history_len >= 1
        self.vocab = vocab
        self.history_len = history_len
        self.char_based = char_based
        self.sort = sort

        self.min_tokens = min_tokens
        self.max_tokens = max_tokens

        self._train_set = self.read_data(train_path)
        self._valid_set = self.read_data(valid_path)

    def train_set(self):
        return self._train_set

    def valid_set(self):
        return self._valid_set

    def read_data(self, path):
        data = []
        sent_count = 0
        for line in open(path):
            line = line.strip()
            # Token count: characters for char-based models, whitespace-separated words otherwise
            wc = len(line) if self.char_based else line.count(" ") + 1
            if wc < self.min_tokens or wc > self.max_tokens:
                continue
            sent_count += 1
            # Wrap the sentence in boundary markers
            sequence = [self.vocab.sent_index]
            tokens = line if self.char_based else line.split(" ")
            for w in tokens:
                sequence.append(self.vocab.index(w))
            sequence.append(self.vocab.sent_index)

            if self.history_len == -1:
                # Full sentence
                data.append(self.convert_to_data(sequence))
            else:
                # Chunk by chunk, each piece at most history_len + 1 indices long
                for begin in range(0, len(sequence), self.history_len):
                    trunk = sequence[begin: begin + self.history_len + 1]
                    if len(trunk) > 1:
                        data.append(self.convert_to_data(trunk))
        if self.sort:
            # Sort by length so similar-length pieces end up next to each other
            data.sort(key=lambda x: len(x[1]))
        logging.info("loaded from %s: %d sentences, %d data pieces" % (path, sent_count, len(data)))
        return data

    def convert_to_data(self, seq):
        # Inputs and targets are the same sequence shifted by one position
        assert len(seq) >= 2
        input_indices = seq[:-1]
        target_indices = seq[1:]
        return input_indices, target_indices
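
For context, a minimal usage sketch of the class above follows. The corpus paths and the ToyVocab stand-in are illustrative assumptions, not part of deepy; the only requirements read_data() places on the real vocab.Vocab are a sent_index attribute and an index() method.

# Usage sketch (hypothetical paths and ToyVocab; only sent_index / index() are required).
class ToyVocab(object):
    """Minimal stand-in for vocab.Vocab."""
    def __init__(self):
        self.sent_index = 0        # index used as the sentence-boundary marker
        self._index = {}

    def index(self, word):
        # Assign indices 1, 2, ... on first sight; reuse them afterwards.
        return self._index.setdefault(word, len(self._index) + 1)

dataset = LMDataset(ToyVocab(), "train.txt", "valid.txt", history_len=20)

# Each data piece is (input_indices, target_indices), shifted by one position:
# convert_to_data([0, 5, 9, 0]) returns ([0, 5, 9], [5, 9, 0]).
for input_indices, target_indices in dataset.train_set():
    pass  # feed the pair into an RNN language-model trainer

With history_len=20 each sentence is cut into pieces of at most 21 indices that overlap by one, so no input sequence exceeds 20 steps; with the default history_len=-1 every full sentence becomes a single piece.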