blob: 1053530d05c559c6d9e1b21cd21a74ac4f3eef43 [file] [log] [blame]
#! /usr/bin/env python3
import os
import shutil
import subprocess
import sys
import tarfile
import tempfile
from urllib.request import urlretrieve

from caffe2.python.models.download import (
    deleteDirectory,
    downloadFromURLToFile,
    getURLFromName,
)
class SomeClass:
    """Helpers to locate and fetch the Caffe2 and ONNX copies of a model.

    Caffe2 files are kept under ``~/.caffe2/models/<model>``; ONNX files
    under ``~/.onnx/models/<model>``.
    """

    # largely copied from
    # https://github.com/onnx/onnx-caffe2/blob/master/tests/caffe2_ref_test.py
    def _download(self, model):
        """Download the Caffe2 net files for ``model`` into its model directory.

        Fetches ``predict_net.pb``, ``init_net.pb`` and ``value_info.json``
        from the URLs returned by ``getURLFromName``. If any download fails,
        the partially populated directory is deleted and the process exits
        with status 1.

        Raises:
            FileExistsError: if the Caffe2 model directory already exists.
        """
        model_dir = self._caffe2_model_dir(model)
        # Explicit check rather than ``assert`` so the guard survives ``python -O``.
        if os.path.exists(model_dir):
            raise FileExistsError(
                f"Caffe2 model directory already exists: {model_dir}"
            )
        os.makedirs(model_dir)
        for f in ["predict_net.pb", "init_net.pb", "value_info.json"]:
            url = getURLFromName(model, f)
            dest = os.path.join(model_dir, f)
            try:
                try:
                    downloadFromURLToFile(url, dest, show_progress=False)
                except TypeError:
                    # show_progress not supported prior to
                    # Caffe2 78c014e752a374d905ecfb465d44fa16e02a28f1
                    # (Sep 17, 2017)
                    downloadFromURLToFile(url, dest)
            except Exception as e:
                print(f"Abort: {e}")
                print("Cleaning up...")
                deleteDirectory(model_dir)
                sys.exit(1)

    def _caffe2_model_dir(self, model):
        """Return the Caffe2 directory for ``model``: ``~/.caffe2/models/<model>``."""
        caffe2_home = os.path.expanduser("~/.caffe2")
        models_dir = os.path.join(caffe2_home, "models")
        return os.path.join(models_dir, model)

    def _onnx_model_dir(self, model):
        """Return ``(model_dir, models_dir)`` for the ONNX copy of ``model``.

        ``model_dir`` is ``~/.onnx/models/<model>``; ``models_dir`` is its
        parent ``~/.onnx/models``, where model tarballs are extracted and
        written.
        """
        onnx_home = os.path.expanduser("~/.onnx")
        models_dir = os.path.join(onnx_home, "models")
        model_dir = os.path.join(models_dir, model)
        return model_dir, os.path.dirname(model_dir)

    # largely copied from
    # https://github.com/onnx/onnx/blob/master/onnx/backend/test/runner/__init__.py
    def _prepare_model_data(self, model):
        """Download and unpack the ONNX tarball for ``model``.

        Does nothing if the ONNX model directory already exists. Otherwise
        fetches ``https://s3.amazonaws.com/download.onnx/models/<model>.tar.gz``
        into a temporary file and extracts it into ``~/.onnx/models``.

        Raises:
            Exception: any download/extraction error is re-raised after the
                model directory created here has been removed, so a later
                run retries instead of treating the model as prepared.
        """
        model_dir, models_dir = self._onnx_model_dir(model)
        if os.path.exists(model_dir):
            return
        os.makedirs(model_dir)
        url = f"https://s3.amazonaws.com/download.onnx/models/{model}.tar.gz"
        # On Windows, NamedTemporaryFile cannot be opened for a
        # second time
        download_file = tempfile.NamedTemporaryFile(delete=False)
        try:
            download_file.close()
            print(f"Start downloading model {model} from {url}")
            urlretrieve(url, download_file.name)
            print("Done")
            with tarfile.open(download_file.name) as t:
                # The archive comes from the network: where supported, use the
                # "data" extraction filter to reject absolute paths, path
                # traversal and special files.
                if hasattr(tarfile, "data_filter"):
                    t.extractall(models_dir, filter="data")
                else:
                    t.extractall(models_dir)
        except Exception as e:
            print(f"Failed to prepare data for model {model}: {e}")
            # Without this, the (empty or partial) directory created above
            # makes every later run skip this model.
            shutil.rmtree(model_dir, ignore_errors=True)
            raise
        finally:
            os.remove(download_file.name)
# Model names handled by every command (download/generate/upload/cleanup).
# Deliberately left out for now:
#   - squeezenet: TODO currently onnx can't translate squeezenet :(
#   - vgg19: TODO doesn't work in the CI environment, possibly due to OOM
models = [
    "bvlc_alexnet", "densenet121",
    "inception_v1", "inception_v2",
    "resnet50", "vgg16",
]
def download_models():
    """Fetch the Caffe2 and ONNX copies of every model in ``models``.

    Each copy is fetched only when its local directory does not exist yet.
    """
    helper = SomeClass()
    for name in models:
        print("update-caffe2-models.py: downloading", name)
        if not os.path.exists(helper._caffe2_model_dir(name)):
            helper._download(name)
        onnx_dir, _ = helper._onnx_model_dir(name)
        if not os.path.exists(onnx_dir):
            helper._prepare_model_data(name)
def generate_models():
    """Convert each Caffe2 model in ``models`` to ONNX and archive it.

    For every model: runs ``convert-caffe2-to-onnx`` on the Caffe2 files in
    the model's Caffe2 directory (passing the text of ``value_info.json`` as
    ``--value-info``), writes ``model.pb`` into the ONNX model directory, and
    then packs that directory into ``<model>.tar.gz`` beside it using ``tar``.

    Raises:
        subprocess.CalledProcessError: if the converter or ``tar`` exits
            non-zero.
        OSError: if ``value_info.json`` cannot be read or a tool is missing.
    """
    sc = SomeClass()
    for model in models:
        print("update-caffe2-models.py: generating", model)
        caffe2_model_dir = sc._caffe2_model_dir(model)
        onnx_model_dir, onnx_models_dir = sc._onnx_model_dir(model)
        # Print directly instead of spawning `echo` (no standalone executable
        # on Windows); flush so it appears before the child processes' output.
        print(model, flush=True)
        value_info_path = os.path.join(caffe2_model_dir, "value_info.json")
        with open(value_info_path, encoding="utf-8") as f:
            value_info = f.read()
        # The converter writes model.pb here; make sure the directory exists.
        os.makedirs(onnx_model_dir, exist_ok=True)
        subprocess.check_call(
            [
                "convert-caffe2-to-onnx",
                "--caffe2-net-name",
                model,
                "--caffe2-init-net",
                os.path.join(caffe2_model_dir, "init_net.pb"),
                "--value-info",
                value_info,
                "-o",
                os.path.join(onnx_model_dir, "model.pb"),
                os.path.join(caffe2_model_dir, "predict_net.pb"),
            ]
        )
        subprocess.check_call(
            ["tar", "-czf", model + ".tar.gz", model], cwd=onnx_models_dir
        )
def upload_models():
    """Publish every ``<model>.tar.gz`` to the ONNX S3 bucket.

    Each archive is copied from the ONNX models directory with the ``aws``
    CLI and a ``public-read`` ACL; a failing copy raises
    ``subprocess.CalledProcessError``.
    """
    helper = SomeClass()
    for name in models:
        print("update-caffe2-models.py: uploading", name)
        _, archive_dir = helper._onnx_model_dir(name)
        archive = name + ".tar.gz"
        destination = f"s3://download.onnx/models/{name}.tar.gz"
        subprocess.check_call(
            ["aws", "s3", "cp", archive, destination, "--acl", "public-read"],
            cwd=archive_dir,
        )
def cleanup():
    """Delete the ``<model>.tar.gz`` archives produced by ``generate_models``.

    Archives that are already absent are reported and skipped, so a partial
    ``generate`` run or a repeated ``cleanup`` does not prevent the remaining
    archives from being removed.
    """
    sc = SomeClass()
    for model in models:
        # Same directory generate_models/upload_models use for the archives.
        _, onnx_models_dir = sc._onnx_model_dir(model)
        archive = os.path.join(onnx_models_dir, model + ".tar.gz")
        try:
            os.remove(archive)
        except FileNotFoundError:
            print(f"update-caffe2-models.py: {archive} not found, skipping")
if __name__ == "__main__":
    # Sub-command name -> function that implements it.
    commands = {
        "download": download_models,
        "generate": generate_models,
        "upload": upload_models,
        "cleanup": cleanup,
    }
    # Validate the command first. A missing argument would otherwise raise
    # IndexError, and a misspelled one would silently do nothing.
    if len(sys.argv) < 2 or sys.argv[1] not in commands:
        sys.exit("usage: update-caffe2-models.py {" + ",".join(commands) + "}")
    try:
        subprocess.check_call(["aws", "sts", "get-caller-identity"])
    except (subprocess.CalledProcessError, OSError):
        # CalledProcessError: the credential check failed.
        # OSError (e.g. FileNotFoundError): the aws CLI is not installed.
        # Unlike a bare except, this lets Ctrl-C propagate.
        print(
            "update-caffe2-models.py: please run `aws configure` manually to set up credentials"
        )
        sys.exit(1)
    commands[sys.argv[1]]()