Skip to content

Commit

Permalink
multiple streams implementation for copy thunk between hosts and devic…
Browse files Browse the repository at this point in the history
…es using events
  • Loading branch information
zhenying-liu committed Mar 20, 2024
1 parent 364ec63 commit f10e739
Show file tree
Hide file tree
Showing 6 changed files with 250 additions and 86 deletions.
78 changes: 56 additions & 22 deletions xla/service/gpu/ir_emitter_unnested.cc
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ inline std::pair<bool, int64_t> GetSendRecvAsyncEventsKey(Thunk::Kind kind,
IrEmitterUnnested::IrEmitterUnnested(IrEmitterContext* ir_emitter_context)
: IrEmitter(ir_emitter_context, /*is_nested=*/false),
send_recv_events_(std::make_shared<SendRecvAsyncEvents>()),
copy_events_(std::make_shared<CopyAsyncEvents>()),
elemental_emitter_(*ir_emitter_context, &b_) {}

std::unique_ptr<IrEmitterUnnested> IrEmitterUnnested::Create(
Expand Down Expand Up @@ -2557,46 +2558,44 @@ static std::optional<GlobalDeviceId> DeviceConstraint(
return std::nullopt;
}

absl::Status IrEmitterUnnested::EmitCopyStartThunk(
const HloCopyStartInstruction* instr) {
absl::Status
IrEmitterUnnested::EmitCopyStartThunk(const HloCopyStartInstruction *instr) {
// copy-start has a tuple shape: {host, device, context},
// or {device, host, context}.
// Only the destination shape is needed to get the output buffer.
TF_ASSIGN_OR_RETURN(BufferAllocation::Slice dst_buffer,
GetAllocationSliceForHlo(instr,
/*ShapeIndex=*/{0}));

const HloInstruction* src = instr->operand(0);
const Shape& input_shape = src->shape();
const HloInstruction *src = instr->operand(0);
const Shape &input_shape = src->shape();
TF_ASSIGN_OR_RETURN(BufferAllocation::Slice src_buffer,
GetAllocationSliceForHlo(src, {}));
Shape shape = instr->shape();
CHECK(shape.IsTuple());
enum copy_direction { H2D = 0, D2H = 1, D2D = 2 };
copy_direction dir = D2D;

int host_memory_space = static_cast<int>(stream_executor::MemoryType::kHost);
if (shape.mutable_tuple_shapes(0)->has_layout() &&
shape.mutable_tuple_shapes(0)->mutable_layout()->memory_space() ==
static_cast<int>(stream_executor::MemoryType::kHost)) {
VLOG(3) << "Device to Host: host memory space "
<< static_cast<int>(stream_executor::MemoryType::kHost);
auto thunk = std::make_unique<DeviceToHostCopyThunk>(
Thunk::ThunkInfo::WithProfileAnnotation(instr),
/*source_buffer=*/src_buffer,
/*destination_buffer=*/dst_buffer,
/*mem_size=*/ShapeUtil::ByteSizeOf(input_shape));
AddThunkToThunkSequence(std::move(thunk));
return absl::OkStatus();
host_memory_space) {
dir = D2H;
}
if (shape.mutable_tuple_shapes(1)->has_layout() &&
shape.mutable_tuple_shapes(1)->mutable_layout()->memory_space() ==
static_cast<int>(stream_executor::MemoryType::kHost)) {
VLOG(3) << "Host to Device from the host memory space "
<< static_cast<int>(stream_executor::MemoryType::kHost);
;
auto thunk = std::make_unique<HostToDeviceCopyThunk>(
host_memory_space) {
dir = H2D;
}
if (dir != D2D) {
auto thunk = std::make_unique<DeviceHostCopyThunk>(
Thunk::ThunkInfo::WithProfileAnnotation(instr),
/*source_buffer=*/src_buffer,
/*destination_buffer=*/dst_buffer,
/*mem_size=*/ShapeUtil::ByteSizeOf(input_shape));
/*mem_size=*/ShapeUtil::ByteSizeOf(input_shape),
/*async_events=*/copy_events_,
/*copy_start_instr=*/instr,
/*device_to_host=*/static_cast<bool>(dir));
AddThunkToThunkSequence(std::move(thunk));
return absl::OkStatus();
}
Expand All @@ -2610,10 +2609,44 @@ absl::Status IrEmitterUnnested::EmitCopyStartThunk(
/*destination_buffer=*/dst_buffer,
/*mem_size=*/ShapeUtil::ByteSizeOf(input_shape));
AddThunkToThunkSequence(std::move(thunk));

return absl::OkStatus();
}

// In the CopyDone thunk, the corresponding copy-start instruction is passed
// in order to find the matching event and make sure it has completed.
absl::Status IrEmitterUnnested::EmitCopyDoneThunk(const HloInstruction *instr) {
  // The result buffer of copy-done is the destination of the async copy.
  // The slice lookups below also validate that buffer assignment produced
  // allocations for both ends of the copy.
  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice dst_buffer,
                      GetAllocationSliceForHlo(instr,
                                               /*ShapeIndex=*/{}));

  // The operand is the matching copy-start; its shape is a tuple of
  // {host, device, context} or {device, host, context}.
  const HloInstruction *src = instr->operand(0);
  // Bind by reference: shape() returns a const Shape&, so copying it into a
  // local Shape value is an unnecessary deep copy.
  const Shape &shape = src->shape();

  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice src_buffer,
                      GetAllocationSliceForHlo(src, {0}));
  CHECK(shape.IsTuple());

  if ((shape.tuple_shapes(0).has_layout() &&
       shape.tuple_shapes(0).layout().memory_space() ==
           static_cast<int>(stream_executor::MemoryType::kHost)) ||
      (shape.tuple_shapes(1).has_layout() &&
       shape.tuple_shapes(1).layout().memory_space() ==
           static_cast<int>(stream_executor::MemoryType::kHost))) {
    // Either endpoint lives in host memory, so this completes a D2H or H2D
    // copy: emit a thunk that waits on the event recorded by the matching
    // copy-start thunk (shared via copy_events_).
    VLOG(3) << "CopyDone from the host memory space "
            << static_cast<int>(stream_executor::MemoryType::kHost) << " "
            << src->ToString();
    auto thunk = std::make_unique<DeviceHostCopyDoneThunk>(
        Thunk::kCopyDone, Thunk::ThunkInfo::WithProfileAnnotation(instr),
        /*async_events=*/copy_events_,
        /*copy_start_instr=*/src);
    AddThunkToThunkSequence(std::move(thunk));
    return absl::OkStatus();
  }

  return absl::InternalError(
      "Unknown copy-done instruction with incorrect memory space color");
}

absl::Status IrEmitterUnnested::EmitSendThunk(const HloSendInstruction* instr) {
if (!instr->channel_id().has_value())
return absl::InternalError("Unknown send instruction channel id");
Expand Down Expand Up @@ -2952,6 +2985,8 @@ absl::Status IrEmitterUnnested::EmitHloInstruction(
return EmitWhile(instr);
case HloOpcode::kCopyStart:
return EmitCopyStartThunk(Cast<HloCopyStartInstruction>(instr));
case HloOpcode::kCopyDone:
return EmitCopyDoneThunk(instr);

// HLO module is already scheduled, so instructions for ordering are noops.
case HloOpcode::kAddDependency:
Expand All @@ -2962,7 +2997,6 @@ absl::Status IrEmitterUnnested::EmitHloInstruction(
case HloOpcode::kGetTupleElement:
case HloOpcode::kParameter:
case HloOpcode::kTuple:
case HloOpcode::kCopyDone:
return absl::OkStatus();
default:
return Internal("Unsupported instruction opcode: %s",
Expand Down
6 changes: 6 additions & 0 deletions xla/service/gpu/ir_emitter_unnested.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ limitations under the License.
#include "xla/service/gpu/ir_emitter.h"
#include "xla/service/gpu/ir_emitter_context.h"
#include "xla/service/gpu/launch_dimensions.h"
#include "xla/service/gpu/runtime/copy_thunk.h"
#include "xla/service/gpu/runtime/send_recv_thunk.h"
#include "xla/service/gpu/thunk.h"
#include "xla/service/llvm_ir/ir_array.h"
Expand Down Expand Up @@ -199,6 +200,8 @@ class IrEmitterUnnested : public IrEmitter {

absl::Status EmitCopyStartThunk(const HloCopyStartInstruction* instr);

absl::Status EmitCopyDoneThunk(const HloInstruction* instr);

absl::Status EmitHloInstruction(const HloInstruction* instr);

absl::Status EmitTargetElementLoop(
Expand Down Expand Up @@ -377,6 +380,9 @@ class IrEmitterUnnested : public IrEmitter {
// Container for async send/recv events shared by send/recv thunks.
std::shared_ptr<SendRecvAsyncEvents> send_recv_events_;

// Container for async copy-start/copy-done events.
std::shared_ptr<CopyAsyncEvents> copy_events_;

// Returns the ShapedSlices for the given operands.
absl::StatusOr<std::vector<ShapedSlice>> GetShapedSlices(
mlir::Operation::operand_range operands);
Expand Down
149 changes: 111 additions & 38 deletions xla/service/gpu/runtime/copy_thunk.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* Copyright 2017 The OpenXLA Authors.
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand All @@ -17,7 +17,7 @@ limitations under the License.

#include <cstdint>

#include "mlir/IR/Value.h" // from @llvm-project
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/service/buffer_assignment.h"
#include "xla/service/gpu/thunk.h"
#include "xla/status.h"
Expand All @@ -27,15 +27,13 @@ namespace xla {
namespace gpu {

DeviceToDeviceCopyThunk::DeviceToDeviceCopyThunk(
ThunkInfo thunk_info, const BufferAllocation::Slice& source_buffer,
const BufferAllocation::Slice& destination_buffer, uint64_t mem_size)
: Thunk(Kind::kCopy, thunk_info),
source_buffer_(source_buffer),
destination_buffer_(destination_buffer),
mem_size_(mem_size) {}

absl::Status DeviceToDeviceCopyThunk::ExecuteOnStream(
const ExecuteParams& params) {
ThunkInfo thunk_info, const BufferAllocation::Slice &source_buffer,
const BufferAllocation::Slice &destination_buffer, uint64_t mem_size)
: Thunk(Kind::kCopy, thunk_info), source_buffer_(source_buffer),
destination_buffer_(destination_buffer), mem_size_(mem_size) {}

absl::Status
DeviceToDeviceCopyThunk::ExecuteOnStream(const ExecuteParams &params) {
se::DeviceMemoryBase destination_data =
params.buffer_allocations->GetDeviceAddress(destination_buffer_);
se::DeviceMemoryBase source_data =
Expand All @@ -45,41 +43,116 @@ absl::Status DeviceToDeviceCopyThunk::ExecuteOnStream(
return params.stream->Memcpy(&destination_data, source_data, mem_size_);
}

DeviceToHostCopyThunk::DeviceToHostCopyThunk(
ThunkInfo thunk_info, const BufferAllocation::Slice& source_buffer,
const BufferAllocation::Slice& destination_buffer, uint64_t mem_size)
: DeviceToDeviceCopyThunk(thunk_info, source_buffer, destination_buffer,
mem_size) {}
//===----------------------------------------------------------------------===//
// CopyAsyncEvents
//===----------------------------------------------------------------------===//

absl::Status DeviceToHostCopyThunk::ExecuteOnStream(
const ExecuteParams& params) {
se::DeviceMemoryBase destination_data =
params.buffer_allocations->GetDeviceAddress(destination());
se::DeviceMemoryBase source_data =
params.buffer_allocations->GetDeviceAddress(source());
void* cpu_dst = destination_data.opaque();
VLOG(3) << "Memcpy D2H for memory offload from " << source_data.opaque()
<< " to " << cpu_dst;
return params.stream->Memcpy(cpu_dst, source_data, size_bytes());
// Emplace() inserts a {key, event} pair into the hash map so the matching
// copy-done thunk can later wait on the recorded event.
// Returns an internal error if an event for the same (executor, instruction)
// pair was already registered.
absl::Status CopyAsyncEvents::Emplace(se::StreamExecutor *executor,
                                      const HloInstruction *instr,
                                      se::Event &&event) {
  Key key = {executor, instr};

  absl::MutexLock lock(&mutex_);
  VLOG(3) << "Emplace event " << event.implementation();
  if (auto [it, inserted] = events_.try_emplace(key, std::move(event));
      inserted) {
    return absl::OkStatus();
  }
  // try_emplace() does not move from `event` when insertion fails, so it is
  // still valid to log here. (Fixed: missing space before "already".)
  VLOG(3) << "ATTN: event " << event.implementation() << " already exists!";
  return absl::InternalError("Async copy event already exists!");
}

HostToDeviceCopyThunk::HostToDeviceCopyThunk(
ThunkInfo thunk_info, const BufferAllocation::Slice& source_buffer,
const BufferAllocation::Slice& destination_buffer, uint64_t mem_size)
// Retrieve the completion event recorded for copy-start instruction `instr`
// on `executor`, removing it from the collection. Fails if no such event was
// ever emplaced (or it was already extracted).
absl::StatusOr<se::Event>
CopyAsyncEvents::Extract(se::StreamExecutor *executor,
                         const HloInstruction *instr) {
  absl::MutexLock lock(&mutex_);
  // extract() detaches the node from the map and hands us ownership of the
  // stored event; an empty node handle means nothing was recorded.
  auto node = events_.extract(Key{executor, instr});
  if (!node) {
    return absl::InternalError("Async copy event was not found!");
  }
  VLOG(3) << "Extract event " << node.mapped().implementation();
  return std::move(node.mapped());
}

//===----------------------------------------------------------------------===//
// DeviceHostCopyThunk
//===----------------------------------------------------------------------===//
DeviceHostCopyThunk::DeviceHostCopyThunk(
ThunkInfo thunk_info, const BufferAllocation::Slice &source_buffer,
const BufferAllocation::Slice &destination_buffer, uint64_t mem_size,
std::shared_ptr<CopyAsyncEvents> async_events, const HloInstruction *instr,
bool device_to_host)
: DeviceToDeviceCopyThunk(thunk_info, source_buffer, destination_buffer,
mem_size) {}
mem_size),
async_events_(std::move(async_events)), instr_(instr),
device_to_host_(device_to_host) {}

absl::Status HostToDeviceCopyThunk::ExecuteOnStream(
const ExecuteParams& params) {
absl::Status DeviceHostCopyThunk::ExecuteOnStream(const ExecuteParams &params) {
se::DeviceMemoryBase destination_data =
params.buffer_allocations->GetDeviceAddress(destination());
se::DeviceMemoryBase source_data =
params.buffer_allocations->GetDeviceAddress(source());
void* cpu_src = source_data.opaque();
VLOG(3) << "Memcpy H2D for memory offload from " << cpu_src << " to "
<< destination_data.opaque();
return params.stream->Memcpy(&destination_data, cpu_src, size_bytes());
void *cpu_dst = destination_data.opaque();
void *cpu_src = source_data.opaque();
TF_ASSIGN_OR_RETURN(
se::Stream * stream,
GetStreamForExecution(Thunk::execution_stream_id(), params));
if (stream == params.stream) {
if (device_to_host_) {
VLOG(3) << "Memcpy D2H from the main stream";
return params.stream->Memcpy(cpu_dst, source_data, size_bytes());
} else {
VLOG(3) << "Memcpy H2D from the main stream";
return params.stream->Memcpy(&destination_data, cpu_src, size_bytes());
}
}
// memcpy is issued from the other stream, not the main compute stream
if (device_to_host_) {
VLOG(3) << "Memcpy D2H from the other stream";
TF_RETURN_IF_ERROR(stream->Memcpy(cpu_dst, source_data, size_bytes()));
} else {
VLOG(3) << "Memcpy H2D from the other stream";
TF_RETURN_IF_ERROR(
stream->Memcpy(&destination_data, cpu_src, size_bytes()));
}
se::StreamExecutor *executor = params.stream->parent();
se::Event event(executor);
if (!event.Init()) {
return absl::InternalError(
"Failed to initialize copy operation async completion event!");
}
// Record memcpy operation completion.
TF_RETURN_IF_ERROR(stream->RecordEvent(&event));
VLOG(3) << "Emplace events: " << event.implementation()
<< " for inst: " << instr_->ToString();
return async_events_->Emplace(executor, instr_, std::move(event));
}

//===----------------------------------------------------------------------===//
// DeviceHostCopyDoneThunk
//===----------------------------------------------------------------------===//
// Constructs the completion half of an asynchronous host<->device copy.
// `async_events` is the collection shared with the copy-start thunk, and
// `copy_start_instr` is the key used to look up the recorded event there.
// The pointer is stored, not owned; it must outlive this thunk.
DeviceHostCopyDoneThunk::DeviceHostCopyDoneThunk(
    Thunk::Kind kind, ThunkInfo thunk_info,
    std::shared_ptr<CopyAsyncEvents> async_events,
    const HloInstruction *copy_start_instr)
    : Thunk(kind, std::move(thunk_info)),
      async_events_(std::move(async_events)),
      copy_start_instr_(copy_start_instr) {}

// Blocks the execution stream until the asynchronous copy recorded for
// `copy_start_instr_` has completed on its dedicated stream.
absl::Status DeviceHostCopyDoneThunk::ExecuteOnStream(
    const ExecuteParams &params) {
  VLOG(3) << "CopyDone thunk between a host and a device for: "
          << copy_start_instr_->ToString();
  // Take ownership of the completion event registered by the matching
  // copy-start thunk; this also removes it from the shared collection.
  TF_ASSIGN_OR_RETURN(
      se::Event event,
      async_events_->Extract(params.stream->parent(), copy_start_instr_));
  // Queue a wait on the main stream for the recorded memcpy completion.
  return params.stream->WaitFor(&event);
}

} // namespace gpu
} // namespace xla
} // namespace gpu
} // namespace xla
Loading

0 comments on commit f10e739

Please sign in to comment.