diff --git a/include/ck_tile/ops/cross_gpu_reduce/kernel/cross_gpu_connect.hpp b/include/ck_tile/ops/cross_gpu_reduce/kernel/cross_gpu_connect.hpp
index 4baa0498fe..fcf7c6eda4 100644
--- a/include/ck_tile/ops/cross_gpu_reduce/kernel/cross_gpu_connect.hpp
+++ b/include/ck_tile/ops/cross_gpu_reduce/kernel/cross_gpu_connect.hpp
@@ -28,6 +28,72 @@
 extern __constant__ DeviceHandle constSlaveSmChannels[8];
 // extern __constant__ DeviceHandle constMasterSmChannel;
 
+static constexpr int kMaxBlocks = 64; // leading dimension of every array in Signal
+// IPC_KEY: the two static_asserts below pin its size and alignment to hipIpcMemHandle_t.
+using IPC_KEY = std::array; // NOTE(review): no template arguments here; ill-formed as written -- restore them
+static_assert(sizeof(IPC_KEY) == sizeof(hipIpcMemHandle_t));
+static_assert(alignof(IPC_KEY) == alignof(hipIpcMemHandle_t));
+// Signal: three uint32_t arrays, each starting on its own 128-byte boundary.
+struct Signal
+{
+    alignas(128) uint32_t start[kMaxBlocks][8]; // NOTE(review): 8 presumably = max rank count (cf. RankSignals) -- confirm
+    alignas(128) uint32_t end[kMaxBlocks][8];
+    alignas(128) uint32_t _flag[kMaxBlocks]; // incremental flags for each rank
+};
+// RankData: 16-byte-aligned table of 8 read-only pointers.
+struct __align__(16) RankData { const void* ptrs[8]; };
+// RankSignals: 16-byte-aligned table of 8 pointers to volatile Signal objects.
+struct __align__(16) RankSignals { volatile Signal* signals[8]; };
+// NOTE(review): everything above is declared outside namespace ck_tile; only DeviceReduceConnect is inside.
+namespace ck_tile {
+struct DeviceReduceConnect // per-rank state; the constructor fills signals.signals[0 .. world_size)
+{
+    index_t rank;
+    index_t world_size; // set from offsets.size() in the constructor
+    bool full_mesh_connect;
+    RankSignals signals; // slot i is written by the constructor for each i < world_size
+    std::unordered_map buffers; // NOTE(review): template arguments missing; not touched by the constructor
+    Signal* self_signal; // the `meta` constructor argument; also stored in signals.signals[rank]
+// [d_rank_data_base, d_rank_data_end) is the RankData range derived from rank_data / rank_data_sz.
+    RankData *d_rank_data_base, *d_rank_data_end;
+    std::vector graph_unreg_buffers; // NOTE(review): template arguments missing; not touched by the constructor
+    map ipc_handles; // NOTE(review): unqualified `map` without template arguments -- confirm intended type
+// NOTE(review): handles is cudaIpcMemHandle_t below but the static_asserts above use hipIpcMemHandle_t -- confirm.
+    // Initialization function: stores rank/world size and resolves one Signal pointer per rank.
+    DeviceReduceConnect(Signal* meta,
+                        void* rank_data,
+                        size_t rank_data_sz, // divided by sizeof(RankData) below, so presumably a byte count -- confirm
+                        const cudaIpcMemHandle_t* handles, // indexed by rank; entry [rank] is never read
+                        const std::vector& offsets, // NOTE(review): element type missing; added to a char* below
+                        int rank,
+                        bool full_mesh_connect = true)
+        : rank(rank),
+          world_size(offsets.size()), // NOTE(review): size_t narrowed to index_t; not checked against the 8 slots
+          full_mesh_connect(full_mesh_connect),
+          self_signal(meta),
+          d_rank_data_base(reinterpret_cast(rank_data)), // NOTE(review): cast target type missing
+          d_rank_data_end(d_rank_data_base + rank_data_sz / sizeof(RankData))
+    {
+        for(int i = 0; i < world_size; i++)
+        {
+            Signal* rank_sg;
+            if(i != rank) // `rank` here is the constructor parameter, which shadows the member
+            {
+                char* handle = open_ipc_handle(&handles[i]); // NOTE(review): open_ipc_handle is not defined in this hunk
+                handle += offsets[i]; // char* arithmetic, so offsets[i] is applied in bytes
+                rank_sg = (Signal*)handle;
+            }
+            else
+            {
+                rank_sg = self_signal; // own rank: use the local Signal, no IPC handle is opened
+            }
+            signals.signals[i] = rank_sg; // NOTE(review): signals has 8 slots; world_size > 8 writes out of bounds
+        }
+    }
+};
+
+} // namespace ck_tile + void setupConnection(int rank, int slaveRank, int worldSize,