| #! /usr/bin/env python3 |
| |
| import os |
| import subprocess |
| import sys |
| import tarfile |
| import tempfile |
| |
| from urllib.request import urlretrieve |
| |
| from caffe2.python.models.download import ( |
| deleteDirectory, |
| downloadFromURLToFile, |
| getURLFromName, |
| ) |
| |
| |
class SomeClass:
    """Locate and fetch the Caffe2 and ONNX files of the reference models.

    Caffe2 models live in ``~/.caffe2/models/NAME`` and ONNX models in
    ``~/.onnx/models/NAME``, where ``NAME`` is the model name.
    """

    # largely copied from
    # https://github.com/onnx/onnx-caffe2/blob/master/tests/caffe2_ref_test.py
    def _download(self, model):
        """Download the Caffe2 files of ``model`` into its model directory.

        Fetches ``predict_net.pb``, ``init_net.pb`` and ``value_info.json``.
        The model directory must not exist yet; it is created here. If any
        download fails, the directory is deleted again and the process
        exits with status 1.
        """
        model_dir = self._caffe2_model_dir(model)
        assert not os.path.exists(model_dir)
        os.makedirs(model_dir)
        for f in ["predict_net.pb", "init_net.pb", "value_info.json"]:
            url = getURLFromName(model, f)
            dest = os.path.join(model_dir, f)
            try:
                try:
                    downloadFromURLToFile(url, dest, show_progress=False)
                except TypeError:
                    # show_progress not supported prior to
                    # Caffe2 78c014e752a374d905ecfb465d44fa16e02a28f1
                    # (Sep 17, 2017)
                    downloadFromURLToFile(url, dest)
            except Exception as e:
                print(f"Abort: {e}")
                print("Cleaning up...")
                deleteDirectory(model_dir)
                sys.exit(1)

    def _caffe2_model_dir(self, model):
        """Return the path ``~/.caffe2/models/NAME`` for ``model``."""
        caffe2_home = os.path.expanduser("~/.caffe2")
        models_dir = os.path.join(caffe2_home, "models")
        return os.path.join(models_dir, model)

    def _onnx_model_dir(self, model):
        """Return ``(model_dir, models_dir)`` for ``model``.

        ``model_dir`` is ``~/.onnx/models/NAME`` and ``models_dir`` is its
        parent directory.
        """
        onnx_home = os.path.expanduser("~/.onnx")
        models_dir = os.path.join(onnx_home, "models")
        model_dir = os.path.join(models_dir, model)
        return model_dir, os.path.dirname(model_dir)

    # largely copied from
    # https://github.com/onnx/onnx/blob/master/onnx/backend/test/runner/__init__.py
    def _prepare_model_data(self, model):
        """Download and unpack the ONNX archive of ``model``.

        Does nothing when the model directory already exists. Otherwise the
        directory is created, the model's ``.tar.gz`` is downloaded to a
        temporary file and extracted into the parent models directory.

        If anything goes wrong (including an interrupt), the freshly
        created model directory is removed again so that a later call
        retries instead of mistaking the leftover directory for a finished
        download. The original exception is re-raised.
        """
        # Local import: only needed to undo a failed download.
        import shutil

        model_dir, models_dir = self._onnx_model_dir(model)
        if os.path.exists(model_dir):
            return
        os.makedirs(model_dir)
        url = f"https://s3.amazonaws.com/download.onnx/models/{model}.tar.gz"

        # On Windows, NamedTemporaryFile cannot be opened for a
        # second time
        download_file = tempfile.NamedTemporaryFile(delete=False)
        prepared = False
        try:
            download_file.close()
            print(f"Start downloading model {model} from {url}")
            urlretrieve(url, download_file.name)
            print("Done")
            with tarfile.open(download_file.name) as t:
                if hasattr(tarfile, "data_filter"):
                    # Refuse archive members that would escape models_dir
                    # (only available on Pythons with extraction filters).
                    t.extractall(models_dir, filter="data")
                else:
                    t.extractall(models_dir)
            prepared = True
        except Exception as e:
            print(f"Failed to prepare data for model {model}: {e}")
            raise
        finally:
            if not prepared:
                # Without this, the existence check at the top would make
                # every later call skip the model.
                shutil.rmtree(model_dir, ignore_errors=True)
            os.remove(download_file.name)
| |
| |
# Names of the models iterated by the download, generate, upload and
# cleanup steps below.
#
# Deliberately left out:
#   'squeezenet' - TODO currently onnx can't translate squeezenet :(
#   'vgg19'      - TODO currently vgg19 doesn't work in the CI environment,
#                  possibly due to OOM
models = [
    "bvlc_alexnet",
    "densenet121",
    "inception_v1",
    "inception_v2",
    "resnet50",
    "vgg16",
]
| |
| |
def download_models():
    """Fetch the Caffe2 and ONNX data of every listed model not yet on disk."""
    fetcher = SomeClass()
    for name in models:
        print("update-caffe2-models.py: downloading", name)
        caffe2_dir = fetcher._caffe2_model_dir(name)
        onnx_dir, _ = fetcher._onnx_model_dir(name)
        # Each source is fetched only when its directory is missing.
        if not os.path.exists(caffe2_dir):
            fetcher._download(name)
        if not os.path.exists(onnx_dir):
            fetcher._prepare_model_data(name)
| |
| |
def generate_models():
    """Convert every listed Caffe2 model to ONNX and archive the result.

    For each model this runs ``convert-caffe2-to-onnx`` to write
    ``model.pb`` into the ONNX model directory, then creates a ``.tar.gz``
    of that directory inside the parent models directory.
    """
    helper = SomeClass()
    for name in models:
        print("update-caffe2-models.py: generating", name)
        caffe2_dir = helper._caffe2_model_dir(name)
        onnx_dir, onnx_parent = helper._onnx_model_dir(name)
        subprocess.check_call(["echo", name])

        # The file's contents, not its path, are handed to --value-info.
        value_info_path = os.path.join(caffe2_dir, "value_info.json")
        with open(value_info_path) as fh:
            value_info_text = fh.read()

        convert_cmd = [
            "convert-caffe2-to-onnx",
            "--caffe2-net-name",
            name,
            "--caffe2-init-net",
            os.path.join(caffe2_dir, "init_net.pb"),
            "--value-info",
            value_info_text,
            "-o",
            os.path.join(onnx_dir, "model.pb"),
            os.path.join(caffe2_dir, "predict_net.pb"),
        ]
        subprocess.check_call(convert_cmd)

        # tar runs in the parent directory so the archive holds "name/...".
        archive_cmd = ["tar", "-czf", name + ".tar.gz", name]
        subprocess.check_call(archive_cmd, cwd=onnx_parent)
| |
| |
def upload_models():
    """Copy each model's ``.tar.gz`` to the ``download.onnx`` S3 bucket.

    Uses ``aws s3 cp`` with a ``public-read`` ACL, run from the ONNX models
    directory where the archives are located.
    """
    helper = SomeClass()
    for name in models:
        print("update-caffe2-models.py: uploading", name)
        _, archive_dir = helper._onnx_model_dir(name)
        local_archive = name + ".tar.gz"
        upload_cmd = [
            "aws",
            "s3",
            "cp",
            local_archive,
            f"s3://download.onnx/models/{name}.tar.gz",
            "--acl",
            "public-read",
        ]
        subprocess.check_call(upload_cmd, cwd=archive_dir)
| |
| |
def cleanup():
    """Delete the ``.tar.gz`` archive of every listed model.

    Archives that do not exist are skipped, so one missing file no longer
    stops the remaining archives from being removed and the step can be
    run repeatedly.
    """
    sc = SomeClass()
    for model in models:
        # The second element is the directory that holds the archives.
        _, onnx_models_dir = sc._onnx_model_dir(model)
        archive = os.path.join(onnx_models_dir, model + ".tar.gz")
        try:
            os.remove(archive)
        except FileNotFoundError:
            # Nothing to delete for this model; carry on with the rest.
            pass
| |
| |
if __name__ == "__main__":
    # Map each supported sub-command to the function implementing it.
    commands = {
        "download": download_models,
        "generate": generate_models,
        "upload": upload_models,
        "cleanup": cleanup,
    }

    # Reject a missing or unknown command up front instead of failing with
    # an IndexError or silently doing nothing.
    if len(sys.argv) < 2 or sys.argv[1] not in commands:
        print(
            "update-caffe2-models.py: usage: update-caffe2-models.py "
            "{download|generate|upload|cleanup}"
        )
        sys.exit(1)

    try:
        subprocess.check_call(["aws", "sts", "get-caller-identity"])
    except (subprocess.CalledProcessError, OSError):
        # CalledProcessError: aws exited non-zero; OSError: aws could not
        # be started at all. Ctrl-C is no longer swallowed here.
        print(
            "update-caffe2-models.py: please run `aws configure` manually to set up credentials"
        )
        sys.exit(1)

    commands[sys.argv[1]]()