"""Pre-compute contextual embeddings for jsonlines datasets and cache them in an HDF5 file."""
import numpy as np
import h5py
import json
import logging
import argparse

from e2edutch import bert
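
# Each input file is expected to be in jsonlines format: one JSON document per line,
# with at least a "doc_key" string and a "sentences" list of token lists.
# Illustrative line (values made up):
#
#     {"doc_key": "dev/doc_001", "sentences": [["Dit", "is", "een", "zin", "."]]}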


def cache_dataset(data_path, out_file, tokenizer, model):
    """Encode every document in the jsonlines file at data_path and store it in out_file."""
    with open(data_path) as in_file:
        for doc_num, line in enumerate(in_file):
            example = json.loads(line)
            sentences = example["sentences"]
            bert_final = bert.encode_sentences(sentences, tokenizer, model)
            # shape: (num_sent, max_sent_len, lm_size, 1)
            text_len = np.array([len(s) for s in sentences])
            # "/" is the HDF5 path separator, so replace it to keep one flat group per document
            file_key = example["doc_key"].replace("/", ":")
            # Overwrite any previously cached version of this document
            if file_key in out_file.keys():
                del out_file[file_key]
            group = out_file.create_group(file_key)
            # One dataset per sentence, truncated from max_sent_len to the true sentence length
            for i, (e, l) in enumerate(zip(bert_final, text_len)):
                e = np.array(e[:l, :, :])
                group[str(i)] = e

            if doc_num % 10 == 0:
                logging.info("Cached {} documents in {}".format(
                    doc_num + 1, data_path))


def get_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument('model_name', choices=['bertje', 'bert-nl', 'robbert'],
                        help='language model to encode the sentences with')
    parser.add_argument('datapath',
                        help='directory where <model_name>_cache.hdf5 is written')
    parser.add_argument('input_files', nargs='+',
                        help='jsonlines files with "doc_key" and "sentences" fields')
    return parser
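
# Illustrative invocation (the script filename "cache_bert.py" is an assumption,
# not part of this file):
#
#     python cache_bert.py bertje ./data train.jsonlines dev.jsonlines
#
# This encodes both jsonlines files with the "bertje" model and writes or updates
# ./data/bertje_cache.hdf5.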


def main(args=None):
    args = get_parser().parse_args(args)
    model_name = args.model_name
    datapath = args.datapath
    # Without this, the logging.info progress messages in cache_dataset are not shown
    logging.basicConfig(level=logging.INFO)
    tokenizer, model = bert.load_bert(model_name)
    with h5py.File("{}/{}_cache.hdf5".format(datapath, model_name), "a") as out_file:
        for json_filename in args.input_files:
            cache_dataset(json_filename, out_file, tokenizer, model)


if __name__ == "__main__":
    main()
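
# Sketch of how the cache can be read back; the file path and document key are
# illustrative, matching the examples above:
#
#     import h5py
#     with h5py.File("./data/bertje_cache.hdf5", "r") as cache:
#         doc = cache["dev:doc_001"]   # "/" in the doc_key is stored as ":"
#         emb = doc["0"][()]           # first sentence, shape (sent_len, lm_size, 1)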