LMDataset.read_data()   F

Complexity

Conditions 12

Size

Total Lines 28

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0 Features 0
Metric   Value
cc       12
c        1
b        0
f        0
dl       0
loc      28
rs       2.7855

How to fix: Complexity

Complex methods like LMDataset.read_data() often do a lot of different things. To break such a method down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for fields or methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often faster.
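For read_data() specifically, Extract Method is the more direct fix: each branch that contributes to the condition count can move into a small, independently testable helper. A minimal sketch, with hypothetical helper names that are not part of the original class:

```python
def count_tokens(line, char_based):
    """Token count used by read_data()'s min/max filter."""
    return len(line) if char_based else line.count(" ") + 1


def split_into_chunks(sequence, history_len):
    """The chunk-splitting branch of read_data(), isolated.

    Yields windows of at most history_len + 1 items that overlap by one
    token; length-1 windows are dropped, mirroring the original guard.
    """
    for begin in range(0, len(sequence), history_len):
        chunk = sequence[begin: begin + history_len + 1]
        if len(chunk) > 1:
            yield chunk
```

read_data() would then reduce to reading lines, filtering with count_tokens(), and appending the output of split_into_chunks(), bringing its cc metric down without changing behavior.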

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import logging as loggers
import numpy as np
from deepy.dataset import Dataset
from deepy.core.env import FLOATX

logging = loggers.getLogger(__name__)


class LMDataset(Dataset):

    def __init__(self, vocab, train_path, valid_path, history_len=-1, char_based=False, max_tokens=999, min_tokens=0, sort=True):
        """
        Generate data for training with RNN
        :type vocab: vocab.Vocab
        """
        assert history_len == -1 or history_len >= 1
        self.vocab = vocab
        self.history_len = history_len
        self.char_based = char_based
        self.sort = sort

        self.min_tokens = min_tokens
        self.max_tokens = max_tokens

        self._train_set = self.read_data(train_path)
        self._valid_set = self.read_data(valid_path)

    def train_set(self):
        return self._train_set

    def valid_set(self):
        return self._valid_set

    def read_data(self, path):
        data = []
        sent_count = 0
        # xreadlines() is Python 2 only; iterating the file object directly
        # works in both versions, and the context manager closes the handle.
        with open(path) as fin:
            for line in fin:
                line = line.strip()
                wc = len(line) if self.char_based else line.count(" ") + 1
                if wc < self.min_tokens or wc > self.max_tokens:
                    continue
                sent_count += 1
                sequence = [self.vocab.sent_index]
                tokens = line if self.char_based else line.split(" ")
                for w in tokens:
                    sequence.append(self.vocab.index(w))
                sequence.append(self.vocab.sent_index)

                if self.history_len == -1:
                    # Full sentence
                    data.append(self.convert_to_data(sequence))
                else:
                    # Chunk by chunk
                    for begin in range(0, len(sequence), self.history_len):
                        chunk = sequence[begin: begin + self.history_len + 1]
                        if len(chunk) > 1:
                            data.append(self.convert_to_data(chunk))
        if self.sort:
            data.sort(key=lambda x: len(x[1]))
        logging.info("loaded from %s: %d sentences, %d data pieces", path, sent_count, len(data))
        return data

    def convert_to_data(self, seq):
        assert len(seq) >= 2
        input_indices = seq[:-1]
        target_indices = seq[1:]
        return input_indices, target_indices
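The (input, target) layout that convert_to_data() produces can be illustrated with a short stand-alone snippet; the index values below are made up for illustration, with 0 standing in for vocab.sent_index:

```python
def convert_to_data(seq):
    """Shift a token-index sequence by one position to form an
    (input, target) pair for next-token prediction."""
    assert len(seq) >= 2
    return seq[:-1], seq[1:]


inputs, targets = convert_to_data([0, 5, 9, 3, 0])
# inputs  == [0, 5, 9, 3]
# targets == [5, 9, 3, 0]
```

Each target position is the token the model should predict after seeing the corresponding input prefix, which is why read_data() wraps every sentence in sentence markers before converting it.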