From ecf3775ed422e4da70366d5436d9ed3ab7a1c12a Mon Sep 17 00:00:00 2001 From: Youssef1313 Date: Wed, 8 Jan 2025 06:08:35 +0100 Subject: [PATCH] Fix potential race condition in GetResultOrRunClassInitialize --- .../Execution/TestClassInfo.cs | 136 ++++++++++-------- 1 file changed, 74 insertions(+), 62 deletions(-) diff --git a/src/Adapter/MSTest.TestAdapter/Execution/TestClassInfo.cs b/src/Adapter/MSTest.TestAdapter/Execution/TestClassInfo.cs index b3c6510e62..86fac8fd8e 100644 --- a/src/Adapter/MSTest.TestAdapter/Execution/TestClassInfo.cs +++ b/src/Adapter/MSTest.TestAdapter/Execution/TestClassInfo.cs @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. using Microsoft.VisualStudio.TestPlatform.MSTest.TestAdapter.Extensions; @@ -254,6 +254,7 @@ public void RunClassInitialize(TestContext testContext) // If no class initialize and no base class initialize, return if (ClassInitializeMethod is null && BaseClassInitMethods.Count == 0) { + DebugEx.Assert(false, "Caller shouldn't call us if nothing to execute"); IsClassInitializeExecuted = true; return; } @@ -270,45 +271,37 @@ public void RunClassInitialize(TestContext testContext) string? failedClassInitializeMethodName = string.Empty; // If class initialization is not done, then do it. + DebugEx.Assert(!IsClassInitializeExecuted, "Caller shouldn't call us if it was executed."); if (!IsClassInitializeExecuted) { - // Acquiring a lock is usually a costly operation which does not need to be - // performed every time if the class initialization is already executed. - lock (_testClassExecuteSyncObject) + try { - // Perform a check again. - if (!IsClassInitializeExecuted) + // We have discovered the methods from bottom (most derived) to top (less derived) but we want to execute + // from top to bottom. + for (int i = BaseClassInitMethods.Count - 1; i >= 0; i--) { - try + initializeMethod = BaseClassInitMethods[i]; + ClassInitializationException = InvokeInitializeMethod(initializeMethod, testContext); + if (ClassInitializationException is not null) { - // We have discovered the methods from bottom (most derived) to top (less derived) but we want to execute - // from top to bottom. - for (int i = BaseClassInitMethods.Count - 1; i >= 0; i--) - { - initializeMethod = BaseClassInitMethods[i]; - ClassInitializationException = InvokeInitializeMethod(initializeMethod, testContext); - if (ClassInitializationException is not null) - { - break; - } - } - - if (ClassInitializationException is null) - { - initializeMethod = ClassInitializeMethod; - ClassInitializationException = InvokeInitializeMethod(ClassInitializeMethod, testContext); - } - } - catch (Exception ex) - { - ClassInitializationException = ex; - failedClassInitializeMethodName = initializeMethod?.Name ?? ClassInitializeMethod?.Name; - } - finally - { - IsClassInitializeExecuted = true; + break; } } + + if (ClassInitializationException is null) + { + initializeMethod = ClassInitializeMethod; + ClassInitializationException = InvokeInitializeMethod(ClassInitializeMethod, testContext); + } + } + catch (Exception ex) + { + ClassInitializationException = ex; + failedClassInitializeMethodName = initializeMethod?.Name ?? ClassInitializeMethod?.Name; + } + finally + { + IsClassInitializeExecuted = true; } } @@ -385,8 +378,6 @@ internal UnitTestResult GetResultOrRunClassInitialize(ITestContext testContext, return clonedInitializeResult; } - DebugEx.Assert(!IsClassInitializeExecuted, "If class initialize was executed, we should have been in the previous if were we have a result available."); - // For optimization purposes, return right away if there is nothing to execute. // For STA, this avoids starting a thread when we know it will do nothing. // But we still return early even not STA. @@ -396,41 +387,62 @@ internal UnitTestResult GetResultOrRunClassInitialize(ITestContext testContext, return _classInitializeResult = new(ObjectModelUnitTestOutcome.Passed, null); } - bool isSTATestClass = AttributeComparer.IsDerived(ClassAttribute); - bool isWindowsOS = RuntimeInformation.IsOSPlatform(OSPlatform.Windows); - if (isSTATestClass - && isWindowsOS - && Thread.CurrentThread.GetApartmentState() != ApartmentState.STA) + // At this point, maybe class initialize was executed by another thread such + // that TryGetClonedCachedClassInitializeResult would return non-null. + // Now, we need to check again, but under a lock. + // Note that we are duplicating the logic above. + // We could keep the logic in lock only and not duplicate, but we don't want to pay + // the lock cost unnecessarily for a common case. + // We also need to lock to avoid concurrency issues and guarantee that class init is called only once. + lock (_testClassExecuteSyncObject) { - UnitTestResult result = new(ObjectModelUnitTestOutcome.Error, "MSTest STATestClass ClassInitialize didn't complete"); - Thread entryPointThread = new(() => result = DoRun()) - { - Name = "MSTest STATestClass ClassInitialize", - }; - - entryPointThread.SetApartmentState(ApartmentState.STA); - entryPointThread.Start(); + clonedInitializeResult = TryGetClonedCachedClassInitializeResult(); - try + // Optimization: If we already ran before and know the result, return it. + if (clonedInitializeResult is not null) { - entryPointThread.Join(); - return result; + DebugEx.Assert(IsClassInitializeExecuted, "Class initialize result should be available if and only if class initialize was executed"); + return clonedInitializeResult; } - catch (Exception ex) + + DebugEx.Assert(!IsClassInitializeExecuted, "If class initialize was executed, we should have been in the previous if were we have a result available."); + + bool isSTATestClass = AttributeComparer.IsDerived(ClassAttribute); + bool isWindowsOS = RuntimeInformation.IsOSPlatform(OSPlatform.Windows); + if (isSTATestClass + && isWindowsOS + && Thread.CurrentThread.GetApartmentState() != ApartmentState.STA) { - PlatformServiceProvider.Instance.AdapterTraceLogger.LogError(ex.ToString()); - return new UnitTestResult(new TestFailedException(ObjectModelUnitTestOutcome.Error, ex.TryGetMessage(), ex.TryGetStackTraceInformation())); + UnitTestResult result = new(ObjectModelUnitTestOutcome.Error, "MSTest STATestClass ClassInitialize didn't complete"); + Thread entryPointThread = new(() => result = DoRun()) + { + Name = "MSTest STATestClass ClassInitialize", + }; + + entryPointThread.SetApartmentState(ApartmentState.STA); + entryPointThread.Start(); + + try + { + entryPointThread.Join(); + return result; + } + catch (Exception ex) + { + PlatformServiceProvider.Instance.AdapterTraceLogger.LogError(ex.ToString()); + return new UnitTestResult(new TestFailedException(ObjectModelUnitTestOutcome.Error, ex.TryGetMessage(), ex.TryGetStackTraceInformation())); + } } - } - else - { - // If the requested apartment state is STA and the OS is not Windows, then warn the user. - if (!isWindowsOS && isSTATestClass) + else { - PlatformServiceProvider.Instance.AdapterTraceLogger.LogWarning(Resource.STAIsOnlySupportedOnWindowsWarning); - } + // If the requested apartment state is STA and the OS is not Windows, then warn the user. + if (!isWindowsOS && isSTATestClass) + { + PlatformServiceProvider.Instance.AdapterTraceLogger.LogWarning(Resource.STAIsOnlySupportedOnWindowsWarning); + } - return DoRun(); + return DoRun(); + } } // Local functions