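"""Pre-compute BERT embeddings for jsonlines datasets and cache them in an HDF5 file."""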
import numpy as np
import h5py
import json
import logging
import argparse

from e2edutch import bert


def cache_dataset(data_path, out_file, tokenizer, model):
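    """Encode every document in a jsonlines file and store it in out_file.

    Each document becomes an HDF5 group named after its doc_key, with one
    dataset per sentence holding that sentence's BERT embeddings.
    """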
    with open(data_path) as in_file:
        for doc_num, line in enumerate(in_file):
            example = json.loads(line)
            sentences = example["sentences"]
            bert_final = bert.encode_sentences(sentences, tokenizer, model)
            # shape: (num_sent, max_sent_len, lm_size, 1)
            text_len = np.array([len(s) for s in sentences])
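            # "/" is the group separator in HDF5 paths, so escape it in the key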
            file_key = example["doc_key"].replace("/", ":")
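            # Drop any stale cache entry for this document before rewriting it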
            if file_key in out_file:
                del out_file[file_key]

            group = out_file.create_group(file_key)
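            # One dataset per sentence, trimmed to the sentence's actual length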
            for i, (e, l) in enumerate(zip(bert_final, text_len)):
                e = np.array(e[:l, :, :])
                group[str(i)] = e
            if doc_num % 10 == 0:
                logging.info("Cached {} documents in {}".format(
                    doc_num + 1, data_path))
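
# A consumer can read the cached embeddings back per document and sentence;
# a minimal sketch (hypothetical reader code, mirroring the layout written
# by cache_dataset above):
#
#   with h5py.File("data/bertje_cache.hdf5", "r") as cache:
#       doc = cache["some:doc_key"]  # "/" in the original doc_key became ":"
#       emb = doc["0"][...]          # sentence 0: (sent_len, lm_size, 1)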


def get_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument('model_name', choices=['bertje', 'bert-nl', 'robbert'])
    parser.add_argument('datapath')
    parser.add_argument('input_files', nargs='+')
    return parser


def main(args=None):
    args = get_parser().parse_args(args)
    model_name = args.model_name
    datapath = args.datapath
    tokenizer, model = bert.load_bert(model_name)
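    # Append mode: caches already written for other documents/files are kept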
    with h5py.File("{}/{}_cache.hdf5".format(datapath, model_name), "a") as out_file:
        for json_filename in args.input_files:
            cache_dataset(json_filename, out_file, tokenizer, model)
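
# Example invocation (assuming this script is saved as cache.py; the input
# files are jsonlines documents with "sentences" and "doc_key" fields):
#
#   python cache.py bertje data train.jsonlines dev.jsonlines
#
# This writes or updates data/bertje_cache.hdf5 with one group per document.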


if __name__ == "__main__":
    main()