experiments.lm.LMDataset.read_data() (rated F)

Complexity
Conditions: 12

Size
Total Lines: 28

Duplication
Duplicated Lines: 0
Duplication Ratio: 0 %

Metric                        Value
cc  (cyclomatic complexity)   12
dl  (duplicated lines)        0
loc (lines of code)           28
rs                            2.7856

How to fix: Complexity

Complex code like experiments.lm.LMDataset.read_data() often does a lot of different things. To break it down, we need to identify a cohesive component within the enclosing class. A common approach to finding such a component is to look for fields and methods that share the same prefixes or suffixes.

Once you have determined which fields belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often faster; a sketch of Extract Class applied here follows.
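For read_data() in particular, the line filtering and vocabulary encoding form such a component: together they touch only the vocab, char_based, min_tokens, and max_tokens fields. A minimal sketch of Extract Class under that assumption; the SentenceEncoder name and its interface are hypothetical, not part of the original code:

class SentenceEncoder(object):
    """Hypothetical extracted class: turns raw lines into index sequences."""

    def __init__(self, vocab, char_based=False, min_tokens=0, max_tokens=999):
        self.vocab = vocab
        self.char_based = char_based
        self.min_tokens = min_tokens
        self.max_tokens = max_tokens

    def accepts(self, line):
        # Same length filter as the original read_data().
        wc = len(line) if self.char_based else line.count(" ") + 1
        return self.min_tokens <= wc <= self.max_tokens

    def encode(self, line):
        # Wrap the token indices with sentence-boundary markers,
        # as the original loop over tokens does.
        tokens = line if self.char_based else line.split(" ")
        return ([self.vocab.sent_index]
                + [self.vocab.index(w) for w in tokens]
                + [self.vocab.sent_index])

read_data() would then reduce to reading lines, calling encoder.accepts() and encoder.encode(), and chunking the result, which removes several of the 12 conditions from one method and makes the length filter independently testable.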

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import logging as loggers
import numpy as np
from deepy.dataset import Dataset
from deepy.utils import FLOATX

logging = loggers.getLogger(__name__)


class LMDataset(Dataset):

    def __init__(self, vocab, train_path, valid_path, history_len=-1, char_based=False, max_tokens=999, min_tokens=0, sort=True):
        """
        Generate data for training with RNN.
        :type vocab: vocab.Vocab
        """
        assert history_len == -1 or history_len >= 1
        self.vocab = vocab
        self.history_len = history_len
        self.char_based = char_based
        self.sort = sort

        self.min_tokens = min_tokens
        self.max_tokens = max_tokens

        self._train_set = self.read_data(train_path)
        self._valid_set = self.read_data(valid_path)

    def train_set(self):
        return self._train_set

    def valid_set(self):
        return self._valid_set

    def read_data(self, path):
        data = []
        sent_count = 0
        with open(path) as fin:
            for line in fin:
                line = line.strip()
                # Token count: characters in char-based mode, whitespace-split words otherwise.
                wc = len(line) if self.char_based else line.count(" ") + 1
                if wc < self.min_tokens or wc > self.max_tokens:
                    continue
                sent_count += 1
                # Wrap the sentence with sentence-boundary indices.
                sequence = [self.vocab.sent_index]
                tokens = line if self.char_based else line.split(" ")
                for w in tokens:
                    sequence.append(self.vocab.index(w))
                sequence.append(self.vocab.sent_index)

                if self.history_len == -1:
                    # Full sentence
                    data.append(self.convert_to_data(sequence))
                else:
                    # Chunk by chunk
                    for begin in range(0, len(sequence), self.history_len):
                        chunk = sequence[begin: begin + self.history_len + 1]
                        if len(chunk) > 1:
                            data.append(self.convert_to_data(chunk))
        if self.sort:
            data.sort(key=lambda x: len(x[1]))
        logging.info("loaded from %s: %d sentences, %d data pieces" % (path, sent_count, len(data)))
        return data

    def convert_to_data(self, seq):
        assert len(seq) >= 2
        input_indices = seq[:-1]
        target_indices = seq[1:]
        return input_indices, target_indices
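On the method level, Extract Method works as well: the chunking branch at the bottom of the loop can move into a generator. A sketch, assuming the hypothetical helper name _split_sequence, which is not present in the original code:

    def _split_sequence(self, sequence):
        """Yield training pieces: the whole sequence when history_len is -1,
        otherwise consecutive windows of up to history_len + 1 indices
        (adjacent windows share one index)."""
        if self.history_len == -1:
            yield sequence
            return
        for begin in range(0, len(sequence), self.history_len):
            chunk = sequence[begin: begin + self.history_len + 1]
            if len(chunk) > 1:
                yield chunk

The branching at the end of the loop body in read_data() would then collapse to a single statement, for piece in self._split_sequence(sequence): data.append(self.convert_to_data(piece)), trimming both the condition count and the nesting depth that drive the cc score.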