Skip to content

Commit

Permalink
Fix kernel cache threading (#3437)
Browse files Browse the repository at this point in the history
* Fix thread safety issue in kernel cache when running tuning

* Fix formatting

* Fix NoGpu build
  • Loading branch information
BrianHarrisonAMD authored Dec 13, 2024
1 parent 97c2c01 commit 4d10252
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 18 deletions.
4 changes: 2 additions & 2 deletions src/hip/handlehip.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -518,8 +518,8 @@ void Handle::ClearKernels(const std::string& algorithm, const std::string& netwo
this->impl->cache.ClearKernels(algorithm, network_config);
}

const std::vector<Kernel>& Handle::GetKernelsImpl(const std::string& algorithm,
const std::string& network_config) const
std::vector<Kernel> Handle::GetKernelsImpl(const std::string& algorithm,
const std::string& network_config) const
{
return this->impl->cache.GetKernels(algorithm, network_config);
}
Expand Down
17 changes: 13 additions & 4 deletions src/include/miopen/handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,18 @@ struct MIOPEN_EXPORT Handle : miopenHandle

auto GetKernels(const std::string& algorithm, const std::string& network_config) const
{
return this->GetKernelsImpl(algorithm, network_config) |
boost::adaptors::transformed([this](Kernel k) { return this->Run(k); });
auto kernels = this->GetKernelsImpl(algorithm, network_config);

std::vector<KernelInvoke> kernelInvokers;
kernelInvokers.resize(kernels.size());
std::transform(kernels.begin(),
kernels.end(),
kernelInvokers.begin(),
[this](const Kernel& k) { return this->Run(k); });

return kernelInvokers;
}

KernelInvoke GetKernel(const std::string& algorithm, const std::string& network_config) const
{
auto ks = this->GetKernelsImpl(algorithm, network_config);
Expand All @@ -137,8 +146,8 @@ struct MIOPEN_EXPORT Handle : miopenHandle
}

KernelInvoke Run(Kernel k, bool coop_launch = false) const;
const std::vector<Kernel>& GetKernelsImpl(const std::string& algorithm,
const std::string& network_config) const;
std::vector<Kernel> GetKernelsImpl(const std::string& algorithm,
const std::string& network_config) const;

Program LoadProgram(const fs::path& program_name,
std::string params,
Expand Down
7 changes: 5 additions & 2 deletions src/include/miopen/kernel_cache.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
#include <unordered_map>
#include <utility>
#include <vector>
#include <shared_mutex>

namespace miopen {

Expand Down Expand Up @@ -81,8 +82,7 @@ class KernelCache

void ClearKernels(const std::string& algorithm, const std::string& network_config);

const std::vector<Kernel>& GetKernels(const std::string& algorithm,
const std::string& network_config);
std::vector<Kernel> GetKernels(const std::string& algorithm, const std::string& network_config);

bool HasProgram(const fs::path& name, const std::string& params) const;
void ClearProgram(const fs::path& name, const std::string& params);
Expand All @@ -92,8 +92,11 @@ class KernelCache
KernelCache();

private:
void AddKernelUnsafe(Key key, Kernel k, std::size_t cache_index);

KernelMap kernel_map;
ProgramMap program_map;
mutable std::shared_mutex lock;
};

} // namespace miopen
Expand Down
34 changes: 28 additions & 6 deletions src/kernel_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,19 @@

#include <iostream>
#include <iterator>
#include <mutex>

MIOPEN_DECLARE_ENV_VAR_STR(MIOPEN_DEVICE_ARCH)

namespace miopen {

const std::vector<Kernel>& KernelCache::GetKernels(const std::string& algorithm,
const std::string& network_config)
using WriteLock = std::unique_lock<std::shared_mutex>;
using ReadLock = std::shared_lock<std::shared_mutex>;

std::vector<Kernel> KernelCache::GetKernels(const std::string& algorithm,
const std::string& network_config)
{
ReadLock readLock(lock);

std::pair<std::string, std::string> key = std::make_pair(algorithm, network_config);

Expand All @@ -73,21 +78,28 @@ const std::vector<Kernel>& KernelCache::GetKernels(const std::string& algorithm,

bool KernelCache::HasProgram(const fs::path& name, const std::string& params) const
{
ReadLock readLock(lock);

const auto key = std::make_pair(name, params);
return program_map.count(key) > 0;
}

void KernelCache::ClearProgram(const fs::path& name, const std::string& params)
{
if(HasProgram(name, params))
WriteLock writeLock(lock);

const auto key = std::make_pair(name, params);
auto program_it = program_map.find(key);
if(program_it != program_map.end())
{
const auto key = std::make_pair(name, params);
program_map.erase(key);
program_map.erase(program_it);
}
}

void KernelCache::AddProgram(Program prog, const fs::path& program_name, std::string params)
{
WriteLock writeLock(lock);

program_map[std::make_pair(program_name, params)] = prog;
}

Expand All @@ -103,6 +115,8 @@ Kernel KernelCache::AddKernel(const Handle& h,
const std::string& kernel_src,
Program* program_out)
{
WriteLock writeLock(lock);

const std::pair<std::string, std::string> key = std::make_pair(algorithm, network_config);
if(!network_config.empty() || !algorithm.empty()) // Don't log only _empty_ keys.
MIOPEN_LOG_I2("Key: " << key.first << " \"" << key.second << '\"');
Expand Down Expand Up @@ -149,12 +163,18 @@ Kernel KernelCache::AddKernel(const Handle& h,

if(!network_config.empty() && !algorithm.empty())
{
this->AddKernel(key, kernel, cache_index);
this->AddKernelUnsafe(key, kernel, cache_index);
}
return kernel;
}

void KernelCache::AddKernel(Key key, Kernel k, std::size_t cache_index)
{
WriteLock writeLock(lock);
AddKernelUnsafe(key, k, cache_index);
}

void KernelCache::AddKernelUnsafe(Key key, Kernel k, std::size_t cache_index)
{
auto&& v = kernel_map[key];
if(cache_index >= v.size())
Expand All @@ -166,6 +186,8 @@ void KernelCache::AddKernel(Key key, Kernel k, std::size_t cache_index)

void KernelCache::ClearKernels(const std::string& algorithm, const std::string& network_config)
{
WriteLock writeLock(lock);

if(network_config.empty() || algorithm.empty())
{
MIOPEN_THROW("Network config or algorithm empty.");
Expand Down
4 changes: 2 additions & 2 deletions src/nogpu/handle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,8 @@ void Handle::ClearProgram(const fs::path& program_name, const std::string& param
this->impl->cache.ClearProgram(program_name, params);
}

const std::vector<Kernel>& Handle::GetKernelsImpl(const std::string& algorithm,
const std::string& network_config) const
std::vector<Kernel> Handle::GetKernelsImpl(const std::string& algorithm,
const std::string& network_config) const
{
return this->impl->cache.GetKernels(algorithm, network_config);
}
Expand Down
4 changes: 2 additions & 2 deletions src/ocl/handleocl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -363,8 +363,8 @@ void Handle::ClearKernels(const std::string& algorithm, const std::string& netwo
this->impl->cache.ClearKernels(algorithm, network_config);
}

const std::vector<Kernel>& Handle::GetKernelsImpl(const std::string& algorithm,
const std::string& network_config) const
std::vector<Kernel> Handle::GetKernelsImpl(const std::string& algorithm,
const std::string& network_config) const
{
return this->impl->cache.GetKernels(algorithm, network_config);
}
Expand Down

0 comments on commit 4d10252

Please sign in to comment.