[MPS] Fix the crash in bitwise ops on x86 platforms. (#85285)
Fixes #ISSUE_NUMBER
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85285
Approved by: https://github.com/razarmehr, https://github.com/malfet
diff --git a/aten/src/ATen/native/mps/operations/BitwiseOps.mm b/aten/src/ATen/native/mps/operations/BitwiseOps.mm
index c16818d..411f35e 100644
--- a/aten/src/ATen/native/mps/operations/BitwiseOps.mm
+++ b/aten/src/ATen/native/mps/operations/BitwiseOps.mm
@@ -122,8 +122,10 @@
return it->second;
}
NSError *error = nil;
+ MTLCompileOptions *options = [[MTLCompileOptions new] autorelease];
+ [options setLanguageVersion: MTLLanguageVersion2_3];
auto rc = [device newLibraryWithSource:[NSString stringWithUTF8String:fmt::format(BITWISE_OPS_TEMPLATE, t1, t2, t3).c_str()]
- options:nil
+ options:options
error:&error];
TORCH_CHECK(rc != nil && error == nil, "Failed to compile library: ", [[error localizedDescription] UTF8String]);
libMap[key] = rc;
@@ -170,6 +172,9 @@
getMetalType(other),
kernel_name);
uint32_t length = output.numel();
+ if (length == 0) {
+ return;
+ }
dispatch_sync(stream->queue(), ^(){
id<MTLCommandBuffer> buffer = stream->commandBuffer();
id<MTLComputeCommandEncoder> commandEncoder = [buffer computeCommandEncoder];
@@ -200,6 +205,9 @@
kernel_name);
uint64_t sval = other.to<int64_t>();
uint32_t length = output.numel();
+ if (length == 0) {
+ return;
+ }
dispatch_sync(stream->queue(), ^(){
id<MTLCommandBuffer> buffer = stream->commandBuffer();
id<MTLComputeCommandEncoder> commandEncoder = [buffer computeCommandEncoder];