diff --git a/xla/service/gpu/ir_emitter_unnested.cc b/xla/service/gpu/ir_emitter_unnested.cc
index 2711662bdf174..cec6887fd0cba 100644
--- a/xla/service/gpu/ir_emitter_unnested.cc
+++ b/xla/service/gpu/ir_emitter_unnested.cc
@@ -191,6 +191,7 @@ inline std::pair<Thunk::Kind, int64_t> GetSendRecvAsyncEventsKey(Thunk::Kind kind,
 IrEmitterUnnested::IrEmitterUnnested(IrEmitterContext* ir_emitter_context)
     : IrEmitter(ir_emitter_context, /*is_nested=*/false),
       send_recv_events_(std::make_shared<SendRecvAsyncEvents>()),
+      copy_events_(std::make_shared<CopyAsyncEvents>()),
       elemental_emitter_(*ir_emitter_context, &b_) {}
 
 std::unique_ptr<IrEmitterUnnested> IrEmitterUnnested::Create(
@@ -2557,8 +2558,8 @@ static std::optional<GlobalDeviceId> DeviceConstraint(
   return std::nullopt;
 }
 
-absl::Status IrEmitterUnnested::EmitCopyStartThunk(
-    const HloCopyStartInstruction* instr) {
+absl::Status
+IrEmitterUnnested::EmitCopyStartThunk(const HloCopyStartInstruction *instr) {
   // copy-start has a tuple shape: {host, device, context},
   // or {device, host, context}.
   // Only the destination shape is needed to get the output buffer.
@@ -2566,37 +2567,35 @@ absl::Status IrEmitterUnnested::EmitCopyStartThunk(
                       GetAllocationSliceForHlo(instr, /*ShapeIndex=*/{0}));
 
-  const HloInstruction* src = instr->operand(0);
-  const Shape& input_shape = src->shape();
+  const HloInstruction *src = instr->operand(0);
+  const Shape &input_shape = src->shape();
   TF_ASSIGN_OR_RETURN(BufferAllocation::Slice src_buffer,
                       GetAllocationSliceForHlo(src, {}));
 
   Shape shape = instr->shape();
   CHECK(shape.IsTuple());
+  enum copy_direction { H2D = 0, D2H = 1, D2D = 2 };
+  copy_direction dir = D2D;
+  int host_memory_space = static_cast<int>(stream_executor::MemoryType::kHost);
   if (shape.mutable_tuple_shapes(0)->has_layout() &&
       shape.mutable_tuple_shapes(0)->mutable_layout()->memory_space() ==
-          static_cast<int>(stream_executor::MemoryType::kHost)) {
-    VLOG(3) << "Device to Host: host memory space "
-            << static_cast<int>(stream_executor::MemoryType::kHost);
-    auto thunk = std::make_unique<DeviceToHostCopyThunk>(
-        Thunk::ThunkInfo::WithProfileAnnotation(instr),
-        /*source_buffer=*/src_buffer,
-        /*destination_buffer=*/dst_buffer,
-        /*mem_size=*/ShapeUtil::ByteSizeOf(input_shape));
-    AddThunkToThunkSequence(std::move(thunk));
-    return absl::OkStatus();
+          host_memory_space) {
+    dir = D2H;
   }
   if (shape.mutable_tuple_shapes(1)->has_layout() &&
       shape.mutable_tuple_shapes(1)->mutable_layout()->memory_space() ==
-          static_cast<int>(stream_executor::MemoryType::kHost)) {
-    VLOG(3) << "Host to Device from the host memory space "
-            << static_cast<int>(stream_executor::MemoryType::kHost);
-    ;
-    auto thunk = std::make_unique<HostToDeviceCopyThunk>(
+          host_memory_space) {
+    dir = H2D;
+  }
+  if (dir != D2D) {
+    auto thunk = std::make_unique<DeviceHostCopyThunk>(
         Thunk::ThunkInfo::WithProfileAnnotation(instr),
         /*source_buffer=*/src_buffer,
         /*destination_buffer=*/dst_buffer,
-        /*mem_size=*/ShapeUtil::ByteSizeOf(input_shape));
+        /*mem_size=*/ShapeUtil::ByteSizeOf(input_shape),
+        /*async_events=*/copy_events_,
+        /*copy_start_instr=*/instr,
+        /*device_to_host=*/static_cast<bool>(dir));
     AddThunkToThunkSequence(std::move(thunk));
     return absl::OkStatus();
   }
@@ -2610,10 +2609,44 @@ absl::Status IrEmitterUnnested::EmitCopyStartThunk(
       /*destination_buffer=*/dst_buffer,
       /*mem_size=*/ShapeUtil::ByteSizeOf(input_shape));
   AddThunkToThunkSequence(std::move(thunk));
-
   return absl::OkStatus();
 }
 
+// In the CopyDone thunk, the corresponding copy-start instruction is passed
+// in order to find the matching event and make sure it has completed.
+absl::Status IrEmitterUnnested::EmitCopyDoneThunk(const HloInstruction *instr) {
+  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice dst_buffer,
+                      GetAllocationSliceForHlo(instr,
+                                               /*ShapeIndex=*/{}));
+
+  const HloInstruction *src = instr->operand(0);
+  const Shape shape = src->shape();
+
+  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice src_buffer,
+                      GetAllocationSliceForHlo(src, {0}));
+  CHECK(shape.IsTuple());
+
+  if ((shape.tuple_shapes(0).has_layout() &&
+       shape.tuple_shapes(0).layout().memory_space() ==
+           static_cast<int>(stream_executor::MemoryType::kHost)) ||
+      (shape.tuple_shapes(1).has_layout() &&
+       shape.tuple_shapes(1).layout().memory_space() ==
+           static_cast<int>(stream_executor::MemoryType::kHost))) {
+    VLOG(3) << "CopyDone from the host memory space "
+            << static_cast<int>(stream_executor::MemoryType::kHost)
+            << src->ToString();
+    auto thunk = std::make_unique<DeviceHostCopyDoneThunk>(
+        Thunk::kCopyDone, Thunk::ThunkInfo::WithProfileAnnotation(instr),
+        /*async_events=*/copy_events_,
+        /*copy_start_instr=*/src);
+    AddThunkToThunkSequence(std::move(thunk));
+    return absl::OkStatus();
+  }
+
+  return absl::InternalError(
+      "Unknown copy-done instruction with incorrect memory space color");
+}
+
 absl::Status IrEmitterUnnested::EmitSendThunk(const HloSendInstruction* instr) {
   if (!instr->channel_id().has_value())
     return absl::InternalError("Unknown send instruction channel id");
@@ -2952,6 +2985,8 @@ absl::Status IrEmitterUnnested::EmitHloInstruction(
       return EmitWhile(instr);
     case HloOpcode::kCopyStart:
       return EmitCopyStartThunk(Cast<HloCopyStartInstruction>(instr));
+    case HloOpcode::kCopyDone:
+      return EmitCopyDoneThunk(instr);
 
     // HLO module is already scheduled, so instructions for ordering are noops.
     case HloOpcode::kAddDependency:
@@ -2962,7 +2997,6 @@ absl::Status IrEmitterUnnested::EmitHloInstruction(
     case HloOpcode::kGetTupleElement:
     case HloOpcode::kParameter:
     case HloOpcode::kTuple:
-    case HloOpcode::kCopyDone:
       return absl::OkStatus();
     default:
       return Internal("Unsupported instruction opcode: %s",
diff --git a/xla/service/gpu/ir_emitter_unnested.h b/xla/service/gpu/ir_emitter_unnested.h
index f378c07872c00..b2182b8547198 100644
--- a/xla/service/gpu/ir_emitter_unnested.h
+++ b/xla/service/gpu/ir_emitter_unnested.h
@@ -42,6 +42,7 @@ limitations under the License.
 #include "xla/service/gpu/ir_emitter.h"
 #include "xla/service/gpu/ir_emitter_context.h"
 #include "xla/service/gpu/launch_dimensions.h"
+#include "xla/service/gpu/runtime/copy_thunk.h"
 #include "xla/service/gpu/runtime/send_recv_thunk.h"
 #include "xla/service/gpu/thunk.h"
 #include "xla/service/llvm_ir/ir_array.h"
@@ -199,6 +200,8 @@ class IrEmitterUnnested : public IrEmitter {
 
   absl::Status EmitCopyStartThunk(const HloCopyStartInstruction* instr);
 
+  absl::Status EmitCopyDoneThunk(const HloInstruction* instr);
+
   absl::Status EmitHloInstruction(const HloInstruction* instr);
 
   absl::Status EmitTargetElementLoop(
@@ -377,6 +380,9 @@ class IrEmitterUnnested : public IrEmitter {
   // Container for async send/recv events shared by send/recv thunks.
   std::shared_ptr<SendRecvAsyncEvents> send_recv_events_;
 
+  // Container for async copy-start/copy-done events.
+  std::shared_ptr<CopyAsyncEvents> copy_events_;
+
   // Returns the ShapedSlices for the given operands.
   absl::StatusOr<std::vector<ShapedSlice>> GetShapedSlices(
       mlir::Operation::operand_range operands);
diff --git a/xla/service/gpu/runtime/copy_thunk.cc b/xla/service/gpu/runtime/copy_thunk.cc
index 4bdc2acf5545c..2f39cab8ed939 100644
--- a/xla/service/gpu/runtime/copy_thunk.cc
+++ b/xla/service/gpu/runtime/copy_thunk.cc
@@ -1,4 +1,4 @@
-/* Copyright 2017 The OpenXLA Authors.
+/* Copyright 2024 The OpenXLA Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -17,7 +17,7 @@ limitations under the License.
 
 #include <cstdint>
 
-#include "mlir/IR/Value.h"  // from @llvm-project
+#include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/service/buffer_assignment.h"
 #include "xla/service/gpu/thunk.h"
 #include "xla/status.h"
@@ -27,15 +27,13 @@ namespace xla {
 namespace gpu {
 
 DeviceToDeviceCopyThunk::DeviceToDeviceCopyThunk(
-    ThunkInfo thunk_info, const BufferAllocation::Slice& source_buffer,
-    const BufferAllocation::Slice& destination_buffer, uint64_t mem_size)
-    : Thunk(Kind::kCopy, thunk_info),
-      source_buffer_(source_buffer),
-      destination_buffer_(destination_buffer),
-      mem_size_(mem_size) {}
-
-absl::Status DeviceToDeviceCopyThunk::ExecuteOnStream(
-    const ExecuteParams& params) {
+    ThunkInfo thunk_info, const BufferAllocation::Slice &source_buffer,
+    const BufferAllocation::Slice &destination_buffer, uint64_t mem_size)
+    : Thunk(Kind::kCopy, thunk_info), source_buffer_(source_buffer),
+      destination_buffer_(destination_buffer), mem_size_(mem_size) {}
+
+absl::Status
+DeviceToDeviceCopyThunk::ExecuteOnStream(const ExecuteParams &params) {
   se::DeviceMemoryBase destination_data =
       params.buffer_allocations->GetDeviceAddress(destination_buffer_);
   se::DeviceMemoryBase source_data =
@@ -45,41 +43,116 @@ absl::Status DeviceToDeviceCopyThunk::ExecuteOnStream(
   return params.stream->Memcpy(&destination_data, source_data, mem_size_);
 }
 
-DeviceToHostCopyThunk::DeviceToHostCopyThunk(
-    ThunkInfo thunk_info, const BufferAllocation::Slice& source_buffer,
-    const BufferAllocation::Slice& destination_buffer, uint64_t mem_size)
-    : DeviceToDeviceCopyThunk(thunk_info, source_buffer, destination_buffer,
-                              mem_size) {}
+//===----------------------------------------------------------------------===//
+// CopyAsyncEvents
+//===----------------------------------------------------------------------===//
 
-absl::Status DeviceToHostCopyThunk::ExecuteOnStream(
-    const ExecuteParams& params) {
-  se::DeviceMemoryBase destination_data =
-      params.buffer_allocations->GetDeviceAddress(destination());
-  se::DeviceMemoryBase source_data =
-      params.buffer_allocations->GetDeviceAddress(source());
-  void* cpu_dst = destination_data.opaque();
-  VLOG(3) << "Memcpy D2H for memory offload from " << source_data.opaque()
-          << " to " << cpu_dst;
-  return params.stream->Memcpy(cpu_dst, source_data, size_bytes());
+// Emplace() inserts a {key, event} pair into the hash map so that the event
+// recorded for the async memcpy can later be awaited by the copy-done thunk.
+absl::Status CopyAsyncEvents::Emplace(se::StreamExecutor *executor,
+                                      const HloInstruction *instr,
+                                      se::Event &&event) {
+  Key key = {executor, instr};
+
+  absl::MutexLock lock(&mutex_);
+  VLOG(3) << "Emplace event " << event.implementation();
+  if (auto [it, inserted] = events_.try_emplace(key, std::move(event));
+      inserted) {
+    return absl::OkStatus();
+  }
+  VLOG(3) << "ATTN: event " << event.implementation() << " already exists!";
+  return absl::InternalError("Async copy event already exists!");
 }
 
-HostToDeviceCopyThunk::HostToDeviceCopyThunk(
-    ThunkInfo thunk_info, const BufferAllocation::Slice& source_buffer,
-    const BufferAllocation::Slice& destination_buffer, uint64_t mem_size)
+// Retrieve a completion event started by copy-start instruction
+// `instr`, and remove the event from the collection.
+absl::StatusOr<se::Event>
+CopyAsyncEvents::Extract(se::StreamExecutor *executor,
+                         const HloInstruction *instr) {
+  Key key = {executor, instr};
+  absl::MutexLock lock(&mutex_);
+  if (auto event = events_.extract(key)) {
+    VLOG(3) << "Extract event " << event.mapped().implementation();
+    return std::move(event.mapped());
+  }
+  return absl::InternalError("Async copy event was not found!");
+}
+
+//===----------------------------------------------------------------------===//
+// DeviceHostCopyThunk
+//===----------------------------------------------------------------------===//
+DeviceHostCopyThunk::DeviceHostCopyThunk(
+    ThunkInfo thunk_info, const BufferAllocation::Slice &source_buffer,
+    const BufferAllocation::Slice &destination_buffer, uint64_t mem_size,
+    std::shared_ptr<CopyAsyncEvents> async_events, const HloInstruction *instr,
+    bool device_to_host)
     : DeviceToDeviceCopyThunk(thunk_info, source_buffer, destination_buffer,
-                              mem_size) {}
+                              mem_size),
+      async_events_(std::move(async_events)), instr_(instr),
+      device_to_host_(device_to_host) {}
 
-absl::Status HostToDeviceCopyThunk::ExecuteOnStream(
-    const ExecuteParams& params) {
+absl::Status DeviceHostCopyThunk::ExecuteOnStream(const ExecuteParams &params) {
   se::DeviceMemoryBase destination_data =
       params.buffer_allocations->GetDeviceAddress(destination());
   se::DeviceMemoryBase source_data =
       params.buffer_allocations->GetDeviceAddress(source());
-  void* cpu_src = source_data.opaque();
-  VLOG(3) << "Memcpy H2D for memory offload from " << cpu_src << " to "
-          << destination_data.opaque();
-  return params.stream->Memcpy(&destination_data, cpu_src, size_bytes());
+  void *cpu_dst = destination_data.opaque();
+  void *cpu_src = source_data.opaque();
+  TF_ASSIGN_OR_RETURN(
+      se::Stream * stream,
+      GetStreamForExecution(Thunk::execution_stream_id(), params));
+  if (stream == params.stream) {
+    if (device_to_host_) {
+      VLOG(3) << "Memcpy D2H from the main stream";
+      return params.stream->Memcpy(cpu_dst, source_data, size_bytes());
+    } else {
+      VLOG(3) << "Memcpy H2D from the main stream";
+      return params.stream->Memcpy(&destination_data, cpu_src, size_bytes());
+    }
+  }
+  // The memcpy is issued on a stream other than the main compute stream.
+  if (device_to_host_) {
+    VLOG(3) << "Memcpy D2H from the other stream";
+    TF_RETURN_IF_ERROR(stream->Memcpy(cpu_dst, source_data, size_bytes()));
+  } else {
+    VLOG(3) << "Memcpy H2D from the other stream";
+    TF_RETURN_IF_ERROR(
+        stream->Memcpy(&destination_data, cpu_src, size_bytes()));
+  }
+  se::StreamExecutor *executor = params.stream->parent();
+  se::Event event(executor);
+  if (!event.Init()) {
+    return absl::InternalError(
+        "Failed to initialize copy operation async completion event!");
+  }
+  // Record memcpy operation completion.
+  TF_RETURN_IF_ERROR(stream->RecordEvent(&event));
+  VLOG(3) << "Emplace events: " << event.implementation()
+          << " for inst: " << instr_->ToString();
+  return async_events_->Emplace(executor, instr_, std::move(event));
+}
+
+//===----------------------------------------------------------------------===//
+// DeviceHostCopyDoneThunk
+//===----------------------------------------------------------------------===//
+DeviceHostCopyDoneThunk::DeviceHostCopyDoneThunk(
+    Thunk::Kind kind, ThunkInfo thunk_info,
+    std::shared_ptr<CopyAsyncEvents> async_events,
+    const HloInstruction *copy_start_instr)
+    : Thunk(kind, std::move(thunk_info)),
+      async_events_(std::move(async_events)),
+      copy_start_instr_(copy_start_instr) {}
+
+absl::Status
+DeviceHostCopyDoneThunk::ExecuteOnStream(const ExecuteParams &params) {
+  VLOG(3) << "CopyDone thunk between a host and a device for: "
+          << copy_start_instr_->ToString();
+  se::StreamExecutor *executor = params.stream->parent();
+  TF_ASSIGN_OR_RETURN(se::Event event,
+                      async_events_->Extract(executor, copy_start_instr_));
+  return params.stream->WaitFor(&event);
 }
 
-}  // namespace gpu
-}  // namespace xla
+} // namespace gpu
+} // namespace xla
diff --git a/xla/service/gpu/runtime/copy_thunk.h b/xla/service/gpu/runtime/copy_thunk.h
index 9ad1a2943de68..87543c641611e 100644
--- a/xla/service/gpu/runtime/copy_thunk.h
+++ b/xla/service/gpu/runtime/copy_thunk.h
@@ -18,62 +18,111 @@ limitations under the License.
 
 #include <cstdint>
 
+#include "absl/container/flat_hash_map.h"
 #include "absl/status/status.h"
+#include "xla/hlo/ir/hlo_instructions.h"
 #include "xla/service/buffer_assignment.h"
 #include "xla/service/gpu/thunk.h"
+#include "xla/stream_executor/event.h"
+#include "xla/stream_executor/stream_executor.h"
 
 namespace xla {
 namespace gpu {
 
 // A thunk that copies data from a device buffer to another device buffer.
 class DeviceToDeviceCopyThunk : public Thunk {
- public:
+public:
   // Constructs a CopyThunk that copies host data from `source_buffer` to the
   // device buffer `destination_buffer`. `mem_size` is the size of the data in
   // bytes.
   DeviceToDeviceCopyThunk(ThunkInfo thunk_info,
-                          const BufferAllocation::Slice& source_buffer,
-                          const BufferAllocation::Slice& destination_buffer,
+                          const BufferAllocation::Slice &source_buffer,
+                          const BufferAllocation::Slice &destination_buffer,
                           uint64_t mem_size);
-  DeviceToDeviceCopyThunk(const DeviceToDeviceCopyThunk&) = delete;
-  DeviceToDeviceCopyThunk& operator=(const DeviceToDeviceCopyThunk&) = delete;
+  DeviceToDeviceCopyThunk(const DeviceToDeviceCopyThunk &) = delete;
+  DeviceToDeviceCopyThunk &operator=(const DeviceToDeviceCopyThunk &) = delete;
 
-  absl::Status ExecuteOnStream(const ExecuteParams& params) override;
+  absl::Status ExecuteOnStream(const ExecuteParams &params) override;
 
   void ClearCompileTimeInfo() override { Thunk::ClearCompileTimeInfo(); }
 
-  const BufferAllocation::Slice& source() const { return source_buffer_; }
-  const BufferAllocation::Slice& destination() const {
+  const BufferAllocation::Slice &source() const { return source_buffer_; }
+  const BufferAllocation::Slice &destination() const {
     return destination_buffer_;
   }
 
   uint64_t size_bytes() const { return mem_size_; }
 
- private:
+private:
   const BufferAllocation::Slice source_buffer_;
   const BufferAllocation::Slice destination_buffer_;
   const uint64_t mem_size_;
 };
 
-class DeviceToHostCopyThunk : public DeviceToDeviceCopyThunk {
- public:
-  DeviceToHostCopyThunk(ThunkInfo thunk_info,
-                        const BufferAllocation::Slice& source_buffer,
-                        const BufferAllocation::Slice& destination_buffer,
-                        uint64_t mem_size);
-  absl::Status ExecuteOnStream(const ExecuteParams& params) override;
+//===----------------------------------------------------------------------===//
+// CopyAsyncEvents
+//===----------------------------------------------------------------------===//
+class CopyAsyncEvents {
+public:
+  // Add a new copy-start completion event.
+  absl::Status Emplace(se::StreamExecutor *executor,
+                       const HloInstruction *instr, se::Event &&event);
+
+  // Retrieve a completion event started by copy-start instruction
+  // `instr`, and remove the event from the collection.
+  absl::StatusOr<se::Event> Extract(se::StreamExecutor *executor,
+                                    const HloInstruction *instr);
+
+private:
+  using Key = std::pair<se::StreamExecutor *, const HloInstruction *>;
+  absl::Mutex mutex_;
+  absl::flat_hash_map<Key, se::Event> events_ ABSL_GUARDED_BY(mutex_);
 };
 
-class HostToDeviceCopyThunk : public DeviceToDeviceCopyThunk {
- public:
-  HostToDeviceCopyThunk(ThunkInfo thunk_info,
-                        const BufferAllocation::Slice& source_buffer,
-                        const BufferAllocation::Slice& destination_buffer,
-                        uint64_t mem_size);
-  absl::Status ExecuteOnStream(const ExecuteParams& params) override;
+//===----------------------------------------------------------------------===//
+// DeviceHostCopyThunk
+//===----------------------------------------------------------------------===//
+// A thunk that copies data between a device buffer and a host buffer.
+class DeviceHostCopyThunk : public DeviceToDeviceCopyThunk {
+public:
+  // Constructs a thunk that copies `mem_size` bytes between `source_buffer`
+  // and `destination_buffer`, one of which resides in host memory. `instr` is
+  // the copy-start instruction. `device_to_host` indicates whether to issue a
+  // device-to-host or a host-to-device memcpy.
+  DeviceHostCopyThunk(ThunkInfo thunk_info,
+                      const BufferAllocation::Slice &source_buffer,
+                      const BufferAllocation::Slice &destination_buffer,
+                      uint64_t mem_size,
+                      std::shared_ptr<CopyAsyncEvents> events,
+                      const HloInstruction *instr, bool device_to_host);
+  absl::Status ExecuteOnStream(const ExecuteParams &params) override;
+
+private:
+  std::shared_ptr<CopyAsyncEvents> async_events_;
+  const HloInstruction *instr_;
+  bool device_to_host_;
+};
+
+//===----------------------------------------------------------------------===//
+// DeviceHostCopyDoneThunk
+//===----------------------------------------------------------------------===//
+
+class DeviceHostCopyDoneThunk : public Thunk {
+public:
+  DeviceHostCopyDoneThunk(Thunk::Kind kind, ThunkInfo thunk_info,
+                          std::shared_ptr<CopyAsyncEvents> events,
+                          const HloInstruction *copy_start_instr);
+
+  absl::Status ExecuteOnStream(const ExecuteParams &params) override;
+
+private:
+  std::shared_ptr<CopyAsyncEvents> async_events_;
+  const HloInstruction *copy_start_instr_;
 };
 
-}  // namespace gpu
-}  // namespace xla
+} // namespace gpu
+} // namespace xla
 
-#endif  // XLA_SERVICE_GPU_RUNTIME_COPY_THUNK_H_
+#endif // XLA_SERVICE_GPU_RUNTIME_COPY_THUNK_H_
diff --git a/xla/service/gpu/thunk.cc b/xla/service/gpu/thunk.cc
index 1019f027a36ec..ab48d422b9cf0 100644
--- a/xla/service/gpu/thunk.cc
+++ b/xla/service/gpu/thunk.cc
@@ -224,6 +224,7 @@ Thunk::ExecuteParams::ExecuteParams(
     CASE(kConvolution);
     CASE(kConvolutionReorder);
     CASE(kCopy);
+    CASE(kCopyDone);
     CASE(kCubSort);
     CASE(kCublasLtMatmul);
     CASE(kCustomCall);
diff --git a/xla/service/gpu/thunk.h b/xla/service/gpu/thunk.h
index d7c06cbdb7284..41f2cfaf605bd 100644
--- a/xla/service/gpu/thunk.h
+++ b/xla/service/gpu/thunk.h
@@ -93,6 +93,7 @@ class Thunk {
     kConvolution,
     kConvolutionReorder,
     kCopy,
+    kCopyDone,
     kCommandBuffer,
     kCubSort,
     kCublasLtMatmul,
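
Note for reviewers (not part of the patch): the sketch below is a minimal illustration of how the two new thunks are expected to cooperate through CopyAsyncEvents when the memcpy runs on a secondary stream. It assumes both streams belong to the same se::StreamExecutor; the function and variable names are placeholders, and only calls that already appear in this diff are used.

// Illustrative sketch only. The copy-start side enqueues the D2H memcpy on a
// secondary stream and records an event keyed by (executor, copy-start
// instruction); the copy-done side extracts that event and makes the main
// stream wait for it before the destination buffer is consumed.
absl::Status IssueAsyncD2HCopy(se::Stream* copy_stream, se::Stream* main_stream,
                               std::shared_ptr<CopyAsyncEvents> events,
                               const HloInstruction* copy_start,
                               se::DeviceMemoryBase src, void* cpu_dst,
                               uint64_t size) {
  // copy-start: enqueue the memcpy and record its completion event.
  TF_RETURN_IF_ERROR(copy_stream->Memcpy(cpu_dst, src, size));
  se::Event event(copy_stream->parent());
  if (!event.Init()) return absl::InternalError("event init failed");
  TF_RETURN_IF_ERROR(copy_stream->RecordEvent(&event));
  TF_RETURN_IF_ERROR(
      events->Emplace(copy_stream->parent(), copy_start, std::move(event)));

  // copy-done: look up the event by (executor, copy-start instruction) and
  // block the main stream on it.
  TF_ASSIGN_OR_RETURN(se::Event done,
                      events->Extract(main_stream->parent(), copy_start));
  return main_stream->WaitFor(&done);
}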