blob: e9a5f28cb880e2964158529119379a17e89fb09d [file] [log] [blame]
#! /usr/bin/env python3
import os
import subprocess
import sys
import tarfile
import tempfile
from six.moves.urllib.request import urlretrieve
from caffe2.python.models.download import downloadFromURLToFile, getURLFromName, deleteDirectory
class SomeClass:
    """Path helpers and downloaders for the Caffe2 and ONNX model files.

    The class holds no state. Its methods compute per-model directories
    under the user's home directory and download model files into them.
    """

    # largely copied from
    # https://github.com/onnx/onnx-caffe2/blob/master/tests/caffe2_ref_test.py
    def _download(self, model):
        """Download the Caffe2 files of ``model`` into its local directory.

        Fetches ``predict_net.pb``, ``init_net.pb`` and ``value_info.json``
        into the directory returned by ``_caffe2_model_dir(model)``, which
        must not exist yet. If any download fails, the directory is deleted
        and the process exits with status 1.
        """
        model_dir = self._caffe2_model_dir(model)
        assert not os.path.exists(model_dir)
        os.makedirs(model_dir)
        for f in ['predict_net.pb', 'init_net.pb', 'value_info.json']:
            url = getURLFromName(model, f)
            dest = os.path.join(model_dir, f)
            try:
                try:
                    downloadFromURLToFile(url, dest,
                                          show_progress=False)
                except TypeError:
                    # show_progress not supported prior to
                    # Caffe2 78c014e752a374d905ecfb465d44fa16e02a28f1
                    # (Sep 17, 2017)
                    downloadFromURLToFile(url, dest)
            except Exception as e:
                print("Abort: {reason}".format(reason=e))
                print("Cleaning up...")
                deleteDirectory(model_dir)
                # sys.exit rather than the exit() helper: the latter is
                # injected by the `site` module and is not always defined.
                sys.exit(1)

    def _caffe2_model_dir(self, model):
        """Return the path ``~/.caffe2/models/<model>`` (home expanded)."""
        caffe2_home = os.path.expanduser('~/.caffe2')
        models_dir = os.path.join(caffe2_home, 'models')
        return os.path.join(models_dir, model)

    def _onnx_model_dir(self, model):
        """Return ``(model_dir, models_dir)`` for ``model``.

        ``model_dir`` is ``~/.onnx/models/<model>`` (home expanded) and
        ``models_dir`` is its parent directory.
        """
        onnx_home = os.path.expanduser('~/.onnx')
        models_dir = os.path.join(onnx_home, 'models')
        model_dir = os.path.join(models_dir, model)
        return model_dir, os.path.dirname(model_dir)

    # largely copied from
    # https://github.com/onnx/onnx/blob/master/onnx/backend/test/runner/__init__.py
    def _prepare_model_data(self, model):
        """Download and unpack the ONNX tarball of ``model``.

        Does nothing when the model directory already exists. Otherwise the
        tarball is downloaded to a temporary file and extracted into the
        parent models directory. On failure the error is printed, the model
        directory is removed, and the exception is re-raised. The temporary
        file is always deleted.
        """
        import shutil  # local import: only needed for failure cleanup

        model_dir, models_dir = self._onnx_model_dir(model)
        if os.path.exists(model_dir):
            return
        os.makedirs(model_dir)
        url = 'https://s3.amazonaws.com/download.onnx/models/{}.tar.gz'.format(model)

        # On Windows, NamedTemporaryFile cannot be opened for a
        # second time
        download_file = tempfile.NamedTemporaryFile(delete=False)
        try:
            download_file.close()
            print('Start downloading model {} from {}'.format(model, url))
            urlretrieve(url, download_file.name)
            print('Done')
            # NOTE(review): extractall() trusts the member paths stored in
            # the archive; consider an extraction filter if the bucket
            # content is ever not fully trusted.
            with tarfile.open(download_file.name) as t:
                t.extractall(models_dir)
        except Exception as e:
            print('Failed to prepare data for model {}: {}'.format(model, e))
            # Remove the (empty or partially extracted) model directory.
            # Otherwise the existence check above would make every later
            # call return early and the model would never be fetched.
            shutil.rmtree(model_dir, ignore_errors=True)
            raise
        finally:
            os.remove(download_file.name)
# Names of the models handled by every sub-command of this script. Each name
# is used as the per-model directory name under the Caffe2 and ONNX model
# directories and as the base name of the generated `<name>.tar.gz` archive.
models = [
    'bvlc_alexnet',
    'densenet121',
    'inception_v1',
    'inception_v2',
    'resnet50',

    # TODO currently onnx can't translate squeezenet :(
    # 'squeezenet',

    'vgg16',

    # TODO currently vgg19 doesn't work in the CI environment,
    # possibly due to OOM
    # 'vgg19'
]
def download_models():
    """Fetch the Caffe2 and ONNX files of every model in ``models``.

    For each model, the Caffe2 files are downloaded when the Caffe2 model
    directory is missing, and the ONNX tarball is downloaded and unpacked
    when the ONNX model directory is missing.
    """
    fetcher = SomeClass()
    for name in models:
        print('update-caffe2-models.py: downloading', name)
        caffe2_dir = fetcher._caffe2_model_dir(name)
        # Only the per-model directory is needed here, not its parent.
        onnx_dir, _ = fetcher._onnx_model_dir(name)

        caffe2_present = os.path.exists(caffe2_dir)
        if not caffe2_present:
            fetcher._download(name)

        onnx_present = os.path.exists(onnx_dir)
        if not onnx_present:
            fetcher._prepare_model_data(name)
def generate_models():
    """Convert every model in ``models`` to ONNX and pack it as a tarball.

    For each model this runs ``convert-caffe2-to-onnx`` on the local Caffe2
    files, writing ``model.pb`` into the ONNX model directory, then runs
    ``tar -czf <model>.tar.gz <model>`` inside the parent models directory.
    A failing external command raises ``subprocess.CalledProcessError``.
    """
    helper = SomeClass()
    for name in models:
        print('update-caffe2-models.py: generating', name)
        src_dir = helper._caffe2_model_dir(name)
        dst_dir, archive_root = helper._onnx_model_dir(name)

        subprocess.check_call(['echo', name])

        # The converter takes the value-info JSON text itself as an argument.
        info_path = os.path.join(src_dir, 'value_info.json')
        with open(info_path, 'r') as info_file:
            value_info = info_file.read()

        convert_cmd = [
            'convert-caffe2-to-onnx',
            '--caffe2-net-name', name,
            '--caffe2-init-net', os.path.join(src_dir, 'init_net.pb'),
            '--value-info', value_info,
            '-o', os.path.join(dst_dir, 'model.pb'),
            os.path.join(src_dir, 'predict_net.pb'),
        ]
        subprocess.check_call(convert_cmd)

        # Run tar from the parent directory so the archive holds `<name>/...`.
        tar_cmd = ['tar', '-czf', name + '.tar.gz', name]
        subprocess.check_call(tar_cmd, cwd=archive_root)
def upload_models():
    """Upload the ``<model>.tar.gz`` archive of every model in ``models``.

    Each archive is copied with ``aws s3 cp`` from the ONNX models directory
    to ``s3://download.onnx/models/`` with the ``public-read`` ACL. A failing
    upload raises ``subprocess.CalledProcessError``.
    """
    helper = SomeClass()
    for name in models:
        print('update-caffe2-models.py: uploading', name)
        # Only the parent models directory (where the tarball lives) is used.
        _, archive_root = helper._onnx_model_dir(name)
        archive_name = name + '.tar.gz'
        destination = "s3://download.onnx/models/{}.tar.gz".format(name)
        upload_cmd = ['aws', 's3', 'cp', archive_name, destination,
                      '--acl', 'public-read']
        subprocess.check_call(upload_cmd, cwd=archive_root)
def cleanup():
    """Delete the ``<model>.tar.gz`` archive of every model in ``models``.

    Archives live in the ONNX models directory. An archive that is already
    missing is skipped, so cleanup can be re-run and one absent tarball does
    not stop the remaining ones from being removed.
    """
    sc = SomeClass()
    for model in models:
        # _onnx_model_dir already returns the parent directory that holds
        # the tarballs, so there is no need to recompute it.
        _, onnx_models_dir = sc._onnx_model_dir(model)
        archive = os.path.join(onnx_models_dir, model + '.tar.gz')
        try:
            os.remove(archive)
        except FileNotFoundError:
            # Nothing to clean for this model; carry on with the others.
            pass
if __name__ == '__main__':
    # Sub-command name -> function implementing it.
    commands = {
        'download': download_models,
        'generate': generate_models,
        'upload': upload_models,
        'cleanup': cleanup,
    }

    # Validate the command line before doing any work, so that a missing or
    # misspelled sub-command gives a usage message instead of an IndexError
    # or a silent no-op.
    if len(sys.argv) < 2 or sys.argv[1] not in commands:
        sys.stderr.write(
            'usage: update-caffe2-models.py '
            '{download|generate|upload|cleanup}\n')
        sys.exit(2)

    try:
        subprocess.check_call(['aws', 'sts', 'get-caller-identity'])
    except (subprocess.CalledProcessError, OSError):
        # CalledProcessError: the aws CLI returned a non-zero status.
        # OSError: the aws executable could not be started.
        # Narrower than a bare `except:`, so Ctrl-C is not misreported as a
        # credentials problem.
        print('update-caffe2-models.py: please run `aws configure` manually to set up credentials')
        sys.exit(1)

    commands[sys.argv[1]]()