diff --git a/src/hip/handlehip.cpp b/src/hip/handlehip.cpp index 5b1852c7cd..a76a8ef470 100644 --- a/src/hip/handlehip.cpp +++ b/src/hip/handlehip.cpp @@ -518,8 +518,8 @@ void Handle::ClearKernels(const std::string& algorithm, const std::string& netwo this->impl->cache.ClearKernels(algorithm, network_config); } -const std::vector& Handle::GetKernelsImpl(const std::string& algorithm, - const std::string& network_config) const +std::vector Handle::GetKernelsImpl(const std::string& algorithm, + const std::string& network_config) const { return this->impl->cache.GetKernels(algorithm, network_config); } diff --git a/src/include/miopen/handle.hpp b/src/include/miopen/handle.hpp index 5e6dc5f793..27d2227292 100644 --- a/src/include/miopen/handle.hpp +++ b/src/include/miopen/handle.hpp @@ -122,9 +122,18 @@ struct MIOPEN_EXPORT Handle : miopenHandle auto GetKernels(const std::string& algorithm, const std::string& network_config) const { - return this->GetKernelsImpl(algorithm, network_config) | - boost::adaptors::transformed([this](Kernel k) { return this->Run(k); }); + auto kernels = this->GetKernelsImpl(algorithm, network_config); + + std::vector kernelInvokers; + kernelInvokers.resize(kernels.size()); + std::transform(kernels.begin(), + kernels.end(), + kernelInvokers.begin(), + [this](const Kernel& k) { return this->Run(k); }); + + return kernelInvokers; } + KernelInvoke GetKernel(const std::string& algorithm, const std::string& network_config) const { auto ks = this->GetKernelsImpl(algorithm, network_config); @@ -137,8 +146,8 @@ struct MIOPEN_EXPORT Handle : miopenHandle } KernelInvoke Run(Kernel k, bool coop_launch = false) const; - const std::vector& GetKernelsImpl(const std::string& algorithm, - const std::string& network_config) const; + std::vector GetKernelsImpl(const std::string& algorithm, + const std::string& network_config) const; Program LoadProgram(const fs::path& program_name, std::string params, diff --git a/src/include/miopen/kernel_cache.hpp b/src/include/miopen/kernel_cache.hpp index cb3eeed527..05773c1c26 100644 --- a/src/include/miopen/kernel_cache.hpp +++ b/src/include/miopen/kernel_cache.hpp @@ -50,6 +50,7 @@ #include #include #include +#include namespace miopen { @@ -81,8 +82,7 @@ class KernelCache void ClearKernels(const std::string& algorithm, const std::string& network_config); - const std::vector& GetKernels(const std::string& algorithm, - const std::string& network_config); + std::vector GetKernels(const std::string& algorithm, const std::string& network_config); bool HasProgram(const fs::path& name, const std::string& params) const; void ClearProgram(const fs::path& name, const std::string& params); @@ -92,8 +92,11 @@ class KernelCache KernelCache(); private: + void AddKernelUnsafe(Key key, Kernel k, std::size_t cache_index); + KernelMap kernel_map; ProgramMap program_map; + mutable std::shared_mutex lock; }; } // namespace miopen diff --git a/src/kernel_cache.cpp b/src/kernel_cache.cpp index 034bfd62bf..b4eef70760 100644 --- a/src/kernel_cache.cpp +++ b/src/kernel_cache.cpp @@ -47,14 +47,19 @@ #include #include +#include MIOPEN_DECLARE_ENV_VAR_STR(MIOPEN_DEVICE_ARCH) namespace miopen { -const std::vector& KernelCache::GetKernels(const std::string& algorithm, - const std::string& network_config) +using WriteLock = std::unique_lock; +using ReadLock = std::shared_lock; + +std::vector KernelCache::GetKernels(const std::string& algorithm, + const std::string& network_config) { + ReadLock readLock(lock); std::pair key = std::make_pair(algorithm, network_config); @@ -73,21 +78,28 @@ const std::vector& KernelCache::GetKernels(const std::string& algorithm, bool KernelCache::HasProgram(const fs::path& name, const std::string& params) const { + ReadLock readLock(lock); + const auto key = std::make_pair(name, params); return program_map.count(key) > 0; } void KernelCache::ClearProgram(const fs::path& name, const std::string& params) { - if(HasProgram(name, params)) + WriteLock writeLock(lock); + + const auto key = std::make_pair(name, params); + auto program_it = program_map.find(key); + if(program_it != program_map.end()) { - const auto key = std::make_pair(name, params); - program_map.erase(key); + program_map.erase(program_it); } } void KernelCache::AddProgram(Program prog, const fs::path& program_name, std::string params) { + WriteLock writeLock(lock); + program_map[std::make_pair(program_name, params)] = prog; } @@ -103,6 +115,8 @@ Kernel KernelCache::AddKernel(const Handle& h, const std::string& kernel_src, Program* program_out) { + WriteLock writeLock(lock); + const std::pair key = std::make_pair(algorithm, network_config); if(!network_config.empty() || !algorithm.empty()) // Don't log only _empty_ keys. MIOPEN_LOG_I2("Key: " << key.first << " \"" << key.second << '\"'); @@ -149,12 +163,18 @@ Kernel KernelCache::AddKernel(const Handle& h, if(!network_config.empty() && !algorithm.empty()) { - this->AddKernel(key, kernel, cache_index); + this->AddKernelUnsafe(key, kernel, cache_index); } return kernel; } void KernelCache::AddKernel(Key key, Kernel k, std::size_t cache_index) +{ + WriteLock writeLock(lock); + AddKernelUnsafe(key, k, cache_index); +} + +void KernelCache::AddKernelUnsafe(Key key, Kernel k, std::size_t cache_index) { auto&& v = kernel_map[key]; if(cache_index >= v.size()) @@ -166,6 +186,8 @@ void KernelCache::AddKernel(Key key, Kernel k, std::size_t cache_index) void KernelCache::ClearKernels(const std::string& algorithm, const std::string& network_config) { + WriteLock writeLock(lock); + if(network_config.empty() || algorithm.empty()) { MIOPEN_THROW("Network config or algorithm empty."); diff --git a/src/nogpu/handle.cpp b/src/nogpu/handle.cpp index 6d3af3a9f2..44b10e9882 100644 --- a/src/nogpu/handle.cpp +++ b/src/nogpu/handle.cpp @@ -170,8 +170,8 @@ void Handle::ClearProgram(const fs::path& program_name, const std::string& param this->impl->cache.ClearProgram(program_name, params); } -const std::vector& Handle::GetKernelsImpl(const std::string& algorithm, - const std::string& network_config) const +std::vector Handle::GetKernelsImpl(const std::string& algorithm, + const std::string& network_config) const { return this->impl->cache.GetKernels(algorithm, network_config); } diff --git a/src/ocl/handleocl.cpp b/src/ocl/handleocl.cpp index e4f314a1a9..793ee25e96 100644 --- a/src/ocl/handleocl.cpp +++ b/src/ocl/handleocl.cpp @@ -363,8 +363,8 @@ void Handle::ClearKernels(const std::string& algorithm, const std::string& netwo this->impl->cache.ClearKernels(algorithm, network_config); } -const std::vector& Handle::GetKernelsImpl(const std::string& algorithm, - const std::string& network_config) const +std::vector Handle::GetKernelsImpl(const std::string& algorithm, + const std::string& network_config) const { return this->impl->cache.GetKernels(algorithm, network_config); }