diff --git a/build.zig.zon b/build.zig.zon
index f4ab3ef60..5afae6cc5 100644
--- a/build.zig.zon
+++ b/build.zig.zon
@@ -28,8 +28,8 @@
             .hash = "1220fe113318a795d366cd63d2d3f1abacfaa6ff0739bc3240eaf7b68b06d21e4917",
         },
         .zstd = .{
-            .url = "https://github.com/Syndica/zstd.zig/archive/d9f86cad45380bf6b5d1e83ac8514cbddf3c4a38.tar.gz",
-            .hash = "1220be2064e0060e3836ee349595119c9d5874b58ee76147897d4070206f9e9e5db6",
+            .url = "https://github.com/Syndica/zstd.zig/archive/a052e839a3dfc44feb00c2eb425815baa3c76e0d.tar.gz",
+            .hash = "122001d56e43ef94e31243739ae83d7508abf0b8102795aff1ac91446e7ff450d875",
         },
         .curl = .{
             .url = "https://github.com/jiacai2050/zig-curl/archive/7b1e1c6adb1daca48bbae6bd18fc1386dc6076ab.tar.gz",
diff --git a/src/accountsdb/snapshots.zig b/src/accountsdb/snapshots.zig
index 4346ba5da..f39b015e3 100644
--- a/src/accountsdb/snapshots.zig
+++ b/src/accountsdb/snapshots.zig
@@ -1016,6 +1016,7 @@ pub fn parallelUnpackZstdTarBall(
         0,
     );
     var tar_stream = try ZstdReader.init(memory);
+    defer tar_stream.deinit();
     const n_files_estimate: usize = if (full_snapshot) 421_764 else 100_000; // estimate

     try parallelUntarToFileSystem(
diff --git a/src/cmd/cmd.zig b/src/cmd/cmd.zig
index efb819950..169f84d6b 100644
--- a/src/cmd/cmd.zig
+++ b/src/cmd/cmd.zig
@@ -39,7 +39,7 @@ const requestIpEcho = @import("../net/echo.zig").requestIpEcho;
 const servePrometheus = @import("../prometheus/http.zig").servePrometheus;
 const parallelUnpackZstdTarBall = @import("../accountsdb/snapshots.zig").parallelUnpackZstdTarBall;
 const downloadSnapshotsFromGossip = @import("../accountsdb/download.zig").downloadSnapshotsFromGossip;
-const SOCKET_TIMEOUT = @import("../net/socket_utils.zig").SOCKET_TIMEOUT;
+const SOCKET_TIMEOUT_US = @import("../net/socket_utils.zig").SOCKET_TIMEOUT_US;
 const config = @import("config.zig");
 // var validator_config = config.current;
@@ -376,6 +376,7 @@ fn gossip() !void {
         &.{},
     );
     defer gossip_service.deinit();
+
     try runGossipWithConfigValues(&gossip_service);
 }
@@ -411,7 +412,7 @@ fn validator() !void {
     // repair
     var repair_socket = try Socket.create(network.AddressFamily.ipv4, network.Protocol.udp);
     try repair_socket.bindToPort(repair_port);
-    try repair_socket.setReadTimeout(SOCKET_TIMEOUT);
+    try repair_socket.setReadTimeout(SOCKET_TIMEOUT_US);

     var repair_svc = try initRepair(
         logger,
@@ -571,8 +572,16 @@ fn initRepair(
 }

 fn runGossipWithConfigValues(gossip_service: *GossipService) !void {
+    // TODO: use better allocator, unless GPA becomes more performant.
+    var gp_message_allocator: std.heap.GeneralPurposeAllocator(.{}) = .{};
+    defer _ = gp_message_allocator.deinit();
+
     const gossip_config = config.current.gossip;
-    return gossip_service.run(gossip_config.spy_node, gossip_config.dump);
+    return gossip_service.run(.{
+        .message_allocator = gp_message_allocator.allocator(),
+        .spy_node = gossip_config.spy_node,
+        .dump = gossip_config.dump,
+    });
 }

 /// determine our shred version and ip. in the solana-labs client, the shred version
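Every GossipService.run call site in this patch now passes a single RunThreadsParams struct instead of positional booleans. For reference, a minimal sketch of the new caller shape, mirroring runGossipWithConfigValues above (the gossip_service variable and the flag values are assumptions, not part of this patch):

    // A dedicated allocator for message (de)serialization keeps that traffic off
    // the service's main allocator; GeneralPurposeAllocator is thread-safe in its
    // default configuration, which RunThreadsParams.message_allocator requires.
    var gp_message_allocator: std.heap.GeneralPurposeAllocator(.{}) = .{};
    defer _ = gp_message_allocator.deinit();

    try gossip_service.run(.{
        .message_allocator = gp_message_allocator.allocator(),
        .spy_node = false, // spy nodes skip the message-builder thread
        .dump = false, // no GossipDumpService thread
    });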
diff --git a/src/gossip/fuzz.zig b/src/gossip/fuzz.zig
index 7cff87e71..7f5581514 100644
--- a/src/gossip/fuzz.zig
+++ b/src/gossip/fuzz.zig
@@ -308,7 +308,13 @@ pub fn run() !void {
         .noop,
     );
-    var fuzz_handle = try std.Thread.spawn(.{}, GossipService.run, .{ &gossip_service_fuzzer, true, false });
+    const fuzz_handle = try std.Thread.spawn(.{}, GossipService.run, .{
+        &gossip_service_fuzzer, .{
+            .message_allocator = allocator,
+            .spy_node = true,
+            .dump = false,
+        },
+    });

     const SLEEP_TIME = 0;
     // const SLEEP_TIME = std.time.ns_per_ms * 10;
diff --git a/src/gossip/service.zig b/src/gossip/service.zig
index 68e5d0940..6c7159bc3 100644
--- a/src/gossip/service.zig
+++ b/src/gossip/service.zig
@@ -1,8 +1,10 @@
 const std = @import("std");
+const builtin = @import("builtin");
 const network = @import("zig-network");
 const EndPoint = network.EndPoint;
 const Packet = @import("../net/packet.zig").Packet;
 const PACKET_DATA_SIZE = @import("../net/packet.zig").PACKET_DATA_SIZE;
+const ThreadPoolTask = @import("../utils/thread.zig").ThreadPoolTask;
 const ThreadPool = @import("../sync/thread_pool.zig").ThreadPool;
 const Task = ThreadPool.Task;
 const Batch = ThreadPool.Batch;
@@ -88,6 +90,8 @@ const DEFAULT_EPOCH_DURATION: u64 = 172800000;
 pub const PUB_GOSSIP_STATS_INTERVAL_MS = 2 * std.time.ms_per_s;
 pub const GOSSIP_TRIM_INTERVAL_MS = 10 * std.time.ms_per_s;
+pub const GOSSIP_VERIFY_PACKET_PARALLEL_TASKS = 4;
+
 pub const GossipService = struct {
     allocator: std.mem.Allocator,
@@ -123,10 +127,10 @@ pub const GossipService = struct {
     stats: GossipStats,

-    const Entrypoint = struct { addr: SocketAddr, info: ?ContactInfo = null };
-
     const Self = @This();

+    const Entrypoint = struct { addr: SocketAddr, info: ?ContactInfo = null };
+
     pub fn init(
         allocator: std.mem.Allocator,
         my_contact_info: ContactInfo,
@@ -164,7 +168,7 @@ pub const GossipService = struct {
         const gossip_address = my_contact_info.getSocket(socket_tag.GOSSIP) orelse return error.GossipAddrUnspecified;
         var gossip_socket = UdpSocket.create(.ipv4, .udp) catch return error.SocketCreateFailed;
         gossip_socket.bindToPort(gossip_address.port()) catch return error.SocketBindFailed;
-        gossip_socket.setReadTimeout(socket_utils.SOCKET_TIMEOUT) catch return error.SocketSetTimeoutFailed; // 1 second
+        gossip_socket.setReadTimeout(socket_utils.SOCKET_TIMEOUT_US) catch return error.SocketSetTimeoutFailed; // 1 second

         const failed_pull_hashes = HashTimeQueue.init(allocator);
         const push_msg_q = ArrayList(SignedGossipData).init(allocator);
@@ -186,7 +190,7 @@ pub const GossipService = struct {
             GOSSIP_PING_CACHE_CAPACITY,
         );

-        return Self{
+        return .{
             .my_contact_info = my_contact_info,
             .my_keypair = my_keypair,
             .my_pubkey = my_pubkey,
@@ -258,12 +262,40 @@ pub const GossipService = struct {
         deinitMux(&self.failed_pull_hashes_mux);
     }

-    /// these threads should run forever - so if they join - somethings wrong
-    /// and we should shutdown
-    fn joinAndExit(self: *Self, handle: *std.Thread) void {
-        handle.join();
-        self.exit.store(true, .unordered);
-    }
+    pub const RunHandles = struct {
+        exit: *AtomicBool,
+        receiver_thread: std.Thread,
+        packet_verifier_thread: std.Thread,
+        message_processor_thread: std.Thread,
+        message_builder_thread: ?std.Thread,
+        responder_thread: std.Thread,
+        dumper_thread: ?std.Thread,
+
+        /// If any of the threads join, all other threads will be signalled to join.
+        pub fn joinAndExit(handles: RunHandles) void {
+            inline for (@typeInfo(RunHandles).Struct.fields, 0..) |field, i| cont: {
+                comptime if (@field(std.meta.FieldEnum(RunHandles), field.name) == .exit) {
+                    std.debug.assert(field.type == *AtomicBool);
+                    continue;
+                };
+                const maybe_thread: ?std.Thread = @field(handles, field.name);
+                const thread = maybe_thread orelse break :cont;
+                thread.join(); // if we end up joining, something's gone wrong, so signal exit
+                if (i == 0) handles.exit.store(true, .unordered);
+            }
+        }
+    };
+
+    pub const RunThreadsParams = struct {
+        /// Allocator used to allocate message metadata.
+        /// Helpful to use a dedicated allocator to reduce contention
+        /// during message allocation & deallocation.
+        /// Should be thread safe, and remain valid until calling `joinAll` on the result.
+        message_allocator: std.mem.Allocator,
+
+        spy_node: bool,
+        dump: bool,
+    };

     /// spawns required threads for the gossip serivce.
     /// including:
@@ -273,67 +305,92 @@ pub const GossipService = struct {
     ///     1) socket reciever
     ///     2) packet verifier
     ///     3) packet processor
     ///     4) build message loop (to send outgoing message) (only active if not a spy node)
     ///     5) a socket responder (to send outgoing packets)
     ///     6) echo server
-    pub fn run(self: *Self, spy_node: bool, dump: bool) !void {
+    pub fn runThreads(
+        self: *Self,
+        params: RunThreadsParams,
+    ) std.Thread.SpawnError!RunHandles {
+        const message_allocator = params.message_allocator;
+        const spy_node = params.spy_node;
+        const dump = params.dump;
+
         // TODO(Ahmad): need new server impl, for now we don't join server thread
         // because http.zig's server doesn't stop when you call server.stop() - it's broken
         // const echo_server_thread = try self.echo_server.listenAndServe();
         // _ = echo_server_thread;

-        var receiver_handle = try Thread.spawn(.{}, socket_utils.readSocket, .{
+        const exitAndJoin = struct {
+            inline fn exitAndJoin(exit: *AtomicBool, thread: std.Thread) void {
+                exit.store(true, .unordered);
+                thread.join();
+            }
+        }.exitAndJoin;
+
+        const receiver_thread = try Thread.spawn(.{}, socket_utils.readSocket, .{
             self.allocator,
             &self.gossip_socket,
             self.packet_incoming_channel,
             self.exit,
             self.logger,
         });
-        defer self.joinAndExit(&receiver_handle);
+        errdefer exitAndJoin(self.exit, receiver_thread);

-        var packet_verifier_handle = try Thread.spawn(.{}, verifyPackets, .{self});
-        defer self.joinAndExit(&packet_verifier_handle);
+        const packet_verifier_thread = try Thread.spawn(.{}, verifyPackets, .{ self, message_allocator });
+        errdefer exitAndJoin(self.exit, packet_verifier_thread);

-        var packet_handle = try Thread.spawn(.{}, processMessages, .{self});
-        defer self.joinAndExit(&packet_handle);
+        const message_processor_thread = try Thread.spawn(.{}, processMessages, .{ self, message_allocator });
+        errdefer exitAndJoin(self.exit, message_processor_thread);

-        var maybe_build_messages_handle = if (!spy_node) try Thread.spawn(.{}, buildMessages, .{self}) else null;
-        defer {
-            if (maybe_build_messages_handle) |*handle| {
-                self.joinAndExit(handle);
-            }
-        }
+        const maybe_message_builder_thread: ?std.Thread = if (!spy_node) try Thread.spawn(.{}, buildMessages, .{self}) else null;
+        errdefer if (maybe_message_builder_thread) |thread| {
+            exitAndJoin(self.exit, thread);
+        };

-        var responder_handle = try Thread.spawn(.{}, socket_utils.sendSocket, .{
+        const responder_thread = try Thread.spawn(.{}, socket_utils.sendSocket, .{
             &self.gossip_socket,
             self.packet_outgoing_channel,
             self.exit,
             self.logger,
         });
-        defer self.joinAndExit(&responder_handle);
+        errdefer exitAndJoin(self.exit, responder_thread);

-        var dump_handle = if (dump) try Thread.spawn(.{}, GossipDumpService.run, .{.{
+        const maybe_dumper_thread: ?std.Thread = if (dump) try Thread.spawn(.{}, GossipDumpService.run, .{.{
             .allocator = self.allocator,
             .logger = self.logger,
             .gossip_table_rw = &self.gossip_table_rw,
             .exit = self.exit,
         }}) else null;
-        defer if (dump_handle) |*h| self.joinAndExit(h);
+        errdefer if (maybe_dumper_thread) |thread| {
+            exitAndJoin(self.exit, thread);
+        };
+
+        return .{
+            .exit = self.exit,
+
+            .receiver_thread = receiver_thread,
+            .packet_verifier_thread = packet_verifier_thread,
+            .message_processor_thread = message_processor_thread,
+            .message_builder_thread = maybe_message_builder_thread,
+            .responder_thread = responder_thread,
+            .dumper_thread = maybe_dumper_thread,
+        };
     }

-    const VerifyMessageTask = struct {
+    pub fn run(self: *Self, params: RunThreadsParams) !void {
+        const run_handles = try self.runThreads(params);
+        defer run_handles.joinAndExit();
+    }
+
+    const VerifyMessageTask = ThreadPoolTask(VerifyMessageEntry);
+    const VerifyMessageEntry = struct {
         allocator: std.mem.Allocator,
         packet_batch: ArrayList(Packet),
         verified_incoming_channel: *Channel(GossipMessageWithEndpoint),
         logger: Logger,
-        task: Task,
-        done: std.atomic.Value(bool) = std.atomic.Value(bool).init(true),
-
-        pub fn callback(task: *Task) void {
-            var self: *@This() = @fieldParentPtr("task", task);
-            std.debug.assert(!self.done.load(.acquire));
-            defer self.done.store(true, .release);
+        pub fn callback(self: *VerifyMessageEntry) !void {
             defer self.packet_batch.deinit();

-            for (self.packet_batch.items) |*packet| {
+            for (@as([]const Packet, self.packet_batch.items)) |*packet| {
                 var message = bincode.readFromSlice(
                     self.allocator,
                     GossipMessage,
@@ -359,37 +416,31 @@ pub const GossipService = struct {
                     continue;
                 };

-                const msg = GossipMessageWithEndpoint{
+                const msg: GossipMessageWithEndpoint = .{
                     .from_endpoint = packet.addr,
                     .message = message,
                 };
-
-                self.verified_incoming_channel.send(msg) catch unreachable;
-            }
-        }
-
-        /// waits for the task to be done, then resets the done state to false
-        fn awaitAndReset(self: *VerifyMessageTask) void {
-            while (!self.done.load(.acquire)) {
-                // wait
-            }
-            self.done.store(false, .release);
+                try self.verified_incoming_channel.send(msg);
             }
         }
     };

     /// main logic for deserializing Packets into GossipMessage messages
     /// and verifing they have valid values, and have valid signatures.
     /// Verified GossipMessagemessages are then sent to the verified_channel.
-    fn verifyPackets(self: *Self) !void {
-        var tasks = try self.allocator.alloc(VerifyMessageTask, 4);
+    fn verifyPackets(
+        self: *Self,
+        /// Must be thread-safe. Can be a specific allocator which will
+        /// only be contended for by the tasks spawned by in function.
+        task_allocator: std.mem.Allocator,
+    ) !void {
+        const tasks = try VerifyMessageTask.init(self.allocator, GOSSIP_VERIFY_PACKET_PARALLEL_TASKS);
         defer self.allocator.free(tasks);

         // pre-allocate all the tasks
         for (tasks) |*task| {
-            task.* = VerifyMessageTask{
-                // .allocator = std.heap.page_allocator, // TODO: swap out with arg
-                .allocator = self.allocator, // TODO: swap out with arg
-                .task = .{ .callback = VerifyMessageTask.callback },
+            task.entry = .{
+                .allocator = task_allocator,
                 .verified_incoming_channel = self.verified_incoming_channel,
                 .packet_batch = undefined,
                 .logger = self.logger,
@@ -397,15 +448,9 @@
         }

         while (!self.exit.load(.unordered)) {
-            const maybe_packets = try self.packet_incoming_channel.try_drain();
-            if (maybe_packets == null) {
-                continue;
-            }
-
-            const packet_batches = maybe_packets.?;
-            defer {
-                self.packet_incoming_channel.allocator.free(packet_batches);
-            }
+            const maybe_packet_batches = try self.packet_incoming_channel.try_drain();
+            const packet_batches = maybe_packet_batches orelse continue;
+            defer self.packet_incoming_channel.allocator.free(packet_batches);

             // count number of packets
             var n_packets_drained: usize = 0;
@@ -416,18 +461,14 @@
             // verify in parallel using the threadpool
             // PERF: investigate CPU pinning
-            var task_i: usize = 0;
-            const n_tasks = tasks.len;
+            var task_search_start_idx: usize = 0;
             for (packet_batches) |packet_batch| {
-                // find a free task
-                var task_ptr = &tasks[task_i];
-                while (!task_ptr.done.load(.acquire)) {
-                    task_i = (task_i + 1) % n_tasks;
-                    task_ptr = &tasks[task_i];
-                }
-                // schedule it
-                task_ptr.done.store(false, .release);
-                task_ptr.packet_batch = packet_batch;
+                const acquired_task_idx = VerifyMessageTask.awaitAndAcquireFirstAvailableTask(tasks, task_search_start_idx);
+                task_search_start_idx = (acquired_task_idx + 1) % tasks.len;
+
+                const task_ptr = &tasks[acquired_task_idx];
+                task_ptr.entry.packet_batch = packet_batch;
+                task_ptr.result catch |err| self.logger.errf("VerifyMessageTask encountered error: {s}", .{@errorName(err)});

                 const batch = Batch.from(&task_ptr.task);
                 self.thread_pool.schedule(batch);
@@ -435,7 +476,8 @@
         }

         for (tasks) |*task| {
-            task.awaitAndReset();
+            task.blockUntilCompletion();
+            task.result catch |err| self.logger.errf("VerifyMessageTask encountered error: {s}", .{@errorName(err)});
         }
         self.logger.debugf("verify_packets loop closed", .{});
@@ -470,7 +512,7 @@
     };

     /// main logic for recieving and processing gossip messages.
-    pub fn processMessages(self: *Self) !void {
+    pub fn processMessages(self: *Self, message_allocator: std.mem.Allocator) !void {
         var timer = std.time.Timer.start() catch unreachable;
         var last_table_trim_ts: u64 = 0;
         var msg_count: usize = 0;
@@ -539,7 +581,7 @@
                     // would be safer. For more info, see:
                     // - GossipTable.remove
                     // - https://github.com/Syndica/sig/pull/69
-                    msg.message.shallowFree(self.allocator);
+                    msg.message.shallowFree(message_allocator);
                 }
                 self.verified_incoming_channel.allocator.free(messages);
             }
@@ -547,14 +589,12 @@
             msg_count += messages.len;

             for (messages) |*message| {
-                var from_endpoint: EndPoint = message.from_endpoint;
-
                 switch (message.message) {
                     .PushMessage => |*push| {
                         try push_messages.append(PushMessage{
                             .gossip_values = push[1],
                             .from_pubkey = &push[0],
-                            .from_endpoint = &from_endpoint,
+                            .from_endpoint = &message.from_endpoint,
                         });
                     },
                     .PullResponse => |*pull| {
@@ -597,7 +637,7 @@
                             },
                         }

-                        const from_addr = SocketAddr.fromEndpoint(&from_endpoint);
+                        const from_addr = SocketAddr.fromEndpoint(&message.from_endpoint);
                         if (from_addr.isUnspecified() or from_addr.port() == 0) {
                             // unable to respond to these messages
                             self.stats.pull_requests_dropped.add(1);
@@ -607,7 +647,7 @@
                         try pull_requests.append(.{
                             .filter = pull[0],
                             .value = value,
-                            .from_endpoint = from_endpoint,
+                            .from_endpoint = message.from_endpoint,
                         });
                     },
                     .PruneMessage => |*prune| {
@@ -624,7 +664,7 @@
                         try prune_messages.append(prune_data);
                     },
                     .PingMessage => |*ping| {
-                        const from_addr = SocketAddr.fromEndpoint(&from_endpoint);
+                        const from_addr = SocketAddr.fromEndpoint(&message.from_endpoint);
                         if (from_addr.isUnspecified() or from_addr.port() == 0) {
                             // unable to respond to these messages
                             self.stats.ping_messages_dropped.add(1);
@@ -633,13 +673,13 @@
                         try ping_messages.append(PingMessage{
                             .ping = ping,
-                            .from_endpoint = &from_endpoint,
+                            .from_endpoint = &message.from_endpoint,
                         });
                     },
                     .PongMessage => |*pong| {
                         try pong_messages.append(PongMessage{
                             .pong = pong,
-                            .from_endpoint = &from_endpoint,
+                            .from_endpoint = &message.from_endpoint,
                         });
                     },
                 }
@@ -764,9 +804,7 @@
     /// main gossip loop for periodically sending new GossipMessagemessages.
     /// this includes sending push messages, pull requests, and triming old
     /// gossip data (in the gossip_table, active_set, and failed_pull_hashes).
-    fn buildMessages(
-        self: *Self,
-    ) !void {
+    fn buildMessages(self: *Self) !void {
         var last_push_ts: u64 = 0;
         var last_stats_publish_ts: u64 = 0;
         var last_pull_req_ts: u64 = 0;
@@ -864,9 +902,7 @@
         self.stats.table_n_pubkeys.add(n_pubkeys);
     }

-    pub fn rotateActiveSet(
-        self: *Self,
-    ) !void {
+    pub fn rotateActiveSet(self: *Self) !void {
         const now = getWallclockMs();
         var buf: [NUM_ACTIVE_SET_ENTRIES]ContactInfo = undefined;
         const gossip_peers = try self.getGossipNodes(&buf, NUM_ACTIVE_SET_ENTRIES, now);
@@ -2628,13 +2664,12 @@ test "gossip.gossip_service: test packet verification" {
         &exit,
         logger,
     );
-    defer gossip_service.deinit();

     var packet_channel = gossip_service.packet_incoming_channel;
     var verified_channel = gossip_service.verified_incoming_channel;

-    var packet_verifier_handle = try Thread.spawn(.{}, GossipService.verifyPackets, .{&gossip_service});
+    const packet_verifier_handle = try Thread.spawn(.{}, GossipService.verifyPackets, .{ &gossip_service, gossip_service.allocator });

     var rng = std.rand.DefaultPrng.init(getWallclockMs());
     var data = gossip.GossipData.randomFromIndex(rng.random(), 0);
@@ -2751,6 +2786,7 @@ test "gossip.gossip_service: test packet verification" {

 test "gossip.gossip_service: process contact info push packet" {
     const allocator = std.testing.allocator;
+    const message_allocator = allocator;
     var exit = AtomicBool.init(false);
     var my_keypair = try KeyPair.create([_]u8{1} ** 32);
     const my_pubkey = Pubkey.fromPublicKey(&my_keypair.public_key);
@@ -2770,8 +2806,8 @@
     );
     defer gossip_service.deinit();

-    var verified_channel = gossip_service.verified_incoming_channel;
-    var responder_channel = gossip_service.packet_outgoing_channel;
+    const verified_channel = gossip_service.verified_incoming_channel;
+    const responder_channel = gossip_service.packet_outgoing_channel;

     var kp = try KeyPair.create(null);
     const pk = Pubkey.fromPublicKey(&kp.public_key);
@@ -2779,7 +2815,7 @@
     var packet_handle = try Thread.spawn(
         .{},
         GossipService.processMessages,
-        .{&gossip_service},
+        .{ &gossip_service, message_allocator },
     );

     // send a push message
@@ -2791,8 +2827,7 @@
         .LegacyContactInfo = legacy_contact_info,
     };
     const gossip_value = try gossip.SignedGossipData.initSigned(gossip_data, &kp);
-    var heap_values = try allocator.alloc(gossip.SignedGossipData, 1);
-    heap_values[0] = gossip_value;
+    const heap_values = try message_allocator.dupe(gossip.SignedGossipData, &.{gossip_value});
     const msg = GossipMessage{
         .PushMessage = .{ id, heap_values },
     };
@@ -2861,11 +2896,13 @@ test "gossip.service: init, exit, and deinit" {
         logger,
     );

-    var handle = try std.Thread.spawn(
-        .{},
-        GossipService.run,
-        .{ &gossip_service, true, false },
-    );
+    const handle = try std.Thread.spawn(.{}, GossipService.run, .{
+        &gossip_service, .{
+            .message_allocator = std.testing.allocator,
+            .spy_node = true,
+            .dump = false,
+        },
+    });

     gossip_service.echo_server.kill();
     exit.store(true, .unordered);
@@ -2948,10 +2985,12 @@ pub const BenchmarkGossipServiceGeneral = struct {
         // reset stats
         defer gossip_service.stats.reset();

-        var packet_handle = try Thread.spawn(.{}, GossipService.run, .{
-            &gossip_service,
-            true, // dont build any outgoing messages
-            false,
+        const packet_handle = try Thread.spawn(.{}, GossipService.run, .{
+            &gossip_service, .{
+                .message_allocator = allocator,
+                .spy_node = true, // dont build any outgoing messages
+                .dump = false,
+            },
         });

         const outgoing_channel = gossip_service.packet_incoming_channel;
@@ -3113,10 +3152,12 @@ pub const BenchmarkGossipServicePullRequests = struct {
             table_lock.unlock();
         }

-        var packet_handle = try Thread.spawn(.{}, GossipService.run, .{
-            &gossip_service,
-            true, // dont build any outgoing messages
-            false,
+        const packet_handle = try Thread.spawn(.{}, GossipService.run, .{
+            &gossip_service, .{
+                .message_allocator = allocator,
+                .spy_node = true, // dont build any outgoing messages
+                .dump = false,
+            },
         });

         const outgoing_channel = gossip_service.packet_incoming_channel;
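In the new API, run() is only a thin wrapper: runThreads() spawns the workers and returns a RunHandles, and RunHandles.joinAndExit() blocks on them, signalling the shared exit flag if any worker returns early. A caller that wants to interleave its own startup work between spawning and blocking could hold the handles directly; a sketch under that assumption (the exit flag, allocator, and the service itself are set up as in the tests above):

    const handles = try gossip_service.runThreads(.{
        .message_allocator = message_allocator,
        .spy_node = false,
        .dump = false,
    });
    // ... other startup work happens here while gossip runs ...

    // External shutdown works as in the tests: flip the shared exit flag the
    // service was constructed with, and every worker loop returns on its own:
    //     exit.store(true, .unordered);
    handles.joinAndExit(); // if any worker exits early, the rest are told to stop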
diff --git a/src/net/socket_utils.zig b/src/net/socket_utils.zig
index 0fd1e08bc..22e94fde4 100644
--- a/src/net/socket_utils.zig
+++ b/src/net/socket_utils.zig
@@ -7,7 +7,7 @@ const Channel = @import("../sync/channel.zig").Channel;
 const std = @import("std");
 const Logger = @import("../trace/log.zig").Logger;

-pub const SOCKET_TIMEOUT: usize = 1000000;
+pub const SOCKET_TIMEOUT_US: usize = 1 * std.time.us_per_s;
 pub const PACKETS_PER_BATCH: usize = 64;

 pub fn readSocket(
@@ -17,96 +17,49 @@
     exit: *const std.atomic.Value(bool),
     logger: Logger,
 ) !void {
-    //Performance out of the IO without poll
+    // Performance out of the IO without poll
     //  * block on the socket until it's readable
     //  * set the socket to non blocking
     //  * read until it fails
     //  * set it back to blocking before returning
-    const MAX_WAIT_NS = std.time.ns_per_ms; // 1ms
+    try socket.setReadTimeout(SOCKET_TIMEOUT_US);
+
     while (!exit.load(.unordered)) {
         // init a new batch
-        var count: usize = 0;
-        const capacity = PACKETS_PER_BATCH;
         var packet_batch = try std.ArrayList(Packet).initCapacity(
             allocator,
-            capacity,
+            PACKETS_PER_BATCH,
         );
-        packet_batch.appendNTimesAssumeCapacity(Packet.default(), capacity);
+        errdefer packet_batch.deinit();

         // NOTE: usually this would be null (ie, blocking)
         // but in order to exit cleanly in tests - we set to 1 second
         try socket.setReadTimeout(std.time.ms_per_s);
-        var timer = std.time.Timer.start() catch unreachable;

         // recv packets into batch
-        while (true) {
-            const n_packets_read = recvMmsg(socket, packet_batch.items[count..capacity], exit) catch |err| {
-                if (count > 0 and err == error.WouldBlock) {
-                    if (timer.read() > MAX_WAIT_NS) {
-                        break;
-                    }
-                }
-                continue;
+        while (packet_batch.items.len != packet_batch.capacity) {
+            var packet: Packet = Packet.default();
+            const recv_meta = socket.receiveFrom(&packet.data) catch |err| switch (err) {
+                error.WouldBlock => {
+                    if (packet_batch.items.len > 0) break;
+                    continue;
+                },
+                else => |e| return e,
             };
-
-            if (count == 0) {
-                // set to nonblocking mode
-                try socket.setReadTimeout(SOCKET_TIMEOUT);
-            }
-            count += n_packets_read;
-            if (timer.read() > MAX_WAIT_NS or count >= capacity) {
-                break;
-            }
+            const bytes_read = recv_meta.numberOfBytes;
+            if (bytes_read == 0) return error.SocketClosed;
+            packet.addr = recv_meta.sender;
+            packet.size = bytes_read;
+            packet_batch.appendAssumeCapacity(packet);
         }

-        if (count < capacity) {
-            packet_batch.shrinkAndFree(count);
-        }
+        packet_batch.shrinkAndFree(packet_batch.items.len);
         try incoming_channel.send(packet_batch);
     }
     logger.debugf("readSocket loop closed", .{});
 }

-pub fn recvMmsg(
-    socket: *UdpSocket,
-    /// pre-allocated array of packets to fill up
-    packet_batch: []Packet,
-    exit: *const std.atomic.Value(bool),
-) !usize {
-    const max_size = packet_batch.len;
-    var count: usize = 0;
-
-    while (count < max_size) {
-        var packet = &packet_batch[count];
-        const recv_meta = socket.receiveFrom(&packet.data) catch |err| {
-            // would block then return
-            if (count > 0 and err == error.WouldBlock) {
-                break;
-            } else {
-                if (exit.load(.unordered)) return 0;
-                continue;
-            }
-        };
-
-        const bytes_read = recv_meta.numberOfBytes;
-        if (bytes_read == 0) {
-            return error.SocketClosed;
-        }
-        packet.addr = recv_meta.sender;
-        packet.size = bytes_read;
-
-        if (count == 0) {
-            // nonblocking mode
-            try socket.setReadTimeout(SOCKET_TIMEOUT);
-        }
-        count += 1;
-    }
-
-    return count;
-}
-
 pub fn sendSocket(
     socket: *UdpSocket,
     outgoing_channel: *Channel(std.ArrayList(Packet)),
@@ -117,12 +70,11 @@
     while (!exit.load(.unordered)) {
         const maybe_packet_batches = try outgoing_channel.try_drain();
-        if (maybe_packet_batches == null) {
+        const packet_batches = maybe_packet_batches orelse {
             // sleep for 1ms
             // std.time.sleep(std.time.ns_per_ms * 1);
             continue;
-        }
-        const packet_batches = maybe_packet_batches.?;
+        };
         defer {
             for (packet_batches) |*packet_batch| {
                 packet_batch.deinit();
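readSocket now fills an ArrayList(Packet) with up to PACKETS_PER_BATCH packets per iteration and hands each batch to the channel; SOCKET_TIMEOUT_US keeps the old literal value (1_000_000) with the microsecond unit now in its name. Ownership on the consumer side follows the same try_drain pattern used by verifyPackets and sendSocket above; a sketch of a consumer loop under those assumptions (the exit flag and incoming_channel come from the surrounding service):

    while (!exit.load(.unordered)) {
        // try_drain returns null when nothing is queued.
        const maybe_batches = try incoming_channel.try_drain();
        const batches = maybe_batches orelse continue;
        // The drained slice is freed with the channel's allocator...
        defer incoming_channel.allocator.free(batches);

        for (batches) |*batch| {
            // ...and each batch is owned (and freed) by the consumer.
            defer batch.deinit();
            for (batch.items) |*packet| {
                // Process packet.data[0..packet.size] from packet.addr here.
                _ = packet;
            }
        }
    }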
diff --git a/src/utils/tar.zig b/src/utils/tar.zig
index 90526c086..d4081682f 100644
--- a/src/utils/tar.zig
+++ b/src/utils/tar.zig
@@ -79,6 +79,7 @@ pub fn parallelUntarToFileSystem(
     }
     logger.infof("using {d} threads to unpack snapshot\n", .{n_threads});
+
     const tasks = try UnTarTask.init(allocator, n_threads);
     defer allocator.free(tasks);
@@ -87,8 +88,10 @@
     var file_count: usize = 0;
     const strip_components: u32 = 0;
     loop: while (true) {
-        var header_buf = try allocator.alloc(u8, 512);
-        _ = try reader.readAtLeast(header_buf, 512);
+        const header_buf = try allocator.alloc(u8, 512);
+        if (try reader.readAtLeast(header_buf, 512) != 512) {
+            std.debug.panic("Actual file size too small for header (< 512).", .{});
+        }
         const header: TarHeaderMinimal = .{ .bytes = header_buf[0..512] };
@@ -136,11 +139,16 @@
                 file_count += 1;

                 const contents = try allocator.alloc(u8, file_size);
-                _ = try reader.readAtLeast(contents, file_size);
+                const actual_contents_len = try reader.readAtLeast(contents, file_size);
+                if (actual_contents_len != file_size) {
+                    std.debug.panic("Reported file ({d}) size does not match actual file size ({d})", .{ contents.len, actual_contents_len });
+                }
                 try reader.skipBytes(pad_len, .{});

-                const entry = UnTarEntry{
+                const task_ptr = &tasks[UnTarTask.awaitAndAcquireFirstAvailableTask(tasks, 0)];
+                task_ptr.result catch |err| logger.errf("UnTarTask encountered error: {s}", .{@errorName(err)});
+                task_ptr.entry = .{
                     .allocator = allocator,
                     .contents = contents,
                     .dir = dir,
@@ -148,7 +156,9 @@
                     .filename_buf = file_name_buf,
                     .header_buf = header_buf,
                 };
-                UnTarTask.queue(&thread_pool, tasks, entry);
+
+                const batch = ThreadPool.Batch.from(&task_ptr.task);
+                thread_pool.schedule(batch);
             },
             .global_extended_header, .extended_header => {
                 return error.TarUnsupportedFileType;
@@ -161,9 +171,8 @@

     // wait for all tasks
     for (tasks) |*task| {
-        while (!task.done.load(.acquire)) {
-            // wait
-        }
+        task.blockUntilCompletion();
+        task.result catch |err| logger.errf("UnTarTask encountered error: {s}", .{@errorName(err)});
     }
 }
diff --git a/src/utils/thread.zig b/src/utils/thread.zig
index b72aa0daf..5bffbfab6 100644
--- a/src/utils/thread.zig
+++ b/src/utils/thread.zig
@@ -41,50 +41,56 @@ pub fn spawnThreadTasks(
     }
 }

-pub fn ThreadPoolTask(
-    comptime EntryType: type,
-) type {
+pub fn ThreadPoolTask(comptime Entry: type) type {
     return struct {
         task: Task,
-        entry: EntryType,
-        done: std.atomic.Value(bool) = std.atomic.Value(bool).init(true),
+        entry: Entry,
+        available: std.atomic.Value(bool) = std.atomic.Value(bool).init(true),
+        result: CallbackError!void = {},

         const Self = @This();

-        pub fn init(allocator: std.mem.Allocator, capacity: usize) ![]Self {
-            const tasks = try allocator.alloc(Self, capacity);
-            for (tasks) |*t| {
-                t.* = .{
-                    .entry = undefined,
-                    .task = .{ .callback = Self.callback },
-                };
-            }
+        const CallbackError = blk: {
+            const CallbackFn = @TypeOf(Entry.callback);
+            const CallbackResult = @typeInfo(CallbackFn).Fn.return_type.?;
+            break :blk switch (@typeInfo(CallbackResult)) {
+                .ErrorUnion => |info| info.error_set,
+                else => error{},
+            };
+        };
+
+        pub fn init(allocator: std.mem.Allocator, task_count: usize) ![]Self {
+            const tasks = try allocator.alloc(Self, task_count);
+            @memset(tasks, .{
+                .entry = undefined,
+                .task = .{ .callback = Self.callback },
+            });
             return tasks;
         }

         fn callback(task: *Task) void {
             const self: *Self = @fieldParentPtr("task", task);
-            std.debug.assert(!self.done.load(.acquire));
-            defer {
-                self.done.store(true, .release);
-            }
-            self.entry.callback() catch |err| {
-                std.debug.print("{s} error: {}\n", .{ @typeName(EntryType), err });
-                return;
-            };
+            self.result = undefined;
+
+            std.debug.assert(!self.available.load(.acquire));
+            defer self.available.store(true, .release);
+
+            self.result = self.entry.callback();
         }

-        pub fn queue(thread_pool: *ThreadPool, tasks: []Self, entry: EntryType) void {
-            var task_i: usize = 0;
-            var task_ptr = &tasks[task_i];
-            while (!task_ptr.done.load(.acquire)) {
-                task_i = (task_i + 1) % tasks.len;
-                task_ptr = &tasks[task_i];
+        /// Waits for any of the tasks in the slice to become available. Once one does,
+        /// it is atomically set to be unavailable, and its index is returned.
+        pub fn awaitAndAcquireFirstAvailableTask(tasks: []Self, start_index: usize) usize {
+            var task_index = start_index;
+            while (tasks[task_index].available.cmpxchgWeak(true, false, .release, .acquire) != null) {
+                task_index = (task_index + 1) % tasks.len;
             }
-            task_ptr.done.store(false, .release);
-            task_ptr.entry = entry;
+            return task_index;
+        }

-            const batch = Batch.from(&task_ptr.task);
-            thread_pool.schedule(batch);
+        pub fn blockUntilCompletion(task: *Self) void {
+            while (!task.available.load(.acquire)) {
+                std.atomic.spinLoopHint();
+            }
         }
     };
 }
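The reworked ThreadPoolTask(Entry) expects an Entry type exposing a callback method; if callback returns an error union, the error lands in the task's result field instead of being printed. A sketch of the full acquire/schedule/await lifecycle, following the same pattern as verifyPackets and parallelUntarToFileSystem above (HashWorkEntry and hashAll are hypothetical; the import paths are the ones used in this diff):

    const std = @import("std");
    const ThreadPool = @import("../sync/thread_pool.zig").ThreadPool;
    const ThreadPoolTask = @import("../utils/thread.zig").ThreadPoolTask;

    /// Hypothetical work item: hash a byte buffer on the pool.
    const HashWorkEntry = struct {
        buf: []const u8,
        out: *u64,

        pub fn callback(self: *HashWorkEntry) !void {
            if (self.buf.len == 0) return error.EmptyBuffer;
            self.out.* = std.hash.Wyhash.hash(0, self.buf);
        }
    };
    const HashWorkTask = ThreadPoolTask(HashWorkEntry);

    fn hashAll(
        allocator: std.mem.Allocator,
        thread_pool: *ThreadPool,
        bufs: []const []const u8,
        outs: []u64,
    ) !void {
        // Fixed pool of reusable task slots, as in verifyPackets.
        const tasks = try HashWorkTask.init(allocator, 4);
        defer allocator.free(tasks);

        var search_start: usize = 0;
        for (bufs, outs) |buf, *out| {
            // Block until a slot is free, then claim it atomically.
            const idx = HashWorkTask.awaitAndAcquireFirstAvailableTask(tasks, search_start);
            search_start = (idx + 1) % tasks.len;

            const task_ptr = &tasks[idx];
            // Surface any error left behind by this slot's previous run.
            task_ptr.result catch |err| std.debug.print("task error: {s}\n", .{@errorName(err)});
            task_ptr.entry = .{ .buf = buf, .out = out };
            thread_pool.schedule(ThreadPool.Batch.from(&task_ptr.task));
        }

        // Wait for the in-flight slots and check their final results.
        for (tasks) |*task| {
            task.blockUntilCompletion();
            task.result catch |err| std.debug.print("task error: {s}\n", .{@errorName(err)});
        }
    }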