forked from rowanz/grover
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdownload_model.py
26 lines (23 loc) · 923 Bytes
/
download_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import os
import requests
import argparse
# Download a pretrained Grover checkpoint.
#
# Usage: python download_model.py <model_type>
#
# Fetches the three TensorFlow checkpoint files for <model_type> from the
# public `grover-models` GCS bucket into models/<model_type>/.

BASE_URL = 'https://storage.googleapis.com/grover-models'
# Suffixes of the files that make up one TF checkpoint ("model.ckpt.<ext>").
CHECKPOINT_EXTENSIONS = ('data-00000-of-00001', 'index', 'meta')
# Any advertised body smaller than this is treated as "not a checkpoint".
MIN_FILE_SIZE = 1000
# 1 MiB chunks: the checkpoints are large, so tiny reads only add overhead.
CHUNK_SIZE = 1 << 20
# Seconds to wait for the connection and for each read before giving up.
DEFAULT_TIMEOUT = 60


def parse_args(argv=None):
    """Parse command-line arguments (``sys.argv[1:]`` when *argv* is None)."""
    parser = argparse.ArgumentParser(description='Download a model!')
    parser.add_argument(
        'model_type',
        type=str,
        help='Valid model names: (base|large)',
    )
    return parser.parse_args(argv)


def checkpoint_url(model_type, ext):
    """Return the bucket URL of checkpoint file ``model.ckpt.<ext>`` for *model_type*."""
    return f'{BASE_URL}/{model_type}/model.ckpt.{ext}'


def download_file(url, dest_path, min_size=MIN_FILE_SIZE,
                  chunk_size=CHUNK_SIZE, timeout=DEFAULT_TIMEOUT):
    """Stream *url* to *dest_path*.

    The body is written to ``dest_path + '.part'`` and only renamed onto
    *dest_path* once the whole transfer has succeeded. An interrupted or
    rejected download therefore never leaves a truncated file at *dest_path*.

    Raises:
        ValueError: if the server responds with an error status, or advertises
            a Content-Length smaller than *min_size*.
    """
    tmp_path = dest_path + '.part'
    # Using the response as a context manager guarantees the connection is
    # released even if validation or writing fails.
    with requests.get(url, stream=True, timeout=timeout) as response:
        if not response.ok:
            raise ValueError(
                f"Could not download {url}: HTTP {response.status_code}")
        content_length = response.headers.get('content-length')
        # Validate before touching the filesystem; the header may be absent
        # (e.g. chunked transfer), in which case the size check is skipped.
        if content_length is not None and int(content_length) < min_size:
            raise ValueError(
                f"{url} is only {content_length} bytes; "
                f"does this model exist?")
        try:
            with open(tmp_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    if chunk:  # skip keep-alive chunks
                        f.write(chunk)
        except BaseException:
            # Remove the partial file on any failure (including Ctrl-C),
            # then propagate the original error.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise
    os.replace(tmp_path, dest_path)


def main(argv=None):
    """Download every checkpoint file for the requested model type."""
    model_type = parse_args(argv).model_type
    model_dir = os.path.join('models', model_type)
    os.makedirs(model_dir, exist_ok=True)
    for ext in CHECKPOINT_EXTENSIONS:
        download_file(
            checkpoint_url(model_type, ext),
            os.path.join(model_dir, f'model.ckpt.{ext}'),
        )
        print(f"Just downloaded {model_type}/model.ckpt.{ext}!", flush=True)


if __name__ == '__main__':
    main()