From 56b894c1e5357120fc6c422674674f7782484c2c Mon Sep 17 00:00:00 2001 From: Reid Kawaja <74506315+reidkwja@users.noreply.github.com> Date: Mon, 11 Nov 2024 22:45:29 +0000 Subject: [PATCH] apply base class to test class example --- .../smoke_solver_ConvBinWinograd3x3U.cpp | 39 ++++++++++++++----- 1 file changed, 29 insertions(+), 10 deletions(-) diff --git a/test/gtest/smoke_solver_ConvBinWinograd3x3U.cpp b/test/gtest/smoke_solver_ConvBinWinograd3x3U.cpp index a1ad5787f0..fcff2bafbc 100644 --- a/test/gtest/smoke_solver_ConvBinWinograd3x3U.cpp +++ b/test/gtest/smoke_solver_ConvBinWinograd3x3U.cpp @@ -25,7 +25,7 @@ *******************************************************************************/ #include #include - +#include "test_base.hpp" #include "gtest_common.hpp" #include "../conv2d.hpp" @@ -34,19 +34,19 @@ namespace { auto GetTestCases() { - const auto env = std::tuple{std::pair{MIOPEN_FIND_MODE, "normal"}, - std::pair{MIOPEN_DEBUG_FIND_ONLY_SOLVER, "ConvBinWinograd3x3U"}}; - const std::string vf = " --verbose --disable-backward-data --disable-backward-weights"; const std::string vb = " --verbose --disable-forward --disable-backward-weights"; return std::vector{ // clang-format off //smoke_solver_ConvAsmImplicitGemmV4R1Dynamic - std::pair{env, vf + " --input 1 20 20 20 --weights 20 20 3 3 --pads_strides_dilations 1 1 1 1 1 1"}, - std::pair{env, vb + " --input 1 20 20 20 --weights 20 20 3 3 --pads_strides_dilations 1 1 1 1 1 1"} - // clang-format on - }; + std::pair{ + env, + vf + " --input 1 20 20 20 --weights 20 20 3 3 --pads_strides_dilations 1 1 1 1 1 1"}, + std::pair{ + env, + vb + " --input 1 20 20 20 --weights 20 20 3 3 --pads_strides_dilations 1 1 1 1 1 1"}}; + // clang-format on } using TestCase = decltype(GetTestCases())::value_type; @@ -64,12 +64,31 @@ bool IsTestSupportedForDevice() } // namespace -class GPU_Conv2dDefault_FP32 : public FloatTestCase> +// Using TestBase +class GPU_Conv2dDefault_FP32 : public TestBase { +public: + static std::vector get_env_vars() + { + return {"MIOPEN_FIND_MODE", "MIOPEN_DEBUG_FIND_ONLY_SOLVER"}; + } + + static std::map get_env_values() + { + return {{"MIOPEN_FIND_MODE", "normal"}, + {"MIOPEN_DEBUG_FIND_ONLY_SOLVER", "ConvBinWinograd3x3U"}}; + } }; +// Parameterized test implementation TEST_P(GPU_Conv2dDefault_FP32, FloatTest_smoke_solver_ConvBinWinograd3x3U) { + // Set environment variables dynamically + for(const auto& [var, val] : GPU_Conv2dDefault_FP32::get_env_values()) + { + miopen::env::setEnvironmentVariable(var, val); + } + if(IsTestSupportedForDevice() && !SkipTest()) { invoke_with_params(default_check); @@ -78,6 +97,6 @@ TEST_P(GPU_Conv2dDefault_FP32, FloatTest_smoke_solver_ConvBinWinograd3x3U) { GTEST_SKIP(); } -}; +} INSTANTIATE_TEST_SUITE_P(Smoke, GPU_Conv2dDefault_FP32, testing::Values(GetTestCases()));