[MPS] Add argmin (#80828)
This PR:
1. adds argmin
2. refactors `reduction_type` in `ReduceOps.mm` to use an enum.
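As a quick illustration of the user-facing change, here is a minimal sketch of `torch.argmin` on the MPS backend (assumes an MPS-capable PyTorch build, e.g. on Apple Silicon; the shapes and dtype below are illustrative, not taken from this PR):

```python
import torch

# Sketch only: requires a PyTorch build with MPS support (e.g. Apple Silicon).
cpu_x = torch.randn(2, 8, 4, 4)
x = cpu_x.to("mps")

# Full reduction: index of the minimum over the flattened tensor.
assert torch.argmin(x).cpu() == torch.argmin(cpu_x)

# Reduction along a dimension, with and without keepdim, should match the CPU
# result (ties are vanishingly unlikely with random floats).
assert torch.equal(torch.argmin(x, dim=1).cpu(), torch.argmin(cpu_x, dim=1))
assert torch.equal(torch.argmin(x, dim=1, keepdim=True).cpu(),
                   torch.argmin(cpu_x, dim=1, keepdim=True))
```

The `MPSReductionType` enum is an internal refactor only (replacing the string-valued `reduction_type` arguments); it does not change the Python-level API.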
Co-authored-by: Kulin Seth <[email protected]>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80828
Approved by: https://github.com/malfet
diff --git a/aten/src/ATen/native/mps/operations/ReduceOps.mm b/aten/src/ATen/native/mps/operations/ReduceOps.mm
index 0bcb2ea..bb1b058 100644
--- a/aten/src/ATen/native/mps/operations/ReduceOps.mm
+++ b/aten/src/ATen/native/mps/operations/ReduceOps.mm
@@ -13,13 +13,23 @@
namespace at {
namespace native {
-using namespace std;
-
enum StdVarType {
STANDARD_VARIANCE,
STANDARD_DEVIATION
};
+enum MPSReductionType {
+ MAX,
+ MIN,
+ AMAX,
+ AMIN,
+ SUM,
+ PROD,
+ MEAN,
+ COUNT_NONZERO
+};
+
+
void set_apparent_shapes(NSMutableArray<NSNumber*> * &apparent_out_shape,
NSMutableArray<NSNumber*> * &apparent_in_shape,
int64_t num_reduce_dims,
@@ -131,8 +141,8 @@
bool keepdim,
c10::optional<ScalarType> dtype,
const Tensor& output_t,
- string reduction_type,
- string func_name) {
+ MPSReductionType reduction_type,
+ const std::string& func_name) {
IntArrayRef input_shape = input_t.sizes();
@@ -196,19 +206,19 @@
MPSGraphTensor* castOutputTensor = nil;
- if(reduction_type == "sum") {
+ if(reduction_type == MPSReductionType::SUM) {
castOutputTensor = [mpsGraph reductionSumWithTensor:castInputTensor
axes:axes
name:nil];
- } else if(reduction_type == "prod") {
+ } else if(reduction_type == MPSReductionType::PROD) {
castOutputTensor = [mpsGraph reductionProductWithTensor:castInputTensor
axes:axes
name:nil];
- } else if(reduction_type == "mean") {
+ } else if(reduction_type == MPSReductionType::MEAN) {
castOutputTensor = [mpsGraph meanOfTensor:inputTensor
axes:axes
name:nil];
- } else if(reduction_type == "count_nonzero") {
+ } else if(reduction_type == MPSReductionType::COUNT_NONZERO) {
MPSGraphTensor* zeros = [mpsGraph constantWithScalar:0
dataType:castInputTensor.dataType];
@@ -220,11 +230,11 @@
axes:axes
name:nil];
}
- else if(reduction_type == "amax") {
+ else if(reduction_type == MPSReductionType::AMAX) {
castOutputTensor = [mpsGraph reductionMaximumWithTensor:inputTensor
axes:axes
name:nil];
- } else if(reduction_type == "amin") {
+ } else if(reduction_type == MPSReductionType::AMIN) {
castOutputTensor = [mpsGraph reductionMinimumWithTensor:inputTensor
axes:axes
name:nil];
@@ -273,7 +283,7 @@
c10::optional<ScalarType> dtype,
const Tensor& output_t) {
- reduction_out_mps(input_t, dim, keepdim, dtype, output_t, "sum", "sum_out_mps");
+ reduction_out_mps(input_t, dim, keepdim, dtype, output_t, MPSReductionType::SUM, "sum_out_mps");
}
TORCH_IMPL_FUNC(prod_out_mps)
@@ -285,7 +295,7 @@
int64_t dims[1] = {dim};
- reduction_out_mps(input_t, IntArrayRef(dims, 1), keepdim, dtype, output_t, "prod", "prod_out_mps");
+ reduction_out_mps(input_t, IntArrayRef(dims, 1), keepdim, dtype, output_t, MPSReductionType::PROD, "prod_out_mps");
}
// Taken from ReduceOps.cpp
@@ -309,7 +319,7 @@
bool keepdim,
const Tensor& output_t) {
- reduction_out_mps(input_t, dim, keepdim, c10::nullopt, output_t, "amax", "amax_out_mps");
+ reduction_out_mps(input_t, dim, keepdim, c10::nullopt, output_t, MPSReductionType::AMAX, "amax_out_mps");
}
TORCH_IMPL_FUNC(amin_out_mps)
@@ -318,7 +328,7 @@
bool keepdim,
const Tensor& output_t) {
- reduction_out_mps(input_t, dim, keepdim, c10::nullopt, output_t, "amin", "amin_out_mps");
+ reduction_out_mps(input_t, dim, keepdim, c10::nullopt, output_t, MPSReductionType::AMIN, "amin_out_mps");
}
Tensor prod_mps(const Tensor &self, c10::optional<ScalarType> opt_dtype) {
@@ -338,7 +348,7 @@
c10::nullopt,
c10::nullopt);
- reduction_out_mps(self, IntArrayRef(dims, num_dims), false, opt_dtype, const_cast<Tensor&>(output_t), "prod", "prod_mps");
+ reduction_out_mps(self, IntArrayRef(dims, num_dims), false, opt_dtype, const_cast<Tensor&>(output_t), MPSReductionType::PROD, "prod_mps");
return output_t;
}
@@ -365,7 +375,7 @@
c10::nullopt,
c10::nullopt);
- reduction_out_mps(self, dims, false, self.scalar_type(), const_cast<Tensor&>(output_t), "count_nonzero", "count_nonzero_mps");
+ reduction_out_mps(self, dims, false, self.scalar_type(), const_cast<Tensor&>(output_t), MPSReductionType::COUNT_NONZERO, "count_nonzero_mps");
free(raw_output_shape);
@@ -379,135 +389,7 @@
c10::optional<ScalarType> dtype,
const Tensor& output_t) {
- reduction_out_mps(input_t, dim, keepdim, dtype, output_t, "mean", "mean_out_mps");
-}
-
-TORCH_IMPL_FUNC(argmax_out_mps)
- (const Tensor& input_t,
- c10::optional<int64_t> dim,
- bool keepdim,
- const Tensor& output_t) {
-
- namespace native_mps = at::native::mps;
-
- // Derive from MPSCachedGraph
- struct CachedGraph : public native_mps::MPSCachedGraph
- {
- CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
- MPSGraphTensor *inputTensor_ = nil;
- MPSGraphTensor *outputTensor_ = nil;
- };
-
- native_mps::MPSGraphCache* cache_ = native_mps::MPSGraphCache::getInstance();
-
- int64_t dim_;
-
- if (dim.has_value()) {
- dim_ = maybe_wrap_dim(dim.value(), input_t.dim());
- native::zero_numel_check_dims(input_t, dim_, "argmax()");
- } else {
- TORCH_CHECK_INDEX(
- input_t.numel() != 0,
- "argmax()", ": Expected reduction dim to be specified for input.numel() == 0.");
- // Since input will be flattened, take argmax along 0'th dimension
- dim_ = 0;
- }
-
- // Calculate the output shape according to keepdim=True
- // If there is no dim argument, the input shape is flattened
- IntArrayRef input_shape = input_t.sizes();
- int64_t num_input_dims = input_shape.size();
- NSMutableArray<NSNumber*> *apparent_in_shape = nil;
- NSMutableArray<NSNumber*> *apparent_out_shape = nil;
-
- if(dim.has_value()) {
- apparent_out_shape = [NSMutableArray<NSNumber*> arrayWithCapacity:num_input_dims];
- for(int i = 0; i < num_input_dims; i++) {
- if(dim_ == i)
- apparent_out_shape[i] = @1;
- else
- apparent_out_shape[i] = [NSNumber numberWithInt:input_shape[i]];
- }
- }
- else {
- apparent_in_shape = [NSMutableArray<NSNumber*> arrayWithCapacity:1];
- int64_t num_in_elements = 1;
- for(int i = 0; i < num_input_dims; i++) {
- num_in_elements *= input_shape[i];
- }
- apparent_in_shape[0] = [NSNumber numberWithInt:num_in_elements];
-
- apparent_out_shape = [NSMutableArray<NSNumber*> arrayWithCapacity:1];
- apparent_out_shape[0] = @1;
- }
-
- if (output_t.numel() == 0) {
- return;
- }
-
- auto stream = at::mps::getCurrentMPSStream();
-
- @autoreleasepool {
- string key = "argmax_out_mps:" + to_string(dim_) + ":" + native_mps::getMPSTypeString(input_t.scalar_type());
- CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
-
- if(!cachedGraph) {
- native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () {
-
- CachedGraph *newCachedGraph = nil;
-
- @autoreleasepool {
- MPSGraph* mpsGraph = native_mps::make_mps_graph();
- newCachedGraph = new CachedGraph(mpsGraph);
-
- MPSGraphTensor* inputTensor = native_mps::mpsGraphUnrankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(input_t.scalar_type()));
-
- MPSGraphTensor* castInputTensor = nil;
-
- if(input_t.scalar_type() != ScalarType::Float &&
- input_t.scalar_type() != ScalarType::Int &&
- input_t.scalar_type() != ScalarType::Half)
- castInputTensor = [mpsGraph castTensor:inputTensor
- toType:MPSDataTypeFloat32
- name:@"castInputTensor"];
- else
- castInputTensor = inputTensor;
-
- MPSGraphTensor* argmaxOutTensor = [mpsGraph reductionArgMaximumWithTensor:castInputTensor
- axis:(NSInteger)dim_
- name:@"argmax_out"];
- MPSGraphTensor* outputTensor = [mpsGraph castTensor:argmaxOutTensor
- toType:MPSDataTypeInt64
- name:@"cast_out"];
-
- newCachedGraph->inputTensor_ = inputTensor;
- newCachedGraph->outputTensor_ = outputTensor;
- }
- return newCachedGraph;
- });
- cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
- }
-
- native_mps::Placeholder inputPlaceholder = native_mps::Placeholder();
- if(apparent_in_shape)
- inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t, apparent_in_shape);
- else
- inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t);
-
- auto outputPlaceholder = native_mps::Placeholder(cachedGraph->outputTensor_, output_t, apparent_out_shape);
-
- NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *feeds = @{
- inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(),
- };
-
- NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *results = @{
- outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
- };
-
- native_mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results);
-
- }
-
+ reduction_out_mps(input_t, dim, keepdim, dtype, output_t, MPSReductionType::MEAN, "mean_out_mps");
}
TORCH_IMPL_FUNC(norm_out_mps)
@@ -1316,8 +1198,8 @@
Tensor min_max_mps
(const Tensor& input_t,
- string reduction_type,
- string func_name) {
+ MPSReductionType reduction_type,
+ const std::string& func_name) {
namespace native_mps = at::native::mps;
@@ -1365,11 +1247,11 @@
MPSGraphTensor* outputTensor = nil;
- if(reduction_type == "max")
+ if(reduction_type == MPSReductionType::MAX)
outputTensor = [mpsGraph reductionMaximumWithTensor:inputTensor
axes:@[@0]
name:nil];
- else if(reduction_type == "min")
+ else if(reduction_type == MPSReductionType::MIN)
outputTensor = [mpsGraph reductionMinimumWithTensor:inputTensor
axes:@[@0]
name:nil];
@@ -1403,13 +1285,13 @@
// Max entire tensor into scalar result
Tensor max_mps(const Tensor& input_t) {
- return min_max_mps(input_t, "max", "max_mps");
+ return min_max_mps(input_t, MPSReductionType::MAX, "max_mps");
}
// Min entire tensor into scalar result
Tensor min_mps(const Tensor& input_t) {
- return min_max_mps(input_t, "min", "min_mps");
+ return min_max_mps(input_t, MPSReductionType::MIN, "min_mps");
}
void min_max_out_mps
@@ -1418,8 +1300,8 @@
bool keepdim,
const Tensor& output_t,
const Tensor& indices_t,
- string reduction_type,
- string func_name) {
+ MPSReductionType reduction_type,
+ const std::string& func_name) {
namespace native_mps = at::native::mps;
@@ -1477,11 +1359,11 @@
MPSGraphTensor* inputTensor = native_mps::mpsGraphUnrankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(input_t.scalar_type()));
MPSGraphTensor* outputTensor = nil;
- if(reduction_type == "max")
+ if(reduction_type == MPSReductionType::MAX)
outputTensor = [mpsGraph reductionMaximumWithTensor:inputTensor
axis:(NSInteger)dim_
name:nil];
- else if(reduction_type == "min")
+ else if(reduction_type == MPSReductionType::MIN)
outputTensor = [mpsGraph reductionMinimumWithTensor:inputTensor
axis:(NSInteger)dim_
name:nil];
@@ -1498,11 +1380,11 @@
castInputTensor = inputTensor;
MPSGraphTensor* argreduceOutTensor = nil;
- if(reduction_type == "max")
+ if(reduction_type == MPSReductionType::MAX)
argreduceOutTensor = [mpsGraph reductionArgMaximumWithTensor:castInputTensor
axis:(NSInteger)dim_
name:@"argmax_out"];
- else if(reduction_type == "min")
+ else if(reduction_type == MPSReductionType::MIN)
argreduceOutTensor = [mpsGraph reductionArgMinimumWithTensor:castInputTensor
axis:(NSInteger)dim_
name:@"argmax_out"];
@@ -1550,7 +1432,7 @@
int64_t dim_ = maybe_wrap_dim(dim, input_t.dim());
native::zero_numel_check_dims(input_t, dim_, "max()");
- min_max_out_mps(input_t, dim, keepdim, output_t, indices_t, "max", "max_out_mps");
+ min_max_out_mps(input_t, dim, keepdim, output_t, indices_t, MPSReductionType::MAX, "max_out_mps");
}
// Min out with dim
@@ -1564,16 +1446,170 @@
int64_t dim_ = maybe_wrap_dim(dim, input_t.dim());
native::zero_numel_check_dims(input_t, dim_, "min()");
- min_max_out_mps(input_t, dim, keepdim, output_t, indices_t, "min", "min_out_mps");
+ min_max_out_mps(input_t, dim, keepdim, output_t, indices_t, MPSReductionType::MIN, "min_out_mps");
}
+void argmax_argmin_out_mps
+ (const Tensor& input_t,
+ c10::optional<int64_t> dim,
+ bool keepdim,
+ const Tensor& output_t,
+ MPSReductionType reduction_type,
+ const std::string& func_name) {
+ namespace native_mps = at::native::mps;
+
+ // Derive from MPSCachedGraph
+ struct CachedGraph : public native_mps::MPSCachedGraph
+ {
+ CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
+ MPSGraphTensor *inputTensor_ = nil;
+ MPSGraphTensor *outputTensor_ = nil;
+ };
+
+ native_mps::MPSGraphCache* cache_ = native_mps::MPSGraphCache::getInstance();
+
+ int64_t dim_;
+
+ if (dim.has_value()) {
+ dim_ = maybe_wrap_dim(dim.value(), input_t.dim());
+ zero_numel_check_dims(input_t, dim_, reduction_type == MPSReductionType::MAX ? "argmax()" : "argmin()");
+ } else {
+ TORCH_CHECK_INDEX(
+ input_t.numel() != 0,
+ reduction_type == MPSReductionType::MAX ? "argmax()" : "argmin()" , ": Expected reduction dim to be specified for input.numel() == 0.");
+ // Since input will be flattened, take argmax or argmin along 0'th dimension
+ dim_ = 0;
+ }
+
+ // Calculate the output shape according to keepdim=True
+ // If there is no dim argument, the input shape is flattened
+ IntArrayRef input_shape = input_t.sizes();
+ int64_t num_input_dims = input_shape.size();
+ NSMutableArray<NSNumber*> *apparent_in_shape = nil;
+ NSMutableArray<NSNumber*> *apparent_out_shape = nil;
+
+ if(dim.has_value()) {
+ apparent_out_shape = [NSMutableArray<NSNumber*> arrayWithCapacity:num_input_dims];
+ for(int i = 0; i < num_input_dims; i++) {
+ if(dim_ == i)
+ apparent_out_shape[i] = @1;
+ else
+ apparent_out_shape[i] = [NSNumber numberWithInt:input_shape[i]];
+ }
+ }
+ else {
+ apparent_in_shape = [NSMutableArray<NSNumber*> arrayWithCapacity:1];
+ int64_t num_in_elements = 1;
+ for(int i = 0; i < num_input_dims; i++) {
+ num_in_elements *= input_shape[i];
+ }
+ apparent_in_shape[0] = [NSNumber numberWithInt:num_in_elements];
+
+ apparent_out_shape = [NSMutableArray<NSNumber*> arrayWithCapacity:1];
+ apparent_out_shape[0] = @1;
+ }
+
+ if (output_t.numel() == 0) {
+ return;
+ }
+
+ auto stream = at::mps::getCurrentMPSStream();
+
+ @autoreleasepool {
+ string key = func_name + to_string(dim_) + ":" + native_mps::getTensorsStringKey(input_t);
+ CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
+
+ if(!cachedGraph) {
+ native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () {
+
+ CachedGraph *newCachedGraph = nil;
+
+ @autoreleasepool {
+ MPSGraph* mpsGraph = native_mps::make_mps_graph();
+ newCachedGraph = new CachedGraph(mpsGraph);
+
+ MPSGraphTensor* inputTensor = native_mps::mpsGraphUnrankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(input_t.scalar_type()));
+
+ MPSGraphTensor* castInputTensor = nil;
+ MPSGraphTensor* argreduceOutTensor = nil;
+
+ if(input_t.scalar_type() != ScalarType::Float &&
+ input_t.scalar_type() != ScalarType::Int &&
+ input_t.scalar_type() != ScalarType::Half)
+ castInputTensor = [mpsGraph castTensor:inputTensor
+ toType:MPSDataTypeFloat32
+ name:@"castInputTensor"];
+ else
+ castInputTensor = inputTensor;
+
+ if (reduction_type == MPSReductionType::MAX) {
+ argreduceOutTensor = [mpsGraph reductionArgMaximumWithTensor:castInputTensor
+ axis:(NSInteger)dim_
+ name:nil];
+ }
+ else {
+ argreduceOutTensor = [mpsGraph reductionArgMinimumWithTensor:castInputTensor
+ axis:(NSInteger)dim_
+ name:nil];
+ }
+ MPSGraphTensor* outputTensor = [mpsGraph castTensor:argreduceOutTensor
+ toType:MPSDataTypeInt64
+                                                         name:@"castOutputTensor"];
+
+ newCachedGraph->inputTensor_ = inputTensor;
+ newCachedGraph->outputTensor_ = outputTensor;
+ }
+ return newCachedGraph;
+ });
+ cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
+ }
+
+ native_mps::Placeholder inputPlaceholder = native_mps::Placeholder();
+ if(apparent_in_shape)
+ inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t, apparent_in_shape);
+ else
+ inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t);
+
+ auto outputPlaceholder = native_mps::Placeholder(cachedGraph->outputTensor_, output_t, apparent_out_shape);
+
+ NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *feeds = @{
+ inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(),
+ };
+
+ NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *results = @{
+ outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
+ };
+
+ native_mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results);
+ }
+}
+
+TORCH_IMPL_FUNC(argmax_out_mps)
+ (const Tensor& input_t,
+ c10::optional<int64_t> dim,
+ bool keepdim,
+ const Tensor& output_t) {
+
+ argmax_argmin_out_mps(input_t, dim, keepdim, output_t, MPSReductionType::MAX, "argmax_out_mps");
+}
+
+TORCH_IMPL_FUNC(argmin_out_mps)
+ (const Tensor& input_t,
+ c10::optional<int64_t> dim,
+ bool keepdim,
+ const Tensor& output_t) {
+
+ argmax_argmin_out_mps(input_t, dim, keepdim, output_t, MPSReductionType::MIN, "argmin_out_mps");
+}
+
+
// Min/Max with dim
std::tuple<Tensor, Tensor> min_max_mps
(const Tensor& input_t,
int64_t dim,
bool keepdim,
- string reduction_type,
- string func_name) {
+ MPSReductionType reduction_type,
+ const std::string& func_name) {
namespace native_mps = at::native::mps;
@@ -1661,7 +1697,7 @@
int64_t dim,
bool keepdim) {
- return min_max_mps(input_t, dim, keepdim, "max", "max_mps");
+ return min_max_mps(input_t, dim, keepdim, MPSReductionType::MAX, "max_mps");
}
// Min with dim
@@ -1670,9 +1706,8 @@
int64_t dim,
bool keepdim) {
- return min_max_mps(input_t, dim, keepdim, "min", "min_mps");
+ return min_max_mps(input_t, dim, keepdim, MPSReductionType::MIN, "min_mps");
}
-}
-
-}
+} // native
+} // at
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index f22852f..4bca2bd 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -660,6 +660,7 @@
structured: True
dispatch:
CPU, CUDA: argmin_out
+ MPS: argmin_out_mps
- func: acosh(Tensor self) -> Tensor
variants: function, method
diff --git a/test/test_mps.py b/test/test_mps.py
index f8614ea..dc857a3 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -2187,9 +2187,14 @@
helper((2, 3, 4, 5))
- # Test forward argmax
- def test_argmax(self):
- def helper(n, c, h, w, dtype=torch.float32):
+ # Test forward argmin argmax
+ def test_argmin_argmax(self):
+ def helper(n, c, h, w, reduction_type, dtype=torch.float32):
+ if reduction_type == "max":
+ arg_reduction_fn = torch.argmax
+ else:
+ arg_reduction_fn = torch.argmin
+
cpu_x = None
x = None
if(dtype not in [torch.float32, torch.bool]):
@@ -2202,46 +2207,50 @@
cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=dtype, requires_grad=True)
x = cpu_x.detach().clone().to('mps').requires_grad_()
- y = torch.argmax(x)
- ref_y = torch.argmax(cpu_x)
+ y = arg_reduction_fn(x)
+ ref_y = arg_reduction_fn(cpu_x)
self.assertEqual(y, ref_y)
- y_0 = torch.argmax(x, dim=0)
- refy_0 = torch.argmax(cpu_x, dim=0)
+ y_0 = arg_reduction_fn(x, dim=0)
+ refy_0 = arg_reduction_fn(cpu_x, dim=0)
self.assertEqual(y_0, refy_0)
- y_0dim = torch.argmax(x, dim=0, keepdim=True)
- refy_0dim = torch.argmax(cpu_x, dim=0, keepdim=True)
+ y_0dim = arg_reduction_fn(x, dim=0, keepdim=True)
+ refy_0dim = arg_reduction_fn(cpu_x, dim=0, keepdim=True)
self.assertEqual(y_0dim, refy_0dim)
- y_1 = torch.argmax(x, dim=1)
- refy_1 = torch.argmax(cpu_x, dim=1)
+ y_1 = arg_reduction_fn(x, dim=1)
+ refy_1 = arg_reduction_fn(cpu_x, dim=1)
self.assertEqual(y_1, refy_1)
- y_1dim = torch.argmax(x, dim=1, keepdim=True)
- refy_1dim = torch.argmax(cpu_x, dim=1, keepdim=True)
+ y_1dim = arg_reduction_fn(x, dim=1, keepdim=True)
+ refy_1dim = arg_reduction_fn(cpu_x, dim=1, keepdim=True)
self.assertEqual(y_1dim, refy_1dim)
- y_2 = torch.argmax(x, dim=2)
- refy_2 = torch.argmax(cpu_x, dim=2)
+ y_2 = arg_reduction_fn(x, dim=2)
+ refy_2 = arg_reduction_fn(cpu_x, dim=2)
self.assertEqual(y_2, refy_2)
- y_2dim = torch.argmax(x, dim=2, keepdim=True)
- refy_2dim = torch.argmax(cpu_x, dim=2, keepdim=True)
+ y_2dim = arg_reduction_fn(x, dim=2, keepdim=True)
+ refy_2dim = arg_reduction_fn(cpu_x, dim=2, keepdim=True)
self.assertEqual(y_2dim, refy_2dim)
- y_3 = torch.argmax(x, dim=3)
- refy_3 = torch.argmax(cpu_x, dim=3)
+ y_3 = arg_reduction_fn(x, dim=3)
+ refy_3 = arg_reduction_fn(cpu_x, dim=3)
self.assertEqual(y_3, refy_3)
- y_3dim = torch.argmax(x, dim=3, keepdim=True)
- refy_3dim = torch.argmax(cpu_x, dim=3, keepdim=True)
+ y_3dim = arg_reduction_fn(x, dim=3, keepdim=True)
+ refy_3dim = arg_reduction_fn(cpu_x, dim=3, keepdim=True)
self.assertEqual(y_3dim, refy_3dim)
- helper(2, 8, 4, 4, torch.float32)
- helper(2, 8, 4, 4, torch.int32)
- helper(2, 8, 4, 4, torch.float16)
- helper(2, 8, 4, 4, torch.int64)
+ helper(2, 8, 4, 4, "max", torch.float32)
+ helper(2, 8, 4, 4, "max", torch.int32)
+ helper(2, 8, 4, 4, "max", torch.float16)
+ helper(2, 8, 4, 4, "max", torch.int64)
+ helper(2, 8, 4, 4, "min", torch.float32)
+ helper(2, 8, 4, 4, "min", torch.int32)
+ helper(2, 8, 4, 4, "min", torch.float16)
+ helper(2, 8, 4, 4, "min", torch.int64)
# Test forward max
# Note - don't test grad now