| ## @package workspace |
| # Module caffe2.python.workspace |
| |
| |
| |
| |
| import collections |
| import contextlib |
| from google.protobuf.message import Message |
| from multiprocessing import Process |
| import os |
| from collections import defaultdict |
| import logging |
| import numpy as np |
| from past.builtins import basestring |
| import shutil |
| import socket |
| import tempfile |
| |
| from caffe2.proto import caffe2_pb2 |
| from caffe2.python import scope, utils |
| from caffe2.python.lazy import TriggerLazyImport |
| |
| import caffe2.python._import_c_extension as C |
| |
| logger = logging.getLogger(__name__) |
| |
| Blobs = C.blobs |
| ResetBlob = C.reset_blob |
| CreateBlob = C.create_blob |
| CurrentWorkspace = C.current_workspace |
| DeserializeBlob = C.deserialize_blob |
| GlobalInit = C.global_init |
| HasBlob = C.has_blob |
| RegisteredOperators = C.registered_operators |
| SerializeBlob = C.serialize_blob |
| SwitchWorkspace = C.switch_workspace |
| RootFolder = C.root_folder |
| Workspaces = C.workspaces |
| BenchmarkNet = C.benchmark_net |
| BenchmarkNetOnce = C.benchmark_net_once |
| GetStats = C.get_stats |
| CreateOfflineTensor = C.create_offline_tensor |
| |
| operator_tracebacks = defaultdict(dict) |
| |
| is_asan = C.is_asan |
| has_fbgemm = C.has_fbgemm |
| has_cuda_support = C.has_cuda_support |
| has_hip_support = C.has_hip_support |
| has_gpu_support = C.has_gpu_support |
| if has_cuda_support: |
| GpuDeviceType = caffe2_pb2.CUDA |
| NumCudaDevices = C.num_cuda_devices |
| # This is a duplicate of NumCudaDevices. Remove |
| # NumCudaDevices once replaced everywhere in the code |
| NumGpuDevices = C.num_cuda_devices |
| GetCUDAVersion = C.get_cuda_version |
| GetCuDNNVersion = C.get_cudnn_version |
| |
| def GetGpuPeerAccessPattern(): |
| return np.asarray(C.get_cuda_peer_access_pattern()) |
| |
| GetDeviceProperties = C.get_device_properties |
| GetGPUMemoryInfo = C.get_gpu_memory_info |
| else: |
| # pyre-fixme[9]: incompatible type assignment |
| NumCudaDevices = lambda: 0 # noqa |
| # pyre-fixme[9]: incompatible type assignment |
| GetCUDAVersion = lambda: 0 # noqa |
| # pyre-fixme[9]: incompatible type assignment |
| GetCuDNNVersion = lambda: 0 # noqa |
| |
| if has_hip_support: |
| GpuDeviceType = caffe2_pb2.HIP |
| # pyre-fixme[9]: incompatible type assignment |
| NumGpuDevices = C.num_hip_devices |
| GetHIPVersion = C.get_hip_version |
| |
| def GetGpuPeerAccessPattern(): |
| return np.asarray(C.get_hip_peer_access_pattern()) |
| GetDeviceProperties = C.get_device_properties |
| GetGPUMemoryInfo = C.get_gpu_memory_info |
| |
| if not has_gpu_support: |
| # setting cuda as the default GpuDeviceType as some tests |
| # like core, scope tests use GpuDeviceType even without gpu support |
| GpuDeviceType = caffe2_pb2.CUDA |
| # pyre-fixme[9]: incompatible type assignment |
| NumGpuDevices = lambda: 0 # noqa |
| GetDeviceProperties = lambda x: None # noqa |
| GetGpuPeerAccessPattern = lambda: np.array([]) # noqa |
| # pyre-fixme[9]: incompatible type assignment |
| GetGPUMemoryInfo = lambda: None # noqa |
| |
| IsNUMAEnabled = C.is_numa_enabled |
| GetNumNUMANodes = C.get_num_numa_nodes |
| GetBlobNUMANode = C.get_blob_numa_node |
| GetBlobSizeBytes = C.get_blob_size_bytes |
| |
| |
| def FillRandomNetworkInputs(net, input_dims, input_types): |
| C.fill_random_network_inputs(net.Proto().SerializeToString(), input_dims, input_types) |
| |
| |
| def _GetFreeFlaskPort(): |
| """Get a free flask port.""" |
| # We will prefer to use 5000. If not, we will then pick a random port. |
| sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) |
| result = sock.connect_ex(('127.0.0.1', 5000)) |
| if result == 0: |
| return 5000 |
| else: |
| s = socket.socket() |
| s.bind(('', 0)) |
| port = s.getsockname()[1] |
| s.close() |
| # Race condition: between the interval we close the socket and actually |
| # start a mint process, another process might have occupied the port. We |
| # don't do much here as this is mostly for convenience in research |
| # rather than 24x7 service. |
| return port |
| |
| def StartMint(root_folder=None, port=None): |
| """Start a mint instance. |
| |
| TODO(Yangqing): this does not work well under ipython yet. According to |
| https://github.com/ipython/ipython/issues/5862 |
| writing up some fix is a todo item. |
| """ |
| from caffe2.python.mint import app |
| if root_folder is None: |
| # Get the root folder from the current workspace |
| root_folder = C.root_folder() |
| if port is None: |
| port = _GetFreeFlaskPort() |
| process = Process( |
| target=app.main, |
| args=( |
| ['-p', str(port), '-r', root_folder], |
| ) |
| ) |
| process.start() |
| print('Mint running at http://{}:{}'.format(socket.getfqdn(), port)) |
| return process |
| |
| |
| def StringifyProto(obj): |
| """Stringify a protocol buffer object. |
| |
| Inputs: |
| obj: a protocol buffer object, or a Pycaffe2 object that has a Proto() |
| function. |
| Outputs: |
| string: the output protobuf string. |
| Raises: |
| AttributeError: if the passed in object does not have the right attribute. |
| """ |
| if isinstance(obj, basestring): |
| return obj |
| else: |
| if isinstance(obj, Message): |
| # First, see if this object is a protocol buffer, which we can |
| # simply serialize with the SerializeToString() call. |
| return obj.SerializeToString() |
| elif hasattr(obj, 'Proto'): |
| return obj.Proto().SerializeToString() |
| else: |
| raise ValueError("Unexpected argument to StringifyProto of type " + |
| type(obj).__name__) |
| |
| |
| def ResetWorkspace(root_folder=None): |
| if root_folder is None: |
| # Reset the workspace, but keep the current root folder setting. |
| return C.reset_workspace(C.root_folder()) |
| else: |
| if not os.path.exists(root_folder): |
| os.makedirs(root_folder) |
| return C.reset_workspace(root_folder) |
| |
| |
| def CreateNet(net, overwrite=False, input_blobs=None): |
| TriggerLazyImport() |
| if input_blobs is None: |
| input_blobs = [] |
| for input_blob in input_blobs: |
| C.create_blob(input_blob) |
| return CallWithExceptionIntercept( |
| C.create_net, |
| C.Workspace.current._last_failed_op_net_position, |
| GetNetName(net), |
| StringifyProto(net), overwrite, |
| ) |
| |
| |
| def Predictor(init_net, predict_net): |
| return C.Predictor(StringifyProto(init_net), StringifyProto(predict_net)) |
| |
| |
| def GetOperatorCost(operator, blobs): |
| return C.get_operator_cost(StringifyProto(operator), blobs) |
| |
| |
| def RunOperatorOnce(operator): |
| return C.run_operator_once(StringifyProto(operator)) |
| |
| |
| def RunOperatorMultiple(operator, num_runs): |
| return C.run_operator_multiple(StringifyProto(operator), num_runs) |
| |
| |
| def RunOperatorsOnce(operators): |
| for op in operators: |
| success = RunOperatorOnce(op) |
| if not success: |
| return False |
| return True |
| |
| |
| def ClearGlobalNetObserver(): |
| return C.clear_global_net_observer() |
| |
| |
| def CallWithExceptionIntercept(func, op_id_fetcher, net_name, *args, **kwargs): |
| try: |
| return func(*args, **kwargs) |
| except Exception: |
| op_id = op_id_fetcher() |
| net_tracebacks = operator_tracebacks.get(net_name, None) |
| logger.warning( |
| 'Original python traceback for operator `{}` in network ' |
| '`{}` in exception above (most recent call last):'.format( |
| op_id, net_name)) |
| if net_tracebacks and op_id in net_tracebacks: |
| tb = net_tracebacks[op_id] |
| for line in reversed(tb): |
| logger.warning(' File "{}", line {}, in {}'.format( |
| line[0], line[1], line[2])) |
| raise |
| |
| |
| def RunNetOnce(net): |
| return CallWithExceptionIntercept( |
| C.run_net_once, |
| C.Workspace.current._last_failed_op_net_position, |
| GetNetName(net), |
| StringifyProto(net), |
| ) |
| |
| |
| def RunNet(name, num_iter=1, allow_fail=False): |
| """Runs a given net. |
| |
| Inputs: |
| name: the name of the net, or a reference to the net. |
| num_iter: number of iterations to run |
| allow_fail: if True, does not assert on net exec failure but returns False |
| Returns: |
| True or an exception. |
| """ |
| return CallWithExceptionIntercept( |
| C.run_net, |
| C.Workspace.current._last_failed_op_net_position, |
| GetNetName(name), |
| StringifyNetName(name), num_iter, allow_fail, |
| ) |
| |
| |
| def RunPlan(plan_or_step): |
| # TODO(jiayq): refactor core.py/workspace.py to avoid circular deps |
| import caffe2.python.core as core |
| if isinstance(plan_or_step, core.ExecutionStep): |
| plan_or_step = core.Plan(plan_or_step) |
| return C.run_plan(StringifyProto(plan_or_step)) |
| |
| |
| def RunPlanInBackground(plan_or_step): |
| # TODO(jiayq): refactor core.py/workspace.py to avoid circular deps |
| import caffe2.python.core as core |
| if isinstance(plan_or_step, core.ExecutionStep): |
| plan_or_step = core.Plan(plan_or_step) |
| return C.run_plan_in_background(StringifyProto(plan_or_step)) |
| |
| |
| def InferShapesAndTypes(nets, blob_dimensions=None, nets_proto=False, |
| blob_types=None): |
| """Infers the shapes and types for the specified nets. |
| |
| Inputs: |
| nets: the list of nets |
| blob_dimensions (optional): a dictionary of blobs and their dimensions. |
| If not specified, the workspace blobs are used. |
| nets_proto (optional): a boolean flag indicating whether the protobuffer |
| representation is passed to the routine. |
| Returns: |
| A tuple of (shapes, types) dictionaries keyed by blob name. |
| """ |
| if nets_proto: |
| net_protos = [StringifyProto(n) for n in nets] |
| else: |
| net_protos = [StringifyProto(n.Proto()) for n in nets] |
| if blob_dimensions is None: |
| assert blob_types is None |
| blobdesc_prototxt = C.infer_shapes_and_types_from_workspace(net_protos) |
| elif blob_types is None: |
| blobdesc_prototxt = C.infer_shapes_and_types_from_map( |
| net_protos, blob_dimensions |
| ) |
| else: |
| blobdesc_prototxt = C.infer_shapes_and_types_from_map( |
| net_protos, blob_dimensions, blob_types |
| ) |
| blobdesc_proto = caffe2_pb2.TensorShapes() |
| blobdesc_proto.ParseFromString(blobdesc_prototxt) |
| shapes = {} |
| types = {} |
| for ts in blobdesc_proto.shapes: |
| if not ts.unknown_shape: |
| shapes[ts.name] = list(ts.dims) |
| types[ts.name] = ts.data_type |
| |
| return (shapes, types) |
| |
| |
| def _StringifyName(name, expected_type): |
| if isinstance(name, basestring): |
| return name |
| assert type(name).__name__ == expected_type, \ |
| "Expected a string or %s" % expected_type |
| return str(name) |
| |
| |
| def StringifyBlobName(name): |
| return _StringifyName(name, "BlobReference") |
| |
| |
| def StringifyNetName(name): |
| return _StringifyName(name, "Net") |
| |
| |
| def GetNetName(net): |
| if isinstance(net, basestring): |
| return net |
| if type(net).__name__ == "Net" or type(net).__name__ == "NetWithShapeInference": |
| return net.Name() |
| if isinstance(net, caffe2_pb2.NetDef): |
| return net.name |
| raise Exception("Not a Net object: {}".format(str(net))) |
| |
| |
| def FeedBlob(name, arr, device_option=None): |
| """Feeds a blob into the workspace. |
| |
| Inputs: |
| name: the name of the blob. |
| arr: either a TensorProto object or a numpy array object to be fed into |
| the workspace. |
| device_option (optional): the device option to feed the data with. |
| Returns: |
| True or False, stating whether the feed is successful. |
| """ |
| ws = C.Workspace.current |
| return _Workspace_feed_blob(ws, name, arr, device_option) |
| |
| |
| def FetchBlobs(names): |
| """Fetches a list of blobs from the workspace. |
| |
| Inputs: |
| names: list of names of blobs - strings or BlobReferences |
| Returns: |
| list of fetched blobs |
| """ |
| return [FetchBlob(name) for name in names] |
| |
| |
| def FetchBlob(name): |
| """Fetches a blob from the workspace. |
| |
| Inputs: |
| name: the name of the blob - a string or a BlobReference |
| Returns: |
| Fetched blob (numpy array or string) if successful |
| """ |
| result = C.fetch_blob(StringifyBlobName(name)) |
| if isinstance(result, tuple): |
| raise TypeError( |
| "Use FetchInt8Blob to fetch Int8 Blob {}".format( |
| StringifyBlobName(name) |
| ) |
| ) |
| return result |
| |
| |
| def FetchTorch(name): |
| ws = C.Workspace.current |
| return ws.blobs[name].to_torch() |
| |
| |
| Int8Tensor = collections.namedtuple( |
| 'Int8Tensor', ['data', 'scale', 'zero_point'] |
| ) |
| |
| |
| def FetchInt8Blob(name): |
| """Fetches an Int8 blob from the workspace. It shared backend implementation |
| with FetchBlob but it is recommended when fetching Int8 Blobs |
| |
| Inputs: |
| name: the name of the Int8 blob - a string or a BlobReference |
| Returns: |
| data: int8 numpy array, data |
| scale: float, fake quantization scale |
| zero_point: int, fake quantization offset |
| """ |
| result = C.fetch_blob(StringifyBlobName(name)) |
| assert isinstance(result, tuple), \ |
| 'You are not fetching an Int8Blob {}. Please use FetchBlob'.format( |
| StringifyBlobName(name)) |
| return Int8Tensor(*result) |
| |
| |
| def FetchInt8BlobRealVal(name): |
| """Fetches an Int8 blob from the workspace and return its real value representation. |
| |
| Inputs: |
| name: the name of the Int8 blob - a string or a BlobReference |
| Returns: |
| real value representation of int8 numpy array |
| """ |
| result = C.fetch_blob(StringifyBlobName(name)) |
| assert isinstance(result, tuple), \ |
| 'You are not fetching an Int8Blob {}. Please use FetchBlob'.format( |
| StringifyBlobName(name)) |
| int8_blob = Int8Tensor(*result) |
| return (int8_blob.data.astype(np.int32) - int(int8_blob.zero_point)).astype( |
| np.float32) * int8_blob.scale |
| |
| |
| def RemoveBlob(name) -> None: |
| ws = C.Workspace.current |
| _Workspace_remove_blob(ws, name) |
| |
| def _Workspace_fetch_int8_blob(ws, name): |
| """Fetches an Int8 blob from the workspace. It shared backend implementation |
| with FetchBlob but it is recommended when fetching Int8 Blobs |
| |
| Inputs: |
| name: the name of the Int8 blob - a string or a BlobReference |
| Returns: |
| data: int8 numpy array, data |
| scale: float, fake quantization scale |
| zero_point: int, fake quantization offset |
| """ |
| result = ws.fetch_blob(name) |
| assert isinstance(result, tuple), \ |
| 'You are not fetching an Int8Blob {}. Please use fetch_blob'.format( |
| StringifyBlobName(name)) |
| return Int8Tensor(*result) |
| |
| |
| C.Workspace.fetch_int8_blob = _Workspace_fetch_int8_blob |
| |
| |
| def ApplyTransform(transform_key, net): |
| """Apply a Transform to a NetDef protobuf object, and returns the new |
| transformed NetDef. |
| |
| Inputs: |
| transform_key: the name of the transform, as it is stored in the registry |
| net: a NetDef protobuf object |
| Returns: |
| Transformed NetDef protobuf object. |
| """ |
| transformed_net = caffe2_pb2.NetDef() |
| transformed_str = C.apply_transform( |
| str(transform_key).encode('utf-8'), |
| net.SerializeToString(), |
| ) |
| transformed_net.ParseFromString(transformed_str) |
| return transformed_net |
| |
| |
| def ApplyTransformIfFaster(transform_key, net, init_net, **kwargs): |
| """Apply a Transform to a NetDef protobuf object, and returns the new |
| transformed NetDef, only if it runs faster than the original. |
| |
| The runs are performed on the current active workspace (gWorkspace). |
| You should initialize that workspace before making a call to this function. |
| |
| Inputs: |
| transform_key: the name of the transform, as it is stored in the registry |
| net: a NetDef protobuf object |
| init_net: The net to initialize the workspace. |
| warmup_runs (optional): |
| Determines how many times the net is run before testing. |
| Will be 5 by default. |
| main_runs (optional): |
| Determines how many times the net is run during testing. |
| Will be 10 by default. |
| improvement_threshold (optional): |
| Determines the factor which the new net needs to be faster |
| in order to replace the old. Will be 1.01 by default. |
| |
| Returns: |
| Either a Transformed NetDef protobuf object, or the original netdef. |
| """ |
| |
| warmup_runs = kwargs['warmup_runs'] if 'warmup_runs' in kwargs else 5 |
| main_runs = kwargs['main_runs'] if 'main_runs' in kwargs else 10 |
| improvement_threshold = kwargs['improvement_threshold'] \ |
| if 'improvement_threshold' in kwargs else 1.01 |
| |
| transformed_net = caffe2_pb2.NetDef() |
| transformed_str = C.apply_transform_if_faster( |
| str(transform_key).encode('utf-8'), |
| net.SerializeToString(), |
| init_net.SerializeToString(), |
| warmup_runs, |
| main_runs, |
| float(improvement_threshold), |
| ) |
| transformed_net.ParseFromString(transformed_str) |
| return transformed_net |
| |
| |
| def GetNameScope(): |
| """Return the current namescope string. To be used to fetch blobs""" |
| return scope.CurrentNameScope() |
| |
| |
| class _BlobDict: |
| """Provides python dict compatible way to do fetching and feeding""" |
| |
| def __getitem__(self, key): |
| return FetchBlob(key) |
| |
| def __setitem__(self, key, value): |
| return FeedBlob(key, value) |
| |
| def __len__(self): |
| return len(C.blobs()) |
| |
| def __iter__(self): |
| return C.blobs().__iter__() |
| |
| def __contains__(self, item): |
| return C.has_blob(item) |
| |
| |
| blobs = _BlobDict() |
| |
| |
| ################################################################################ |
| # Utilities for immediate mode |
| # |
| # Caffe2's immediate mode implements the following behavior: between the two |
| # function calls StartImmediate() and StopImmediate(), for any operator that is |
| # called through CreateOperator(), we will also run that operator in a workspace |
| # that is specific to the immediate mode. The user is explicitly expected to |
| # make sure that these ops have proper inputs and outputs, i.e. one should not |
| # run an op where an external input is not created or fed. |
| # |
| # Users can use FeedImmediate() and FetchImmediate() to interact with blobs |
| # in the immediate workspace. |
| # |
| # Once StopImmediate() is called, all contents in the immediate workspace is |
| # freed up so one can continue using normal runs. |
| # |
| # The immediate mode is solely for debugging purposes and support will be very |
| # sparse. |
| ################################################################################ |
| |
| _immediate_mode = False |
| _immediate_workspace_name = "_CAFFE2_IMMEDIATE" |
| _immediate_root_folder = '' |
| |
| |
| def IsImmediate(): |
| return _immediate_mode |
| |
| |
| @contextlib.contextmanager |
| def WorkspaceGuard(workspace_name): |
| current = CurrentWorkspace() |
| SwitchWorkspace(workspace_name, True) |
| yield |
| SwitchWorkspace(current) |
| |
| |
| def StartImmediate(i_know=False): |
| global _immediate_mode |
| global _immediate_root_folder |
| if IsImmediate(): |
| # already in immediate mode. We will kill the previous one |
| # and start from fresh. |
| StopImmediate() |
| _immediate_mode = True |
| with WorkspaceGuard(_immediate_workspace_name): |
| _immediate_root_folder = tempfile.mkdtemp() |
| ResetWorkspace(_immediate_root_folder) |
| if i_know: |
| # if the user doesn't want to see the warning message, sure... |
| return |
| print(""" |
| Enabling immediate mode in caffe2 python is an EXTREMELY EXPERIMENTAL |
| feature and may very easily go wrong. This is because Caffe2 uses a |
| declarative way of defining operators and models, which is essentially |
| not meant to run things in an interactive way. Read the following carefully |
| to make sure that you understand the caveats. |
| |
| (1) You need to make sure that the sequences of operators you create are |
| actually runnable sequentially. For example, if you create an op that takes |
| an input X, somewhere earlier you should have already created X. |
| |
| (2) Caffe2 immediate uses one single workspace, so if the set of operators |
| you run are intended to be under different workspaces, they will not run. |
| To create boundaries between such use cases, you can call FinishImmediate() |
| and StartImmediate() manually to flush out everything no longer needed. |
| |
| (3) Underlying objects held by the immediate mode may interfere with your |
| normal run. For example, if there is a leveldb that you opened in immediate |
| mode and did not close, your main run will fail because leveldb does not |
| support double opening. Immediate mode may also occupy a lot of memory esp. |
| on GPUs. Call FinishImmediate() as soon as possible when you no longer |
| need it. |
| |
| (4) Immediate is designed to be slow. Every immediate call implicitly |
| creates a temp operator object, runs it, and destroys the operator. This |
| slow-speed run is by design to discourage abuse. For most use cases other |
| than debugging, do NOT turn on immediate mode. |
| |
| (5) If there is anything FATAL happening in the underlying C++ code, the |
| immediate mode will immediately (pun intended) cause the runtime to crash. |
| |
| Thus you should use immediate mode with extra care. If you still would |
| like to, have fun [https://xkcd.com/149/]. |
| """) |
| |
| |
| def StopImmediate(): |
| """Stops an immediate mode run.""" |
| # Phew, that was a dangerous ride. |
| global _immediate_mode |
| global _immediate_root_folder |
| if not IsImmediate(): |
| return |
| with WorkspaceGuard(_immediate_workspace_name): |
| ResetWorkspace() |
| shutil.rmtree(_immediate_root_folder) |
| _immediate_root_folder = '' |
| _immediate_mode = False |
| |
| |
| def ImmediateBlobs(): |
| with WorkspaceGuard(_immediate_workspace_name): |
| return Blobs() |
| |
| |
| def RunOperatorImmediate(op): |
| with WorkspaceGuard(_immediate_workspace_name): |
| RunOperatorOnce(op) |
| |
| |
| def FetchImmediate(*args, **kwargs): |
| with WorkspaceGuard(_immediate_workspace_name): |
| return FetchBlob(*args, **kwargs) |
| |
| |
| def FeedImmediate(*args, **kwargs): |
| with WorkspaceGuard(_immediate_workspace_name): |
| return FeedBlob(*args, **kwargs) |
| |
| |
| # C.Workspace methods. |
| |
| def _Workspace_create_net_with_exception_intercept(ws, net, overwrite=False): |
| return CallWithExceptionIntercept( |
| ws._create_net, |
| ws._last_failed_op_net_position, |
| GetNetName(net), |
| StringifyProto(net), overwrite, |
| ) |
| |
| |
| def _Workspace_run(ws, obj): |
| if hasattr(obj, 'Proto'): |
| obj = obj.Proto() |
| if isinstance(obj, caffe2_pb2.PlanDef): |
| return ws._run_plan(obj.SerializeToString()) |
| if isinstance(obj, caffe2_pb2.NetDef): |
| return CallWithExceptionIntercept( |
| ws._run_net, |
| ws._last_failed_op_net_position, |
| GetNetName(obj), |
| obj.SerializeToString(), |
| ) |
| # return ws._run_net(obj.SerializeToString()) |
| if isinstance(obj, caffe2_pb2.OperatorDef): |
| return ws._run_operator(obj.SerializeToString()) |
| raise ValueError( |
| "Don't know how to do Workspace.run() on {}".format(type(obj))) |
| |
| |
| def _Workspace_feed_blob(ws, name, arr, device_option=None): |
| if type(arr) is caffe2_pb2.TensorProto: |
| arr = utils.Caffe2TensorToNumpyArray(arr) |
| if type(arr) is np.ndarray and arr.dtype.kind in 'SU': |
| # Plain NumPy strings are weird, let's use objects instead |
| arr = arr.astype(object) |
| |
| if device_option is None: |
| device_option = scope.CurrentDeviceScope() |
| |
| if device_option and device_option.device_type == caffe2_pb2.CUDA: |
| if arr.dtype == np.dtype('float64'): |
| logger.warning( |
| "CUDA operators do not support 64-bit doubles, " + |
| "please use arr.astype(np.float32) or np.int32 for ints." + |
| " Blob: {}".format(name) + |
| " type: {}".format(str(arr.dtype)) |
| ) |
| |
| name = StringifyBlobName(name) |
| if device_option is not None: |
| return ws.create_blob(name).feed(arr, device_option) |
| else: |
| return ws.create_blob(name).feed(arr) |
| |
| |
| def _Workspace_remove_blob(ws, blob): |
| ws._remove_blob(str(blob)) |
| |
| |
| Workspace = C.Workspace |
| Workspace.create_net = _Workspace_create_net_with_exception_intercept |
| Workspace.run = _Workspace_run |
| Workspace.feed_blob = _Workspace_feed_blob |
| Workspace.remove_blob = _Workspace_remove_blob |
| |
| # C.Blob methods. |
| |
| |
| def _Blob_feed(blob, arg, device_option=None): |
| # conservative type check to avoid unnecessary import |
| if type(arg).__name__ == 'Tensor' and type(arg).__module__ == 'torch': |
| import torch |
| if isinstance(arg, torch.Tensor): |
| assert device_option is None, \ |
| "device_option doesn't make sense with PyTorch tensors" |
| handle = torch._C._tensor_impl_raw_handle(arg) |
| blob._wrap_tensor_impl(handle) |
| return True # _feed() returns True for some reason |
| if device_option is not None: |
| device_option = StringifyProto(device_option) |
| return blob._feed(arg, device_option) |
| |
| |
| C.Blob.feed = _Blob_feed |
| |
| |
| def _Tensor_to_torch(tensor): |
| """ |
| PyTorch tensor interop (TensorCPU methods) |
| |
| Can be accessed as: |
| workspace.Workspace.current.blobs['foo'].tensor().to_torch() |
| """ |
| # avoiding circular dependency |
| import torch |
| handle = tensor._tensor_impl_raw_handle() |
| return torch._C._wrap_tensor_impl(handle) |
| |
| C.TensorCPU.to_torch = _Tensor_to_torch |
| |
| |
| def _Blob_to_torch(blob): |
| if not blob.is_tensor(): |
| raise RuntimeError("Blob has to be a tensor") |
| return blob.as_tensor().to_torch() |
| |
| C.Blob.to_torch = _Blob_to_torch |