diff --git a/aten/src/ATen/mps/MPSDevice.h b/aten/src/ATen/mps/MPSDevice.h index 1d8dd4182a6c9a..1890d6050d9493 100644 --- a/aten/src/ATen/mps/MPSDevice.h +++ b/aten/src/ATen/mps/MPSDevice.h @@ -32,6 +32,7 @@ enum class MacOSVersion : uint32_t { MACOS_VER_13_0_PLUS = 0, MACOS_VER_13_1_PLUS, MACOS_VER_13_2_PLUS, + MACOS_VER_13_3_PLUS, }; //----------------------------------------------------------------- diff --git a/aten/src/ATen/mps/MPSDevice.mm b/aten/src/ATen/mps/MPSDevice.mm index 3a485ba295947e..0576f9bb78990e 100644 --- a/aten/src/ATen/mps/MPSDevice.mm +++ b/aten/src/ATen/mps/MPSDevice.mm @@ -98,11 +98,15 @@ static inline MTLLanguageVersion getMetalLanguageVersion(const id& de static bool _macos_13_1_plus = [mpsCD instancesRespondToSelector:@selector( sampleGridWithSourceTensor:coordinateTensor:layout:normalizeCoordinates:relativeCoordinates:alignCorners:paddingMode:samplingMode:constantValue:name:)] == YES; static bool _macos_13_2_plus = [mpsCD instancesRespondToSelector:@selector(convolution3DWithSourceTensor:weightsTensor:descriptor:name:)] == YES; + static bool _macos_13_3_plus = NO; + if (@available(macOS 13.3, *)) + _macos_13_3_plus = YES; switch (version) { case MacOSVersion::MACOS_VER_13_0_PLUS: return _macos_13_0_plus; case MacOSVersion::MACOS_VER_13_1_PLUS: return _macos_13_1_plus; case MacOSVersion::MACOS_VER_13_2_PLUS: return _macos_13_2_plus; + case MacOSVersion::MACOS_VER_13_3_PLUS: return _macos_13_3_plus; default: return false; } } diff --git a/aten/src/ATen/native/mps/operations/Inverse.mm b/aten/src/ATen/native/mps/operations/Inverse.mm index 354cdb435959b5..519de6afa3b85a 100644 --- a/aten/src/ATen/native/mps/operations/Inverse.mm +++ b/aten/src/ATen/native/mps/operations/Inverse.mm @@ -10,7 +10,7 @@ TORCH_IMPL_FUNC(linalg_inv_ex_out_mps)(const Tensor& A, bool check_errors, const Tensor& result, const Tensor& info) { TORCH_CHECK(result.is_mps(), "Output tensor is not MPS"); - if (!is_macos_13_or_newer()) { + if (!is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS)) { TORCH_WARN_ONCE("torch.linalg_inv_ex.inverse is supported by MPS on MacOS 13+, please upgrade. Falling back to CPU."); auto cpu_info = at::empty({0}, kInt, c10::nullopt, kCPU, c10::nullopt, c10::nullopt); auto cpu_result = result.clone().to("cpu"); @@ -24,6 +24,10 @@ MPSStream* stream = getCurrentMPSStream(); info.zero_(); + if (A.numel() == 0) { + return; + } + struct CachedGraph : public MPSCachedGraph { CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} diff --git a/aten/src/ATen/native/mps/operations/ReduceOps.mm b/aten/src/ATen/native/mps/operations/ReduceOps.mm index c07e22ef750282..f858714fb82d5c 100644 --- a/aten/src/ATen/native/mps/operations/ReduceOps.mm +++ b/aten/src/ATen/native/mps/operations/ReduceOps.mm @@ -1023,7 +1023,13 @@ Tensor std_mps( TORCH_IMPL_FUNC(any_all_out_mps)(const Tensor& input_t, const Tensor& output_t) { using CachedGraph = MPSUnaryCachedGraph; - if (output_t.numel() == 0 || input_t.numel() == 0) { + if (input_t.numel() == 0) { + output_t.zero_(); + return; + } else if (input_t.numel() == 1) { + output_t.copy_(input_t.view_as(output_t).to(at::kBool)); + return; + } else if (output_t.numel() == 0) { return; } diff --git a/test/test_mps.py b/test/test_mps.py index 81ba49a782e599..4841e6a0e75706 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -8957,6 +8957,8 @@ class TestConsistency(TestCase): 'native_batch_norm': ['f32'], 'minreduction_with_dim': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'maxreduction_with_dim': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'linalg.inv': ['f32'], + 'linalg.inv_ex': ['f32'], } @@ -9171,7 +9173,6 @@ class TestConsistency(TestCase): 'chalf': None, 'diag_embed': [torch.uint8], 'diagonal_scatter': [torch.uint8], - 'linalg.inv': [torch.float32], 'long': None, 'nn.functional.conv1d': [torch.int64], 'nn.functional.conv2d': [torch.int64],