blob: 6f9db0837946254bcd80d5713a7b0b7a3b3a53a9 [file] [log] [blame]
#pragma once
#include <caffe2/utils/proto_utils.h>
namespace caffe2 {
namespace cast {
inline TensorProto_DataType GetCastDataType(const ArgumentHelper& helper, std::string arg) {
TensorProto_DataType to;
if (helper.HasSingleArgumentOfType<string>(arg)) {
string s = helper.GetSingleArgument<string>(arg, "float");
std::transform(s.begin(), s.end(), s.begin(), ::toupper);
#ifndef CAFFE2_USE_LITE_PROTO
CAFFE_ENFORCE(TensorProto_DataType_Parse(s, &to), "Unknown 'to' argument: ", s);
#else
// Manually implement in the lite proto case.
#define X(t) \
if (s == #t) { \
return TensorProto_DataType_##t; \
}
X(FLOAT);
X(INT32);
X(BYTE);
X(STRING);
X(BOOL);
X(UINT8);
X(INT8);
X(UINT16);
X(INT16);
X(INT64);
X(FLOAT16);
X(DOUBLE);
#undef X
CAFFE_THROW("Unhandled type argument: ", s);
#endif
} else {
to = static_cast<TensorProto_DataType>(
helper.GetSingleArgument<int>(arg, TensorProto_DataType_FLOAT));
}
return to;
}
}; // namespace cast
}; // namespace caffe2