From a43c7821cca85e476143aff05ccc8166aba9b60b Mon Sep 17 00:00:00 2001 From: jslhcl Date: Thu, 23 May 2024 08:11:20 -0700 Subject: [PATCH] new APIs for ORT-genai --- include/custom_op/custom_op_lite.h | 9 ++++++++- include/custom_op/kernel_context.h | 2 ++ include/ort_c_to_cpp.h | 3 +++ 3 files changed, 13 insertions(+), 1 deletion(-) diff --git a/include/custom_op/custom_op_lite.h b/include/custom_op/custom_op_lite.h index 960f64230..113b5ef0b 100644 --- a/include/custom_op/custom_op_lite.h +++ b/include/custom_op/custom_op_lite.h @@ -458,7 +458,7 @@ class OrtGraphCudaKernelContext : public CUDAKernelContext { public: static const int cuda_resource_ver = 1; - OrtGraphCudaKernelContext(const OrtApi& api, const OrtKernelContext& ctx) : api_(api) { + OrtGraphCudaKernelContext(const OrtApi& api, const OrtKernelContext& ctx) : api_(api), kernel_context_(ctx) { api.KernelContext_GetResource(&ctx, cuda_resource_ver, CudaResource::cuda_handle_t, &cuda_stream_); if (!cuda_stream_) { ORTX_CXX_API_THROW("Failed to fetch cuda stream from context", ORT_RUNTIME_EXCEPTION); @@ -526,8 +526,15 @@ class OrtGraphCudaKernelContext : public CUDAKernelContext { return device_id_; } + void* GetScratchBufferUnderMultiStream(const OrtMemoryInfo* mem_info, size_t count_or_bytes) override { + void* ret = nullptr; + api_.KernelContext_GetScratchBuffer(&kernel_context_, mem_info, count_or_bytes, &ret); + return ret; + } + private: const OrtApi& api_; + const OrtKernelContext& kernel_context_; OrtAllocator* cpu_allocator_; OrtAllocator* cuda_allocator_; void* cuda_stream_ = {}; diff --git a/include/custom_op/kernel_context.h b/include/custom_op/kernel_context.h index d8ac8fbbd..d56c79314 100644 --- a/include/custom_op/kernel_context.h +++ b/include/custom_op/kernel_context.h @@ -5,6 +5,7 @@ #include #include #include +#include "onnxruntime_c_api.h" namespace Ort { namespace Custom { @@ -29,6 +30,7 @@ class CUDAKernelContext : public KernelContext { virtual void* GetCudaStream() const = 0; virtual void* GetCublasHandle() const = 0; virtual int GetCudaDeviceId() const = 0; + virtual void* GetScratchBufferUnderMultiStream(const OrtMemoryInfo* , size_t ) { return nullptr; } }; #endif diff --git a/include/ort_c_to_cpp.h b/include/ort_c_to_cpp.h index 92c2fb01d..152aa6633 100644 --- a/include/ort_c_to_cpp.h +++ b/include/ort_c_to_cpp.h @@ -81,6 +81,9 @@ class API { return instance()->KernelContext_GetAllocator(context, mem_info, out); } #endif + static void ReleaseMemoryInfo(OrtMemoryInfo* mem_info) { + return instance()->ReleaseMemoryInfo(mem_info); + } private: const OrtApi* operator->() const { return &api_;