| #! /usr/bin/env python3 |
| |
| import os |
| import subprocess |
| import sys |
| import tarfile |
| import tempfile |
| |
| from six.moves.urllib.request import urlretrieve |
| |
| from caffe2.python.models.download import downloadFromURLToFile, getURLFromName, deleteDirectory |
| |
class SomeClass:
    """Helpers for locating and fetching Caffe2 and ONNX model data.

    Caffe2 model files live under ``~/.caffe2/models/<model>`` and ONNX
    model data under ``~/.onnx/models/<model>``.
    """

    # largely copied from
    # https://github.com/onnx/onnx-caffe2/blob/master/tests/caffe2_ref_test.py
    def _download(self, model):
        """Download the Caffe2 files for ``model`` into a fresh directory.

        Fetches ``predict_net.pb``, ``init_net.pb`` and ``value_info.json``
        via caffe2's download helpers. If any download fails, the partially
        populated directory is deleted and the process exits with status 1.

        Raises:
            FileExistsError: if the model directory already exists.
        """
        model_dir = self._caffe2_model_dir(model)
        # os.makedirs raises FileExistsError when model_dir already exists;
        # this replaces an ``assert`` that would be stripped under ``python -O``.
        os.makedirs(model_dir)
        for f in ['predict_net.pb', 'init_net.pb', 'value_info.json']:
            url = getURLFromName(model, f)
            dest = os.path.join(model_dir, f)
            try:
                try:
                    downloadFromURLToFile(url, dest,
                                          show_progress=False)
                except TypeError:
                    # show_progress not supported prior to
                    # Caffe2 78c014e752a374d905ecfb465d44fa16e02a28f1
                    # (Sep 17, 2017)
                    downloadFromURLToFile(url, dest)
            except Exception as e:
                print("Abort: {reason}".format(reason=e))
                print("Cleaning up...")
                deleteDirectory(model_dir)
                # sys.exit rather than the site-injected builtin ``exit``,
                # which is missing under ``python -S`` and in frozen builds.
                sys.exit(1)

    def _caffe2_model_dir(self, model):
        """Return the local Caffe2 directory for ``model`` (``~/.caffe2/models/<model>``)."""
        caffe2_home = os.path.expanduser('~/.caffe2')
        models_dir = os.path.join(caffe2_home, 'models')
        return os.path.join(models_dir, model)

    def _onnx_model_dir(self, model):
        """Return ``(model_dir, models_dir)`` for ``model`` under ``~/.onnx/models``.

        ``models_dir`` is the parent of ``model_dir``.
        """
        onnx_home = os.path.expanduser('~/.onnx')
        models_dir = os.path.join(onnx_home, 'models')
        model_dir = os.path.join(models_dir, model)
        return model_dir, os.path.dirname(model_dir)

    # largely copied from
    # https://github.com/onnx/onnx/blob/master/onnx/backend/test/runner/__init__.py
    def _prepare_model_data(self, model):
        """Download and unpack the ONNX tarball for ``model`` if not already present.

        Does nothing if the ONNX model directory already exists. Otherwise it
        fetches ``<model>.tar.gz`` from the ONNX S3 bucket and extracts it
        into the models directory.

        Raises:
            Exception: any download or extraction error is re-raised after
                the temporary file and the newly created model directory
                have been removed.
        """
        import shutil

        model_dir, models_dir = self._onnx_model_dir(model)
        if os.path.exists(model_dir):
            return
        os.makedirs(model_dir)
        url = 'https://s3.amazonaws.com/download.onnx/models/{}.tar.gz'.format(model)

        # On Windows, NamedTemporaryFile cannot be opened for a
        # second time
        download_file = tempfile.NamedTemporaryFile(delete=False)
        try:
            download_file.close()
            print('Start downloading model {} from {}'.format(model, url))
            urlretrieve(url, download_file.name)
            print('Done')
            with tarfile.open(download_file.name) as t:
                # The archive comes from the network: reject absolute paths,
                # ``..`` traversal and unsafe links where the running tarfile
                # supports extraction filters.
                if hasattr(tarfile, 'data_filter'):
                    t.extractall(models_dir, filter='data')
                else:
                    t.extractall(models_dir)
        except Exception as e:
            print('Failed to prepare data for model {}: {}'.format(model, e))
            # Without this, the empty or partial model_dir left behind would
            # make the existence check above skip this model on every later run.
            shutil.rmtree(model_dir, ignore_errors=True)
            raise
        finally:
            os.remove(download_file.name)
| |
# Models handled by this script (download -> generate -> upload -> cleanup).
# Temporarily disabled:
#   - squeezenet: TODO, onnx currently can't translate it
#   - vgg19: TODO, doesn't work in the CI environment, possibly due to OOM
models = [
    'bvlc_alexnet',
    'densenet121',
    'inception_v1',
    'inception_v2',
    'resnet50',
    'vgg16',
]
| |
def download_models():
    """Fetch any missing Caffe2 and ONNX data for each name in ``models``.

    A model's Caffe2 files are downloaded only if its Caffe2 directory does
    not exist yet; likewise for its ONNX directory.
    """
    helper = SomeClass()
    for name in models:
        print('update-caffe2-models.py: downloading', name)
        c2_dir = helper._caffe2_model_dir(name)
        onnx_dir = helper._onnx_model_dir(name)[0]
        if not os.path.exists(c2_dir):
            helper._download(name)
        if not os.path.exists(onnx_dir):
            helper._prepare_model_data(name)
| |
def generate_models():
    """Convert each Caffe2 model to ONNX and pack it as ``<model>.tar.gz``.

    Runs ``convert-caffe2-to-onnx`` on the downloaded Caffe2 files, writing
    ``model.pb`` into the ONNX model directory. It then runs ``tar`` from the
    parent ONNX models directory to create the archive.
    """
    helper = SomeClass()
    for name in models:
        print('update-caffe2-models.py: generating', name)
        c2_dir = helper._caffe2_model_dir(name)
        onnx_dir, onnx_parent = helper._onnx_model_dir(name)
        subprocess.check_call(['echo', name])

        value_info_path = os.path.join(c2_dir, 'value_info.json')
        with open(value_info_path, 'r') as fh:
            value_info = fh.read()

        # The value-info JSON is passed inline as a command-line argument.
        convert_cmd = [
            'convert-caffe2-to-onnx',
            '--caffe2-net-name', name,
            '--caffe2-init-net', os.path.join(c2_dir, 'init_net.pb'),
            '--value-info', value_info,
            '-o', os.path.join(onnx_dir, 'model.pb'),
            os.path.join(c2_dir, 'predict_net.pb'),
        ]
        subprocess.check_call(convert_cmd)

        # Archive relative to the parent dir so entries are stored as ``<model>/...``.
        archive_cmd = ['tar', '-czf', name + '.tar.gz', name]
        subprocess.check_call(archive_cmd, cwd=onnx_parent)
| |
def upload_models():
    """Copy every ``<model>.tar.gz`` to ``s3://download.onnx/models/`` with a public-read ACL."""
    helper = SomeClass()
    for name in models:
        print('update-caffe2-models.py: uploading', name)
        _, onnx_parent = helper._onnx_model_dir(name)
        archive = name + '.tar.gz'
        destination = "s3://download.onnx/models/{}.tar.gz".format(name)
        upload_cmd = ['aws', 's3', 'cp', archive, destination, '--acl', 'public-read']
        # Run from the ONNX models dir, where generate_models left the archive.
        subprocess.check_call(upload_cmd, cwd=onnx_parent)
| |
def cleanup():
    """Delete each model's ``<model>.tar.gz`` from the ONNX models directory."""
    helper = SomeClass()
    for name in models:
        model_dir = helper._onnx_model_dir(name)[0]
        archive_path = os.path.join(os.path.dirname(model_dir), name + '.tar.gz')
        os.remove(archive_path)
| |
if __name__ == '__main__':
    # Sub-command name -> implementing function.
    commands = {
        'download': download_models,
        'generate': generate_models,
        'upload': upload_models,
        'cleanup': cleanup,
    }
    # Validate the argument up front: a missing argument used to raise
    # IndexError, and an unknown command used to exit 0 having done nothing.
    if len(sys.argv) < 2 or sys.argv[1] not in commands:
        print('usage: update-caffe2-models.py {{{}}}'.format('|'.join(commands)),
              file=sys.stderr)
        sys.exit(1)
    try:
        subprocess.check_call(['aws', 'sts', 'get-caller-identity'])
    except (OSError, subprocess.CalledProcessError):
        # OSError covers a missing `aws` binary; CalledProcessError covers
        # missing or invalid credentials. KeyboardInterrupt is no longer swallowed.
        print('update-caffe2-models.py: please run `aws configure` manually to set up credentials')
        sys.exit(1)
    commands[sys.argv[1]]()