blob: a484585770e044c948958a4686ade00defcc0317 [file] [log] [blame]
#include "caffe2/core/common_gpu.h"
#include "caffe2/core/context_gpu.h"
#include "caffe2/image/image_input_op.h"
namespace caffe2 {

// CUDA specialization of ImageInputOp::ApplyTransformOnGPU.
//
// Depending on `output_type_`, requests output 0 with shape `dims` on device
// `type` as either a float or an at::Half tensor, then calls TransformOnGPU
// with the prefetched on-device image, `mean_gpu_`, `std_gpu_` and the
// operator's context. TransformOnGPU is instantiated with uint8_t as its
// first template argument (presumably the source pixel type -- confirm
// against its declaration) and the chosen output element type as its second.
//
// Returns true after TransformOnGPU has been called for FLOAT or FLOAT16,
// and false for any other `output_type_`, in which case no output tensor is
// requested here.
template <>
bool ImageInputOp<CUDAContext>::ApplyTransformOnGPU(
    const std::vector<std::int64_t>& dims,
    const c10::Device& type) {
  // The GPU transform kernel lets the output element type be chosen
  // explicitly, so dispatch on the configured type.
  if (output_type_ == TensorProto_DataType_FLOAT) {
    const auto float_options = at::dtype<float>().device(type);
    auto* float_output = OperatorBase::OutputTensor(0, dims, float_options);
    TransformOnGPU<uint8_t, float, CUDAContext>(
        prefetched_image_on_device_,
        float_output,
        mean_gpu_,
        std_gpu_,
        &context_);
    return true;
  }

  if (output_type_ == TensorProto_DataType_FLOAT16) {
    const auto half_options = at::dtype<at::Half>().device(type);
    auto* half_output = OperatorBase::OutputTensor(0, dims, half_options);
    TransformOnGPU<uint8_t, at::Half, CUDAContext>(
        prefetched_image_on_device_,
        half_output,
        mean_gpu_,
        std_gpu_,
        &context_);
    return true;
  }

  // Any other requested output type is not handled by this specialization.
  return false;
}

// Registers the CUDAContext instantiation under the operator name ImageInput.
REGISTER_CUDA_OPERATOR(ImageInput, ImageInputOp<CUDAContext>);

} // namespace caffe2