Skip to content

Commit

Permalink
Start to substitute the mscclpp
Browse files Browse the repository at this point in the history
  • Loading branch information
ThomasNing committed Jan 13, 2025
1 parent abd2755 commit cfa8268
Showing 1 changed file with 66 additions and 0 deletions.
66 changes: 66 additions & 0 deletions include/ck_tile/ops/cross_gpu_reduce/kernel/cross_gpu_connect.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,72 @@ extern __constant__ DeviceHandle<mscclpp::SmChannel> constSlaveSmChannels[8]; //

extern __constant__ DeviceHandle<mscclpp::SmChannel> constMasterSmChannel;

static constexpr int kMaxBlocks = 64;

using IPC_KEY = std::array<uint8_t, sizeof(hipIpcMemHandle_t)>;
static_assert(sizeof(IPC_KEY) == sizeof(hipIpcMemHandle_t));
static_assert(alignof(IPC_KEY) == alignof(hipIpcMemHandle_t));

struct Signal
{
alignas(128) uint32_t start[kMaxBlocks][8];
alignas(128) uint32_t end[kMaxBlocks][8];
alignas(128) uint32_t _flag[kMaxBlocks]; // incremental flags for each rank
};

struct __align__(16) RankData { const void* ptrs[8]; };

struct __align__(16) RankSignals { volatile Signal* signals[8]; };

namespace ck_tile {
struct DeviceReduceConnect
{
index_t rank;
index_t world_size;
bool full_mesh_connect;
RankSignals signals;
std::unordered_map<void*, RankData*> buffers;
Signal* self_signal;

RankData *d_rank_data_base, *d_rank_data_end;
std::vector<void*> graph_unreg_buffers;
map<IPC_KEY, char*> ipc_handles;

// Initialization function
DeviceReduceConnect(Signal* meta,
void* rank_data,
size_t rank_data_sz,
const cudaIpcMemHandle_t* handles,
const std::vector<int64_t>& offsets,
int rank,
bool full_mesh_connect = true)
: rank(rank),
world_size(offsets.size()),
full_mesh_connect(full_mesh_connect),
self_signal(meta),
d_rank_data_base(reinterpret_cast<RankData*>(rank_data)),
d_rank_data_end(d_rank_data_base + rank_data_sz / sizeof(RankData))
{
for(int i = 0; i < world_size; i++)
{
Signal* rank_sg;
if(i != rank)
{
char* handle = open_ipc_handle(&handles[i]);
handle += offsets[i];
rank_sg = (Signal*)handle;
}
else
{
rank_sg = self_signal;
}
signals.signals[i] = rank_sg;
}
}
};

} // namespace ck_tile

void setupConnection(int rank,
int slaveRank,
int worldSize,
Expand Down

0 comments on commit cfa8268

Please sign in to comment.