Skip to content

Commit

Permalink
apply base class to test class example
Browse files Browse the repository at this point in the history
  • Loading branch information
reidkwja committed Nov 11, 2024
1 parent 70d2ec5 commit 56b894c
Showing 1 changed file with 29 additions and 10 deletions.
39 changes: 29 additions & 10 deletions test/gtest/smoke_solver_ConvBinWinograd3x3U.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
*******************************************************************************/
#include <tuple>
#include <string_view>

#include "test_base.hpp"
#include "gtest_common.hpp"

#include "../conv2d.hpp"
Expand All @@ -34,19 +34,19 @@ namespace {

auto GetTestCases()
{
const auto env = std::tuple{std::pair{MIOPEN_FIND_MODE, "normal"},
std::pair{MIOPEN_DEBUG_FIND_ONLY_SOLVER, "ConvBinWinograd3x3U"}};

const std::string vf = " --verbose --disable-backward-data --disable-backward-weights";
const std::string vb = " --verbose --disable-forward --disable-backward-weights";

return std::vector{
// clang-format off
//smoke_solver_ConvAsmImplicitGemmV4R1Dynamic
std::pair{env, vf + " --input 1 20 20 20 --weights 20 20 3 3 --pads_strides_dilations 1 1 1 1 1 1"},
std::pair{env, vb + " --input 1 20 20 20 --weights 20 20 3 3 --pads_strides_dilations 1 1 1 1 1 1"}
// clang-format on
};
std::pair{
env,
vf + " --input 1 20 20 20 --weights 20 20 3 3 --pads_strides_dilations 1 1 1 1 1 1"},
std::pair{
env,
vb + " --input 1 20 20 20 --weights 20 20 3 3 --pads_strides_dilations 1 1 1 1 1 1"}};
// clang-format on
}

using TestCase = decltype(GetTestCases())::value_type;
Expand All @@ -64,12 +64,31 @@ bool IsTestSupportedForDevice()

} // namespace

class GPU_Conv2dDefault_FP32 : public FloatTestCase<std::vector<TestCase>>
// Using TestBase
class GPU_Conv2dDefault_FP32 : public TestBase<GPU_Conv2dDefault_FP32>
{
public:
static std::vector<std::string> get_env_vars()
{
return {"MIOPEN_FIND_MODE", "MIOPEN_DEBUG_FIND_ONLY_SOLVER"};
}

static std::map<std::string, std::string> get_env_values()
{
return {{"MIOPEN_FIND_MODE", "normal"},
{"MIOPEN_DEBUG_FIND_ONLY_SOLVER", "ConvBinWinograd3x3U"}};
}
};

// Parameterized test implementation
TEST_P(GPU_Conv2dDefault_FP32, FloatTest_smoke_solver_ConvBinWinograd3x3U)
{
// Set environment variables dynamically
for(const auto& [var, val] : GPU_Conv2dDefault_FP32::get_env_values())
{
miopen::env::setEnvironmentVariable(var, val);
}

if(IsTestSupportedForDevice() && !SkipTest())
{
invoke_with_params<conv2d_driver, GPU_Conv2dDefault_FP32>(default_check);
Expand All @@ -78,6 +97,6 @@ TEST_P(GPU_Conv2dDefault_FP32, FloatTest_smoke_solver_ConvBinWinograd3x3U)
{
GTEST_SKIP();
}
};
}

INSTANTIATE_TEST_SUITE_P(Smoke, GPU_Conv2dDefault_FP32, testing::Values(GetTestCases()));

0 comments on commit 56b894c

Please sign in to comment.