1
|
|
|
import requests |
|
|
|
|
2
|
|
|
import gzip |
|
|
|
|
3
|
|
|
import zipfile |
|
|
|
|
4
|
|
|
import argparse |
|
|
|
|
5
|
|
|
import logging |
|
|
|
|
6
|
|
|
from tqdm import tqdm |
7
|
|
|
from pathlib import Path |
|
|
|
|
8
|
|
|
from e2edutch import util |
9
|
|
|
|
10
|
|
|
# Logger shared by the functions in this module. Calling getLogger() without
# a name returns the root logger.
logger = logging.getLogger()
11
|
|
|
|
12
|
|
|
|
13
|
|
|
def download_file(url, path):
    """
    Download a URL into a file as specified by `path`.

    The response body is streamed to disk in chunks while a tqdm progress
    bar is updated with the number of bytes written.

    Args:
        url: Address to fetch with an HTTP GET request.
        path: Destination file. It is opened in binary write mode, so an
            existing file is overwritten.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    # This function is copied from stanza
    # https://github.com/stanfordnlp/stanza/blob/f0338f891a03e242c7e11e440dec6e191d54ab77/stanza/resources/common.py#L103
    default_chunk_size = 131072
    desc = 'Downloading ' + url
    # Use the response as a context manager so the connection is released
    # even when writing to disk fails.
    with requests.get(url, stream=True) as r:
        # Check the status before opening `path`, so an error page is never
        # stored as if it were the requested data.
        r.raise_for_status()
        # Content-Length is optional; without it tqdm gets total=None and
        # shows a running byte count instead of a percentage.
        content_length = r.headers.get('content-length')
        file_size = None if content_length is None else int(content_length)
        with open(path, 'wb') as f:
            with tqdm(total=file_size, unit='B', unit_scale=True,
                      desc=desc) as pbar:
                for chunk in r.iter_content(chunk_size=default_chunk_size):
                    if chunk:
                        f.write(chunk)
                        pbar.update(len(chunk))
31
|
|
|
|
32
|
|
|
|
33
|
|
|
def download_data(config=None):
    """
    Download the resources used by e2e-Dutch into the configured directory.

    Three artefacts are fetched, each only when it is not present yet:
    the fastText word vectors (saved as `fasttext.300.vec` without the first
    line of the upstream file), the e2e model (unzipped and moved to the
    `final` sub-directory) and the character vocabulary
    (`char_vocab.dutch.txt`).

    Args:
        config: Mapping with a 'datapath' entry naming the target directory.
            The directory, including parents, is created when missing.

    Raises:
        KeyError: If `config` has no 'datapath' entry, which includes the
            case where the argument is omitted.
    """
    # A literal {} default would be one dict shared by every call.
    if config is None:
        config = {}

    # Create the data directory if it doesn't exist yet
    data_dir = Path(config['datapath'])
    logger.info('Downloading to {}'.format(data_dir))
    data_dir.mkdir(parents=True, exist_ok=True)

    # Download word vectors
    logger.info('Download word vectors')
    url = "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.nl.300.vec.gz"
    fname = data_dir / 'fasttext.300.vec'
    fname_gz = data_dir / 'fasttext.300.vec.gz'
    # Decompress into a scratch file first, so an interrupted run never
    # leaves a truncated file that the exists() check below would accept.
    fname_part = data_dir / 'fasttext.300.vec.part'
    if not fname.exists():
        download_file(url, fname_gz)
        with gzip.open(fname_gz, 'rb') as fin:
            with open(fname_part, 'wb') as fout:
                # We need to remove the first line
                fin.readline()
                # Stream the remaining lines one at a time instead of
                # loading the whole decompressed file into memory.
                for line in fin:
                    fout.write(line)
        fname_part.replace(fname)
        # Remove gz file
        fname_gz.unlink()
    else:
        logger.info('Word vectors file already exists')

    # Download e2e dutch model
    logger.info('Download e2e model')
    url = "https://surfdrive.surf.nl/files/index.php/s/UnZMyDrBEFunmQZ/download"
    fname_zip = data_dir / 'model.zip'
    log_dir_name = data_dir / 'final'
    model_file = log_dir_name / 'model.max.ckpt.index'
    if not fname_zip.exists() and not model_file.exists():
        download_file(url, fname_zip)
    if not model_file.exists():
        with zipfile.ZipFile(fname_zip, 'r') as zfile:
            zfile.extractall(data_dir)
        # The code expects the archive to unpack to logs/final; move that
        # directory up one level and remove the then-empty logs directory.
        Path(data_dir / 'logs' / 'final').rename(log_dir_name)
        Path(data_dir, 'logs').rmdir()
    else:
        logger.info('E2e model file already exists')

    # Download char_dict
    logger.info('Download char dict')
    url = "https://github.com/Filter-Bubble/e2e-Dutch/raw/v0.2.0/data/char_vocab.dutch.txt"
    fname = data_dir / 'char_vocab.dutch.txt'
    if not fname.exists():
        download_file(url, fname)
    else:
        logger.info('Char dict file already exists')
81
|
|
|
|
82
|
|
|
|
83
|
|
|
def get_parser():
    """
    Build the command line parser for this script.

    Options:
        -d / --datapath: target directory, defaults to None.
        -v / --verbose: boolean flag, False unless given.
    """
    option_specs = (
        (('-d', '--datapath'), {'default': None}),
        (('-v', '--verbose'), {'action': 'store_true'}),
    )
    cli = argparse.ArgumentParser()
    for flags, settings in option_specs:
        cli.add_argument(*flags, **settings)
    return cli
88
|
|
|
|
89
|
|
|
|
90
|
|
|
def main():
    """
    Command line entry point: parse the arguments and download the data.

    With --verbose, logging is configured at INFO level. The configuration
    passed to `download_data` is either a dict holding the given --datapath
    or, when no path was given, whatever `util.initialize_from_env` returns
    for the 'final' model.
    """
    args = get_parser().parse_args()
    if args.verbose:
        logging.basicConfig(level=logging.INFO)
    # To do: argparse for config file
    if args.datapath is not None:
        config = {'datapath': args.datapath}
    else:
        config = util.initialize_from_env(model_name='final')
    download_data(config)
102
|
|
|
|
103
|
|
|
|
104
|
|
|
# Run the downloader only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    main()
106
|
|
|
|