cache_bert (overall rating: A)

Complexity

Total Complexity 10

Size/Duplication

Total Lines 51
Duplicated Lines 0 %

Importance

Changes 0

Metric  Value
wmc     10
eloc    41
dl      0
loc     51
rs      10
c       0
b       0
f       0

3 Functions

Rating  Name             Duplication  Size  Complexity
B       cache_dataset()  0            19    6
A       get_parser()     0            6     1
A       main()           0            8     3
import numpy as np
import h5py
import json
import logging
import argparse

from e2edutch import bert


def cache_dataset(data_path, out_file, tokenizer, model):
    """Encode each document in a jsonlines file with BERT and store the
    embeddings in the given HDF5 file, one group per document."""
    with open(data_path) as in_file:
        for doc_num, line in enumerate(in_file.readlines()):
            example = json.loads(line)
            sentences = example["sentences"]
            bert_final = bert.encode_sentences(sentences, tokenizer, model)
            # shape: (num_sent, max_sent_len, lm_size, 1)
            text_len = np.array([len(s) for s in sentences])
            # Replace "/" so h5py does not interpret the key as a nested path
            file_key = example["doc_key"].replace("/", ":")
            # Overwrite any previously cached entry for this document
            if file_key in out_file.keys():
                del out_file[file_key]

            group = out_file.create_group(file_key)
            for i, (e, l) in enumerate(zip(bert_final, text_len)):
                # Trim padding: keep only the first l token embeddings
                e = np.array(e[:l, :, :])
                group[str(i)] = e
            # Report progress every tenth document
            if doc_num % 10 == 0:
                logging.info("Cached {} documents in {}".format(
                    doc_num + 1, data_path))


def get_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument('model_name', choices=['bertje', 'bert-nl', 'robbert'])
    parser.add_argument('datapath')
    parser.add_argument('input_files', nargs='+')
    return parser


def main(args=None):
    # Pass args through so main() can also be called programmatically;
    # with args=None, argparse falls back to sys.argv.
    args = get_parser().parse_args(args)
    model_name = args.model_name
    datapath = args.datapath
    tokenizer, model = bert.load_bert(model_name)
    with h5py.File("{}/{}_cache.hdf5".format(datapath, model_name), "a") as out_file:
        for json_filename in args.input_files:
            cache_dataset(json_filename, out_file, tokenizer, model)


if __name__ == "__main__":
    main()
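
The cache written above can be read back with plain h5py: each document becomes a group keyed by its doc_key (with "/" replaced by ":"), and each sentence a dataset keyed by its index as a string. A minimal sketch, assuming the default output path used by main(); the datapath, model name, and doc_key below are hypothetical placeholders:

import h5py

# Hypothetical values for illustration only; adjust to your data.
datapath = "data"
model_name = "bertje"
doc_key = "dev/wiki_0001"  # as it appears in the input jsonlines file

# Equivalent cache-building run (placeholder input file):
# main(["bertje", "data", "train.jsonlines"])

with h5py.File("{}/{}_cache.hdf5".format(datapath, model_name), "r") as cache:
    group = cache[doc_key.replace("/", ":")]
    for i in range(len(group)):
        emb = group[str(i)][()]  # ndarray of shape (sent_len, lm_size, 1)
        print(i, emb.shape)

Note that main() opens the cache file in append mode ("a"), so repeated runs update existing documents in place; this is why cache_dataset deletes a stale group before recreating it.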