From b060b1766ac754e183a34236e6186bf066707229 Mon Sep 17 00:00:00 2001
From: Trevor Berrange Sanchez
Date: Thu, 9 Jan 2025 01:47:49 +0100
Subject: [PATCH 1/5] RPC server: io_uring upgrade

Separates the server into two parts: the context and the work pool.
The context contains everything generally needed to run the server,
while the work pool provides a statically polymorphic implementation
of the pool that the actual work is dispatched to. In doing this, we
also separate certain things out into a few different files. The RPC
server context API has been modified slightly to reflect this, and
the work pool is directly exposed, for now.
---
 src/accountsdb/db.zig           |   1 +
 src/accountsdb/snapshots.zig    |  11 +-
 src/rpc/lib.zig                 |   1 -
 src/rpc/server.zig              | 592 ++++++++---------
 src/rpc/server/LinuxIoUring.zig | 810 ++++++++++++++++++++++++++++++++
 src/rpc/server/connection.zig   | 223 +++++++++
 src/rpc/server/requests.zig     | 269 +++++++++++
 src/utils/fmt.zig               |   2 +-
 src/utils/io.zig                |  84 ++++
 9 files changed, 1585 insertions(+), 408 deletions(-)
 create mode 100644 src/rpc/server/LinuxIoUring.zig
 create mode 100644 src/rpc/server/connection.zig
 create mode 100644 src/rpc/server/requests.zig

diff --git a/src/accountsdb/db.zig b/src/accountsdb/db.zig
index 9d36754b0..496a7d24b 100644
--- a/src/accountsdb/db.zig
+++ b/src/accountsdb/db.zig
@@ -3390,6 +3390,7 @@ test "testWriteSnapshot" {
     );
 }
 
+/// Unpacks the snapshots from `sig.TEST_DATA_DIR`.
 pub fn findAndUnpackTestSnapshots(
     n_threads: usize,
     /// The directory into which the snapshots are unpacked.
diff --git a/src/accountsdb/snapshots.zig b/src/accountsdb/snapshots.zig
index d0c0844e8..eb53d70f4 100644
--- a/src/accountsdb/snapshots.zig
+++ b/src/accountsdb/snapshots.zig
@@ -2225,7 +2225,7 @@ pub const FullSnapshotFileInfo = struct {
     slot: Slot,
     hash: Hash,
 
-    const SnapshotArchiveNameFmtSpec = sig.utils.fmt.BoundedSpec("snapshot-{[slot]d}-{[hash]s}.tar.zst");
+    pub const SnapshotArchiveNameFmtSpec = sig.utils.fmt.BoundedSpec("snapshot-{[slot]d}-{[hash]s}.tar.zst");
 
     pub const SnapshotArchiveNameStr = SnapshotArchiveNameFmtSpec.BoundedArrayValue(.{
         .slot = std.math.maxInt(Slot),
@@ -2352,7 +2352,7 @@ pub const IncrementalSnapshotFileInfo = struct {
         };
     }
 
-    const SnapshotArchiveNameFmtSpec = sig.utils.fmt.BoundedSpec("incremental-snapshot-{[base_slot]d}-{[slot]d}-{[hash]s}.tar.zst");
+    pub const SnapshotArchiveNameFmtSpec = sig.utils.fmt.BoundedSpec("incremental-snapshot-{[base_slot]d}-{[slot]d}-{[hash]s}.tar.zst");
 
     pub const SnapshotArchiveNameStr = SnapshotArchiveNameFmtSpec.BoundedArrayValue(.{
         .base_slot = std.math.maxInt(Slot),
@@ -2498,15 +2498,16 @@ pub const SnapshotFiles = struct {
     full: FullSnapshotFileInfo,
     incremental_info: ?SlotAndHash,
 
-    pub fn incremental(snapshot_files: SnapshotFiles) ?IncrementalSnapshotFileInfo {
-        const inc_info = snapshot_files.incremental_info orelse return null;
+    pub fn incremental(self: SnapshotFiles) ?IncrementalSnapshotFileInfo {
+        const inc_info = self.incremental_info orelse return null;
         return .{
-            .base_slot = snapshot_files.full.slot,
+            .base_slot = self.full.slot,
             .slot = inc_info.slot,
             .hash = inc_info.hash,
         };
     }
 
+    /// Asserts that `if (maybe_incremental_info) |inc| inc.base_slot == full_info.slot`.
pub fn fromFileInfos( full_info: FullSnapshotFileInfo, maybe_incremental_info: ?IncrementalSnapshotFileInfo, diff --git a/src/rpc/lib.zig b/src/rpc/lib.zig index 0b5fdb86c..49a932fcb 100644 --- a/src/rpc/lib.zig +++ b/src/rpc/lib.zig @@ -6,7 +6,6 @@ pub const response = @import("response.zig"); pub const types = @import("types.zig"); pub const Client = client.Client; -pub const Server = server.Server; pub const Request = request.Request; pub const Response = response.Response; diff --git a/src/rpc/server.zig b/src/rpc/server.zig index c320e3900..3c66d1e64 100644 --- a/src/rpc/server.zig +++ b/src/rpc/server.zig @@ -1,39 +1,30 @@ +const builtin = @import("builtin"); const std = @import("std"); const sig = @import("../sig.zig"); +const connection = @import("server/connection.zig"); +const requests = @import("server/requests.zig"); + +const IoUring = std.os.linux.IoUring; const SnapshotGenerationInfo = sig.accounts_db.AccountsDB.SnapshotGenerationInfo; -const FullSnapshotFileInfo = sig.accounts_db.snapshots.FullSnapshotFileInfo; -const IncrementalSnapshotFileInfo = sig.accounts_db.snapshots.IncrementalSnapshotFileInfo; -const ThreadPool = sig.sync.ThreadPool; - -const LOGGER_SCOPE = "rpc.Server"; -const ScopedLogger = sig.trace.log.ScopedLogger(LOGGER_SCOPE); - -pub const Server = struct { - //! Basic usage: - //! ```zig - //! var server = try Server.init(.{...}); - //! defer server.joinDeinit(); - //! - //! try server.serveSpawnDetached(); // or `.serveDirect`, if the caller can block or is managing the separate thread themselves. - //! ``` +pub const Context = struct { allocator: std.mem.Allocator, logger: ScopedLogger, - snapshot_dir: std.fs.Dir, latest_snapshot_gen_info: *sig.sync.RwMux(?SnapshotGenerationInfo), /// Wait group for all currently running tasks, used to wait for /// all of them to finish before deinitializing. wait_group: std.Thread.WaitGroup, - thread_pool: *ThreadPool, - + tcp: std.net.Server, /// Must not be mutated. read_buffer_size: usize, - tcp: std.net.Server, - pub const MIN_READ_BUFFER_SIZE = 256; + pub const LOGGER_SCOPE = "rpc.Server"; + pub const ScopedLogger = sig.trace.log.ScopedLogger(LOGGER_SCOPE); + + pub const MIN_READ_BUFFER_SIZE = 4096; /// The returned result must be pinned to a memory location before calling any methods. pub fn init(params: struct { @@ -47,14 +38,12 @@ pub const Server = struct { /// given time with respect to the contents of the specified `snapshot_dir`. latest_snapshot_gen_info: *sig.sync.RwMux(?SnapshotGenerationInfo), - thread_pool: *ThreadPool, - /// The size for the read buffer allocated to every request. /// Clamped to be greater than or equal to `MIN_READ_BUFFER_SIZE`. read_buffer_size: u32, /// The socket address to listen on for incoming HTTP and/or RPC requests. socket_addr: std.net.Address, - }) std.net.Address.ListenError!Server { + }) std.net.Address.ListenError!Context { var tcp_server = try params.socket_addr.listen(.{ // NOTE: ideally we would be doing this nonblockingly, however this doesn't work properly on mac, // so for testing purposes we can't test the `serve` functionality directly. 
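Taken together, the new API is intended to be used roughly as follows. This is a sketch pieced together from the declarations in this patch (`Context`, `WorkPool`, and `WorkPool.LinuxIoUring` below); the `allocator`, `logger`, `snapshot_dir`, and `latest_snapshot_gen_info` values are assumed to be in scope, and the port is arbitrary:

```zig
// Sketch only; assumes `Context` and `WorkPool` from src/rpc/server.zig are in scope.
var server_ctx = try Context.init(.{
    .allocator = allocator,
    .logger = logger,
    .snapshot_dir = snapshot_dir,
    .latest_snapshot_gen_info = &latest_snapshot_gen_info,
    .socket_addr = std.net.Address.initIp4(.{ 0, 0, 0, 0 }, 8899), // port chosen arbitrarily
    .read_buffer_size = 4096,
});
defer server_ctx.joinDeinit();

// Use the io_uring-backed pool on supported Linux kernels, else fall back to `.basic`.
var maybe_liou = try WorkPool.LinuxIoUring.init();
defer if (maybe_liou != null) maybe_liou.?.deinit();
const work_pool: WorkPool = if (maybe_liou != null)
    .{ .linux_io_uring = &maybe_liou.? }
else
    .basic;

var exit = std.atomic.Value(bool).init(false);
const serve_thread = try server_ctx.serveSpawn(&exit, work_pool);
// ... on shutdown: the accept is blocking, so after setting `exit` a dummy
// request may be needed to wake the serve loop (see the test below).
exit.store(true, .release);
serve_thread.join();
```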
@@ -65,467 +54,250 @@ pub const Server = struct { return .{ .allocator = params.allocator, .logger = params.logger.withScope(LOGGER_SCOPE), - .snapshot_dir = params.snapshot_dir, .latest_snapshot_gen_info = params.latest_snapshot_gen_info, .wait_group = .{}, - .thread_pool = params.thread_pool, - .read_buffer_size = @max(params.read_buffer_size, MIN_READ_BUFFER_SIZE), .tcp = tcp_server, }; } - /// Blocks until all tasks are completed, and then closes the server. + /// Blocks until all tasks are completed, and then closes the server context. /// Does not force the server to exit. - pub fn joinDeinit(server: *Server) void { + pub fn joinDeinit(server: *Context) void { server.wait_group.wait(); server.tcp.deinit(); } /// Spawn the serve loop as a separate thread. pub fn serveSpawn( - server: *Server, + server_ctx: *Context, exit: *std.atomic.Value(bool), + /// The pool to dispatch work to. + work_pool: WorkPool, ) std.Thread.SpawnError!std.Thread { - return std.Thread.spawn(.{}, serve, .{ server, exit }); + return try std.Thread.spawn(.{}, serve, .{ server_ctx, exit, work_pool }); } /// Calls `acceptAndServeConnection` in a loop until `exit.load(.acquire)`. pub fn serve( - server: *Server, + server_ctx: *Context, exit: *std.atomic.Value(bool), - ) AcceptAndServeConnectionError!void { + /// The pool to dispatch work to. + work_pool: WorkPool, + ) WorkPool.AcceptAndServeConnectionError!void { while (!exit.load(.acquire)) { - try server.acceptAndServeConnection(); + try work_pool.acceptAndServeConnection(server_ctx); } } - - pub const AcceptAndServeConnectionError = - std.mem.Allocator.Error || - std.http.Server.ReceiveHeadError || - AcceptConnectionError; - - pub fn acceptAndServeConnection(server: *Server) AcceptAndServeConnectionError!void { - const conn = (try acceptConnection(&server.tcp, server.logger)).?; - errdefer conn.stream.close(); - - server.wait_group.start(); - errdefer server.wait_group.finish(); - - const new_hct = try HandleConnectionTask.createAndReceiveHead(server, conn); - errdefer new_hct.destroyAndClose(); - - server.thread_pool.schedule(ThreadPool.Batch.from(&new_hct.task)); - } }; -const HandleConnectionTask = struct { - task: ThreadPool.Task, - server: *Server, - http_server: std.http.Server, - request: std.http.Server.Request, - - fn createAndReceiveHead( - server: *Server, - conn: std.net.Server.Connection, - ) (std.http.Server.ReceiveHeadError || std.mem.Allocator.Error)!*HandleConnectionTask { - const allocator = server.allocator; - - const hct_buf_align = @alignOf(HandleConnectionTask); - const hct_buf_size = initBufferSize(server.read_buffer_size); - - const hct_buffer = try allocator.alignedAlloc(u8, hct_buf_align, hct_buf_size); - errdefer server.allocator.free(hct_buffer); - - const hct: *HandleConnectionTask = std.mem.bytesAsValue( - HandleConnectionTask, - hct_buffer[0..@sizeOf(HandleConnectionTask)], - ); - hct.* = .{ - .task = .{ .callback = callback }, - .server = server, - .http_server = std.http.Server.init(conn, getReadBuffer(server.read_buffer_size, hct)), - .request = try hct.http_server.receiveHead(), - }; - - return hct; - } - - /// Does not release the connection. 
- fn destroyAndClose(hct: *HandleConnectionTask) void { - const allocator = hct.server.allocator; - - const full_buffer = getFullBuffer(hct.server.read_buffer_size, hct); - defer allocator.free(full_buffer); - - const connection = hct.http_server.connection; - defer connection.stream.close(); - } - - fn initBufferSize(read_buffer_size: usize) usize { - return @sizeOf(HandleConnectionTask) + read_buffer_size; - } - - fn getFullBuffer( - read_buffer_size: usize, - hct: *HandleConnectionTask, - ) []align(@alignOf(HandleConnectionTask)) u8 { - const ptr: [*]align(@alignOf(HandleConnectionTask)) u8 = @ptrCast(hct); - return ptr[0..initBufferSize(read_buffer_size)]; - } - - fn getReadBuffer( - read_buffer_size: usize, - hct: *HandleConnectionTask, - ) []u8 { - return getFullBuffer(read_buffer_size, hct)[@sizeOf(HandleConnectionTask)..]; - } - - fn callback(task: *ThreadPool.Task) void { - const hct: *HandleConnectionTask = @fieldParentPtr("task", task); - defer hct.destroyAndClose(); - - const server = hct.server; - const logger = server.logger; - - const wait_group = &server.wait_group; - defer wait_group.finish(); - - handleRequest( - logger, - &hct.request, - server.snapshot_dir, - server.latest_snapshot_gen_info, - ) catch |err| { - if (@errorReturnTrace()) |stack_trace| { - logger.err().logf("{s}\n{}", .{ @errorName(err), stack_trace }); - } else { - logger.err().logf("{s}", .{@errorName(err)}); - } - }; - } -}; - -fn handleRequest( - logger: ScopedLogger, - request: *std.http.Server.Request, - snapshot_dir: std.fs.Dir, - latest_snapshot_gen_info_rw: *sig.sync.RwMux(?SnapshotGenerationInfo), -) !void { - const conn_address = request.server.connection.address; - - logger.info().logf("Responding to request from {}: {} {s}", .{ - conn_address, methodFmt(request.head.method), request.head.target, - }); - switch (request.head.method) { - .POST => { - logger.err().logf("{} tried to invoke our RPC", .{conn_address}); - return try request.respond("RPCs are not yet implemented", .{ - .status = .service_unavailable, - .keep_alive = false, - }); - }, - .GET => get_blk: { - if (!std.mem.startsWith(u8, request.head.target, "/")) break :get_blk; - const path = request.head.target[1..]; - - // we hold the lock for the entirety of this process in order to prevent - // the snapshot generation process from deleting the associated snapshot. 
- const maybe_latest_snapshot_gen_info, // - var latest_snapshot_info_lg // - = latest_snapshot_gen_info_rw.readWithLock(); - defer latest_snapshot_info_lg.unlock(); - - const full_info: ?FullSnapshotFileInfo, // - const inc_info: ?IncrementalSnapshotFileInfo // - = blk: { - const latest_snapshot_gen_info = maybe_latest_snapshot_gen_info.* orelse - break :blk .{ null, null }; - const latest_full = latest_snapshot_gen_info.full; - const full_info: FullSnapshotFileInfo = .{ - .slot = latest_full.slot, - .hash = latest_full.hash, - }; - const latest_incremental = latest_snapshot_gen_info.inc orelse - break :blk .{ full_info, null }; - const inc_info: IncrementalSnapshotFileInfo = .{ - .base_slot = latest_full.slot, - .slot = latest_incremental.slot, - .hash = latest_incremental.hash, - }; - break :blk .{ full_info, inc_info }; - }; - - logger.debug().logf("Available full: {?s}", .{ - if (full_info) |info| info.snapshotArchiveName().constSlice() else null, - }); - logger.debug().logf("Available inc: {?s}", .{ - if (inc_info) |info| info.snapshotArchiveName().constSlice() else null, - }); - - if (full_info) |full| { - const full_archive_name_bounded = full.snapshotArchiveName(); - const full_archive_name = full_archive_name_bounded.constSlice(); - if (std.mem.eql(u8, path, full_archive_name)) { - const archive_file = try snapshot_dir.openFile(full_archive_name, .{}); - defer archive_file.close(); - var send_buffer: [4096]u8 = undefined; - try httpResponseSendFile(request, archive_file, &send_buffer); - return; - } - } - - if (inc_info) |inc| { - const inc_archive_name_bounded = inc.snapshotArchiveName(); - const inc_archive_name = inc_archive_name_bounded.constSlice(); - if (std.mem.eql(u8, path, inc_archive_name)) { - const archive_file = try snapshot_dir.openFile(inc_archive_name, .{}); - defer archive_file.close(); - var send_buffer: [4096]u8 = undefined; - try httpResponseSendFile(request, archive_file, &send_buffer); - return; - } - } - }, - else => {}, - } - - logger.err().logf( - "{} made an unrecognized request '{} {s}'", - .{ conn_address, methodFmt(request.head.method), request.head.target }, - ); - try request.respond("", .{ - .status = .not_found, - .keep_alive = false, - }); -} - -fn httpResponseSendFile( - request: *std.http.Server.Request, - archive_file: std.fs.File, - send_buffer: []u8, -) !void { - const archive_len = try archive_file.getEndPos(); - - var response = request.respondStreaming(.{ - .send_buffer = send_buffer, - .content_length = archive_len, - }); - const writer = sig.utils.io.narrowAnyWriter( - response.writer(), - std.http.Server.Response.WriteError, - ); - - const Fifo = std.fifo.LinearFifo(u8, .{ .Static = 1 }); - var fifo: Fifo = Fifo.init(); - try archive_file.seekTo(0); - try fifo.pump(archive_file.reader(), writer); +pub const WorkPool = union(enum) { + basic, + linux_io_uring: switch (LinuxIoUring.can_use) { + .yes, .check => *LinuxIoUring, + .no => noreturn, + }, - try response.end(); -} + pub const LinuxIoUring = @import("server/LinuxIoUring.zig"); -const AcceptConnectionError = error{ - ProcessFdQuotaExceeded, - SystemFdQuotaExceeded, - SystemResources, - ProtocolFailure, - BlockedByFirewall, - NetworkSubsystemFailed, -} || std.posix.UnexpectedError; - -fn acceptConnection( - tcp_server: *std.net.Server, - logger: ScopedLogger, -) AcceptConnectionError!?std.net.Server.Connection { - const conn = tcp_server.accept() catch |err| switch (err) { - error.Unexpected, - => |e| return e, - - error.ProcessFdQuotaExceeded, - error.SystemFdQuotaExceeded, - 
error.SystemResources, - error.ProtocolFailure, - error.BlockedByFirewall, - error.NetworkSubsystemFailed, - => |e| return e, - - error.FileDescriptorNotASocket, - error.SocketNotListening, - error.OperationNotSupported, - => @panic("Improperly initialized server."), - - error.WouldBlock, - => return null, - - error.ConnectionResetByPeer, - error.ConnectionAborted, - => |e| { - logger.warn().logf("{}", .{e}); - return null; - }, - }; - - return conn; -} + const BasicAASCError = + connection.AcceptHandledError || + std.mem.Allocator.Error || + std.http.Server.ReceiveHeadError || + requests.HandleRequestError; + const IoUringAASCError = + LinuxIoUring.AcceptAndServeConnectionsError; -fn methodFmt(method: std.http.Method) MethodFmt { - return .{ .method = method }; -} + pub const AcceptAndServeConnectionError = + BasicAASCError || + IoUringAASCError; -const MethodFmt = struct { - method: std.http.Method, - pub fn format( - fmt: MethodFmt, - comptime fmt_str: []const u8, - fmt_options: std.fmt.FormatOptions, - writer: anytype, - ) @TypeOf(writer).Error!void { - _ = fmt_options; - if (fmt_str.len != 0) std.fmt.invalidFmtError(fmt_str, fmt); - try fmt.method.write(writer); + pub fn acceptAndServeConnection( + work_pool: WorkPool, + server: *Context, + ) AcceptAndServeConnectionError!void { + switch (work_pool) { + .basic => { + const conn = try connection.acceptHandled(server.tcp); + defer conn.stream.close(); + + server.wait_group.start(); + defer server.wait_group.finish(); + + const buffer = try server.allocator.alloc(u8, server.read_buffer_size); + defer server.allocator.free(buffer); + + var http_server = std.http.Server.init(conn, buffer); + var request = try http_server.receiveHead(); + + try requests.handleRequest( + server.logger, + &request, + server.snapshot_dir, + server.latest_snapshot_gen_info, + ); + }, + .linux_io_uring => |linux| try linux.acceptAndServeConnections(server), + } } }; -test Server { +test Context { const allocator = std.testing.allocator; var prng = std.Random.DefaultPrng.init(0); const random = prng.random(); - // const logger: sig.trace.Logger = .{ .direct_print = .{ .max_level = .trace } }; - const logger: sig.trace.Logger = .noop; - - var test_data_dir = try std.fs.cwd().openDir("data/test-data", .{ .iterate = true }); - defer test_data_dir.close(); - var tmp_dir_root = std.testing.tmpDir(.{}); defer tmp_dir_root.cleanup(); const tmp_dir = tmp_dir_root.dir; - var snap_dir = try tmp_dir.makeOpenPath("snapshot", .{ .iterate = true }); - defer snap_dir.close(); - - const SnapshotFiles = sig.accounts_db.snapshots.SnapshotFiles; - const snap_files = try SnapshotFiles.find(allocator, test_data_dir); - - const full_snap_name_bounded = snap_files.full.snapshotArchiveName(); - const maybe_inc_snap_name_bounded = - if (snap_files.incremental()) |inc| inc.snapshotArchiveName() else null; - - { - const full_snap_name = full_snap_name_bounded.constSlice(); - - try test_data_dir.copyFile(full_snap_name, snap_dir, full_snap_name, .{}); - const full_snap_file = try snap_dir.openFile(full_snap_name, .{}); - defer full_snap_file.close(); - - const unpack = sig.accounts_db.snapshots.parallelUnpackZstdTarBall; - try unpack(allocator, logger, full_snap_file, snap_dir, 1, true); - } - - if (maybe_inc_snap_name_bounded) |inc_snap_name_bounded| { - const inc_snap_name = inc_snap_name_bounded.constSlice(); + // const logger: sig.trace.Logger = .{ .direct_print = .{ .max_level = .trace } }; + const logger: sig.trace.Logger = .noop; - try test_data_dir.copyFile(inc_snap_name, snap_dir, 
inc_snap_name, .{});
-        const inc_snap_file = try snap_dir.openFile(inc_snap_name, .{});
-        defer inc_snap_file.close();
+    // the directory into which the snapshots will be unpacked.
+    var unpacked_snap_dir = try tmp_dir.makeOpenPath("snapshot", .{});
+    defer unpacked_snap_dir.close();
 
-        const unpack = sig.accounts_db.snapshots.parallelUnpackZstdTarBall;
-        try unpack(allocator, logger, inc_snap_file, snap_dir, 1, false);
-    }
+    // the source from which `findAndUnpackTestSnapshots` will unpack the snapshots.
+    var test_data_dir = try std.fs.cwd().openDir(sig.TEST_DATA_DIR, .{ .iterate = true });
+    defer test_data_dir.close();
 
-    var accountsdb = try sig.accounts_db.AccountsDB.init(.{
-        .allocator = allocator,
-        .logger = logger,
-        .snapshot_dir = snap_dir,
-        .geyser_writer = null,
-        .gossip_view = null,
-        .index_allocation = .ram,
-        .number_of_index_shards = 4,
-        .lru_size = null,
-    });
-    defer accountsdb.deinit();
+    const snap_files = try sig.accounts_db.db.findAndUnpackTestSnapshots(
+        std.Thread.getCpuCount() catch 1,
+        unpacked_snap_dir,
+    );
 
-    {
+    var latest_snapshot_gen_info = sig.sync.RwMux(?SnapshotGenerationInfo).init(blk: {
         const FullAndIncrementalManifest = sig.accounts_db.snapshots.FullAndIncrementalManifest;
         const all_snap_fields = try FullAndIncrementalManifest.fromFiles(
             allocator,
             logger,
-            snap_dir,
+            unpacked_snap_dir,
             snap_files,
         );
         defer all_snap_fields.deinit(allocator);
 
-        (try accountsdb.loadWithDefaults(
-            allocator,
-            all_snap_fields,
-            1,
-            true,
-            300,
-            false,
-            false,
-        )).deinit(allocator);
-    }
-
-    var thread_pool = sig.sync.ThreadPool.init(.{ .max_threads = 1 });
-    defer {
-        thread_pool.shutdown();
-        thread_pool.deinit();
-    }
+        break :blk .{
+            .full = .{
+                .slot = snap_files.full.slot,
+                .hash = snap_files.full.hash,
+                .capitalization = all_snap_fields.full.bank_fields.capitalization,
+            },
+            .inc = inc: {
+                const inc = all_snap_fields.incremental orelse break :inc null;
+                // if the incremental snapshot field is not null, these shouldn't be either
+                const inc_info = snap_files.incremental_info.?;
+                const inc_persist = inc.bank_extra.snapshot_persistence.?;
+                break :inc .{
+                    .slot = inc_info.slot,
+                    .hash = inc_info.hash,
+                    .capitalization = inc_persist.incremental_capitalization,
+                };
+            },
+        };
+    });
 
     const rpc_port = random.intRangeLessThan(u16, 8_000, 10_000);
-    var rpc_server = try Server.init(.{
+    var rpc_server_ctx = try Context.init(.{
         .allocator = allocator,
         .logger = logger,
-        .snapshot_dir = snap_dir,
-        .latest_snapshot_gen_info = &accountsdb.latest_snapshot_gen_info,
-        .thread_pool = &thread_pool,
+        .snapshot_dir = test_data_dir,
+        .latest_snapshot_gen_info = &latest_snapshot_gen_info,
         .socket_addr = std.net.Address.initIp4(.{ 0, 0, 0, 0 }, rpc_port),
         .read_buffer_size = 4096,
     });
-    defer rpc_server.joinDeinit();
+    defer rpc_server_ctx.joinDeinit();
 
-    try testExpectSnapshotResponse(
-        allocator,
-        &rpc_server,
-        &full_snap_name_bounded,
-        snap_dir,
+    var maybe_liou = try WorkPool.LinuxIoUring.init();
+    // TODO: currently `if (a) |*b|` on `?noreturn` causes analysis of the unwrap block, even though `if (a) |b|` doesn't.
+    // Filed a bug report for this: https://github.com/ziglang/zig/issues/22556; it has a linked PR to fix it, so hopefully this will be fixed in 0.14.
+ defer if (maybe_liou != null) maybe_liou.?.deinit(); + + const self_url_bounded_str = sig.utils.fmt.boundedFmt( + "http://localhost:{d}/", + .{rpc_server_ctx.tcp.listen_address.getPort()}, ); + const self_uri = try std.Uri.parse(self_url_bounded_str.constSlice()); + + for ([_]?WorkPool{ + .basic, + if (maybe_liou != null) .{ .linux_io_uring = &maybe_liou.? } else null, // TODO: see above TODO about `if (a) |*b|` on `?noreturn`. + }) |maybe_work_pool| { + const work_pool = maybe_work_pool orelse continue; + + var exit = std.atomic.Value(bool).init(false); + const serve_thread = try rpc_server_ctx.serveSpawn(&exit, work_pool); + defer blk: { + exit.store(true, .release); + // send a dummy request so that the serve thread will get the accept and observe `exit`. + if (work_pool == .basic) interruptSelf(allocator, self_uri) catch |err| { + if (@errorReturnTrace()) |st| { + std.log.err("{s}\n{}", .{ @errorName(err), st }); + } else { + std.log.err("{s}", .{@errorName(err)}); + } + break :blk; // don't attempt to join the thread if an error occurred + }; + serve_thread.join(); + } - if (maybe_inc_snap_name_bounded) |inc_snap_name_bounded| { try testExpectSnapshotResponse( allocator, - &rpc_server, - &inc_snap_name_bounded, - snap_dir, + test_data_dir, + rpc_server_ctx.tcp.listen_address.getPort(), + .full, + snap_files.full, ); + + if (snap_files.incremental()) |inc| { + try testExpectSnapshotResponse( + allocator, + test_data_dir, + rpc_server_ctx.tcp.listen_address.getPort(), + .incremental, + inc, + ); + } } } fn testExpectSnapshotResponse( allocator: std.mem.Allocator, - rpc_server: *Server, - snap_name_bounded: anytype, snap_dir: std.fs.Dir, + rpc_port: u16, + comptime kind: enum { full, incremental }, + snap_info: switch (kind) { + .full => sig.accounts_db.snapshots.FullSnapshotFileInfo, + .incremental => sig.accounts_db.snapshots.IncrementalSnapshotFileInfo, + }, ) !void { - const rpc_port = rpc_server.tcp.listen_address.getPort(); + const snap_name_bounded = snap_info.snapshotArchiveName(); + const snap_name = snap_name_bounded.constSlice(); + + const expected_file = try snap_dir.openFile(snap_name, .{}); + defer expected_file.close(); + + const expected_data: []align(std.mem.page_size) const u8 = try std.posix.mmap( + null, + try expected_file.getEndPos(), + std.posix.PROT.READ, + .{ .TYPE = .PRIVATE }, + expected_file.handle, + 0, + ); + defer std.posix.munmap(expected_data); + const snap_url_str_bounded = sig.utils.fmt.boundedFmt( "http://localhost:{d}/{s}", - .{ rpc_port, sig.utils.fmt.boundedString(snap_name_bounded) }, + .{ rpc_port, sig.utils.fmt.boundedString(&snap_name_bounded) }, ); const snap_url = try std.Uri.parse(snap_url_str_bounded.constSlice()); - const serve_thread = try std.Thread.spawn(.{}, Server.acceptAndServeConnection, .{rpc_server}); const actual_data = try testDownloadSelfSnapshot(allocator, snap_url); defer allocator.free(actual_data); - serve_thread.join(); - - const snap_name = snap_name_bounded.constSlice(); - - const expected_data = try snap_dir.readFileAlloc(allocator, snap_name, 1 << 32); - defer allocator.free(expected_data); try std.testing.expectEqualSlices(u8, expected_data, actual_data); } @@ -557,3 +329,21 @@ fn testDownloadSelfSnapshot( return response_content; } + +fn interruptSelf( + allocator: std.mem.Allocator, + snap_url: std.Uri, +) !void { + var client: std.http.Client = .{ .allocator = allocator }; + defer client.deinit(); + + var server_header_buffer: [4096 * 16]u8 = undefined; + var request = try client.open(.HEAD, snap_url, .{ + 
.server_header_buffer = &server_header_buffer, + }); + defer request.deinit(); + + try request.send(); + try request.finish(); + try request.wait(); +} diff --git a/src/rpc/server/LinuxIoUring.zig b/src/rpc/server/LinuxIoUring.zig new file mode 100644 index 000000000..e4992e09d --- /dev/null +++ b/src/rpc/server/LinuxIoUring.zig @@ -0,0 +1,810 @@ +const builtin = @import("builtin"); +const std = @import("std"); +const sig = @import("../../sig.zig"); + +const connection = @import("connection.zig"); +const requests = @import("requests.zig"); + +const IoUring = std.os.linux.IoUring; +const ServerCtx = sig.rpc.server.Context; + +const LinuxIoUring = @This(); +io_uring: IoUring, +multishot_accept_submitted: bool, +pending_cqes_count: u8, +pending_cqes_buf: [255]std.os.linux.io_uring_cqe, + +pub const can_use: enum { no, yes, check } = switch (builtin.os.getVersionRange()) { + .linux => |version| can_use: { + const min_version: std.SemanticVersion = .{ .major = 6, .minor = 0, .patch = 0 }; + const is_at_least = version.isAtLeast(min_version) orelse break :can_use .check; + break :can_use if (is_at_least) .yes else .no; + }, + else => .no, +}; + +pub const InitError = std.posix.MMapError || error{ + EntriesZero, + EntriesNotPowerOfTwo, + + ParamsOutsideAccessibleAddressSpace, + ArgumentsInvalid, + ProcessFdQuotaExceeded, + SystemFdQuotaExceeded, + SystemResources, + + PermissionDenied, + SystemOutdated, +}; + +// NOTE(ink): constructing the return type as `E!?T`, where `E` and `T` are resolved +// separately seems to help ZLS with understanding the types involved better, which is +// why I've done it like that here. If ZLS gets smarter in the future, you could probably +// inline this into a single branch in the return type expression. +const InitErrOrEmpty = if (can_use == .no) error{} else InitError; +const InitResultOrNoreturn = if (can_use == .no) noreturn else LinuxIoUring; +pub fn init() InitErrOrEmpty!?InitResultOrNoreturn { + const need_runtime_check = switch (can_use) { + .no => return null, + .yes => false, + .check => true, + }; + + var io_uring = IoUring.init(4096, 0) catch |err| return switch (err) { + error.SystemOutdated, + error.PermissionDenied, + => |e| if (!need_runtime_check) e else return null, + else => |e| e, + }; + errdefer io_uring.deinit(); + + return .{ + .io_uring = io_uring, + .multishot_accept_submitted = false, + .pending_cqes_count = 0, + .pending_cqes_buf = undefined, + }; +} + +pub fn deinit(self: *LinuxIoUring) void { + self.io_uring.deinit(); +} + +pub const AcceptAndServeConnectionsError = error{ + /// This was the first call, and we failed to prep, queue, and submit the multishot accept. 
FailedToAcceptMultishot,
+    SubmissionQueueFull,
+} || IouSubmitError ||
+    HandleOurCqeError ||
+    std.mem.Allocator.Error;
+
+pub fn acceptAndServeConnections(
+    self: *LinuxIoUring,
+    server_ctx: *ServerCtx,
+) AcceptAndServeConnectionsError!void {
+    if (!self.multishot_accept_submitted) {
+        self.multishot_accept_submitted = true;
+        errdefer self.multishot_accept_submitted = false;
+        _ = self.io_uring.accept_multishot(
+            @bitCast(Entry.ACCEPT),
+            server_ctx.tcp.stream.handle,
+            null,
+            null,
+            std.os.linux.SOCK.CLOEXEC,
+        ) catch |err| return switch (err) {
+            error.SubmissionQueueFull => {
+                server_ctx.logger.err().log(
+                    "Under normal circumstances the accept_multishot would be" ++
+                        " the first SQE to be queued, but somehow the queue was full.",
+                );
+                return error.FailedToAcceptMultishot;
+            },
+        };
+        if (try self.io_uring.submit() != 1) {
+            return error.FailedToAcceptMultishot;
+        }
+        return;
+    }
+
+    _ = try self.io_uring.submit();
+
+    if (self.pending_cqes_count != self.pending_cqes_buf.len) {
+        self.pending_cqes_count += @intCast(try self.io_uring.copy_cqes(
+            self.pending_cqes_buf[self.pending_cqes_count..],
+            0,
+        ));
+    }
+    const cqes_pending = self.pending_cqes_buf[0..self.pending_cqes_count];
+
+    for (cqes_pending, 0..) |raw_cqe, i| {
+        self.pending_cqes_count -= 1;
+        errdefer std.mem.copyForwards(
+            std.os.linux.io_uring_cqe,
+            self.pending_cqes_buf[0..self.pending_cqes_count],
+            self.pending_cqes_buf[i + 1 ..][0..self.pending_cqes_count],
+        );
+        const our_cqe = OurCqe.fromCqe(raw_cqe);
+        consumeOurCqe(self, server_ctx, our_cqe) catch |err| switch (err) {
+            // connection errors
+            error.ConnectionAborted,
+            error.ConnectionRefused,
+            error.ConnectionResetByPeer,
+            error.ConnectionTimedOut,
+
+            // our http parse errors
+            error.RequestHeadersTooBig,
+            error.RequestTargetTooLong,
+            error.RequestContentTypeUnrecognized,
+
+            // std http parse errors
+            error.UnknownHttpMethod,
+            error.HttpHeadersInvalid,
+            error.InvalidContentLength,
+            error.HttpHeaderContinuationsUnsupported,
+            error.HttpTransferEncodingUnsupported,
+            error.HttpConnectionHeaderUnsupported,
+            error.CompressionUnsupported,
+            error.MissingFinalNewline,
+
+            // splice errors
+            error.BadFileDescriptors,
+            error.BadFdOffset,
+            error.InvalidSplice,
+            => |e| {
+                server_ctx.logger.err().logf("{s}", .{@errorName(e)});
+                continue;
+            },
+
+            error.SubmissionQueueFull => |e| return e,
+            else => |e| return e,
+        };
+    }
+}
+
+const HandleOurCqeError = error{
+    SubmissionQueueFull,
+
+    /// Connection was aborted; not necessarily critical.
+    ConnectionAborted,
+    /// A remote host refused to allow the network connection, typically because it is not
+    /// running the requested service.
+    ConnectionRefused,
+    /// The connection was forcibly reset by the remote peer.
+    ConnectionResetByPeer,
+    ConnectionTimedOut,
+
+    /// The headers recv'd in a request were too big.
+    RequestHeadersTooBig,
+    /// The request line recv'd was too long.
+    RequestTargetTooLong,
+    /// The request 'Content-Type' did not match any recognized `ContentType`.
+    RequestContentTypeUnrecognized,
+} || connection.HandleAcceptError ||
+    connection.HandleRecvError ||
+    connection.HandleSendError ||
+    connection.HandleSpliceError ||
+    std.mem.Allocator.Error ||
+    std.http.Server.Request.Head.ParseError ||
+    std.fs.File.OpenError ||
+    std.fs.File.GetSeekPosError;
+
+/// On return, `cqe.user_data` is in an undefined state - this is to say,
+/// it has either already been `deinit`ed, or it has been re-submitted
+/// in a new `SQE` and should not be modified; in either scenario, the caller
+/// should not interact with it.
+fn consumeOurCqe(
+    liou: *LinuxIoUring,
+    server_ctx: *ServerCtx,
+    cqe: OurCqe,
+) HandleOurCqeError!void {
+    const entry = cqe.user_data;
+    errdefer entry.deinit(server_ctx.allocator);
+
+    const entry_data: *EntryData = entry.ptr orelse {
+        // multishot accept cqe
+
+        switch (try connection.handleAcceptResult(cqe.err())) {
+            .success => {},
+            .intr => std.debug.panic("TODO: does this mean the multishot accept has stopped? If no, just warn. If yes, re-queue here and warn.", .{}), // TODO:
+            .again => std.debug.panic("The socket should not be in nonblocking mode.", .{}),
+            .conn_aborted => return error.ConnectionAborted,
+        }
+
+        const stream: std.net.Stream = .{ .handle = cqe.res };
+        errdefer stream.close();
+
+        server_ctx.wait_group.start();
+        errdefer server_ctx.wait_group.finish();
+
+        const buffer = try server_ctx.allocator.alloc(u8, server_ctx.read_buffer_size);
+        errdefer server_ctx.allocator.free(buffer);
+
+        const new_recv_entry: Entry = entry: {
+            const data_ptr = try server_ctx.allocator.create(EntryData);
+            errdefer comptime unreachable;
+
+            data_ptr.* = .{
+                .buffer = buffer,
+                .stream = stream,
+                .state = EntryState.INIT,
+            };
+            break :entry .{ .ptr = data_ptr };
+        };
+        errdefer if (new_recv_entry.ptr) |data_ptr| server_ctx.allocator.destroy(data_ptr);
+
+        _ = liou.io_uring.recv(
+            @bitCast(new_recv_entry),
+            stream.handle,
+            .{ .buffer = buffer },
+            0,
+        ) catch |err| switch (err) {
+            error.SubmissionQueueFull => |e| {
+                server_ctx.logger.err().logf(
+                    "Failed to submit the SQE for the initial recv" ++
+                        " for the connection from '{!}'",
+                    .{connection.getSockName(stream.handle)}, // if we fail to getSockName, just print the error in place of the address
+                );
+                return e;
+            },
+        };
+
+        return;
+    };
+    errdefer server_ctx.wait_group.finish();
+
+    const err_logger = server_ctx.logger.err().field(
+        "address",
+        // if we fail to getSockName, just print the error in place of the address;
+        connection.getSockName(entry_data.stream.handle),
+    );
+    errdefer err_logger.logf("Dropping connection", .{});
+
+    switch (entry_data.state) {
+        .recv_head => |*head| {
+            switch (try connection.handleRecvResult(cqe.err())) {
+                .success => {},
+
+                .intr => std.debug.panic("TODO: how to handle interrupts on this?", .{}), // TODO:
+                .again => std.debug.panic("The socket should not be in nonblocking mode.", .{}),
+
+                .conn_refused => return error.ConnectionRefused,
+                .conn_reset => return error.ConnectionResetByPeer,
+                .timed_out => return error.ConnectionTimedOut,
+            }
+
+            const recv_len: usize = @intCast(cqe.res);
+            std.debug.assert(head.parser.state != .finished);
+
+            const recv_start = head.end;
+            const recv_end = recv_start + recv_len;
+            head.end += head.parser.feed(entry_data.buffer[recv_start..recv_end]);
+
+            if (head.parser.state != .finished) {
+                std.debug.assert(head.end == recv_end);
+                if (head.end == entry_data.buffer.len) {
+                    return error.RequestHeadersTooBig;
+                }
+
+                _ = try liou.io_uring.recv(
+                    @bitCast(entry),
entry_data.stream.handle, + .{ .buffer = entry_data.buffer[head.end..] }, + 0, + ); + return; + } + + // copy relevant headers and information out of the buffer, + // so we can use the buffer exclusively for the request body. + const HeadInfo = requests.HeadInfo; + const head_info: HeadInfo = head_info: { + const head_bytes = entry_data.buffer[0..head.end]; + const std_head = try std.http.Server.Request.Head.parse(head_bytes); + std.debug.assert(std_head.compression == .none); // at the time of writing, this always holds true for the result of `Head.parse`. + break :head_info HeadInfo.parseFromStdHead(std_head) catch |err| switch (err) { + error.RequestTargetTooLong => |e| { + err_logger.logf("Request target was too long: '{}'", .{ + std.zig.fmtEscapes(std_head.target), + }); + return e; + }, + else => |e| return e, + }; + }; + + // ^ we just copied the relevant head info, so we're going to move + // the body content to the start of the buffer. + const content_end = blk: { + const old_content_bytes = entry_data.buffer[head.end..recv_end]; + std.mem.copyForwards( + u8, + entry_data.buffer[0..old_content_bytes.len], + old_content_bytes, + ); + break :blk old_content_bytes.len; + }; + + entry_data.state = .{ .recv_body = .{ + .head_info = head_info, + .need_to_check_cqe = false, + .content_end = content_end, + } }; + const body = &entry_data.state.recv_body; + try handleRecvBody(liou, server_ctx, err_logger, entry, body); + return; + }, + + .recv_body => |*body| { + if (body.need_to_check_cqe) { + switch (try connection.handleRecvResult(cqe.err())) { + .success => {}, + + .intr => std.debug.panic("TODO: how to handle interrupts on this?", .{}), // TODO: + .again => std.debug.panic("The socket should not be in nonblocking mode.", .{}), + + .conn_refused => return error.ConnectionRefused, + .conn_reset => return error.ConnectionResetByPeer, + .timed_out => return error.ConnectionTimedOut, + } + + const recv_len: usize = @intCast(cqe.res); + body.content_end += recv_len; + } + + try handleRecvBody(liou, server_ctx, err_logger, entry, body); + return; + }, + + .send_file_head => |*sfh| { + switch (try connection.handleSendResult(cqe.err())) { + .success => {}, + .intr => std.debug.panic("TODO: how to handle interrupts on this?", .{}), // TODO: + .again => std.debug.panic("The socket should not be in nonblocking mode.", .{}), + } + const sent_len: usize = @intCast(cqe.res); + sfh.sent_bytes += sent_len; + + switch (try sfh.computeAndMaybePrepSend(entry, &liou.io_uring)) { + .sending_more => return, + .all_sent => { + const sfd = sfh.sfd; + entry_data.state = .{ .send_file_body = .{ + .sfd = sfd, + .spliced_to_pipe = 0, + .spliced_to_socket = 0, + .which = .to_pipe, + } }; + const sfb = &entry_data.state.send_file_body; + try sfb.prepSpliceFileToPipe(entry, &liou.io_uring); + return; + }, + } + }, + + .send_file_body => |*sfb| switch (sfb.which) { + .to_pipe => { + switch (try connection.handleSpliceResult(cqe.err())) { + .success => {}, + .again => std.debug.panic("The socket should not be in nonblocking mode.", .{}), + } + sfb.spliced_to_pipe += @intCast(cqe.res); + + sfb.which = .to_socket; + try sfb.prepSplicePipeToSocket(entry, &liou.io_uring); + + return; + }, + .to_socket => { + switch (try connection.handleSpliceResult(cqe.err())) { + .success => {}, + .again => std.debug.panic("The socket should not be in nonblocking mode.", .{}), + } + sfb.spliced_to_socket += @intCast(cqe.res); + + if (sfb.spliced_to_socket < sfb.sfd.file_size) { + sfb.which = .to_pipe; + try 
sfb.prepSpliceFileToPipe(entry, &liou.io_uring);
+                } else {
+                    std.debug.assert(sfb.spliced_to_socket == sfb.spliced_to_pipe);
+                    entry.deinit(server_ctx.allocator);
+                    server_ctx.wait_group.finish();
+                }
+                return;
+            },
+        },
+
+        .send_no_body => |*snb| {
+            switch (try connection.handleSendResult(cqe.err())) {
+                .success => {},
+                .intr => std.debug.panic("TODO: how to handle interrupts on this?", .{}), // TODO:
+                .again => std.debug.panic("The socket should not be in nonblocking mode.", .{}),
+            }
+            const sent_len: usize = @intCast(cqe.res);
+            snb.end_index += sent_len;
+
+            if (snb.end_index < snb.head.len) {
+                // there is still head data left to send; the entry has been
+                // re-submitted in the prepped SQE, so it must stay alive
+                // until its CQE arrives - don't deinit it here.
+                try snb.prepSend(entry, &liou.io_uring);
+                return;
+            }
+            std.debug.assert(snb.end_index == snb.head.len);
+
+            entry.deinit(server_ctx.allocator);
+            server_ctx.wait_group.finish();
+            return;
+        },
+    }
+}
+
+fn handleRecvBody(
+    liou: *LinuxIoUring,
+    server_ctx: *ServerCtx,
+    err_logger: anytype,
+    entry: Entry,
+    body: *EntryState.RecvBody,
+) !void {
+    const entry_data = entry.ptr.?;
+    std.debug.assert(body == &entry_data.state.recv_body);
+
+    if (!body.head_info.method.requestHasBody()) {
+        if (body.head_info.content_len) |content_len| {
+            err_logger.logf(
+                "{} request isn't expected to have a body, but got Content-Length: {d}",
+                .{ requests.methodFmt(body.head_info.method), content_len },
+            );
+        }
+    }
+
+    switch (body.head_info.method) {
+        .POST => {
+            entry_data.state = .{
+                .send_no_body = EntryState.SendNoBody.initHttpStatus(
+                    .@"HTTP/1.0",
+                    .service_unavailable,
+                ),
+            };
+            const snb = &entry_data.state.send_no_body;
+            try snb.prepSend(entry, &liou.io_uring);
+            return;
+        },
+
+        .GET => switch (requests.getRequestTargetResolve(
+            server_ctx.logger,
+            body.head_info.target.constSlice(),
+            server_ctx.latest_snapshot_gen_info,
+        )) {
+            inline .full_snapshot, .inc_snapshot => |pair| {
+                const snap_info, var full_info_lg = pair;
+                errdefer full_info_lg.unlock();
+
+                const archive_name_bounded = snap_info.snapshotArchiveName();
+                const archive_name = archive_name_bounded.constSlice();
+
+                const snapshot_dir = server_ctx.snapshot_dir;
+                const archive_file = try snapshot_dir.openFile(archive_name, .{});
+                errdefer archive_file.close();
+                const file_size = try archive_file.getEndPos();
+
+                const pipe_r, const pipe_w = try std.posix.pipe();
+                errdefer std.posix.close(pipe_w);
+                errdefer std.posix.close(pipe_r);
+
+                entry_data.state = .{ .send_file_head = .{
+                    .sfd = .{
+                        .file_lg = full_info_lg,
+                        .file = archive_file,
+                        .file_size = file_size,
+
+                        .pipe_w = pipe_w,
+                        .pipe_r = pipe_r,
+                    },
+                    .sent_bytes = 0,
+                } };
+                const sfh = &entry_data.state.send_file_head;
+                switch (try sfh.computeAndMaybePrepSend(entry, &liou.io_uring)) {
+                    .sending_more => return,
+                    .all_sent => unreachable, // we know this for certain
+                }
+            },
+            .unrecognized => {},
+        },
+
+        else => {},
+    }
+
+    entry_data.state = .{
+        .send_no_body = EntryState.SendNoBody.initHttpStatus(
+            .@"HTTP/1.0",
+            .not_found,
+        ),
+    };
+    const snb = &entry_data.state.send_no_body;
+    try snb.prepSend(entry, &liou.io_uring);
+}
+
+const OurCqe = extern struct {
+    user_data: Entry,
+    res: i32,
+    flags: u32,
+
+    fn fromCqe(cqe: std.os.linux.io_uring_cqe) OurCqe {
+        return .{
+            .user_data = @bitCast(cqe.user_data),
+            .res = cqe.res,
+            .flags = cqe.flags,
+        };
+    }
+
+    fn asCqe(self: OurCqe) std.os.linux.io_uring_cqe {
+        return .{
+            .user_data = @bitCast(self.user_data),
+            .res = self.res,
+            .flags = self.flags,
+        };
+    }
+
+    fn err(self: OurCqe) std.os.linux.E {
+        return self.asCqe().err();
+    }
+};
+
+const Entry = packed struct(u64) {
+    /// If null, this is an `accept` entry.
+    ptr: ?*EntryData,
+
+    const ACCEPT: Entry = .{ .ptr = null };
+
+    fn deinit(entry: Entry, allocator: std.mem.Allocator) void {
+        const ptr = entry.ptr orelse return;
+        ptr.deinit(allocator);
+        allocator.destroy(ptr);
+    }
+};
+
+const EntryData = struct {
+    buffer: []u8,
+    stream: std.net.Stream,
+    state: EntryState,
+
+    fn deinit(data: *EntryData, allocator: std.mem.Allocator) void {
+        data.state.deinit();
+        allocator.free(data.buffer);
+        data.stream.close();
+    }
+};
+
+const EntryState = union(enum) {
+    recv_head: RecvHead,
+    recv_body: RecvBody,
+    send_file_head: SendFileHead,
+    send_file_body: SendFileBody,
+    send_no_body: SendNoBody,
+
+    const INIT: EntryState = .{
+        .recv_head = .{
+            .end = 0,
+            .parser = .{},
+        },
+    };
+
+    fn deinit(state: *EntryState) void {
+        switch (state.*) {
+            .recv_head => {},
+            .recv_body => {},
+            .send_file_head => |*sfh| sfh.deinit(),
+            .send_file_body => |*sfb| sfb.deinit(),
+            .send_no_body => {},
+        }
+    }
+
+    const RecvHead = struct {
+        end: usize,
+        parser: std.http.HeadParser,
+    };
+
+    const RecvBody = struct {
+        head_info: requests.HeadInfo,
+        /// Should be true when submitting the SQE.
+        /// Will be true when receiving the CQE, and false when we've
+        /// been `continue`'d into by another prong in the switch loop.
+        need_to_check_cqe: bool,
+        /// The current number of content bytes read into the buffer.
+        content_end: usize,
+    };
+
+    const SendFileData = struct {
+        file_lg: requests.GetRequestTargetResolved.SnapshotReadLock,
+        file: std.fs.File,
+        file_size: u64,
+
+        pipe_w: std.os.linux.fd_t,
+        pipe_r: std.os.linux.fd_t,
+
+        fn deinit(self: *SendFileData) void {
+            self.file.close();
+            self.file_lg.unlock();
+            std.posix.close(self.pipe_w);
+            std.posix.close(self.pipe_r);
+        }
+    };
+
+    const SendFileHead = struct {
+        sfd: SendFileData,
+        sent_bytes: u64,
+
+        fn deinit(self: *SendFileHead) void {
+            self.sfd.deinit();
+        }
+
+        fn computeAndMaybePrepSend(
+            self: *SendFileHead,
+            entry: Entry,
+            io_uring: *IoUring,
+        ) !enum {
+            /// The head has been fully sent already, no send was prepped.
+            all_sent,
+            /// There is still more head data to send.
+            sending_more,
+        } {
+            const entry_data = entry.ptr.?;
+            std.debug.assert(self == &entry_data.state.send_file_head);
+
+            const rendered_len = blk: {
+                // render segments of the head into our buffer,
+                // sending them as they become rendered.
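+                // Each (re)send re-renders the whole head from scratch: the
+                // `WindowedWriter` windows the output so that only the bytes
+                // from `sent_bytes` onward land in `buffer`, while the
+                // `countingWriter` tracks the total rendered length, letting
+                // us detect when the entire head has already gone out.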
+
+                var ww = sig.utils.io.WindowedWriter.init(entry_data.buffer, self.sent_bytes);
+                var cw = std.io.countingWriter(ww.writer());
+                const writer = cw.writer();
+
+                const status: std.http.Status = .ok;
+                writer.print("{[version]s} {[status]d}{[space]s}{[phrase]s}\r\n", .{
+                    .version = @tagName(std.http.Version.@"HTTP/1.0"),
+                    .status = @intFromEnum(status),
+                    .space = if (status.phrase() != null) " " else "",
+                    .phrase = if (status.phrase()) |str| str else "",
+                }) catch |err| switch (err) {};
+
+                writer.print("Content-Length: {d}\r\n", .{
+                    self.sfd.file_size,
+                }) catch |err| switch (err) {};
+
+                writer.writeAll("\r\n") catch |err| switch (err) {};
+
+                if (self.sent_bytes == cw.bytes_written) return .all_sent;
+                std.debug.assert(self.sent_bytes < cw.bytes_written);
+                break :blk ww.end_index;
+            };
+
+            _ = try io_uring.send(
+                @bitCast(entry),
+                entry_data.stream.handle,
+                entry_data.buffer[0..rendered_len],
+                0,
+            );
+
+            return .sending_more;
+        }
+    };
+
+    const SendFileBody = struct {
+        sfd: SendFileData,
+        spliced_to_pipe: u64,
+        spliced_to_socket: u64,
+        which: Which,
+
+        const Which = enum {
+            to_pipe,
+            to_socket,
+        };
+
+        fn deinit(self: *SendFileBody) void {
+            self.sfd.deinit();
+        }
+
+        fn prepSpliceFileToPipe(
+            self: *const SendFileBody,
+            entry: Entry,
+            io_uring: *IoUring,
+        ) !void {
+            const entry_ptr = entry.ptr.?;
+            std.debug.assert(self == &entry_ptr.state.send_file_body);
+            std.debug.assert(self.which == .to_pipe);
+
+            _ = try io_uring.splice(
+                @bitCast(entry),
+                self.sfd.file.handle,
+                self.spliced_to_pipe,
+                self.sfd.pipe_w,
+                std.math.maxInt(u64),
+                self.sfd.file_size - self.spliced_to_pipe,
+            );
+        }
+
+        fn prepSplicePipeToSocket(
+            self: *const SendFileBody,
+            entry: Entry,
+            io_uring: *IoUring,
+        ) !void {
+            const entry_ptr = entry.ptr.?;
+            std.debug.assert(self == &entry_ptr.state.send_file_body);
+            std.debug.assert(self.which == .to_socket);
+
+            const stream = entry_ptr.stream;
+            _ = try io_uring.splice(
+                @bitCast(entry),
+                self.sfd.pipe_r,
+                std.math.maxInt(u64),
+                stream.handle,
+                std.math.maxInt(u64),
+                self.sfd.file_size - self.spliced_to_socket,
+            );
+        }
+    };
+
+    const SendNoBody = struct {
+        /// Should be a statically-lived string.
+        head: []const u8,
+        end_index: usize,
+
+        fn initString(comptime str: []const u8) SendNoBody {
+            return .{
+                .head = str,
+                .end_index = 0,
+            };
+        }
+
+        fn initHttpStatus(
+            comptime version: std.http.Version,
+            comptime status: std.http.Status,
+        ) SendNoBody {
+            const head = comptime std.fmt.comptimePrint("{s} {d}{s}\r\n\r\n", .{
+                @tagName(version),
+                @intFromEnum(status),
+                if (status.phrase()) |phrase| " " ++ phrase else "",
+            });
+            return initString(head);
+        }
+
+        fn prepSend(
+            self: *const SendNoBody,
+            entry: Entry,
+            io_uring: *IoUring,
+        ) !void {
+            const entry_ptr = entry.ptr.?;
+            std.debug.assert(self == &entry_ptr.state.send_no_body);
+            _ = try io_uring.send(
+                @bitCast(entry),
+                entry_ptr.stream.handle,
+                self.head[self.end_index..],
+                0,
+            );
+        }
+    };
+};
+
+/// Extracted from `std.os.linux.IoUring.submit`
+const IouSubmitError = IouEnterError;
+
+/// Extracted from `std.os.linux.IoUring.enter`.
+const IouEnterError = error{
+    /// The kernel was unable to allocate memory or ran out of resources for the request.
+    /// The application should wait for some completions and try again.
+    SystemResources,
+    /// The SQE `fd` is invalid, or IOSQE_FIXED_FILE was set but no files were registered.
+    FileDescriptorInvalid,
+    /// The file descriptor is valid, but the ring is not in the right state.
+ /// See io_uring_register(2) for how to enable the ring. + FileDescriptorInBadState, + /// The application attempted to overcommit the number of requests it can have pending. + /// The application should wait for some completions and try again. + CompletionQueueOvercommitted, + /// The SQE is invalid, or valid but the ring was setup with IORING_SETUP_IOPOLL. + SubmissionQueueEntryInvalid, + /// The buffer is outside the process' accessible address space, or IORING_OP_READ_FIXED + /// or IORING_OP_WRITE_FIXED was specified but no buffers were registered, or the range + /// described by `addr` and `len` is not within the buffer registered at `buf_index`: + BufferInvalid, + RingShuttingDown, + /// The kernel believes our `self.fd` does not refer to an io_uring instance, + /// or the opcode is valid but not supported by this kernel (more likely): + OpcodeNotSupported, + /// The operation was interrupted by a delivery of a signal before it could complete. + /// This can happen while waiting for events with IORING_ENTER_GETEVENTS: + SignalInterrupt, +} || std.posix.UnexpectedError; diff --git a/src/rpc/server/connection.zig b/src/rpc/server/connection.zig new file mode 100644 index 000000000..133bb06b2 --- /dev/null +++ b/src/rpc/server/connection.zig @@ -0,0 +1,223 @@ +const builtin = @import("builtin"); +const std = @import("std"); + +pub fn getSockName( + socket_handle: std.posix.socket_t, +) std.posix.GetSockNameError!std.net.Address { + var addr: std.net.Address = .{ .any = undefined }; + var addr_len: std.posix.socklen_t = @sizeOf(@TypeOf(addr.any)); + try std.posix.getsockname(socket_handle, &addr.any, &addr_len); + return addr; +} + +pub const AcceptHandledError = HandleAcceptError || error{ + ConnectionAborted, + WouldBlock, +}; +pub fn acceptHandled( + tcp_server: std.net.Server, +) AcceptHandledError!std.net.Server.Connection { + while (true) { + var addr: std.net.Address = .{ .any = undefined }; + var addr_len: std.posix.socklen_t = @sizeOf(@TypeOf(addr.any)); + const rc = if (!builtin.target.isDarwin()) std.posix.system.accept4( + tcp_server.stream.handle, + &addr.any, + &addr_len, + std.posix.SOCK.CLOEXEC, + ) else std.posix.system.accept( + tcp_server.stream.handle, + &addr.any, + &addr_len, + ); + + return switch (try handleAcceptResult(std.posix.errno(rc))) { + .intr => continue, + .conn_aborted => error.ConnectionAborted, + .again => error.WouldBlock, + .success => .{ + .stream = .{ .handle = rc }, + .address = addr, + }, + }; + } +} + +pub const HandleAcceptError = error{ + ProcessFdQuotaExceeded, + SystemFdQuotaExceeded, + SystemResources, + ProtocolFailure, + BlockedByFirewall, +} || std.posix.UnexpectedError; + +pub const HandleAcceptResult = enum { + success, + intr, + again, + conn_aborted, +}; + +/// Resembles the error handling of `std.posix.accept`. +pub fn handleAcceptResult( + /// Must be the result of `std.posix.accept` or equivalent (ie io_uring cqe.err()). 
rc: std.posix.E,
+) HandleAcceptError!HandleAcceptResult {
+    comptime std.debug.assert( //
+        builtin.target.isDarwin() or builtin.target.os.tag == .linux //
+    );
+    return switch (rc) {
+        .SUCCESS => .success,
+        .INTR => .intr,
+        .AGAIN => .again,
+        .CONNABORTED => .conn_aborted,
+
+        .BADF, // always a race condition
+        .FAULT, // don't address bad memory
+        .NOTSOCK, // don't call accept on a non-socket
+        .OPNOTSUPP, // socket must support accept
+        .INVAL, // socket must be listening
+        => |e| std.debug.panic("{s}", .{@tagName(e)}),
+
+        .MFILE => return error.ProcessFdQuotaExceeded,
+        .NFILE => return error.SystemFdQuotaExceeded,
+        .NOBUFS => return error.SystemResources,
+        .NOMEM => return error.SystemResources,
+        .PROTO => return error.ProtocolFailure,
+        .PERM => return error.BlockedByFirewall,
+        else => |err| return std.posix.unexpectedErrno(err),
+    };
+}
+
+pub const HandleRecvError = error{
+    SystemResources,
+} || std.posix.UnexpectedError;
+
+pub const HandleRecvResult = enum {
+    success,
+    intr,
+    again,
+    conn_refused,
+    conn_reset,
+    timed_out,
+};
+
+/// Resembles the error handling of `std.posix.recv`.
+pub fn handleRecvResult(
+    /// Must be the result of `std.posix.recv` or equivalent (ie io_uring cqe.err()).
+    rc: std.posix.E,
+) HandleRecvError!HandleRecvResult {
+    comptime std.debug.assert( //
+        builtin.target.isDarwin() or builtin.target.os.tag == .linux //
+    );
+    return switch (rc) {
+        .SUCCESS => .success,
+        .INTR => .intr,
+        .AGAIN => .again,
+        .CONNREFUSED => .conn_refused,
+        .CONNRESET => .conn_reset,
+        .TIMEDOUT => .timed_out,
+
+        .BADF, // always a race condition
+        .FAULT, // don't address bad memory
+        .INVAL, // we always pass valid arguments
+        .NOTSOCK, // don't call recv on a non-socket
+        .NOTCONN, // we should always be connected
+        => |e| std.debug.panic("{s}", .{@tagName(e)}),
+
+        .NOMEM => return error.SystemResources,
+        else => |err| return std.posix.unexpectedErrno(err),
+    };
+}
+
+pub const HandleSendError = error{
+    AccessDenied,
+    FastOpenAlreadyInProgress,
+    ConnectionResetByPeer,
+    MessageTooBig,
+    SystemResources,
+    BrokenPipe,
+    NetworkSubsystemFailed,
+} || std.posix.UnexpectedError;
+
+pub const HandleSendResult = enum {
+    success,
+    intr,
+    again,
+};
+
+pub fn handleSendResult(
+    /// Must be the result of `std.posix.send` or equivalent (ie io_uring cqe.err()).
+    rc: std.posix.E,
+) HandleSendError!HandleSendResult {
+    comptime std.debug.assert( //
+        builtin.target.isDarwin() or builtin.target.os.tag == .linux //
+    );
+    return switch (rc) {
+        .SUCCESS => .success,
+        .INTR => .intr,
+        .AGAIN => .again,
+
+        .BADF, // always a race condition
+        .DESTADDRREQ, // The socket is not connection-mode, and no peer address is set.
+        .FAULT, // An invalid user space address was specified for an argument.
+        .ISCONN, // connection-mode socket was connected already but a recipient was specified
+        .NOTSOCK, // The file descriptor sockfd does not refer to a socket.
+        .OPNOTSUPP, // Some bit in the flags argument is inappropriate for the socket type.
+
+        // these are all reachable through `sendto`, but unreachable through `send`.
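+        // (`send(fd, buf, n, flags)` is equivalent to
+        // `sendto(fd, buf, n, flags, null, 0)`, so the address-related
+        // errno values below cannot come back from a plain `send`.)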
+ .AFNOSUPPORT, + .LOOP, + .NAMETOOLONG, + .NOENT, + .NOTDIR, + .HOSTUNREACH, + .NETUNREACH, + .NOTCONN, + .INVAL, + => |e| std.debug.panic("{s}", .{@tagName(e)}), + + .ACCES => return error.AccessDenied, + .ALREADY => return error.FastOpenAlreadyInProgress, + .CONNRESET => return error.ConnectionResetByPeer, + .MSGSIZE => return error.MessageTooBig, + .NOBUFS, .NOMEM => return error.SystemResources, + .PIPE => return error.BrokenPipe, + .NETDOWN => return error.NetworkSubsystemFailed, + else => |e| std.posix.unexpectedErrno(e), + }; +} + +pub const HandleSpliceError = error{ + /// One or both file descriptors are not valid, or do not have proper read-write mode. + BadFileDescriptors, + /// Either off_in or off_out was not NULL, but the corresponding file descriptor refers to a pipe. + BadFdOffset, + /// Could be one of many reasons, see the manpage for splice. + InvalidSplice, + /// Out of memory. + SystemResources, +} || std.posix.UnexpectedError; + +pub const HandleSpliceResult = enum { + success, + again, +}; + +pub fn handleSpliceResult( + /// Must be the result of calling the `splice` syscall or equivalent (ie io_uring cqe.err()). + rc: std.posix.E, +) HandleSpliceError!HandleSpliceResult { + comptime std.debug.assert( // + builtin.target.os.tag == .linux // + ); + return switch (rc) { + .SUCCESS => .success, + .AGAIN => .again, + .INVAL => return error.InvalidSplice, + .SPIPE => return error.BadFdOffset, + .BADF => return error.BadFileDescriptors, + .NOMEM => return error.SystemResources, + else => |err| std.posix.unexpectedErrno(err), + }; +} diff --git a/src/rpc/server/requests.zig b/src/rpc/server/requests.zig new file mode 100644 index 000000000..8d9e15f0f --- /dev/null +++ b/src/rpc/server/requests.zig @@ -0,0 +1,269 @@ +//! This file defines most of the shared logic for the bounds and handling +//! of RPC requests. + +const builtin = @import("builtin"); +const std = @import("std"); +const sig = @import("../../sig.zig"); +const connection = @import("connection.zig"); + +const IoUring = std.os.linux.IoUring; + +const ServerCtx = sig.rpc.server.Context; +const SnapshotGenerationInfo = sig.accounts_db.AccountsDB.SnapshotGenerationInfo; +const FullSnapshotFileInfo = sig.accounts_db.snapshots.FullSnapshotFileInfo; +const IncrementalSnapshotFileInfo = sig.accounts_db.snapshots.IncrementalSnapshotFileInfo; + +/// A single request body cannot be larger than this; +/// a single chunk in a chunked request body cannot be larger than this, +/// but all together they may be allowed to be larger than this, +/// depending on the request. +pub const MAX_REQUEST_BODY_SIZE: usize = 50 * 1024; // 50 KiB + +/// All of the relevant information from a request head parsed into a narrow +/// format that is comprised of bounded data and can be copied by value. +pub const HeadInfo = struct { + method: std.http.Method, + target: TargetBoundedStr, + content_len: ?u64, + content_type: ?ContentType, + transfer_encoding: std.http.TransferEncoding, + content_encoding: std.http.ContentEncoding, + + const StdHead = std.http.Server.Request.Head; + + pub const ParseError = StdHead.ParseError || ParseFromStdHeadError; + + pub fn parse(head_bytes: []const u8) ParseError!HeadInfo { + const parsed_head = try StdHead.parse(head_bytes); + std.debug.assert(parsed_head.compression == .none); // at the time of writing, this always holds true for the result of `Head.parse`. 
+ return try parseFromStdHead(parsed_head); + } + + pub const ParseFromStdHeadError = error{ + RequestTargetTooLong, + RequestContentTypeUnrecognized, + }; + + pub fn parseFromStdHead(std_head: StdHead) ParseFromStdHeadError!HeadInfo { + // TODO: should we care about these? + _ = std_head.version; + _ = std_head.expect; + _ = std_head.keep_alive; + + const target = TargetBoundedStr.fromSlice(std_head.target) catch + return error.RequestTargetTooLong; + + const content_type: ?ContentType = ct: { + const str = std_head.content_type orelse break :ct null; + break :ct std.meta.stringToEnum(ContentType, str) orelse + return error.RequestContentTypeUnrecognized; + }; + + return .{ + .method = std_head.method, + .target = target, + .content_len = std_head.content_length, + .content_type = content_type, + .transfer_encoding = std_head.transfer_encoding, + .content_encoding = std_head.transfer_compression, + }; + } +}; + +pub const ContentType = enum(u8) { + @"application/json", +}; + +pub const MAX_TARGET_LEN: usize = blk: { + const SnapSpec = IncrementalSnapshotFileInfo.SnapshotArchiveNameFmtSpec; + break :blk "/".len + SnapSpec.fmtLenValue(.{ + .base_slot = std.math.maxInt(sig.core.Slot), + .slot = std.math.maxInt(sig.core.Slot), + .hash = sig.core.Hash.base58String(.{ .data = .{255} ** sig.core.Hash.size }).constSlice(), + }); +}; +pub const TargetBoundedStr = std.BoundedArray(u8, MAX_TARGET_LEN); + +pub const GetRequestTargetResolved = union(enum) { + unrecognized, + full_snapshot: struct { FullSnapshotFileInfo, SnapshotReadLock }, + inc_snapshot: struct { IncrementalSnapshotFileInfo, SnapshotReadLock }, + + // TODO: also handle the snapshot archive aliases & other routes + + pub const SnapshotReadLock = sig.sync.RwMux(?SnapshotGenerationInfo).RLockGuard; +}; + +/// Resolve a `GET` request target. +pub fn getRequestTargetResolve( + logger: ServerCtx.ScopedLogger, + target: []const u8, + latest_snapshot_gen_info_rw: *sig.sync.RwMux(?SnapshotGenerationInfo), +) GetRequestTargetResolved { + if (!std.mem.startsWith(u8, target, "/")) return .unrecognized; + const path = target[1..]; + + const is_snapshot_archive_like = + !std.meta.isError(FullSnapshotFileInfo.parseFileNameTarZst(path)) or + !std.meta.isError(IncrementalSnapshotFileInfo.parseFileNameTarZst(path)); + + if (is_snapshot_archive_like) { + // we hold the lock for the entirety of this process in order to prevent + // the snapshot generation process from deleting the associated snapshot. 
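+        // note that, on a successful match below, this lock is returned to the
+        // caller inside the result, and it becomes their responsibility to release it.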
+ const maybe_latest_snapshot_gen_info, // + var latest_snapshot_info_lg // + = latest_snapshot_gen_info_rw.readWithLock(); + errdefer latest_snapshot_info_lg.unlock(); + + const full_info: ?FullSnapshotFileInfo, // + const inc_info: ?IncrementalSnapshotFileInfo // + = blk: { + const latest_snapshot_gen_info = maybe_latest_snapshot_gen_info.* orelse + break :blk .{ null, null }; + const latest_full = latest_snapshot_gen_info.full; + const full_info: FullSnapshotFileInfo = .{ + .slot = latest_full.slot, + .hash = latest_full.hash, + }; + const latest_incremental = latest_snapshot_gen_info.inc orelse + break :blk .{ full_info, null }; + const inc_info: IncrementalSnapshotFileInfo = .{ + .base_slot = latest_full.slot, + .slot = latest_incremental.slot, + .hash = latest_incremental.hash, + }; + break :blk .{ full_info, inc_info }; + }; + + logger.debug().logf("Available full: {?s}", .{ + if (full_info) |info| info.snapshotArchiveName().constSlice() else null, + }); + logger.debug().logf("Available inc: {?s}", .{ + if (inc_info) |info| info.snapshotArchiveName().constSlice() else null, + }); + + if (full_info) |full| { + const full_archive_name_bounded = full.snapshotArchiveName(); + const full_archive_name = full_archive_name_bounded.constSlice(); + if (std.mem.eql(u8, path, full_archive_name)) { + return .{ .full_snapshot = .{ full, latest_snapshot_info_lg } }; + } + } + + if (inc_info) |inc| { + const inc_archive_name_bounded = inc.snapshotArchiveName(); + const inc_archive_name = inc_archive_name_bounded.constSlice(); + if (std.mem.eql(u8, path, inc_archive_name)) { + return .{ .inc_snapshot = .{ inc, latest_snapshot_info_lg } }; + } + } + } + + return .unrecognized; +} + +pub const HandleRequestError = + std.fs.File.OpenError || + std.http.Server.Response.WriteError || + std.fs.File.GetSeekPosError || + std.fs.File.ReadError; + +pub fn handleRequest( + logger: ServerCtx.ScopedLogger, + request: *std.http.Server.Request, + snapshot_dir: std.fs.Dir, + latest_snapshot_gen_info_rw: *sig.sync.RwMux(?SnapshotGenerationInfo), +) !void { + const conn_address = request.server.connection.address; + logger.info().logf("Responding to request from {}: {} {s}", .{ + conn_address, methodFmt(request.head.method), request.head.target, + }); + + switch (request.head.method) { + .HEAD, .GET => switch (getRequestTargetResolve( + logger, + request.head.target, + latest_snapshot_gen_info_rw, + )) { + inline .full_snapshot, .inc_snapshot => |pair| { + const snap_info, var full_info_lg = pair; + defer full_info_lg.unlock(); + + const archive_name_bounded = snap_info.snapshotArchiveName(); + const archive_name = archive_name_bounded.constSlice(); + + const archive_file = try snapshot_dir.openFile(archive_name, .{}); + defer archive_file.close(); + + const archive_len = try archive_file.getEndPos(); + + var send_buffer: [4096]u8 = undefined; + var response = request.respondStreaming(.{ + .send_buffer = &send_buffer, + .content_length = archive_len, + .respond_options = .{}, + }); + + if (!response.elide_body) { + // use a length which is still a multiple of 2, greater than the send_buffer length, + // in order to almost always force the http server method to flush, instead of + // pointlessly copying data into the send buffer. 
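+                    // (e.g. with the 4096-byte `send_buffer` above, this comes
+                    // out to 4098: 4097 rounded up to the next multiple of 2.)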
+ const read_buffer_len = comptime std.mem.alignForward(usize, send_buffer.len + 1, 2); + var read_buffer: [read_buffer_len]u8 = undefined; + + while (true) { + const file_data_len = try archive_file.read(&read_buffer); + if (file_data_len == 0) break; + const file_data = read_buffer[0..file_data_len]; + try response.writeAll(file_data); + } + } else { + std.debug.assert(response.transfer_encoding.content_length == archive_len); + // NOTE: in order to avoid needing to actually spend time writing the response body, + // just trick the API into thinking we already wrote the entire thing by setting this + // to 0. + response.transfer_encoding.content_length = 0; + } + + try response.end(); + return; + }, + .unrecognized => {}, + }, + .POST => { + logger.err().logf("{} tried to invoke our RPC", .{conn_address}); + return try request.respond("RPCs are not yet implemented", .{ + .status = .service_unavailable, + .keep_alive = false, + }); + }, + else => {}, + } + + logger.err().logf( + "{} made an unrecognized request '{} {s}'", + .{ conn_address, methodFmt(request.head.method), request.head.target }, + ); + try request.respond("", .{ + .status = .not_found, + .keep_alive = false, + }); +} + +pub fn methodFmt(method: std.http.Method) MethodFmt { + return .{ .method = method }; +} + +pub const MethodFmt = struct { + method: std.http.Method, + pub fn format( + fmt: MethodFmt, + comptime fmt_str: []const u8, + fmt_options: std.fmt.FormatOptions, + writer: anytype, + ) @TypeOf(writer).Error!void { + _ = fmt_options; + if (fmt_str.len != 0) std.fmt.invalidFmtError(fmt_str, fmt); + try fmt.method.write(writer); + } +}; diff --git a/src/utils/fmt.zig b/src/utils/fmt.zig index bfd739a57..e8a3f519b 100644 --- a/src/utils/fmt.zig +++ b/src/utils/fmt.zig @@ -36,7 +36,7 @@ pub fn BoundedSpec(comptime spec: []const u8) type { /// try expectEqual("foo-255".len, boundedLenValue("{[a]s}-{[b]d}", .{ .a = "foo", .b = 255 })); /// ``` pub inline fn fmtLenValue(comptime args_value: anytype) usize { - comptime return fmtLen(fmt_str, @TypeOf(args_value)); + comptime return fmtLen(@TypeOf(args_value)); } pub fn BoundedArray(comptime Args: type) type { diff --git a/src/utils/io.zig b/src/utils/io.zig index becea8b07..f2d72dcab 100644 --- a/src/utils/io.zig +++ b/src/utils/io.zig @@ -87,3 +87,87 @@ fn NarrowAnyStream(comptime Error: type) type { } }; } + +/// Writer which captures only an offset window of data into a buffer. +/// This can be useful for incrementally capturing formatted data. 
+pub const WindowedWriter = struct { + remaining_to_ignore: u64, + end_index: usize, + buffer: []u8, + + pub fn init( + buffer: []u8, + start_bytes_to_ignore: u64, + ) WindowedWriter { + std.debug.assert(buffer.len != 0); + return .{ + .remaining_to_ignore = start_bytes_to_ignore, + .end_index = 0, + .buffer = buffer, + }; + } + + pub fn reset(self: *WindowedWriter, start_bytes_to_ignore: usize) void { + self.remaining_to_ignore = start_bytes_to_ignore; + self.end_index = 0; + } + + pub fn write(self: *WindowedWriter, bytes: []const u8) void { + const bytes_to_skip = @min(self.remaining_to_ignore, bytes.len); + self.remaining_to_ignore -|= bytes.len; + + const src_target_bytes = bytes[bytes_to_skip..]; + const writable = self.buffer[self.end_index..]; + + const amt = @min(writable.len, src_target_bytes.len); + @memcpy(writable[0..amt], src_target_bytes[0..amt]); + self.end_index += amt; + } + + pub const Writer = std.io.GenericWriter(*WindowedWriter, error{}, writerFn); + pub fn writer(self: *WindowedWriter) Writer { + return .{ .context = self }; + } + + fn writerFn(self: *WindowedWriter, bytes: []const u8) error{}!usize { + self.write(bytes); + return bytes.len; + } +}; + +fn testWindowedWriter( + comptime kind: enum { bin, str }, + params: struct { start: usize, size: usize }, + data: []const u8, + expected: []const u8, +) !void { + const buffer = try std.testing.allocator.alloc(u8, params.size); + defer std.testing.allocator.free(buffer); + + var ww = WindowedWriter.init(buffer, params.start); + for (0..data.len) |split_i| { + ww.reset(params.start); + ww.write(data[0..split_i]); + ww.write(data[split_i..]); + try std.testing.expectEqual(expected.len, ww.end_index); + switch (kind) { + .bin => try std.testing.expectEqualSlices(u8, expected, ww.buffer), + .str => try std.testing.expectEqualStrings(expected, ww.buffer), + } + } +} + +test WindowedWriter { + try testWindowedWriter(.str, .{ .start = 0, .size = 3 }, "foo\n", "foo"); + try testWindowedWriter(.str, .{ .start = 1, .size = 2 }, "foo\n", "oo"); + try testWindowedWriter(.str, .{ .start = 1, .size = 1 }, "foo\n", "o"); + try testWindowedWriter(.str, .{ .start = 2, .size = 1 }, "foo\n", "o"); + + try testWindowedWriter(.str, .{ .start = 1, .size = 3 }, "foo\n", "oo\n"); + try testWindowedWriter(.str, .{ .start = 2, .size = 2 }, "foo\n", "o\n"); + + try testWindowedWriter(.str, .{ .start = 0, .size = 1 }, "foo\n", "f"); + try testWindowedWriter(.str, .{ .start = 1, .size = 1 }, "foo\n", "o"); + try testWindowedWriter(.str, .{ .start = 2, .size = 1 }, "foo\n", "o"); + try testWindowedWriter(.str, .{ .start = 3, .size = 1 }, "foo\n", "\n"); +} From d9c53115d8076a29bd036d9065d18a8fe076991b Mon Sep 17 00:00:00 2001 From: Trevor Berrange Sanchez Date: Mon, 20 Jan 2025 20:23:20 +0100 Subject: [PATCH 2/5] Don't use file-as-struct --- src/rpc/server.zig | 2 +- .../{LinuxIoUring.zig => linux_io_uring.zig} | 271 +++++++++--------- 2 files changed, 138 insertions(+), 135 deletions(-) rename src/rpc/server/{LinuxIoUring.zig => linux_io_uring.zig} (80%) diff --git a/src/rpc/server.zig b/src/rpc/server.zig index 3c66d1e64..17282fd05 100644 --- a/src/rpc/server.zig +++ b/src/rpc/server.zig @@ -100,7 +100,7 @@ pub const WorkPool = union(enum) { .no => noreturn, }, - pub const LinuxIoUring = @import("server/LinuxIoUring.zig"); + pub const LinuxIoUring = @import("server/linux_io_uring.zig").LinuxIoUring; const BasicAASCError = connection.AcceptHandledError || diff --git a/src/rpc/server/LinuxIoUring.zig b/src/rpc/server/linux_io_uring.zig similarity 
index 80% rename from src/rpc/server/LinuxIoUring.zig rename to src/rpc/server/linux_io_uring.zig index e4992e09d..4d35529cb 100644 --- a/src/rpc/server/LinuxIoUring.zig +++ b/src/rpc/server/linux_io_uring.zig @@ -8,155 +8,158 @@ const requests = @import("requests.zig"); const IoUring = std.os.linux.IoUring; const ServerCtx = sig.rpc.server.Context; -const LinuxIoUring = @This(); -io_uring: IoUring, -multishot_accept_submitted: bool, -pending_cqes_count: u8, -pending_cqes_buf: [255]std.os.linux.io_uring_cqe, - -pub const can_use: enum { no, yes, check } = switch (builtin.os.getVersionRange()) { - .linux => |version| can_use: { - const min_version: std.SemanticVersion = .{ .major = 6, .minor = 0, .patch = 0 }; - const is_at_least = version.isAtLeast(min_version) orelse break :can_use .check; - break :can_use if (is_at_least) .yes else .no; - }, - else => .no, -}; - -pub const InitError = std.posix.MMapError || error{ - EntriesZero, - EntriesNotPowerOfTwo, +pub const LinuxIoUring = struct { + io_uring: IoUring, + multishot_accept_submitted: bool, + pending_cqes_count: u8, + pending_cqes_buf: [255]std.os.linux.io_uring_cqe, + + pub const can_use: enum { no, yes, check } = switch (builtin.os.getVersionRange()) { + .linux => |version| can_use: { + const min_version: std.SemanticVersion = .{ .major = 6, .minor = 0, .patch = 0 }; + const is_at_least = version.isAtLeast(min_version) orelse break :can_use .check; + break :can_use if (is_at_least) .yes else .no; + }, + else => .no, + }; - ParamsOutsideAccessibleAddressSpace, - ArgumentsInvalid, - ProcessFdQuotaExceeded, - SystemFdQuotaExceeded, - SystemResources, + pub const InitError = std.posix.MMapError || error{ + EntriesZero, + EntriesNotPowerOfTwo, - PermissionDenied, - SystemOutdated, -}; + ParamsOutsideAccessibleAddressSpace, + ArgumentsInvalid, + ProcessFdQuotaExceeded, + SystemFdQuotaExceeded, + SystemResources, -// NOTE(ink): constructing the return type as `E!?T`, where `E` and `T` are resolved -// separately seems to help ZLS with understanding the types involved better, which is -// why I've done it like that here. If ZLS gets smarter in the future, you could probably -// inline this into a single branch in the return type expression. -const InitErrOrEmpty = if (can_use == .no) error{} else InitError; -const InitResultOrNoreturn = if (can_use == .no) noreturn else LinuxIoUring; -pub fn init() InitErrOrEmpty!?InitResultOrNoreturn { - const need_runtime_check = switch (can_use) { - .no => return null, - .yes => false, - .check => true, + PermissionDenied, + SystemOutdated, }; - var io_uring = IoUring.init(4096, 0) catch |err| return switch (err) { - error.SystemOutdated, - error.PermissionDenied, - => |e| if (!need_runtime_check) e else return null, - else => |e| e, - }; - errdefer io_uring.deinit(); + // NOTE(ink): constructing the return type as `E!?T`, where `E` and `T` are resolved + // separately seems to help ZLS with understanding the types involved better, which is + // why I've done it like that here. If ZLS gets smarter in the future, you could probably + // inline this into a single branch in the return type expression. 
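+    // (Concretely: on a non-Linux target, `init` below ends up returning
+    // `error{}!?noreturn`, whose only possible value is `null`.)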
+ const InitErrOrEmpty = if (can_use == .no) error{} else InitError; + const InitResultOrNoreturn = if (can_use == .no) noreturn else LinuxIoUring; + pub fn init() InitErrOrEmpty!?InitResultOrNoreturn { + const need_runtime_check = switch (can_use) { + .no => return null, + .yes => false, + .check => true, + }; - return .{ - .io_uring = io_uring, - .multishot_accept_submitted = false, - .pending_cqes_count = 0, - .pending_cqes_buf = undefined, - }; -} + var io_uring = IoUring.init(4096, 0) catch |err| return switch (err) { + error.SystemOutdated, + error.PermissionDenied, + => |e| if (!need_runtime_check) e else return null, + else => |e| e, + }; + errdefer io_uring.deinit(); -pub fn deinit(self: *LinuxIoUring) void { - self.io_uring.deinit(); -} + return .{ + .io_uring = io_uring, + .multishot_accept_submitted = false, + .pending_cqes_count = 0, + .pending_cqes_buf = undefined, + }; + } -pub const AcceptAndServeConnectionsError = error{ - /// This was the first call, and we failed to prep, queue, and submit the multishot accept. - FailedToAcceptMultishot, - SubmissionQueueFull, -} || IouSubmitError || - HandleOurCqeError || - std.mem.Allocator.Error; + pub fn deinit(self: *LinuxIoUring) void { + self.io_uring.deinit(); + } -pub fn acceptAndServeConnections( - self: *LinuxIoUring, - server_ctx: *ServerCtx, -) AcceptAndServeConnectionsError!void { - if (!self.multishot_accept_submitted) { - self.multishot_accept_submitted = true; - errdefer self.multishot_accept_submitted = false; - _ = self.io_uring.accept_multishot( - @bitCast(Entry.ACCEPT), - server_ctx.tcp.stream.handle, - null, - null, - std.os.linux.SOCK.CLOEXEC, - ) catch |err| return switch (err) { - error.SubmissionQueueFull => { - server_ctx.logger.err().log( - "Under normal circumstances the accept_multishot would be" ++ - " the first SQE to be queued, but somehow the queue was full.", - ); + pub const AcceptAndServeConnectionsError = error{ + /// This was the first call, and we failed to prep, queue, and submit the multishot accept. + FailedToAcceptMultishot, + SubmissionQueueFull, + } || IouSubmitError || + HandleOurCqeError || + std.mem.Allocator.Error; + + pub fn acceptAndServeConnections( + self: *LinuxIoUring, + server_ctx: *ServerCtx, + ) AcceptAndServeConnectionsError!void { + if (!self.multishot_accept_submitted) { + self.multishot_accept_submitted = true; + errdefer self.multishot_accept_submitted = false; + _ = self.io_uring.accept_multishot( + @bitCast(Entry.ACCEPT), + server_ctx.tcp.stream.handle, + null, + null, + std.os.linux.SOCK.CLOEXEC, + ) catch |err| return switch (err) { + error.SubmissionQueueFull => { + server_ctx.logger.err().log( + "Under normal circumstances the accept_multishot would be" ++ + " the first SQE to be queued, but somehow the queue was full.", + ); + return error.FailedToAcceptMultishot; + }, + }; + if (try self.io_uring.submit() != 1) { return error.FailedToAcceptMultishot; - }, - }; - if (try self.io_uring.submit() != 1) { - return error.FailedToAcceptMultishot; + } + return; } - return; - } - _ = try self.io_uring.submit(); + _ = try self.io_uring.submit(); - if (self.pending_cqes_count != self.pending_cqes_buf.len) { - self.pending_cqes_count += @intCast(try self.io_uring.copy_cqes(self.pending_cqes_buf[self.pending_cqes_count..], 0)); - } - const cqes_pending = self.pending_cqes_buf[0..self.pending_cqes_count]; - - for (cqes_pending, 0..) 
|raw_cqe, i| { - self.pending_cqes_count -= 1; - errdefer std.mem.copyForwards( - std.os.linux.io_uring_cqe, - self.pending_cqes_buf[0..self.pending_cqes_count], - self.pending_cqes_buf[i + 1 ..][0..self.pending_cqes_count], - ); - const our_cqe = OurCqe.fromCqe(raw_cqe); - consumeOurCqe(self, server_ctx, our_cqe) catch |err| switch (err) { - // connection errors - error.ConnectionAborted, - error.ConnectionRefused, - error.ConnectionResetByPeer, - error.ConnectionTimedOut, - - // our http parse errors - error.RequestHeadersTooBig, - error.RequestTargetTooLong, - error.RequestContentTypeUnrecognized, - - // std http parse errors - error.UnknownHttpMethod, - error.HttpHeadersInvalid, - error.InvalidContentLength, - error.HttpHeaderContinuationsUnsupported, - error.HttpTransferEncodingUnsupported, - error.HttpConnectionHeaderUnsupported, - error.CompressionUnsupported, - error.MissingFinalNewline, - - // splice errors - error.BadFileDescriptors, - error.BadFdOffset, - error.InvalidSplice, - => |e| { - server_ctx.logger.err().logf("{s}", .{@errorName(e)}); - continue; - }, + if (self.pending_cqes_count != self.pending_cqes_buf.len) { + const unused = self.pending_cqes_buf[self.pending_cqes_count..]; + const new_cqe_count = try self.io_uring.copy_cqes(unused, 0); + self.pending_cqes_count += @intCast(new_cqe_count); + } + const cqes_pending = self.pending_cqes_buf[0..self.pending_cqes_count]; + + for (cqes_pending, 0..) |raw_cqe, i| { + self.pending_cqes_count -= 1; + errdefer std.mem.copyForwards( + std.os.linux.io_uring_cqe, + self.pending_cqes_buf[0..self.pending_cqes_count], + self.pending_cqes_buf[i + 1 ..][0..self.pending_cqes_count], + ); + const our_cqe = OurCqe.fromCqe(raw_cqe); + consumeOurCqe(self, server_ctx, our_cqe) catch |err| switch (err) { + // connection errors + error.ConnectionAborted, + error.ConnectionRefused, + error.ConnectionResetByPeer, + error.ConnectionTimedOut, + + // our http parse errors + error.RequestHeadersTooBig, + error.RequestTargetTooLong, + error.RequestContentTypeUnrecognized, + + // std http parse errors + error.UnknownHttpMethod, + error.HttpHeadersInvalid, + error.InvalidContentLength, + error.HttpHeaderContinuationsUnsupported, + error.HttpTransferEncodingUnsupported, + error.HttpConnectionHeaderUnsupported, + error.CompressionUnsupported, + error.MissingFinalNewline, + + // splice errors + error.BadFileDescriptors, + error.BadFdOffset, + error.InvalidSplice, + => |e| { + server_ctx.logger.err().logf("{s}", .{@errorName(e)}); + continue; + }, - error.SubmissionQueueFull => |e| return e, - else => |e| return e, - }; + error.SubmissionQueueFull => |e| return e, + else => |e| return e, + }; + } } -} +}; const HandleOurCqeError = error{ SubmissionQueueFull, From 6679100b7d34e24253c0707ec25cd981253368dd Mon Sep 17 00:00:00 2001 From: Trevor Berrange Sanchez Date: Mon, 20 Jan 2025 20:30:17 +0100 Subject: [PATCH 3/5] Run style script, respect line length limit --- src/rpc/server.zig | 4 ++-- src/rpc/server/linux_io_uring.zig | 30 +++++++++++++++++++++++------- src/rpc/server/requests.zig | 11 +++++++---- 3 files changed, 32 insertions(+), 13 deletions(-) diff --git a/src/rpc/server.zig b/src/rpc/server.zig index 17282fd05..048992a23 100644 --- a/src/rpc/server.zig +++ b/src/rpc/server.zig @@ -5,7 +5,6 @@ const sig = @import("../sig.zig"); const connection = @import("server/connection.zig"); const requests = @import("server/requests.zig"); -const IoUring = std.os.linux.IoUring; const SnapshotGenerationInfo = sig.accounts_db.AccountsDB.SnapshotGenerationInfo; 
pub const Context = struct { @@ -224,7 +223,8 @@ test Context { for ([_]?WorkPool{ .basic, - if (maybe_liou != null) .{ .linux_io_uring = &maybe_liou.? } else null, // TODO: see above TODO about `if (a) |*b|` on `?noreturn`. + // TODO: see above TODO about `if (a) |*b|` on `?noreturn`. + if (maybe_liou != null) .{ .linux_io_uring = &maybe_liou.? } else null, }) |maybe_work_pool| { const work_pool = maybe_work_pool orelse continue; diff --git a/src/rpc/server/linux_io_uring.zig b/src/rpc/server/linux_io_uring.zig index 4d35529cb..787cf088a 100644 --- a/src/rpc/server/linux_io_uring.zig +++ b/src/rpc/server/linux_io_uring.zig @@ -206,7 +206,10 @@ fn consumeOurCqe( switch (try connection.handleAcceptResult(cqe.err())) { .success => {}, - .intr => std.debug.panic("TODO: does this mean the multishot accept has stopped? If no, just warn. If yes, re-queue here and warn.", .{}), // TODO: + + // TODO: does this mean the multishot accept has stopped? If no, just warn. If yes, re-queue here and warn. + .intr => std.debug.panic("TODO:", .{}), + .again => std.debug.panic("The socket should not be in nonblocking mode.", .{}), .conn_aborted => return error.ConnectionAborted, } @@ -243,7 +246,8 @@ fn consumeOurCqe( server_ctx.logger.err().logf( "Failed to submit the SQE for the initial recv" ++ " for the connection from '{!}'", - .{connection.getSockName(stream.handle)}, // if we fail to getSockName, just print the error in place of the address + // if we fail to getSockName, just print the error in place of the address + .{connection.getSockName(stream.handle)}, ); return e; }, @@ -301,7 +305,8 @@ fn consumeOurCqe( const head_info: HeadInfo = head_info: { const head_bytes = entry_data.buffer[0..head.end]; const std_head = try std.http.Server.Request.Head.parse(head_bytes); - std.debug.assert(std_head.compression == .none); // at the time of writing, this always holds true for the result of `Head.parse`. + // at the time of writing, this always holds true for the result of `Head.parse`. + std.debug.assert(std_head.compression == .none); break :head_info HeadInfo.parseFromStdHead(std_head) catch |err| switch (err) { error.RequestTargetTooLong => |e| { err_logger.logf("Request target was too long: '{}'", .{ @@ -340,8 +345,13 @@ fn consumeOurCqe( switch (try connection.handleRecvResult(cqe.err())) { .success => {}, - .intr => std.debug.panic("TODO: how to handle interrupts on this?", .{}), // TODO: - .again => std.debug.panic("The socket should not be in nonblocking mode.", .{}), + // TODO: how to handle interrupts on this? 
+                    .intr => std.debug.panic("TODO:", .{}),
+
+                    .again => std.debug.panic(
+                        "The socket should not be in nonblocking mode.",
+                        .{},
+                    ),
                     .conn_refused => return error.ConnectionRefused,
                     .conn_reset => return error.ConnectionResetByPeer,
@@ -386,7 +396,10 @@ fn consumeOurCqe(
                 .to_pipe => {
                     switch (try connection.handleSpliceResult(cqe.err())) {
                         .success => {},
-                        .again => std.debug.panic("The socket should not be in nonblocking mode.", .{}),
+                        .again => std.debug.panic(
+                            "The socket should not be in nonblocking mode.",
+                            .{},
+                        ),
                     }
 
                     sfb.spliced_to_pipe += @intCast(cqe.res);
@@ -398,7 +411,10 @@ fn consumeOurCqe(
                 .to_socket => {
                     switch (try connection.handleSpliceResult(cqe.err())) {
                         .success => {},
-                        .again => std.debug.panic("The socket should not be in nonblocking mode.", .{}),
+                        .again => std.debug.panic(
+                            "The socket should not be in nonblocking mode.",
+                            .{},
+                        ),
                     }
 
                     sfb.spliced_to_socket += @intCast(cqe.res);
diff --git a/src/rpc/server/requests.zig b/src/rpc/server/requests.zig
index 8d9e15f0f..b98fa0a05 100644
--- a/src/rpc/server/requests.zig
+++ b/src/rpc/server/requests.zig
@@ -6,8 +6,6 @@ const std = @import("std");
 const sig = @import("../../sig.zig");
 const connection = @import("connection.zig");
 
-const IoUring = std.os.linux.IoUring;
-
 const ServerCtx = sig.rpc.server.Context;
 const SnapshotGenerationInfo = sig.accounts_db.AccountsDB.SnapshotGenerationInfo;
 const FullSnapshotFileInfo = sig.accounts_db.snapshots.FullSnapshotFileInfo;
@@ -35,7 +33,8 @@ pub const HeadInfo = struct {
 
     pub fn parse(head_bytes: []const u8) ParseError!HeadInfo {
         const parsed_head = try StdHead.parse(head_bytes);
-        std.debug.assert(parsed_head.compression == .none); // at the time of writing, this always holds true for the result of `Head.parse`.
+        // at the time of writing, this always holds true for the result of `Head.parse`.
+        std.debug.assert(parsed_head.compression == .none);
         return try parseFromStdHead(parsed_head);
     }
 
@@ -208,7 +207,11 @@ pub fn handleRequest(
             // use a length which is still a multiple of 2, greater than the send_buffer length,
             // in order to almost always force the http server method to flush, instead of
             // pointlessly copying data into the send buffer.
-            const read_buffer_len = comptime std.mem.alignForward(usize, send_buffer.len + 1, 2);
+            const read_buffer_len = comptime std.mem.alignForward(
+                usize,
+                send_buffer.len + 1,
+                2,
+            );
             var read_buffer: [read_buffer_len]u8 = undefined;
 
             while (true) {

From 2867ffc33eb1718ecff3143a68bdfcc1f6b319f1 Mon Sep 17 00:00:00 2001
From: Trevor Berrange Sanchez
Date: Tue, 21 Jan 2025 12:48:09 +0100
Subject: [PATCH 4/5] Improve accept failure handling & update TODOs

* Handle potential failure/cancellation of the `accept_multishot`
  operation by re-queueing it, based on the `IORING_CQE_F_MORE` flag.
* Revise and simplify the queueing logic for the `accept_multishot` SQE.
* Resolve the EINTR TODO panics by returning a catch-all error value
  that treats the interruption as a bad, but non-critical, error.
* Update the `a: ?noreturn` `if (a) |*b|` TODO, noting that it's fixed
  in 0.14; it should be resolved after we update to 0.14.
* Unify the EAGAIN panic message.
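
As an illustrative sketch (using the local names from `consumeOurCqe`, where
`liou` is the `*LinuxIoUring` work pool and `cqe` is the completion being
handled), the re-queueing condition boils down to:

    // if the kernel did not set IORING_CQE_F_MORE, no further completions
    // will be posted for the multishot accept, and it must be re-armed.
    const accept_cancelled = cqe.flags & std.os.linux.IORING_CQE_F_MORE == 0;
    if (accept_cancelled) liou.multishot_accept_submitted = false;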
---
 src/rpc/server.zig                |  4 ++--
 src/rpc/server/linux_io_uring.zig | 67 +++++++++++++++----------------
 2 files changed, 34 insertions(+), 37 deletions(-)

diff --git a/src/rpc/server.zig b/src/rpc/server.zig
index 048992a23..faf92b0f8 100644
--- a/src/rpc/server.zig
+++ b/src/rpc/server.zig
@@ -211,8 +211,8 @@ test Context {
     defer rpc_server_ctx.joinDeinit();
 
     var maybe_liou = try WorkPool.LinuxIoUring.init();
-    // TODO: currently `if (a) |*b|` on `?noreturn` causes analysis of the unwrap block, even though `if (a) |b|` doesn't.
-    // Filed a bug report for this: https://github.com/ziglang/zig/issues/22556, has a linked PR to fix it; hopefully should be fixed in 0.14.
+    // TODO: currently `if (a) |*b|` on `a: ?noreturn` causes analysis of
+    // the unwrap block, even though `if (a) |b|` doesn't; fixed in 0.14
     defer if (maybe_liou != null) maybe_liou.?.deinit();
 
     const self_url_bounded_str = sig.utils.fmt.boundedFmt(
diff --git a/src/rpc/server/linux_io_uring.zig b/src/rpc/server/linux_io_uring.zig
index 787cf088a..7b50a035f 100644
--- a/src/rpc/server/linux_io_uring.zig
+++ b/src/rpc/server/linux_io_uring.zig
@@ -100,10 +100,6 @@ pub const LinuxIoUring = struct {
                     return error.FailedToAcceptMultishot;
                 },
             };
-            if (try self.io_uring.submit() != 1) {
-                return error.FailedToAcceptMultishot;
-            }
-            return;
         }
 
         _ = try self.io_uring.submit();
@@ -124,6 +120,9 @@ pub const LinuxIoUring = struct {
             );
             const our_cqe = OurCqe.fromCqe(raw_cqe);
             consumeOurCqe(self, server_ctx, our_cqe) catch |err| switch (err) {
+                // EINTR catch-all
+                error.SignalInterruptedOperation,
+
                 // connection errors
                 error.ConnectionAborted,
                 error.ConnectionRefused,
@@ -164,6 +163,12 @@
 
 const HandleOurCqeError = error{
     SubmissionQueueFull,
+    /// Operation resulted in EINTR. In general there doesn't seem to be a good way to
+    /// handle or recover from interruptions, so we simply fail and drop whichever
+    /// connection was interrupted. In practice this shouldn't be a significant issue,
+    /// assuming the RPC server isn't run in a process/thread that is interrupted frequently.
+    SignalInterruptedOperation,
+
     /// Connection was aborted; not necessarily critical.
     ConnectionAborted,
     /// A remote host refused to allow the network connection, typically because it is not
@@ -189,6 +194,12 @@ const HandleOurCqeError = error{
     std.fs.File.OpenError ||
     std.fs.File.GetSeekPosError;
 
+/// Panic message for handling `EAGAIN`; we're not using nonblocking sockets at all,
+/// so it should be impossible to receive that error, or for such an error to be
+/// triggered just from malicious connections.
+const EAGAIN_PANIC_MSG =
+    "The socket should not be in nonblocking mode; server or socket configuration error.";
+
 /// On return, `cqe.user_data` is in an undefined state - this is to say,
 /// it has either already been `deinit`ed, or it has been re-submitted
 /// in a new `SQE` and should not be modified; in either scenario, the caller
@@ -202,15 +213,16 @@ fn consumeOurCqe(
     errdefer entry.deinit(server_ctx.allocator);
 
     const entry_data: *EntryData = entry.ptr orelse {
-        // multishot accept cqe
+        // `accept_multishot` cqe
+
+        // we may need to re-submit the `accept_multishot` sqe.
+ const accept_cancelled = cqe.flags & std.os.linux.IORING_CQE_F_MORE == 0; + if (accept_cancelled) liou.multishot_accept_submitted = false; switch (try connection.handleAcceptResult(cqe.err())) { .success => {}, - - // TODO: does this mean the multishot accept has stopped? If no, just warn. If yes, re-queue here and warn. - .intr => std.debug.panic("TODO:", .{}), - - .again => std.debug.panic("The socket should not be in nonblocking mode.", .{}), + .again => std.debug.panic(EAGAIN_PANIC_MSG, .{}), + .intr => return error.SignalInterruptedOperation, .conn_aborted => return error.ConnectionAborted, } @@ -268,10 +280,8 @@ fn consumeOurCqe( .recv_head => |*head| { switch (try connection.handleRecvResult(cqe.err())) { .success => {}, - - .intr => std.debug.panic("TODO: how to handle interrupts on this?", .{}), // TODO: - .again => std.debug.panic("The socket should not be in nonblocking mode.", .{}), - + .again => std.debug.panic(EAGAIN_PANIC_MSG, .{}), + .intr => return error.SignalInterruptedOperation, .conn_refused => return error.ConnectionRefused, .conn_reset => return error.ConnectionResetByPeer, .timed_out => return error.ConnectionTimedOut, @@ -344,15 +354,8 @@ fn consumeOurCqe( if (body.need_to_check_cqe) { switch (try connection.handleRecvResult(cqe.err())) { .success => {}, - - // TODO: how to handle interrupts on this? - .intr => std.debug.panic("TODO:", .{}), - - .again => std.debug.panic( - "The socket should not be in nonblocking mode.", - .{}, - ), - + .again => std.debug.panic(EAGAIN_PANIC_MSG, .{}), + .intr => return error.SignalInterruptedOperation, .conn_refused => return error.ConnectionRefused, .conn_reset => return error.ConnectionResetByPeer, .timed_out => return error.ConnectionTimedOut, @@ -369,8 +372,8 @@ fn consumeOurCqe( .send_file_head => |*sfh| { switch (try connection.handleSendResult(cqe.err())) { .success => {}, - .intr => std.debug.panic("TODO: how to handle interrupts on this?", .{}), // TODO: - .again => std.debug.panic("The socket should not be in nonblocking mode.", .{}), + .again => std.debug.panic(EAGAIN_PANIC_MSG, .{}), + .intr => return error.SignalInterruptedOperation, } const sent_len: usize = @intCast(cqe.res); sfh.sent_bytes += sent_len; @@ -396,10 +399,7 @@ fn consumeOurCqe( .to_pipe => { switch (try connection.handleSpliceResult(cqe.err())) { .success => {}, - .again => std.debug.panic( - "The socket should not be in nonblocking mode.", - .{}, - ), + .again => std.debug.panic(EAGAIN_PANIC_MSG, .{}), } sfb.spliced_to_pipe += @intCast(cqe.res); @@ -411,10 +411,7 @@ fn consumeOurCqe( .to_socket => { switch (try connection.handleSpliceResult(cqe.err())) { .success => {}, - .again => std.debug.panic( - "The socket should not be in nonblocking mode.", - .{}, - ), + .again => std.debug.panic(EAGAIN_PANIC_MSG, .{}), } sfb.spliced_to_socket += @intCast(cqe.res); @@ -433,8 +430,8 @@ fn consumeOurCqe( .send_no_body => |*snb| { switch (try connection.handleSendResult(cqe.err())) { .success => {}, - .intr => std.debug.panic("TODO: how to handle interrupts on this?", .{}), // TODO: - .again => std.debug.panic("The socket should not be in nonblocking mode.", .{}), + .again => std.debug.panic(EAGAIN_PANIC_MSG, .{}), + .intr => return error.SignalInterruptedOperation, } const sent_len: usize = @intCast(cqe.res); snb.end_index += sent_len; From 922a7bb19504ed99279782cf23e0f31d53dbb0cf Mon Sep 17 00:00:00 2001 From: Trevor Berrange Sanchez Date: Tue, 21 Jan 2025 13:08:46 +0100 Subject: [PATCH 5/5] Add TODO to remove hacky-ish workaround --- src/rpc/server.zig | 4 
++++ 1 file changed, 4 insertions(+) diff --git a/src/rpc/server.zig b/src/rpc/server.zig index faf92b0f8..cd10f4696 100644 --- a/src/rpc/server.zig +++ b/src/rpc/server.zig @@ -233,6 +233,10 @@ test Context { defer blk: { exit.store(true, .release); // send a dummy request so that the serve thread will get the accept and observe `exit`. + // TODO(ink): we only issue interruptSelf for the basic work pool, and not the io_uring one, + // because for some reason that causes it to hang. this is kinda nasty and it would be + // nice for this to Just Work, however I suspect it may have something to do with the + // fact that it's the process sending itself a connection multiple times. if (work_pool == .basic) interruptSelf(allocator, self_uri) catch |err| { if (@errorReturnTrace()) |st| { std.log.err("{s}\n{}", .{ @errorName(err), st });