Skip to content

Commit

Permalink
[MPS] Added zero check to inverse & fix for any op to avoid segfault …
Browse files Browse the repository at this point in the history
…issue (pytorch#94551)

Fixes empty placeholder error in inverse op. Change to any op should also resolve previously seen segfaults
Pull Request resolved: pytorch#94551
Approved by: https://github.com/kulinseth
  • Loading branch information
DenisVieriu97 authored and pytorchmergebot committed Feb 10, 2023
1 parent 45edf9a commit 0b31ebf
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 3 deletions.
1 change: 1 addition & 0 deletions aten/src/ATen/mps/MPSDevice.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ enum class MacOSVersion : uint32_t {
MACOS_VER_13_0_PLUS = 0,
MACOS_VER_13_1_PLUS,
MACOS_VER_13_2_PLUS,
MACOS_VER_13_3_PLUS,
};

//-----------------------------------------------------------------
Expand Down
4 changes: 4 additions & 0 deletions aten/src/ATen/mps/MPSDevice.mm
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,15 @@ static inline MTLLanguageVersion getMetalLanguageVersion(const id<MTLDevice>& de
static bool _macos_13_1_plus = [mpsCD instancesRespondToSelector:@selector(
sampleGridWithSourceTensor:coordinateTensor:layout:normalizeCoordinates:relativeCoordinates:alignCorners:paddingMode:samplingMode:constantValue:name:)] == YES;
static bool _macos_13_2_plus = [mpsCD instancesRespondToSelector:@selector(convolution3DWithSourceTensor:weightsTensor:descriptor:name:)] == YES;
static bool _macos_13_3_plus = NO;
if (@available(macOS 13.3, *))
_macos_13_3_plus = YES;

switch (version) {
case MacOSVersion::MACOS_VER_13_0_PLUS: return _macos_13_0_plus;
case MacOSVersion::MACOS_VER_13_1_PLUS: return _macos_13_1_plus;
case MacOSVersion::MACOS_VER_13_2_PLUS: return _macos_13_2_plus;
case MacOSVersion::MACOS_VER_13_3_PLUS: return _macos_13_3_plus;
default: return false;
}
}
Expand Down
6 changes: 5 additions & 1 deletion aten/src/ATen/native/mps/operations/Inverse.mm
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
TORCH_IMPL_FUNC(linalg_inv_ex_out_mps)(const Tensor& A, bool check_errors, const Tensor& result, const Tensor& info)
{
TORCH_CHECK(result.is_mps(), "Output tensor is not MPS");
if (!is_macos_13_or_newer()) {
if (!is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS)) {
TORCH_WARN_ONCE("torch.linalg_inv_ex.inverse is supported by MPS on MacOS 13+, please upgrade. Falling back to CPU.");
auto cpu_info = at::empty({0}, kInt, c10::nullopt, kCPU, c10::nullopt, c10::nullopt);
auto cpu_result = result.clone().to("cpu");
Expand All @@ -24,6 +24,10 @@
MPSStream* stream = getCurrentMPSStream();
info.zero_();

if (A.numel() == 0) {
return;
}

struct CachedGraph : public MPSCachedGraph
{
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
Expand Down
8 changes: 7 additions & 1 deletion aten/src/ATen/native/mps/operations/ReduceOps.mm
Original file line number Diff line number Diff line change
Expand Up @@ -1023,7 +1023,13 @@ Tensor std_mps(

TORCH_IMPL_FUNC(any_all_out_mps)(const Tensor& input_t, const Tensor& output_t) {
using CachedGraph = MPSUnaryCachedGraph;
if (output_t.numel() == 0 || input_t.numel() == 0) {
if (input_t.numel() == 0) {
output_t.zero_();
return;
} else if (input_t.numel() == 1) {
output_t.copy_(input_t.view_as(output_t).to(at::kBool));
return;
} else if (output_t.numel() == 0) {
return;
}

Expand Down
3 changes: 2 additions & 1 deletion test/test_mps.py
Original file line number Diff line number Diff line change
Expand Up @@ -8957,6 +8957,8 @@ class TestConsistency(TestCase):
'native_batch_norm': ['f32'],
'minreduction_with_dim': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'maxreduction_with_dim': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'linalg.inv': ['f32'],
'linalg.inv_ex': ['f32'],
}


Expand Down Expand Up @@ -9171,7 +9173,6 @@ class TestConsistency(TestCase):
'chalf': None,
'diag_embed': [torch.uint8],
'diagonal_scatter': [torch.uint8],
'linalg.inv': [torch.float32],
'long': None,
'nn.functional.conv1d': [torch.int64],
'nn.functional.conv2d': [torch.int64],
Expand Down

0 comments on commit 0b31ebf

Please sign in to comment.