Skip to content

Commit

Permalink
Thread/task management improvements
Browse files Browse the repository at this point in the history
* [ThreadPoolTask] Rename `done` field to `avaialable`, for better
  clarity in the contexts it's used.

* [ThreadPoolTask] Add `awaitAndAcquireFirstAvailableTask` &
  `blockUntilCompletion` functions, remove `queue` function,
  replace usages of `queue` and code which would now duplicate
  the aforementioned additions.

* [ThreadPoolTask] Add `result` field for signalling the result
  of the callback.

* [gossip] Make use of `ThreadPoolTask` instead of manually implementing
  the VerifyMessageTask Task interface details.

* [gossip] Add `GOSSIP_VERIFY_PACKET_PARALLEL_TASKS` constant, and use
  it instead of hardcoding the tasks allocated in `verifyPackets`.

* [gossip] Make `verifyPackets` and `processMessages` accept an
  a shared allocator, and make `run` provide a thread-safe allocator
  which will only be contended by a maximum of the avilable tasks
  plus the number of messages processed at the same time.

* [tar] Panic on some invalid but possible invariants.

* Constify more things that don't need to be mutable.
  • Loading branch information
InKryption committed May 31, 2024
1 parent d0307e1 commit 3b7b21b
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 118 deletions.
140 changes: 61 additions & 79 deletions src/gossip/service.zig
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
const std = @import("std");
const builtin = @import("builtin");
const network = @import("zig-network");
const EndPoint = network.EndPoint;
const Packet = @import("../net/packet.zig").Packet;
const PACKET_DATA_SIZE = @import("../net/packet.zig").PACKET_DATA_SIZE;
const ThreadPoolTask = @import("../utils/thread.zig").ThreadPoolTask;
const ThreadPool = @import("../sync/thread_pool.zig").ThreadPool;
const Task = ThreadPool.Task;
const Batch = ThreadPool.Batch;
Expand Down Expand Up @@ -88,6 +90,8 @@ const DEFAULT_EPOCH_DURATION: u64 = 172800000;
pub const PUB_GOSSIP_STATS_INTERVAL_MS = 2 * std.time.ms_per_s;
pub const GOSSIP_TRIM_INTERVAL_MS = 10 * std.time.ms_per_s;

pub const GOSSIP_VERIFY_PACKET_PARALLEL_TASKS = 4;

pub const GossipService = struct {
allocator: std.mem.Allocator,

Expand Down Expand Up @@ -260,7 +264,7 @@ pub const GossipService = struct {

/// these threads should run forever - so if they join - somethings wrong
/// and we should shutdown
fn joinAndExit(self: *Self, handle: *std.Thread) void {
fn joinAndExit(self: *const Self, handle: std.Thread) void {
handle.join();
self.exit.store(true, .unordered);
}
Expand All @@ -279,61 +283,57 @@ pub const GossipService = struct {
// const echo_server_thread = try self.echo_server.listenAndServe();
// _ = echo_server_thread;

var receiver_handle = try Thread.spawn(.{}, socket_utils.readSocket, .{
const receiver_thread = try Thread.spawn(.{}, socket_utils.readSocket, .{
self.allocator,
&self.gossip_socket,
self.packet_incoming_channel,
self.exit,
self.logger,
});
defer self.joinAndExit(&receiver_handle);
defer self.joinAndExit(receiver_thread);

var packet_verifier_handle = try Thread.spawn(.{}, verifyPackets, .{self});
defer self.joinAndExit(&packet_verifier_handle);
// TODO: use better allocator, unless GPA becomes more performant.
var gp_message_allocator: std.heap.GeneralPurposeAllocator(.{ .thread_safe = !builtin.single_threaded }) = .{};
defer _ = gp_message_allocator.deinit();
const message_allocator = gp_message_allocator.allocator();

var packet_handle = try Thread.spawn(.{}, processMessages, .{self});
defer self.joinAndExit(&packet_handle);
const packet_verifier_handle = try Thread.spawn(.{}, verifyPackets, .{ self, message_allocator });
defer self.joinAndExit(packet_verifier_handle);

var maybe_build_messages_handle = if (!spy_node) try Thread.spawn(.{}, buildMessages, .{self}) else null;
defer {
if (maybe_build_messages_handle) |*handle| {
self.joinAndExit(handle);
}
}
const packet_handle = try Thread.spawn(.{}, processMessages, .{ self, message_allocator });
defer self.joinAndExit(packet_handle);

var responder_handle = try Thread.spawn(.{}, socket_utils.sendSocket, .{
const maybe_build_messages_handle = if (!spy_node) try Thread.spawn(.{}, buildMessages, .{self}) else null;
defer if (maybe_build_messages_handle) |h| self.joinAndExit(h);

const responder_handle = try Thread.spawn(.{}, socket_utils.sendSocket, .{
&self.gossip_socket,
self.packet_outgoing_channel,
self.exit,
self.logger,
});
defer self.joinAndExit(&responder_handle);
defer self.joinAndExit(responder_handle);

var dump_handle = if (dump) try Thread.spawn(.{}, GossipDumpService.run, .{.{
const dump_handle = if (dump) try Thread.spawn(.{}, GossipDumpService.run, .{.{
.allocator = self.allocator,
.logger = self.logger,
.gossip_table_rw = &self.gossip_table_rw,
.exit = self.exit,
}}) else null;
defer if (dump_handle) |*h| self.joinAndExit(h);
defer if (dump_handle) |h| self.joinAndExit(h);
}

const VerifyMessageTask = struct {
const VerifyMessageTask = ThreadPoolTask(VerifyMessageEntry);
const VerifyMessageEntry = struct {
allocator: std.mem.Allocator,
packet_batch: ArrayList(Packet),
verified_incoming_channel: *Channel(GossipMessageWithEndpoint),
logger: Logger,

task: Task,
done: std.atomic.Value(bool) = std.atomic.Value(bool).init(true),

pub fn callback(task: *Task) void {
var self: *@This() = @fieldParentPtr("task", task);
std.debug.assert(!self.done.load(.acquire));
defer self.done.store(true, .release);
pub fn callback(self: *VerifyMessageEntry) !void {
defer self.packet_batch.deinit();

for (self.packet_batch.items) |*packet| {
for (@as([]const Packet, self.packet_batch.items)) |*packet| {
var message = bincode.readFromSlice(
self.allocator,
GossipMessage,
Expand All @@ -359,53 +359,39 @@ pub const GossipService = struct {
continue;
};

const msg = GossipMessageWithEndpoint{
const msg: GossipMessageWithEndpoint = .{
.from_endpoint = packet.addr,
.message = message,
};

self.verified_incoming_channel.send(msg) catch unreachable;
try self.verified_incoming_channel.send(msg);
}
}

/// waits for the task to be done, then resets the done state to false
fn awaitAndReset(self: *VerifyMessageTask) void {
while (!self.done.load(.acquire)) {
// wait
}
self.done.store(false, .release);
}
};

/// main logic for deserializing Packets into GossipMessage messages
/// and verifing they have valid values, and have valid signatures.
/// Verified GossipMessagemessages are then sent to the verified_channel.
fn verifyPackets(self: *Self) !void {
var tasks = try self.allocator.alloc(VerifyMessageTask, 4);
fn verifyPackets(
self: *Self,
/// Must be thread-safe. Can be a specific allocator which will
/// only be contended for by the tasks spawned by in function.
task_allocator: std.mem.Allocator,
) !void {
const tasks = try VerifyMessageTask.init(self.allocator, GOSSIP_VERIFY_PACKET_PARALLEL_TASKS);
defer self.allocator.free(tasks);

// pre-allocate all the tasks
for (tasks) |*task| {
task.* = VerifyMessageTask{
// .allocator = std.heap.page_allocator, // TODO: swap out with arg
.allocator = self.allocator, // TODO: swap out with arg
.task = .{ .callback = VerifyMessageTask.callback },
.verified_incoming_channel = self.verified_incoming_channel,
.packet_batch = undefined,
.logger = self.logger,
};
}
for (tasks) |*task| task.entry = .{
.allocator = task_allocator,
.verified_incoming_channel = self.verified_incoming_channel,
.packet_batch = undefined,
.logger = self.logger,
};

while (!self.exit.load(.unordered)) {
const maybe_packets = try self.packet_incoming_channel.try_drain();
if (maybe_packets == null) {
continue;
}

const packet_batches = maybe_packets.?;
defer {
self.packet_incoming_channel.allocator.free(packet_batches);
}
const maybe_packet_batches = try self.packet_incoming_channel.try_drain();
const packet_batches = maybe_packet_batches orelse continue;
defer self.packet_incoming_channel.allocator.free(packet_batches);

// count number of packets
var n_packets_drained: usize = 0;
Expand All @@ -416,26 +402,23 @@ pub const GossipService = struct {

// verify in parallel using the threadpool
// PERF: investigate CPU pinning
var task_i: usize = 0;
const n_tasks = tasks.len;
var task_search_start_idx: usize = 0;
for (packet_batches) |packet_batch| {
// find a free task
var task_ptr = &tasks[task_i];
while (!task_ptr.done.load(.acquire)) {
task_i = (task_i + 1) % n_tasks;
task_ptr = &tasks[task_i];
}
// schedule it
task_ptr.done.store(false, .release);
task_ptr.packet_batch = packet_batch;
const acquired_task_idx = VerifyMessageTask.awaitAndAcquireFirstAvailableTask(tasks, task_search_start_idx);
task_search_start_idx = (acquired_task_idx + 1) % tasks.len;

const task_ptr = &tasks[acquired_task_idx];
task_ptr.entry.packet_batch = packet_batch;
task_ptr.result catch |err| self.logger.errf("VerifyMessageTask encountered error: {s}", .{@errorName(err)});

const batch = Batch.from(&task_ptr.task);
self.thread_pool.schedule(batch);
}
}

for (tasks) |*task| {
task.awaitAndReset();
task.blockUntilCompletion();
task.result catch |err| self.logger.errf("VerifyMessageTask encountered error: {s}", .{@errorName(err)});
}

self.logger.debugf("verify_packets loop closed", .{});
Expand Down Expand Up @@ -470,7 +453,7 @@ pub const GossipService = struct {
};

/// main logic for recieving and processing gossip messages.
pub fn processMessages(self: *Self) !void {
pub fn processMessages(self: *Self, message_allocator: std.mem.Allocator) !void {
var timer = std.time.Timer.start() catch unreachable;
var last_table_trim_ts: u64 = 0;
var msg_count: usize = 0;
Expand Down Expand Up @@ -539,7 +522,7 @@ pub const GossipService = struct {
// would be safer. For more info, see:
// - GossipTable.remove
// - https://github.com/Syndica/sig/pull/69
msg.message.shallowFree(self.allocator);
msg.message.shallowFree(message_allocator);
}
self.verified_incoming_channel.allocator.free(messages);
}
Expand Down Expand Up @@ -2626,13 +2609,12 @@ test "gossip.gossip_service: test packet verification" {
&exit,
logger,
);

defer gossip_service.deinit();

var packet_channel = gossip_service.packet_incoming_channel;
var verified_channel = gossip_service.verified_incoming_channel;

var packet_verifier_handle = try Thread.spawn(.{}, GossipService.verifyPackets, .{&gossip_service});
const packet_verifier_handle = try Thread.spawn(.{}, GossipService.verifyPackets, .{ &gossip_service, gossip_service.allocator });

var rng = std.rand.DefaultPrng.init(getWallclockMs());
var data = gossip.GossipData.randomFromIndex(rng.random(), 0);
Expand Down Expand Up @@ -2749,6 +2731,7 @@ test "gossip.gossip_service: test packet verification" {

test "gossip.gossip_service: process contact info push packet" {
const allocator = std.testing.allocator;
const message_allocator = allocator;
var exit = AtomicBool.init(false);
var my_keypair = try KeyPair.create([_]u8{1} ** 32);
const my_pubkey = Pubkey.fromPublicKey(&my_keypair.public_key);
Expand All @@ -2768,16 +2751,16 @@ test "gossip.gossip_service: process contact info push packet" {
);
defer gossip_service.deinit();

var verified_channel = gossip_service.verified_incoming_channel;
var responder_channel = gossip_service.packet_outgoing_channel;
const verified_channel = gossip_service.verified_incoming_channel;
const responder_channel = gossip_service.packet_outgoing_channel;

var kp = try KeyPair.create(null);
const pk = Pubkey.fromPublicKey(&kp.public_key);

var packet_handle = try Thread.spawn(
.{},
GossipService.processMessages,
.{&gossip_service},
.{ &gossip_service, message_allocator },
);

// send a push message
Expand All @@ -2789,8 +2772,7 @@ test "gossip.gossip_service: process contact info push packet" {
.LegacyContactInfo = legacy_contact_info,
};
const gossip_value = try gossip.SignedGossipData.initSigned(gossip_data, &kp);
var heap_values = try allocator.alloc(gossip.SignedGossipData, 1);
heap_values[0] = gossip_value;
const heap_values = try message_allocator.dupe(gossip.SignedGossipData, &.{gossip_value});
const msg = GossipMessage{
.PushMessage = .{ id, heap_values },
};
Expand Down
25 changes: 17 additions & 8 deletions src/utils/tar.zig
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ pub fn parallelUntarToFileSystem(
}

logger.infof("using {d} threads to unpack snapshot\n", .{n_threads});

const tasks = try UnTarTask.init(allocator, n_threads);
defer allocator.free(tasks);

Expand All @@ -87,8 +88,10 @@ pub fn parallelUntarToFileSystem(
var file_count: usize = 0;
const strip_components: u32 = 0;
loop: while (true) {
var header_buf = try allocator.alloc(u8, 512);
_ = try reader.readAtLeast(header_buf, 512);
const header_buf = try allocator.alloc(u8, 512);
if (try reader.readAtLeast(header_buf, 512) != 512) {
std.debug.panic("Actual file size too small for header (< 512).", .{});
}

const header: TarHeaderMinimal = .{ .bytes = header_buf[0..512] };

Expand Down Expand Up @@ -136,19 +139,26 @@ pub fn parallelUntarToFileSystem(
file_count += 1;

const contents = try allocator.alloc(u8, file_size);
_ = try reader.readAtLeast(contents, file_size);
const actual_contents_len = try reader.readAtLeast(contents, file_size);
if (actual_contents_len != file_size) {
std.debug.panic("Reported file ({d}) size does not match actual file size ({d})", .{ contents.len, actual_contents_len });
}

try reader.skipBytes(pad_len, .{});

const entry = UnTarEntry{
const task_ptr = &tasks[UnTarTask.awaitAndAcquireFirstAvailableTask(tasks, 0)];
task_ptr.result catch |err| logger.errf("UnTarTask encountered error: {s}", .{@errorName(err)});
task_ptr.entry = .{
.allocator = allocator,
.contents = contents,
.dir = dir,
.file_name = file_name,
.filename_buf = file_name_buf,
.header_buf = header_buf,
};
UnTarTask.queue(&thread_pool, tasks, entry);

const batch = ThreadPool.Batch.from(&task_ptr.task);
thread_pool.schedule(batch);
},
.global_extended_header, .extended_header => {
return error.TarUnsupportedFileType;
Expand All @@ -161,9 +171,8 @@ pub fn parallelUntarToFileSystem(

// wait for all tasks
for (tasks) |*task| {
while (!task.done.load(.acquire)) {
// wait
}
task.blockUntilCompletion();
task.result catch |err| logger.errf("UnTarTask encountered error: {s}", .{@errorName(err)});
}
}

Expand Down
Loading

0 comments on commit 3b7b21b

Please sign in to comment.