[MPS] Fix bidirectional LSTM & small one-direction LSTM fix (#95563)
Fixes #94754
With this PR I hope to finish my breathtaking journey of fixing MPS LSTM.
Here, I enable `bidirectional` LSTMs on MPS. I also noticed that the graph cache key did not account for all parameters, so a one-directional LSTM created without bias or dropout and then recreated with one of them could have picked up a cached graph built for the wrong configuration.
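Below is a minimal sketch of the case this PR enables (a bidirectional LSTM running on MPS and matching the CPU reference); it assumes an MPS-capable machine with this PR applied, and the shapes and seed are arbitrary, for illustration only:

```python
import torch
import torch.nn as nn

torch.manual_seed(42)
# Bidirectional, multi-layer LSTM built on CPU first to get a reference result.
rnn = nn.LSTM(input_size=11, hidden_size=7, num_layers=2, bidirectional=True)
inp = torch.randn(3, 5, 11)    # (seq_len, batch, input_size)
hx = torch.randn(2 * 2, 5, 7)  # (num_layers * num_directions, batch, hidden_size)
cx = torch.randn(2 * 2, 5, 7)

cpu_out, (cpu_hn, cpu_cn) = rnn(inp, (hx, cx))

# Before this PR, the bidirectional case did not take the _lstm_mps path.
mps = torch.device("mps")
mps_out, (mps_hn, mps_cn) = rnn.to(mps)(inp.to(mps), (hx.to(mps), cx.to(mps)))

torch.testing.assert_close(mps_out.cpu(), cpu_out)
torch.testing.assert_close(mps_hn.cpu(), cpu_hn)
torch.testing.assert_close(mps_cn.cpu(), cpu_cn)
```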
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95563
Approved by: https://github.com/jhavukainen, https://github.com/kulinseth, https://github.com/malfet
diff --git a/aten/src/ATen/native/RNN.cpp b/aten/src/ATen/native/RNN.cpp
index 6b2b985..c2dbb99 100644
--- a/aten/src/ATen/native/RNN.cpp
+++ b/aten/src/ATen/native/RNN.cpp
@@ -1422,7 +1422,7 @@
return std::make_tuple(std::move(output), std::move(hy), std::move(cy));
}
#ifdef USE_MPS
- if (_input.is_mps() && !bidirectional) {
+ if (_input.is_mps()) {
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> output = at::_lstm_mps(_input, hx, _params, has_biases,
num_layers, dropout_p, train, bidirectional, batch_first);
std::tuple<Tensor, Tensor, Tensor> return_values = std::make_tuple(std::get<0>(output), std::get<1>(output), std::get<2>(output));
diff --git a/aten/src/ATen/native/mps/operations/RnnOps.mm b/aten/src/ATen/native/mps/operations/RnnOps.mm
index 9e59a6c..fbf0a99 100644
--- a/aten/src/ATen/native/mps/operations/RnnOps.mm
+++ b/aten/src/ATen/native/mps/operations/RnnOps.mm
@@ -23,6 +23,85 @@
return output_dimensions;
}
+/**
+ * Accepts tensors in PyTorch API format and returns tensors in MPS API format
+ * @return tuple of tensors to use with MPS API in order:
+ * stateTensor, cellStateTensor, recurrentWeight, inputWeight, biasTensor
+ */
+static std::tuple<MPSGraphTensor*, MPSGraphTensor*, MPSGraphTensor*, MPSGraphTensor*, MPSGraphTensor*>
+ getMPSTensorsFromPytorchTensors(MPSGraph* mpsGraph,
+ MPSGraphTensor* stateTensor, MPSGraphTensor* cellStateTensor,
+ NSMutableArray<MPSGraphTensor*> *recurrentKernelWeightsList,
+ NSMutableArray<MPSGraphTensor*> *kernelWeightsList,
+ NSMutableArray<MPSGraphTensor*> *kernelBiasList,
+ NSMutableArray<MPSGraphTensor*> *recurrentBiasList,
+ bool has_biases, bool bidirectional, size_t layer_no) {
+ MPSGraphTensor* biasTensor_ = nil;
+ MPSGraphTensor* stateTensor_ = nil, *cellStateTensor_ = nil;
+ MPSGraphTensor* recurrentWeight_ = nil, *inputWeight_ = nil;
+
+ if (bidirectional) {
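+    // Pack the two directions' tensors into the layout consumed by the MPSGraph LSTM op:
+    // initial states along the feature dimension ([N, 2 * H]), recurrent weights stacked
+    // along a new leading axis, input weights and biases concatenated along dimension 0.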
+ stateTensor_ = [mpsGraph sliceTensor:stateTensor
+ dimension:0
+ start:layer_no * 2
+ length:2
+ name:nil];
+ // [2, N, H] -> [N, 2, H]
+ stateTensor_ = [mpsGraph transposeTensor:stateTensor_ dimension: 0 withDimension: 1 name:nil];
+ // [N, 2, H] -> [N, 2 * H]
+ stateTensor_ = [mpsGraph flatten2DTensor:stateTensor_ axis:1 name:nil];
+ cellStateTensor_ = [mpsGraph sliceTensor:cellStateTensor
+ dimension:0
+ start:layer_no * 2
+ length:2
+ name:nil];
+ cellStateTensor_ = [mpsGraph transposeTensor:cellStateTensor_ dimension: 0 withDimension: 1 name:nil];
+ cellStateTensor_ = [mpsGraph flatten2DTensor:cellStateTensor_ axis:1 name:nil];
+
+ recurrentWeight_ = [mpsGraph
+ concatTensor: [mpsGraph expandDimsOfTensor: recurrentKernelWeightsList[layer_no * 2] axis: 0 name: nil]
+ withTensor: [mpsGraph expandDimsOfTensor: recurrentKernelWeightsList[layer_no * 2 + 1] axis: 0 name: nil]
+ dimension: 0
+ name: nil
+ ];
+ inputWeight_ = [mpsGraph
+ concatTensor: kernelWeightsList[layer_no * 2]
+ withTensor: kernelWeightsList[layer_no * 2 + 1]
+ dimension: 0
+ name: nil
+ ];
+ if (has_biases) {
+ auto biasTensorFwd_ = [mpsGraph additionWithPrimaryTensor:kernelBiasList[layer_no * 2]
+ secondaryTensor:recurrentBiasList[layer_no * 2]
+ name:nil];
+ auto biasTensorBack_ = [mpsGraph additionWithPrimaryTensor:kernelBiasList[layer_no * 2 + 1]
+ secondaryTensor:recurrentBiasList[layer_no * 2 + 1]
+ name:nil];
+
+ biasTensor_ = [mpsGraph concatTensor:biasTensorFwd_ withTensor:biasTensorBack_ dimension:0 name:nil];
+ }
+ } else {
+ stateTensor_ = [mpsGraph sliceTensor:stateTensor
+ dimension:0
+ start:layer_no
+ length:1
+ name:nil];
+ cellStateTensor_ = [mpsGraph sliceTensor:cellStateTensor
+ dimension:0
+ start:layer_no
+ length:1
+ name:nil];
+ recurrentWeight_ = recurrentKernelWeightsList[layer_no];
+ inputWeight_ = kernelWeightsList[layer_no];
+ if (has_biases) {
+ biasTensor_ = [mpsGraph additionWithPrimaryTensor:kernelBiasList[layer_no]
+ secondaryTensor:recurrentBiasList[layer_no]
+ name:nil];
+ }
+ }
+ return std::make_tuple(stateTensor_, cellStateTensor_, recurrentWeight_, inputWeight_, biasTensor_);
+}
+
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> _lstm_mps(const Tensor& input, TensorList hx, TensorList params, bool has_biases, int64_t num_layers, double dropout_p, bool train, bool bidirectional, bool batch_first) {
using namespace mps;
@@ -38,15 +117,17 @@
std::vector<Tensor> recurrent_kernel_weights;
std::vector<Tensor> biases;
std::vector<Tensor> recurrent_biases;
- for (size_t i = 0; i < num_layers; i+=1) {
+
+ const int64_t total_layers = num_layers * (bidirectional ? 2 : 1);
+
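+  // PyTorch stores a separate parameter group per direction, so a bidirectional LSTM
+  // contributes 2 * num_layers groups of (weight_ih, weight_hh[, bias_ih, bias_hh]) to params.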
+ for (const auto i : c10::irange(total_layers)) {
+ const int stride = (has_biases ? 4 : 2);
+ kernel_weights.push_back(params[i*stride]);
+ recurrent_kernel_weights.push_back(params[i*stride+1]);
+
if (has_biases) {
- kernel_weights.push_back(params[i*4]);
- recurrent_kernel_weights.push_back(params[i*4+1]);
- biases.push_back(params[i*4+2]);
- recurrent_biases.push_back(params[i*4+3]);
- } else {
- kernel_weights.push_back(params[i*2]);
- recurrent_kernel_weights.push_back(params[i*2+1]);
+ biases.push_back(params[i*stride+2]);
+ recurrent_biases.push_back(params[i*stride+3]);
}
}
@@ -65,7 +146,7 @@
MPSStream* stream = getCurrentMPSStream();
@autoreleasepool {
- string key = "lstm_" + getTensorsStringKey({input, hx[0], hx[1]}) + getMPSTypeString(input.scalar_type()) + "_num_layers_" + std::to_string(num_layers);
+ string key = "lstm_" + getTensorsStringKey({input, hx[0], hx[1]}) + getMPSTypeString(input.scalar_type()) + "_num_layers_" + std::to_string(num_layers) + "_bidirectional_" + std::to_string(bidirectional) + "_has_biases_" + std::to_string(has_biases) + "_dropout_" + std::to_string(dropout_p) + "_batch_first_" + std::to_string(batch_first);
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
if(!cachedGraph) {
MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
@@ -81,7 +162,7 @@
NSMutableArray<MPSGraphTensor*> *recurrentBiasList = [[NSMutableArray alloc] initWithCapacity:params.size()];
NSMutableArray<MPSGraphTensor*> *layersOutputsList = [[NSMutableArray alloc] initWithCapacity:num_layers];
- for (size_t i = 0; i < num_layers; i += 1) {
+ for (const auto i : c10::irange(total_layers)) {
[kernelWeightsList addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input.scalar_type()), getMPSShape(kernel_weights[i]))];
[recurrentKernelWeightsList addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input.scalar_type()),getMPSShape(recurrent_kernel_weights[i]))];
if(has_biases) {
@@ -100,7 +181,7 @@
MPSGraphTensor* cellStateTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input.scalar_type()), getMPSShape(hx[1]));
std::vector<MPSGraphTensor*> inputTensors = {inputTensor, stateTensor, cellStateTensor,};
- if(batch_first) {
+ if (batch_first) {
inputTensor = [mpsGraph transposeTensor:inputTensor
dimension:0
withDimension:1
@@ -113,49 +194,61 @@
NSMutableArray<MPSGraphTensor*>* outputCellStateArray = [[NSMutableArray alloc] initWithCapacity:num_layers];
NSMutableArray<MPSGraphTensor*>* outputZStateArray = [[NSMutableArray alloc] initWithCapacity:num_layers];
NSMutableArray<MPSGraphTensor*>* outputCellStateFwdArray = [[NSMutableArray alloc] initWithCapacity:num_layers];
- for(int i = 0; i < num_layers; i++) {
- MPSGraphTensor* biasTensor = nil;
- if(has_biases) {
- biasTensor = [mpsGraph additionWithPrimaryTensor:kernelBiasList[i]
- secondaryTensor:recurrentBiasList[i]
- name:nil];
- }
- MPSGraphTensor* stateTensor_ = [mpsGraph sliceTensor:stateTensor
- dimension:0
- start:i
- length:1
- name:nil];
- MPSGraphTensor* cellStateTensor_ = [mpsGraph sliceTensor:cellStateTensor
- dimension:0
- start:i
- length:1
- name:nil];
+ for (int i = 0; i < num_layers; i++) {
+ auto tensorsData = getMPSTensorsFromPytorchTensors(mpsGraph, stateTensor, cellStateTensor,
+ recurrentKernelWeightsList, kernelWeightsList,
+ kernelBiasList, recurrentBiasList, has_biases,
+ bidirectional, i);
+ MPSGraphTensor* stateTensor_ = std::get<0>(tensorsData), *cellStateTensor_ = std::get<1>(tensorsData);
+ MPSGraphTensor* recurrentWeight_ = std::get<2>(tensorsData), *inputWeight_ = std::get<3>(tensorsData);
+ MPSGraphTensor* biasTensor_ = std::get<4>(tensorsData);
+
+
outputs = [mpsGraph LSTMWithSourceTensor:inputTensor_
- recurrentWeight:recurrentKernelWeightsList[i]
- inputWeight:kernelWeightsList[i]
- bias:biasTensor
+ recurrentWeight:recurrentWeight_
+ inputWeight:inputWeight_
+ bias:biasTensor_
initState:stateTensor_
initCell:cellStateTensor_
descriptor:opDesc
name:nil];
inputTensor_ = [outputs objectAtIndex:0];
- // no need to keep a final layer output copy as it is
+ // no need to keep the final layer output copy as it is
// returned anyway and not used in backprop
- if(i != num_layers - 1) {
+ if (i != num_layers - 1) {
[layersOutputsList addObject:[mpsGraph expandDimsOfTensor:inputTensor_
axis:0
name:nil]];
}
- if(dropout_p>0.0 && train && (i!=num_layers-1)) {
+ if (dropout_p>0.0 && train && (i!=num_layers-1)) {
inputTensor_ = [mpsGraph dropoutTensor:inputTensor_
rate:dropout_p
name:nil];
}
- [outputStateArray addObject:[mpsGraph sliceTensor:[outputs objectAtIndex:0] dimension:0 start:-1 length:1 name:nil]];
- [outputCellStateArray addObject:[mpsGraph sliceTensor:[outputs objectAtIndex:1] dimension:0 start:-1 length:1 name:nil]];
+ if (bidirectional) {
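+          // The forward direction's final hidden/cell state lives at the last time step of the
+          // output, while the backward direction's final state lives at the first time step;
+          // each direction occupies its own half of the feature dimension.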
+ // [1, N, 2 * H]
+ auto stateLastT = [mpsGraph sliceTensor:[outputs objectAtIndex:0] dimension:0 start:-1 length:1 name:nil];
+ auto stateFirstT = [mpsGraph sliceTensor:[outputs objectAtIndex:0] dimension:0 start:0 length:1 name:nil];
+ // [1, N, H] ([1, N, 0:H])
+ auto stateForward = [mpsGraph sliceTensor:stateLastT dimension: -1 start:0 length:hx[0].sizes()[2] name:nil];
+ // [1, N, H] ([1, N, H:2H])
+ auto stateBack = [mpsGraph sliceTensor:stateFirstT dimension: -1 start:hx[0].sizes()[2] length:hx[0].sizes()[2] name:nil];
+ [outputStateArray addObject:stateForward];
+ [outputStateArray addObject:stateBack];
+
+ auto cellStateLastT = [mpsGraph sliceTensor:[outputs objectAtIndex:1] dimension:0 start:-1 length:1 name:nil];
+ auto cellStateFirstT = [mpsGraph sliceTensor:[outputs objectAtIndex:1] dimension:0 start:0 length:1 name:nil];
+ auto cellStateForward = [mpsGraph sliceTensor:cellStateLastT dimension: -1 start:0 length:hx[1].sizes()[2] name:nil];
+ auto cellStateBack = [mpsGraph sliceTensor:cellStateFirstT dimension: -1 start:hx[1].sizes()[2] length:hx[1].sizes()[2] name:nil];
+ [outputCellStateArray addObject:cellStateForward];
+ [outputCellStateArray addObject:cellStateBack];
+ } else {
+ [outputStateArray addObject:[mpsGraph sliceTensor:[outputs objectAtIndex:0] dimension:0 start:-1 length:1 name:nil]];
+ [outputCellStateArray addObject:[mpsGraph sliceTensor:[outputs objectAtIndex:1] dimension:0 start:-1 length:1 name:nil]];
+ }
[outputCellStateFwdArray addObject: [mpsGraph expandDimsOfTensor:[outputs objectAtIndex:1]
axis:0
name:nil]];
@@ -205,21 +298,18 @@
NSMutableArray<MPSGraphTensor*> *biasList = cachedGraph->biasList_;
NSMutableArray<MPSGraphTensor*> *recurrentBiasList = cachedGraph->recurrentBiasList_;
- Placeholder kernelWeight, recurrentKernelWeight, bias, recurrentBias;
-
NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*> *feeds = [[[NSMutableDictionary alloc] init] autorelease];
- for (size_t i = 0; i < num_layers; i+=1) {
- kernelWeight = Placeholder([kernelWeightsList objectAtIndex:i], kernel_weights[i]);
- recurrentKernelWeight = Placeholder([recurrentKernelWeightsList objectAtIndex:i], recurrent_kernel_weights[i]);
+ for (const auto i : c10::irange(total_layers)) {
+ Placeholder kernelWeight = Placeholder([kernelWeightsList objectAtIndex:i], kernel_weights[i]);
+ Placeholder recurrentKernelWeight = Placeholder([recurrentKernelWeightsList objectAtIndex:i], recurrent_kernel_weights[i]);
[feeds setObject:kernelWeight.getMPSGraphTensorData() forKey:kernelWeight.getMPSGraphTensor()];
[feeds setObject:recurrentKernelWeight.getMPSGraphTensorData() forKey:recurrentKernelWeight.getMPSGraphTensor()];
- if(has_biases) {
- bias = Placeholder([biasList objectAtIndex:i], biases[i]);
- recurrentBias = Placeholder([recurrentBiasList objectAtIndex:i], recurrent_biases[i]);
+ if (has_biases) {
+ Placeholder bias = Placeholder([biasList objectAtIndex:i], biases[i]);
+ Placeholder recurrentBias = Placeholder([recurrentBiasList objectAtIndex:i], recurrent_biases[i]);
[feeds setObject:bias.getMPSGraphTensorData() forKey:bias.getMPSGraphTensor()];
[feeds setObject:recurrentBias.getMPSGraphTensorData() forKey:recurrentBias.getMPSGraphTensor()];
}
-
}
Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensors_[0], input);
Placeholder selfState = Placeholder(cachedGraph->inputTensors_[1], hx[0]);
@@ -274,22 +364,22 @@
std::vector<Tensor> recurrent_kernel_weights;
std::vector<Tensor> biases;
std::vector<Tensor> recurrent_biases;
- for (size_t i = 0; i < num_layers; i+=1) {
- if(has_biases) {
- kernel_weights.push_back(params[i*4]);
- recurrent_kernel_weights.push_back(params[i*4+1]);
- biases.push_back(params[i*4+2]);
- recurrent_biases.push_back(params[i*4+3]);
- } else {
- kernel_weights.push_back(params[i*2]);
- recurrent_kernel_weights.push_back(params[i*2+1]);
- }
+
+ const int64_t total_layers = num_layers * (bidirectional ? 2 : 1);
+
+ for (const auto i : c10::irange(total_layers)) {
+ const int stride = (has_biases ? 4 : 2);
+ kernel_weights.push_back(params[i*stride]);
+ recurrent_kernel_weights.push_back(params[i*stride+1]);
+ if(has_biases) {
+ biases.push_back(params[i*stride + 2]);
+ recurrent_biases.push_back(params[i*stride + 3]);
+ }
}
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
std::vector<MPSGraphTensor*> inputTensors_;
- std::vector<MPSGraphTensor*> outputTensors_;
NSMutableArray<MPSGraphTensor*> *kernelWeightsList_ = nil;
NSMutableArray<MPSGraphTensor*> *recurrentKernelWeightsList_ = nil;
NSMutableArray<MPSGraphTensor*> *biasList_ = nil;
@@ -308,9 +398,9 @@
MPSStream* stream = getCurrentMPSStream();
@autoreleasepool {
- string key = "lstm_backward_" + getTensorsStringKey({input, z_state, cell_state_fwd, grad_y, grad_cy, grad_hy})+ getMPSTypeString(input.scalar_type()) + "_num_layers_" + std::to_string(num_layers);
+ string key = "lstm_backward_" + getTensorsStringKey({input, z_state, cell_state_fwd, grad_y, grad_cy, grad_hy})+ getMPSTypeString(input.scalar_type()) + "_num_layers_" + std::to_string(num_layers) + "_bidirectional_" + std::to_string(bidirectional) + "_has_biases_" + std::to_string(has_biases) + "_batch_first_" + std::to_string(batch_first);
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
- if(!cachedGraph) {
+ if (!cachedGraph) {
MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
CachedGraph *newCachedGraph = nil;
@@ -323,10 +413,10 @@
NSMutableArray<MPSGraphTensor*> *kernelBiasList = [[NSMutableArray alloc] initWithCapacity:params.size()];
NSMutableArray<MPSGraphTensor*> *recurrentBiasList = [[NSMutableArray alloc] initWithCapacity:params.size()];
- for (size_t i = 0; i < num_layers; i += 1) {
+ for (const auto i : c10::irange(total_layers)) {
[kernelWeightsList addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input.scalar_type()), getMPSShape(kernel_weights[i]))];
[recurrentKernelWeightsList addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input.scalar_type()),getMPSShape(recurrent_kernel_weights[i]))];
- if(has_biases) {
+ if (has_biases) {
[kernelBiasList addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input.scalar_type()),getMPSShape(biases[i]))];
[recurrentBiasList addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input.scalar_type()),getMPSShape(recurrent_biases[i]))];
}
@@ -377,6 +467,8 @@
NSMutableArray<MPSGraphTensor*>* gradStateArray = [[NSMutableArray alloc] initWithCapacity:num_layers];
NSMutableArray<MPSGraphTensor*>* gradCellStateArray = [[NSMutableArray alloc] initWithCapacity:num_layers];
+ auto hidden_size = hx[0].sizes()[2];
+
for (int i = num_layers - 1; i >= 0; i--) {
MPSGraphTensor* zState = [mpsGraph sliceTensor:zStateTensor
dimension:0
@@ -394,46 +486,57 @@
cellStateFwd = [mpsGraph squeezeTensor:cellStateFwd
axis:0
name:nil];
- MPSGraphTensor* biasTensor = nil;
- if(has_biases) {
- biasTensor = [mpsGraph additionWithPrimaryTensor:kernelBiasList[i]
- secondaryTensor:recurrentBiasList[i]
- name:nil];
+ auto tensorsData = getMPSTensorsFromPytorchTensors(mpsGraph, stateTensor, cellStateTensor,
+ recurrentKernelWeightsList, kernelWeightsList,
+ kernelBiasList, recurrentBiasList, has_biases,
+ bidirectional, i);
+ MPSGraphTensor* stateTensor_ = std::get<0>(tensorsData), *cellStateTensor_ = std::get<1>(tensorsData);
+ MPSGraphTensor* recurrentWeight_ = std::get<2>(tensorsData), *inputWeight_ = std::get<3>(tensorsData);
+ MPSGraphTensor* biasTensor_ = std::get<4>(tensorsData);
+
+ MPSGraphTensor* gradientHyTensor_ = nil, *gradientCyTensor_ = nil;
+ if (bidirectional) {
+ gradientHyTensor_ = [mpsGraph sliceTensor:gradientHyTensor
+ dimension:0
+ start:i * 2
+ length:2
+ name:nil];
+ // [2, N, H] -> [N, 2, H]
+ gradientHyTensor_ = [mpsGraph transposeTensor:gradientHyTensor_ dimension: 0 withDimension: 1 name:nil];
+ // [N, 2, H] -> [N, 2 * H]
+ gradientHyTensor_ = [mpsGraph flatten2DTensor:gradientHyTensor_ axis:1 name:nil];
+
+
+ gradientCyTensor_ = [mpsGraph sliceTensor:gradientCyTensor
+ dimension:0
+ start:i * 2
+ length:2
+ name:nil];
+ gradientCyTensor_ = [mpsGraph transposeTensor:gradientCyTensor_ dimension: 0 withDimension: 1 name:nil];
+ gradientCyTensor_ = [mpsGraph flatten2DTensor:gradientCyTensor_ axis:1 name:nil];
} else {
- biasTensor = [mpsGraph constantWithScalar:0.0
- dataType:inputTensor.dataType];
+ gradientHyTensor_ = [mpsGraph sliceTensor:gradientHyTensor
+ dimension:0
+ start:i
+ length:1
+ name:nil];
+
+ gradientCyTensor_ = [mpsGraph sliceTensor:gradientCyTensor
+ dimension:0
+ start:i
+ length:1
+ name:nil];
}
- MPSGraphTensor* stateTensor_ = [mpsGraph sliceTensor:stateTensor
- dimension:0
- start:i
- length:1
- name:nil];
- MPSGraphTensor* cellStateTensor_ = [mpsGraph sliceTensor:cellStateTensor
- dimension:0
- start:i
- length:1
- name:nil];
- MPSGraphTensor* gradientHyTensor_ = [mpsGraph sliceTensor:gradientHyTensor
- dimension:0
- start:i
- length:1
- name:nil];
-
- MPSGraphTensor* gradientCyTensor_ = [mpsGraph sliceTensor:gradientCyTensor
- dimension:0
- start:i
- length:1
- name:nil];
-
MPSGraphTensor* iterationInputTensor_ = nil;
if (i == 0) {
iterationInputTensor_ = inputTensor;
} else {
iterationInputTensor_ = [mpsGraph sliceTensor:layersOutputsTensor
dimension: 0
- // last element in layersOutputsTensor contains
- // **inputs** for the last layer
+ // the last element in layersOutputsTensor
+ // contains **inputs** for the **last** layer
+ // and so on
start: i - num_layers
length: 1
name: nil];
@@ -443,14 +546,14 @@
}
outputs = [mpsGraph LSTMGradientsWithSourceTensor: iterationInputTensor_
- recurrentWeight: recurrentKernelWeightsList[i]
+ recurrentWeight: recurrentWeight_
sourceGradient: gradientTensor_
zState: zState
cellOutputFwd: cellStateFwd
stateGradient: gradientHyTensor_
cellGradient: gradientCyTensor_
- inputWeight: kernelWeightsList[i]
- bias: biasTensor
+ inputWeight: inputWeight_
+ bias: biasTensor_
initState: stateTensor_
initCell: cellStateTensor_
mask: nil
@@ -459,14 +562,103 @@
name: nil];
gradientTensor_ = [outputs objectAtIndex:0];
- [gradRecWeightsArray insertObject:[outputs objectAtIndex:1] atIndex:0];
- [gradWeightsArray insertObject:[outputs objectAtIndex:2] atIndex:0];
- [gradBiasArray insertObject: [outputs objectAtIndex:3] atIndex:0];
- [gradStateArray insertObject: [mpsGraph expandDimsOfTensor:[outputs objectAtIndex:4] axis:0 name:nil] atIndex:0];
- [gradCellStateArray insertObject: [mpsGraph expandDimsOfTensor:[outputs objectAtIndex:5] axis:0 name:nil] atIndex:0];
- }
- std::vector<MPSGraphTensor*> outputTensors = {[outputs objectAtIndex:0],[outputs objectAtIndex:1],[outputs objectAtIndex:2],[outputs objectAtIndex:3], [outputs objectAtIndex:4], [outputs objectAtIndex:5]};
+ if (bidirectional) {
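+          // For the bidirectional case the MPS op returns the two directions' gradients packed
+          // together; split them apart to match PyTorch's per-direction parameter layout.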
+ int outputIter = 1;
+ auto gradRecWeightsBidirectional = [outputs objectAtIndex:outputIter++];
+ auto gradRecWeightFwd = [mpsGraph sliceTensor:gradRecWeightsBidirectional
+ dimension: 0
+ start: 0
+ length: 1
+ name: nil];
+ gradRecWeightFwd = [mpsGraph squeezeTensor:gradRecWeightFwd axis:0 name: nil];
+ auto gradRecWeightBack = [mpsGraph sliceTensor:gradRecWeightsBidirectional
+ dimension: 0
+ start: 1
+ length: 1
+ name: nil];
+ gradRecWeightBack = [mpsGraph squeezeTensor:gradRecWeightBack axis:0 name: nil];
+          // insert in reverse order so the forward-direction gradient ends up before the backward one
+ [gradRecWeightsArray insertObject:gradRecWeightBack atIndex:0];
+ [gradRecWeightsArray insertObject:gradRecWeightFwd atIndex:0];
+
+ auto gradWeightsBidirectional = [outputs objectAtIndex:outputIter++];
+ auto gradWeightFwd = [mpsGraph sliceTensor:gradWeightsBidirectional
+ dimension: 0
+ start: 0
+ length: hidden_size * 4
+ name: nil];
+ auto gradWeightBack = [mpsGraph sliceTensor:gradWeightsBidirectional
+ dimension: 0
+ start: hidden_size * 4
+ length: hidden_size * 4
+ name: nil];
+
+ [gradWeightsArray insertObject:gradWeightBack atIndex:0];
+ [gradWeightsArray insertObject:gradWeightFwd atIndex:0];
+
+ if (has_biases) {
+            // the bias gradient has shape [1, 1, 8H] instead of the expected [8H],
+            // so squeeze away the first two dimensions
+ auto gradBiasBidirectional = [outputs objectAtIndex:outputIter++];
+ gradBiasBidirectional = [mpsGraph squeezeTensor: gradBiasBidirectional
+ axes: @[@0, @1]
+ name: nil];
+ auto gradBiasFwd = [mpsGraph sliceTensor:gradBiasBidirectional
+ dimension: 0
+ start: 0
+ length: hidden_size * 4
+ name: nil];
+ auto gradBiasBack = [mpsGraph sliceTensor:gradBiasBidirectional
+ dimension: 0
+ start: hidden_size * 4
+ length: hidden_size * 4
+ name: nil];
+
+ [gradBiasArray insertObject: gradBiasBack atIndex:0];
+ [gradBiasArray insertObject: gradBiasFwd atIndex:0];
+ }
+
+ auto gradStateBidirectional = [outputs objectAtIndex:outputIter++];
+ auto gradStateFwd = [mpsGraph sliceTensor:gradStateBidirectional
+ dimension: 1
+ start: 0
+ length: hidden_size
+ name: nil];
+ auto gradStateBack = [mpsGraph sliceTensor:gradStateBidirectional
+ dimension: 1
+ start: hidden_size
+ length: hidden_size
+ name: nil];
+
+ [gradStateArray insertObject: [mpsGraph expandDimsOfTensor:gradStateBack axis:0 name:nil] atIndex:0];
+ [gradStateArray insertObject: [mpsGraph expandDimsOfTensor:gradStateFwd axis:0 name:nil] atIndex:0];
+
+ auto gradCellStateBidirectional = [outputs objectAtIndex:outputIter++];
+ auto gradCellStateFwd = [mpsGraph sliceTensor:gradCellStateBidirectional
+ dimension: 1
+ start: 0
+ length: hidden_size
+ name: nil];
+ auto gradCellStateBack = [mpsGraph sliceTensor:gradCellStateBidirectional
+ dimension: 1
+ start: hidden_size
+ length: hidden_size
+ name: nil];
+
+ [gradCellStateArray insertObject: [mpsGraph expandDimsOfTensor:gradCellStateBack axis:0 name:nil] atIndex:0];
+ [gradCellStateArray insertObject: [mpsGraph expandDimsOfTensor:gradCellStateFwd axis:0 name:nil] atIndex:0];
+ } else {
+ int outputIter = 1;
+ [gradRecWeightsArray insertObject:[outputs objectAtIndex:outputIter++] atIndex:0];
+ [gradWeightsArray insertObject:[outputs objectAtIndex:outputIter++] atIndex:0];
+ if (has_biases) {
+ [gradBiasArray insertObject: [outputs objectAtIndex:outputIter++] atIndex:0];
+ }
+ [gradStateArray insertObject: [mpsGraph expandDimsOfTensor:[outputs objectAtIndex:outputIter++] axis:0 name:nil] atIndex:0];
+ [gradCellStateArray insertObject: [mpsGraph expandDimsOfTensor:[outputs objectAtIndex:outputIter++] axis:0 name:nil] atIndex:0];
+ }
+ }
if (batch_first) {
MPSGraphTensor* gradientTensorTransposed = [mpsGraph transposeTensor:gradientTensor_
dimension: 0
@@ -477,7 +669,6 @@
newCachedGraph->gradOutput_ = gradientTensor_;
}
- newCachedGraph->outputTensors_ = outputTensors;
newCachedGraph->gradRecWeights_ = gradRecWeightsArray;
newCachedGraph->gradWeights_ = gradWeightsArray;
newCachedGraph->gradBias_ = gradBiasArray;
@@ -514,18 +705,15 @@
NSMutableArray<MPSGraphTensor*> *recurrentKernelWeightsList = cachedGraph->recurrentKernelWeightsList_;
NSMutableArray<MPSGraphTensor*> *biasList = cachedGraph->biasList_;
NSMutableArray<MPSGraphTensor*> *recurrentBiasList = cachedGraph->recurrentBiasList_;
- Placeholder kernelWeight;
- Placeholder recurrentKernelWeight;
- Placeholder bias;
- Placeholder recurrentBias;
- for (size_t i = 0; i < num_layers; i+=1) {
- kernelWeight = Placeholder([kernelWeightsList objectAtIndex:i], kernel_weights[i]);
- recurrentKernelWeight = Placeholder([recurrentKernelWeightsList objectAtIndex:i], recurrent_kernel_weights[i]);
+
+ for (const auto i : c10::irange(total_layers)) {
+ Placeholder kernelWeight = Placeholder([kernelWeightsList objectAtIndex:i], kernel_weights[i]);
+ Placeholder recurrentKernelWeight = Placeholder([recurrentKernelWeightsList objectAtIndex:i], recurrent_kernel_weights[i]);
[feeds setObject:kernelWeight.getMPSGraphTensorData() forKey:kernelWeight.getMPSGraphTensor()];
[feeds setObject:recurrentKernelWeight.getMPSGraphTensorData() forKey:recurrentKernelWeight.getMPSGraphTensor()];
- if(has_biases) {
- bias = Placeholder([biasList objectAtIndex:i], biases[i]);
- recurrentBias = Placeholder([recurrentBiasList objectAtIndex:i], recurrent_biases[i]);
+ if (has_biases) {
+ Placeholder bias = Placeholder([biasList objectAtIndex:i], biases[i]);
+ Placeholder recurrentBias = Placeholder([recurrentBiasList objectAtIndex:i], recurrent_biases[i]);
[feeds setObject:bias.getMPSGraphTensorData() forKey:bias.getMPSGraphTensor()];
[feeds setObject:recurrentBias.getMPSGraphTensorData() forKey:recurrentBias.getMPSGraphTensor()];
}
@@ -556,25 +744,32 @@
Placeholder gradRecWeightsPlaceholder, gradWeightsPlaceholder, gradBiasPlaceholder;
std::vector<Tensor> weights;
- for (int i = 0; i < num_layers; i++) {
+ for (const auto i : c10::irange(total_layers)) {
Tensor grad_rec_weights = at::empty_like(recurrent_kernel_weights[i]);
Tensor grad_weights = at::empty_like(kernel_weights[i]);
- Tensor grad_bias = at::empty((kernel_weights[i].size(0)), kernel_weights[i].options());
+
weights.push_back(grad_weights);
weights.push_back(grad_rec_weights);
- if(has_biases) {
- weights.push_back(grad_bias);
- weights.push_back(grad_bias);
- }
-
gradRecWeightsPlaceholder = Placeholder([gradRecWeightsArray objectAtIndex: i], grad_rec_weights);
gradWeightsPlaceholder = Placeholder([gradWeightsArray objectAtIndex: i], grad_weights);
- gradBiasPlaceholder = Placeholder([gradBiasArray objectAtIndex: i], grad_bias);
- [results setObject:gradBiasPlaceholder.getMPSGraphTensorData() forKey:gradBiasPlaceholder.getMPSGraphTensor()];
[results setObject:gradRecWeightsPlaceholder.getMPSGraphTensorData() forKey:gradRecWeightsPlaceholder.getMPSGraphTensor()];
[results setObject:gradWeightsPlaceholder.getMPSGraphTensorData() forKey:gradWeightsPlaceholder.getMPSGraphTensor()];
+
+ if (has_biases) {
+ Tensor grad_bias = at::empty((kernel_weights[i].size(0)), kernel_weights[i].options());
+
+      // The PyTorch LSTM API exposes two bias vectors; the second one exists only for CuDNN compatibility.
+      // This implementation adds the two biases together before use, so both share the same gradient,
+      // which is therefore pushed twice, once for each bias vector.
+ weights.push_back(grad_bias);
+ weights.push_back(grad_bias);
+
+ gradBiasPlaceholder = Placeholder([gradBiasArray objectAtIndex: i], grad_bias);
+ [results setObject:gradBiasPlaceholder.getMPSGraphTensorData() forKey:gradBiasPlaceholder.getMPSGraphTensor()];
+ }
}
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
diff --git a/test/test_mps.py b/test/test_mps.py
index 062a513..bce494c 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -9255,91 +9255,91 @@
self.assertEqual(out, torch.zeros(2, device=device), atol=0, rtol=0)
class TestRNNMPS(TestCaseMPS):
- def test_lstm_1(self, device="mps", dtype=torch.float32):
- for layers in [1] if product_version < 13.0 else [1, 2, 5]:
- torch.random.manual_seed(42)
- rnn = nn.LSTM(7, 4, layers, device="cpu")
- input = torch.randn(2, 3, 7, device="cpu")
- hx = torch.randn(layers, 3, 4, device="cpu")
- cx = torch.randn(layers, 3, 4, device="cpu")
+ def _lstm_helper(self, num_layers, dtype, device, bidirectional=False, bias=True, batch_first=False,
+ seq_len=3, batch_size=5, hidden_size=7, input_size=11, backward=False):
+ rnn = nn.LSTM(
+ input_size=input_size,
+ hidden_size=hidden_size,
+ num_layers=num_layers,
+ bias=bias,
+ bidirectional=bidirectional,
+ batch_first=batch_first,
+ device="cpu"
+ )
+ bidirectional_mul = 2 if bidirectional else 1
- cpu_output, (cpu_hn, cpu_cn) = rnn(input, (hx, cx))
+ if batch_first:
+ input = torch.randn(batch_size, seq_len, input_size, device="cpu", dtype=dtype, requires_grad=backward)
+ hx = torch.randn(num_layers * bidirectional_mul, batch_size, hidden_size, device="cpu", dtype=dtype,
+ requires_grad=backward)
+ cx = torch.randn(num_layers * bidirectional_mul, batch_size, hidden_size, device="cpu", dtype=dtype,
+ requires_grad=backward)
+ else:
+ input = torch.randn(seq_len, batch_size, input_size, device="cpu", dtype=dtype, requires_grad=backward)
+ hx = torch.randn(num_layers * bidirectional_mul, batch_size, hidden_size, device="cpu", dtype=dtype,
+ requires_grad=backward)
+ cx = torch.randn(num_layers * bidirectional_mul, batch_size, hidden_size, device="cpu", dtype=dtype,
+ requires_grad=backward)
+ cpu_output, (cpu_hn, cpu_cn) = rnn(input, (hx, cx))
+
+ rnn = rnn.to(device)
+ input = input.to(device)
+ hx = hx.to(device)
+ cx = cx.to(device)
+ output, (hn, cn) = rnn(input, (hx, cx))
+
+ self.assertEqual(cpu_output, output)
+ self.assertEqual(cpu_hn, hn)
+ self.assertEqual(cpu_cn, cn)
+
+ def get_backward_results(rnn, device, inp, hx, cx):
rnn = rnn.to(device)
- input = input.to(device)
- hx = hx.to(device)
- cx = cx.to(device)
- output, (hn, cn) = rnn(input, (hx, cx))
+ inp, hx, cx = inp.to(device), hx.to(device), cx.to(device)
- self.assertEqual(cpu_output, output)
- self.assertEqual(cpu_hn, hn)
- self.assertEqual(cpu_cn, cn)
+ output, _ = rnn(inp, (hx, cx))
+ f = 3 * output.sum() + (hx * cx).sum()
- # test batch_first
- rnn = nn.LSTM(7, 4, layers, device="cpu", batch_first=True)
- input = torch.randn(3, 2, 7, device="cpu")
- hx = torch.randn(layers, 3, 4, device="cpu")
- cx = torch.randn(layers, 3, 4, device="cpu")
- cpu_output, (cpu_hn, cpu_cn) = rnn(input, (hx, cx))
+ param_names, params = zip(*rnn.named_parameters())
+ param_grads = zip(param_names, torch.autograd.grad(f, params, retain_graph=True))
- rnn = rnn.to(device)
- input = input.to(device)
- hx = hx.to(device)
- cx = cx.to(device)
- output, (hn, cn) = rnn(input, (hx, cx))
+ input_grad, hx_grad, cx_grad = torch.autograd.grad(f, [inp, hx, cx])
+ return output, param_grads, input_grad, hx_grad, cx_grad
- self.assertEqual(cpu_output, output)
- self.assertEqual(cpu_hn, hn)
- self.assertEqual(cpu_cn, cn)
+ if backward:
+ cpu_output, cpu_weights_grad, cpu_input_grad, cpu_hx_grad, cpu_cx_grad =\
+ get_backward_results(rnn, "cpu", input, hx, cx)
+ mps_output, mps_weights_grad, mps_input_grad, mps_hx_grad, mps_cx_grad =\
+ get_backward_results(rnn, device, input, hx, cx)
+
+ self.assertEqual(cpu_hx_grad, mps_hx_grad)
+ self.assertEqual(cpu_cx_grad, mps_cx_grad)
+ self.assertEqual(cpu_output, mps_output)
+ self.assertEqual(cpu_input_grad, mps_input_grad)
+ for (cpu_name, cpu_weight_grad), (mps_name, mps_weight_grad) in zip(cpu_weights_grad, mps_weights_grad):
+ self.assertEqual(cpu_weight_grad, mps_weight_grad,
+ f"mismatch in cpu:{cpu_name} vs mps:{mps_name}, layers: {num_layers}")
+
+ LSTM_TEST_CASES = [
+ dict(), # default
+ dict(batch_first=True),
+ dict(bias=False),
+ dict(bidirectional=True),
+ dict(batch_first=True, bias=False),
+ dict(bidirectional=True, bias=False),
+ dict(bidirectional=True, batch_first=True),
+ dict(bidirectional=True, batch_first=True, bias=False)
+ ]
+
+ def test_lstm_forward(self, device="mps", dtype=torch.float32):
+ for num_layers in [1] if product_version < 13.0 else [1, 2, 5]:
+ for test_options in self.LSTM_TEST_CASES:
+ self._lstm_helper(num_layers=num_layers, dtype=dtype, device=device, **test_options)
def test_lstm_backward(self, device="mps", dtype=torch.float32):
- for layers in [1] if product_version < 13.0 else [1, 2, 5]:
- lstm = nn.LSTM(2, 4, layers) # initialized globally for consistent parameters init
- lstm.train()
-
- def get_results(device, inp, hx, cx):
- rnn = lstm.to(device)
- inp, hx, cx = inp.to(device), hx.to(device), cx.to(device)
-
- output, _ = rnn(inp, (hx, cx))
- f = output.sum()
-
- param_names, params = zip(*rnn.named_parameters())
- param_grads = zip(param_names, torch.autograd.grad(f, params, retain_graph=True))
-
- input_grad, hx_grad, cx_grad = torch.autograd.grad(f, [inp, hx, cx])
- return output, param_grads, input_grad, hx_grad, cx_grad
-
- inp = torch.randn((5, 3, 2), requires_grad=True, dtype=dtype, device=device)
- hx = torch.randn((layers, 3, 4), requires_grad=True, dtype=dtype, device=device)
- cx = torch.randn((layers, 3, 4), requires_grad=True, dtype=dtype, device=device)
-
- cpu_output, cpu_weights_grad, cpu_input_grad, cpu_hx_grad, cpu_cx_grad = get_results("cpu", inp, hx, cx)
- mps_output, mps_weights_grad, mps_input_grad, mps_hx_grad, mps_cx_grad = get_results(device, inp, hx, cx)
-
- self.assertEqual(cpu_hx_grad, mps_hx_grad)
- self.assertEqual(cpu_cx_grad, mps_cx_grad)
- self.assertEqual(cpu_output, mps_output)
- self.assertEqual(cpu_input_grad, mps_input_grad)
- for (cpu_name, cpu_weight_grad), (mps_name, mps_weight_grad) in zip(cpu_weights_grad, mps_weights_grad):
- self.assertEqual(cpu_weight_grad, mps_weight_grad, f"mismatch in cpu:{cpu_name} vs mps:{mps_name}")
-
- # test batch_first backward
- lstm = nn.LSTM(2, 4, layers, batch_first=True)
- lstm.train()
-
- hx = torch.randn((layers, 5, 4), requires_grad=True, dtype=dtype, device=device)
- cx = torch.randn((layers, 5, 4), requires_grad=True, dtype=dtype, device=device)
-
- cpu_output, cpu_weights_grad, cpu_input_grad, cpu_hx_grad, cpu_cx_grad = get_results("cpu", inp, hx, cx)
- mps_output, mps_weights_grad, mps_input_grad, mps_hx_grad, mps_cx_grad = get_results(device, inp, hx, cx)
-
- self.assertEqual(cpu_hx_grad, mps_hx_grad)
- self.assertEqual(cpu_cx_grad, mps_cx_grad)
- self.assertEqual(cpu_output, mps_output)
- self.assertEqual(cpu_input_grad, mps_input_grad)
- for (cpu_name, cpu_weight_grad), (mps_name, mps_weight_grad) in zip(cpu_weights_grad, mps_weights_grad):
- self.assertEqual(cpu_weight_grad, mps_weight_grad, f"mismatch in cpu:{cpu_name} vs mps:{mps_name}")
+ for num_layers in [1] if product_version < 13.0 else [1, 2, 5]:
+ for test_options in self.LSTM_TEST_CASES:
+ self._lstm_helper(num_layers=num_layers, dtype=dtype, device=device, backward=True, **test_options)
def test_RNN_cell_no_broadcasting(self):