Skip to content

Commit

Permalink
Fix potential race condition in GetResultOrRunClassInitialize
Browse files Browse the repository at this point in the history
  • Loading branch information
Youssef1313 committed Jan 9, 2025
1 parent ef3b985 commit ecf3775
Showing 1 changed file with 74 additions and 62 deletions.
136 changes: 74 additions & 62 deletions src/Adapter/MSTest.TestAdapter/Execution/TestClassInfo.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using Microsoft.VisualStudio.TestPlatform.MSTest.TestAdapter.Extensions;
Expand Down Expand Up @@ -254,6 +254,7 @@ public void RunClassInitialize(TestContext testContext)
// If no class initialize and no base class initialize, return
if (ClassInitializeMethod is null && BaseClassInitMethods.Count == 0)
{
DebugEx.Assert(false, "Caller shouldn't call us if nothing to execute");
IsClassInitializeExecuted = true;
return;
}
Expand All @@ -270,45 +271,37 @@ public void RunClassInitialize(TestContext testContext)
string? failedClassInitializeMethodName = string.Empty;

// If class initialization is not done, then do it.
DebugEx.Assert(!IsClassInitializeExecuted, "Caller shouldn't call us if it was executed.");
if (!IsClassInitializeExecuted)
{
// Acquiring a lock is usually a costly operation which does not need to be
// performed every time if the class initialization is already executed.
lock (_testClassExecuteSyncObject)
try
{
// Perform a check again.
if (!IsClassInitializeExecuted)
// We have discovered the methods from bottom (most derived) to top (less derived) but we want to execute
// from top to bottom.
for (int i = BaseClassInitMethods.Count - 1; i >= 0; i--)
{
try
initializeMethod = BaseClassInitMethods[i];
ClassInitializationException = InvokeInitializeMethod(initializeMethod, testContext);
if (ClassInitializationException is not null)
{
// We have discovered the methods from bottom (most derived) to top (less derived) but we want to execute
// from top to bottom.
for (int i = BaseClassInitMethods.Count - 1; i >= 0; i--)
{
initializeMethod = BaseClassInitMethods[i];
ClassInitializationException = InvokeInitializeMethod(initializeMethod, testContext);
if (ClassInitializationException is not null)
{
break;
}
}

if (ClassInitializationException is null)
{
initializeMethod = ClassInitializeMethod;
ClassInitializationException = InvokeInitializeMethod(ClassInitializeMethod, testContext);
}
}
catch (Exception ex)
{
ClassInitializationException = ex;
failedClassInitializeMethodName = initializeMethod?.Name ?? ClassInitializeMethod?.Name;
}
finally
{
IsClassInitializeExecuted = true;
break;
}
}

if (ClassInitializationException is null)
{
initializeMethod = ClassInitializeMethod;
ClassInitializationException = InvokeInitializeMethod(ClassInitializeMethod, testContext);
}
}
catch (Exception ex)
{
ClassInitializationException = ex;
failedClassInitializeMethodName = initializeMethod?.Name ?? ClassInitializeMethod?.Name;
}
finally
{
IsClassInitializeExecuted = true;
}
}

Expand Down Expand Up @@ -385,8 +378,6 @@ internal UnitTestResult GetResultOrRunClassInitialize(ITestContext testContext,
return clonedInitializeResult;
}

DebugEx.Assert(!IsClassInitializeExecuted, "If class initialize was executed, we should have been in the previous if were we have a result available.");

// For optimization purposes, return right away if there is nothing to execute.
// For STA, this avoids starting a thread when we know it will do nothing.
// But we still return early even not STA.
Expand All @@ -396,41 +387,62 @@ internal UnitTestResult GetResultOrRunClassInitialize(ITestContext testContext,
return _classInitializeResult = new(ObjectModelUnitTestOutcome.Passed, null);
}

bool isSTATestClass = AttributeComparer.IsDerived<STATestClassAttribute>(ClassAttribute);
bool isWindowsOS = RuntimeInformation.IsOSPlatform(OSPlatform.Windows);
if (isSTATestClass
&& isWindowsOS
&& Thread.CurrentThread.GetApartmentState() != ApartmentState.STA)
// At this point, maybe class initialize was executed by another thread such
// that TryGetClonedCachedClassInitializeResult would return non-null.
// Now, we need to check again, but under a lock.
// Note that we are duplicating the logic above.
// We could keep the logic in lock only and not duplicate, but we don't want to pay
// the lock cost unnecessarily for a common case.
// We also need to lock to avoid concurrency issues and guarantee that class init is called only once.
lock (_testClassExecuteSyncObject)
{
UnitTestResult result = new(ObjectModelUnitTestOutcome.Error, "MSTest STATestClass ClassInitialize didn't complete");
Thread entryPointThread = new(() => result = DoRun())
{
Name = "MSTest STATestClass ClassInitialize",
};

entryPointThread.SetApartmentState(ApartmentState.STA);
entryPointThread.Start();
clonedInitializeResult = TryGetClonedCachedClassInitializeResult();

try
// Optimization: If we already ran before and know the result, return it.
if (clonedInitializeResult is not null)
{
entryPointThread.Join();
return result;
DebugEx.Assert(IsClassInitializeExecuted, "Class initialize result should be available if and only if class initialize was executed");
return clonedInitializeResult;
}
catch (Exception ex)

DebugEx.Assert(!IsClassInitializeExecuted, "If class initialize was executed, we should have been in the previous if were we have a result available.");

bool isSTATestClass = AttributeComparer.IsDerived<STATestClassAttribute>(ClassAttribute);
bool isWindowsOS = RuntimeInformation.IsOSPlatform(OSPlatform.Windows);
if (isSTATestClass
&& isWindowsOS
&& Thread.CurrentThread.GetApartmentState() != ApartmentState.STA)
{
PlatformServiceProvider.Instance.AdapterTraceLogger.LogError(ex.ToString());
return new UnitTestResult(new TestFailedException(ObjectModelUnitTestOutcome.Error, ex.TryGetMessage(), ex.TryGetStackTraceInformation()));
UnitTestResult result = new(ObjectModelUnitTestOutcome.Error, "MSTest STATestClass ClassInitialize didn't complete");
Thread entryPointThread = new(() => result = DoRun())
{
Name = "MSTest STATestClass ClassInitialize",
};

entryPointThread.SetApartmentState(ApartmentState.STA);
entryPointThread.Start();

try
{
entryPointThread.Join();
return result;
}
catch (Exception ex)
{
PlatformServiceProvider.Instance.AdapterTraceLogger.LogError(ex.ToString());
return new UnitTestResult(new TestFailedException(ObjectModelUnitTestOutcome.Error, ex.TryGetMessage(), ex.TryGetStackTraceInformation()));
}
}
}
else
{
// If the requested apartment state is STA and the OS is not Windows, then warn the user.
if (!isWindowsOS && isSTATestClass)
else
{
PlatformServiceProvider.Instance.AdapterTraceLogger.LogWarning(Resource.STAIsOnlySupportedOnWindowsWarning);
}
// If the requested apartment state is STA and the OS is not Windows, then warn the user.
if (!isWindowsOS && isSTATestClass)
{
PlatformServiceProvider.Instance.AdapterTraceLogger.LogWarning(Resource.STAIsOnlySupportedOnWindowsWarning);
}

return DoRun();
return DoRun();
}
}

// Local functions
Expand Down

0 comments on commit ecf3775

Please sign in to comment.