Completed
Push — master ( f73e69...91b7c0 )
by Raphael
01:35
created

experiments.lm.Vocab   A

Complexity

Total Complexity 22

Size/Duplication

Total Lines 71
Duplicated Lines 0 %
Metric Value
dl 0
loc 71
rs 10
wmc 22

10 Methods

Rating   Name   Duplication   Size   Complexity  
A Vocab._load_fixed_size() 0 10 4
A Vocab.sent_vector() 0 3 1
A Vocab.__init__() 0 11 3
A Vocab.index() 0 5 2
A Vocab.transform() 0 4 1
A Vocab.sent_index() 0 3 1
A Vocab.word() 0 6 3
A Vocab.load() 0 10 4
A Vocab.transform_index() 0 4 1
A Vocab.add() 0 4 2
1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
4
5
import logging as loggers
from collections import Counter

import numpy as np

# Reserved vocabulary entries.
SENT_MARK = "</s>"    # sentence marker; registered when ``is_lang`` is true
NULL_MARK = "<null>"  # optional entry; registered first when ``null_mark`` is true
UNK_MARK = "<unk>"    # fallback entry returned by ``Vocab.index`` for unknown words

# NOTE: the module logger is bound to the name ``logging`` (the stdlib module
# itself is imported as ``loggers``).  The name is kept as-is because the code
# below refers to it.
logging = loggers.getLogger(__name__)


class Vocab(object):
    """Bidirectional mapping between tokens and contiguous integer indices.

    Tokens receive indices in order of first insertion, starting at 0.

    Attributes:
        vocab_map: dict mapping token -> index.
        reversed_map: lazily built dict mapping index -> token, or ``None``
            when it has to be (re)built.
        size: number of registered tokens.
        null_mark: the ``null_mark`` flag passed to the constructor.
    """

    def __init__(self, is_lang=True, char_based=False, null_mark=False):
        """Create a vocabulary pre-populated with the requested markers.

        Args:
            is_lang: when true, register ``SENT_MARK`` and ``UNK_MARK``.
            char_based: when true, ``load`` treats every character of a line
                as a token instead of splitting the line on single spaces.
            null_mark: when true, register ``NULL_MARK`` before any other
                entry, so it receives index 0.
        """
        self.vocab_map = {}
        self.reversed_map = None
        self.size = 0
        self._char_based = char_based
        self.null_mark = null_mark
        if null_mark:
            self.add(NULL_MARK)
        if is_lang:
            self.add(SENT_MARK)
            self.add(UNK_MARK)

    def add(self, word):
        """Register ``word`` under the next free index; no-op if already known."""
        if word not in self.vocab_map:
            self.vocab_map[word] = self.size
            self.size += 1
            # The index -> token cache no longer covers every entry.
            self.reversed_map = None

    def index(self, word):
        """Return the index of ``word``, or the ``UNK_MARK`` index if unknown.

        Raises:
            KeyError: if ``word`` is unknown and ``UNK_MARK`` is not
                registered (e.g. the vocabulary was built with
                ``is_lang=False``).
        """
        if word in self.vocab_map:
            return self.vocab_map[word]
        if UNK_MARK not in self.vocab_map:
            raise KeyError(
                "%r is not in the vocabulary and no %s entry is registered"
                % (word, UNK_MARK)
            )
        return self.vocab_map[UNK_MARK]

    def word(self, index):
        """Return the token stored at ``index``.

        Raises:
            KeyError: if no token has that index.
        """
        if self.reversed_map is None:
            self.reversed_map = {i: w for w, i in self.vocab_map.items()}
        return self.reversed_map[index]

    def transform(self, word):
        """Return the one-hot integer vector (length ``size``) for ``word``."""
        return self.transform_index(self.index(word))

    def transform_index(self, index):
        """Return the one-hot integer vector (length ``size``) for ``index``."""
        v = np.zeros(self.size, dtype=int)
        v[index] = 1
        return v

    def _tokenize(self, line):
        """Split one raw line into tokens according to ``char_based``.

        Word mode splits on single spaces, exactly as before; a blank line
        therefore yields one empty-string token.  This is kept so indices of
        previously built vocabularies do not shift.
        """
        line = line.strip()
        if self._char_based:
            return line
        return line.split(" ")

    def _load_fixed_size(self, path, max_size):
        """Add the ``max_size`` most frequent tokens found in ``path``."""
        logging.info("fixed size: %d", max_size)
        counter = Counter()
        with open(path) as stream:
            for line in stream:
                counter.update(self._tokenize(line))
        for token, _ in counter.most_common(max_size):
            self.add(token)

    def load(self, path, max_size=-1):
        """Populate the vocabulary from the text file at ``path``.

        Args:
            path: file to read, one sentence per line.
            max_size: when positive, only the ``max_size`` most frequent
                tokens are added (on top of entries already present);
                otherwise every token is added in order of appearance.
        """
        logging.info("load data from %s", path)
        if max_size > 0:
            self._load_fixed_size(path, max_size)
        else:
            with open(path) as stream:
                for line in stream:
                    # Explicit loop: ``map`` is lazy on Python 3 and would
                    # never call ``add``.
                    for token in self._tokenize(line):
                        self.add(token)
        logging.info("vocab size: %d", self.size)

    @property
    def sent_index(self):
        """Index of ``SENT_MARK`` (subject to the ``index`` fallback rules)."""
        return self.index(SENT_MARK)

    @property
    def sent_vector(self):
        """One-hot vector for ``SENT_MARK``."""
        return self.transform(SENT_MARK)
86
87
88