diff --git a/build.zig b/build.zig index 444240d3e..fde06def7 100644 --- a/build.zig +++ b/build.zig @@ -123,43 +123,63 @@ pub fn build(b: *std.Build) void { const run_step = b.step("run", "Run the app"); run_step.dependOn(&run_cmd.step); - // gossip fuzz testing - // find ./zig-cache/o/* | grep fuzz - // lldb $(above path) - const fuzz_exe = b.addExecutable(.{ - .name = "fuzz", - .root_source_file = .{ .path = "src/gossip/fuzz.zig" }, - .target = target, - .optimize = optimize, - .main_pkg_path = .{ .path = "src" }, - }); - fuzz_exe.addModule("base58-zig", base58_module); - fuzz_exe.addModule("zig-network", zig_network_module); - fuzz_exe.addModule("zig-cli", zig_cli_module); - fuzz_exe.addModule("getty", getty_mod); - b.installArtifact(fuzz_exe); - const fuzz_cmd = b.addRunArtifact(fuzz_exe); - b.step("fuzz_gossip", "fuzz gossip").dependOn(&fuzz_cmd.step); - - // benchmarking - const benchmark_exe = b.addExecutable(.{ - .name = "benchmark", - .root_source_file = .{ .path = "src/benchmarks.zig" }, - .target = target, - // TODO: make it work - // .optimize = std.builtin.Mode.ReleaseSafe, // to get decent results - but things get optimized away - .optimize = optimize, - .main_pkg_path = .{ .path = "src" }, - }); - benchmark_exe.addModule("base58-zig", base58_module); - benchmark_exe.addModule("zig-network", zig_network_module); - benchmark_exe.addModule("zig-cli", zig_cli_module); - benchmark_exe.addModule("getty", getty_mod); - b.installArtifact(benchmark_exe); - const benchmark_cmd = b.addRunArtifact(benchmark_exe); - if (b.args) |args| { - benchmark_cmd.addArgs(args); + const ExecCommand = struct { + name: []const u8, + path: []const u8, + description: []const u8 = "", + }; + + const exec_commands = [_]ExecCommand{ + ExecCommand { + .name = "fuzz", + .path = "src/gossip/fuzz.zig", + .description = "gossip fuzz testing", + }, + ExecCommand { + .name = "benchmark", + .path = "src/benchmarks.zig", + .description = "benchmark client", + }, + ExecCommand { + .name = "snapshot_utils", + .path = "src/cmd/snapshot_utils.zig", + .description = "snapshot utils", + }, + ExecCommand { + .name = "snapshot_verify", + .path = "src/cmd/snapshot_verify.zig", + .description = "verify snapshot account hashes", + }, + // tmp :: remove when done with + ExecCommand { + .name = "accounts", + .path = "src/core/accounts_db.zig", + .description = "tmp file", + }, + }; + + for (exec_commands) |command_info| { + const exec = b.addExecutable(.{ + .name = command_info.name, + .root_source_file = .{ .path = command_info.path }, + .target = target, + .optimize = optimize, + .main_pkg_path = .{ .path = "src" }, + }); + + // TODO: maybe we dont need all these for all bins + exec.addModule("base58-zig", base58_module); + exec.addModule("zig-network", zig_network_module); + exec.addModule("zig-cli", zig_cli_module); + exec.addModule("getty", getty_mod); + + // this lets us run it as an exec + b.installArtifact(exec); + + const cmd = b.addRunArtifact(exec); + if (b.args) |args| cmd.addArgs(args); + b + .step(command_info.name, command_info.description) + .dependOn(&cmd.step); } - - b.step("benchmark", "benchmark gossip").dependOn(&benchmark_cmd.step); } diff --git a/src/cmd/snapshot_utils.zig b/src/cmd/snapshot_utils.zig new file mode 100644 index 000000000..684c2398b --- /dev/null +++ b/src/cmd/snapshot_utils.zig @@ -0,0 +1,508 @@ +const std = @import("std"); +const cli = @import("zig-cli"); +const bincode = @import("../bincode/bincode.zig"); +const AccountsDbFields = 
@import("../core/snapshot_fields.zig").AccountsDbFields; +const AppendVecInfo = @import("../core/snapshot_fields.zig").AppendVecInfo; +const AppendVec = @import("../core/append_vec.zig").AppendVec; +const TmpPubkey = @import("../core/append_vec.zig").TmpPubkey; +const Account = @import("../core/account.zig").Account; +const Pubkey = @import("../core/pubkey.zig").Pubkey; +const Slot = @import("../core/clock.zig").Slot; +const ArrayList = std.ArrayList; +const ThreadPool = @import("../sync/thread_pool.zig").ThreadPool; +const Task = ThreadPool.Task; +const Batch = ThreadPool.Batch; + +const Channel = @import("../sync/channel.zig").Channel; +const SnapshotFields = @import("../core/snapshot_fields.zig").SnapshotFields; + +pub const AccountAndPubkey = struct { + pubkey: TmpPubkey, + account: Account, +}; + +pub const CsvRows = []u8; +pub const CsvChannel = Channel(CsvRows); + +pub fn accountsToCsvRowAndSend( + alloc: std.mem.Allocator, + accounts_db_fields: *AccountsDbFields, + accounts_dir_path: []const u8, + channel: *CsvChannel, + owner_filter: ?TmpPubkey, + // ! + filename: []const u8, +) !void { + // parse "{slot}.{id}" from the filename + var fiter = std.mem.tokenizeSequence(u8, filename, "."); + const slot = try std.fmt.parseInt(Slot, fiter.next().?, 10); + const append_vec_id = try std.fmt.parseInt(usize, fiter.next().?, 10); + + // read metadata + const slot_metas: ArrayList(AppendVecInfo) = accounts_db_fields.map.get(slot).?; + std.debug.assert(slot_metas.items.len == 1); + const slot_meta = slot_metas.items[0]; + std.debug.assert(slot_meta.id == append_vec_id); + + // read appendVec from file + var abs_path_buf: [1024]u8 = undefined; + const abs_path = try std.fmt.bufPrint(&abs_path_buf, "{s}/{s}", .{ accounts_dir_path, filename }); + const append_vec_file = try std.fs.openFileAbsolute(abs_path, .{ .mode = .read_write }); + + var append_vec = AppendVec.init(append_vec_file, slot_meta, slot) catch return; + defer append_vec.deinit(); + + // verify its valid + append_vec.sanitize() catch { + append_vec.deinit(); + return; + }; + + const pubkey_and_refs = try append_vec.getAccountsRefs(alloc); + defer pubkey_and_refs.deinit(); + + // compute the full size to allocate at once + var total_fmt_size: u64 = 0; + for (pubkey_and_refs.items) |*pubkey_and_ref| { + const account = try append_vec.getAccount(pubkey_and_ref.account_ref.offset); + + if (owner_filter) |owner| { + if (!account.account_info.owner.equals(&owner)) continue; + } + + // 5 seperators = 5 bytes + // new line = 1 byte + // pubkey string = 44 bytes + // owner string = 44 bytes + // data = { 1, 2, 3, 4 } + // ?? the number themeselves (1*data.len bytes) + // + comma per datapoint ( 1*data.len) + // + whitespace ( ~2*data.len ) + '{' '}' = ~4 * data.len + 2 + // lamports = 8 bytes + // executable = "true" or "false" = 5 bytes + // rent_epoch = 8 bytes + + // estimate? 
+ const fmt_count = 120 + 5 * account.data.len; + total_fmt_size += fmt_count; + } + + const csv_string = alloc.alloc(u8, total_fmt_size) catch unreachable; + var csv_string_offset: usize = 0; + + for (pubkey_and_refs.items) |*pubkey_and_ref| { + const pubkey = pubkey_and_ref.pubkey; + const account = try append_vec.getAccount(pubkey_and_ref.account_ref.offset); + if (owner_filter) |owner| { + if (!account.account_info.owner.equals(&owner)) continue; + } + + const owner_pk = try Pubkey.fromBytes(&account.account_info.owner.data, .{}); + + + const fmt_slice_len = (std.fmt.bufPrint( + csv_string[csv_string_offset..], + "{s};{s};{any};{d};{any};{d}\n", + .{ + try pubkey.toString(), + owner_pk.string(), + account.data, + account.account_info.lamports, + account.account_info.executable, + account.account_info.rent_epoch, + }, + ) catch unreachable).len; + + csv_string_offset += fmt_slice_len; + } + + _ = channel.send(csv_string) catch unreachable; +} + +// what all the tasks will need +const CsvTask = struct { + allocator: std.mem.Allocator, + accounts_db_fields: *AccountsDbFields, + accounts_dir_path: []const u8, + channel: *CsvChannel, + owner_filter: ?TmpPubkey, + + file_names: [][]const u8, + + task: Task, + done: std.atomic.Atomic(bool) = std.atomic.Atomic(bool).init(false), + + pub fn callback(task: *Task) void { + var self = @fieldParentPtr(@This(), "task", task); + defer self.done.store(true, std.atomic.Ordering.Release); + + for (self.file_names) |file_name| { + accountsToCsvRowAndSend( + self.allocator, + self.accounts_db_fields, + self.accounts_dir_path, + self.channel, + self.owner_filter, + file_name, + ) catch {}; + } + } + + pub fn reset(self: *@This()) void { + self.done.store(false, std.atomic.Ordering.Release); + for (self.file_names) |file_name| { + self.allocator.free(file_name); + } + self.allocator.free(self.file_names); + } +}; + +pub fn runTaskScheduler( + allocator: std.mem.Allocator, + thread_pool: *ThreadPool, + iter: *std.fs.IterableDir.Iterator, + tasks_slice: anytype, // []SomeAccountFileTask + is_done: *std.atomic.Atomic(bool), + comptime chunk_size: usize, +) void { + const n_tasks = tasks_slice.len; + + var ready_indexes = std.ArrayList(usize).initCapacity(allocator, n_tasks) catch unreachable; + defer ready_indexes.deinit(); + var running_indexes = std.ArrayList(usize).initCapacity(allocator, n_tasks) catch unreachable; + defer running_indexes.deinit(); + + // at the start = all ready to schedule + for (0..n_tasks) |i| ready_indexes.appendAssumeCapacity(i); + + var account_name_buf: [chunk_size][]const u8 = undefined; + var has_sent_all_accounts = false; + while (!has_sent_all_accounts) { + // queue the ready tasks + var batch = Batch{}; + const n_ready = ready_indexes.items.len; + for (0..n_ready) |_| { + var i: usize = 0; + while (i < chunk_size) : (i += 1) { + const file = iter.next() catch { + has_sent_all_accounts = true; + break; + } orelse { + has_sent_all_accounts = true; + break; + }; + account_name_buf[i] = file.name; + } + if (i == 0) break; + + // populate the task + const task_index = ready_indexes.pop(); + const task = &tasks_slice[task_index]; + + // fill out the filename + var file_names = allocator.alloc([]const u8, i) catch unreachable; + for (0..i) |idx| { + var filename: []const u8 = account_name_buf[idx]; + var heap_filename = allocator.alloc(u8, filename.len) catch unreachable; + @memcpy(heap_filename, filename); + file_names[idx] = heap_filename; + } + task.file_names = file_names[0..i]; + + const task_batch = Batch.from(&task.task); + 
batch.push(task_batch); + + running_indexes.appendAssumeCapacity(task_index); + + if (has_sent_all_accounts) break; + } + + if (batch.len != 0) { + ThreadPool.schedule(thread_pool, batch); + } + + if (has_sent_all_accounts) { + std.debug.print("sent all account files!\n", .{}); + } + + var current_index: usize = 0; + const n_running = running_indexes.items.len; + for (0..n_running) |_| { + const task_index = running_indexes.items[current_index]; + const task = &tasks_slice[task_index]; + + if (!task.done.load(std.atomic.Ordering.Acquire)) { + if (has_sent_all_accounts) { + // these are the last tasks so we wait for them until they are done + while (!task.done.load(std.atomic.Ordering.Acquire)) {} + } + // check the next task + current_index += 1; + } else { + ready_indexes.appendAssumeCapacity(task_index); + // removing so next task can be checked without changing current_index + _ = running_indexes.orderedRemove(current_index); + task.reset(); + } + } + } + + is_done.store(true, std.atomic.Ordering.Release); +} + +pub fn recvAndWriteCsv( + total_append_vec_count: usize, + csv_file: std.fs.File, + channel: *CsvChannel, + is_done: *std.atomic.Atomic(bool), +) void { + var append_vec_count: usize = 0; + var writer = csv_file.writer(); + const start_time: u64 = @intCast(std.time.milliTimestamp() * std.time.ns_per_ms); + + while (true) { + const maybe_csv_rows = channel.try_drain() catch unreachable; + + var csv_rows = maybe_csv_rows orelse { + // check if all tasks are done + if (is_done.load(std.atomic.Ordering.Acquire)) break; + continue; + }; + defer channel.allocator.free(csv_rows); + + for (csv_rows) |csv_row| { + writer.writeAll(csv_row) catch unreachable; + channel.allocator.free(csv_row); + append_vec_count += 1; + + const vecs_left = total_append_vec_count - append_vec_count; + if (append_vec_count % 100 == 0 or vecs_left < 100) { + // estimate how long left + const now: u64 = @intCast(std.time.milliTimestamp() * std.time.ns_per_ms); + const elapsed = now - start_time; + const ns_per_vec = elapsed / append_vec_count; + const time_left = ns_per_vec * vecs_left / std.time.ns_per_min; + + std.debug.print("dumped {d}/{d} appendvecs - (mins left: {d})\r", .{ + append_vec_count, + total_append_vec_count, + time_left, + }); + } + } + } +} + +var owner_filter_option = cli.Option{ + .long_name = "owner-filter", + .short_alias = 's', + .help = "owner pubkey to filter what accounts to dump", + .required = false, + .value = .{ .string = null }, +}; + +var snapshot_dir_option = cli.Option{ + .long_name = "snapshot-dir", + .short_alias = 's', + .help = "absolute path to the snapshot directory", + .required = true, + .value = .{ .string = null }, +}; + +var app = &cli.App{ + .name = "dump_snapshot", + .description = "utils for snapshot dumping", + .author = "Syndica & Contributors", + .subcommands = &.{ + // requires: dump_account_fields to be run first + &cli.Command{ + .name = "dump_snapshot", + .help = "Dump snapshot accounts to a csv file", + .action = dumpSnapshot, + .options = &.{ + &snapshot_dir_option, + &owner_filter_option, + }, + }, + &cli.Command{ + .name = "dump_account_fields", + .help = "dumps account db fields for faster loading (should run first)", + .options = &.{ + &snapshot_dir_option, + }, + .action = dumpAccountFields, + }, + }, +}; + +pub fn main() !void { + // eg, + // zig build snapshot_utils -Doptimize=ReleaseSafe + // 1) dump the account fields + // ./zig-out/bin/snapshot_utils dump_account_fields -s /Users/tmp/snapshots + // 2) dump the snapshot info + // 
./zig-out/bin/snapshot_utils dump_snapshot -s /Users/tmp/snapshots + + var gpa = std.heap.GeneralPurposeAllocator(.{}){}; + var allocator = gpa.allocator(); + try cli.run(app, allocator); +} + +/// we do this bc the bank_fields in the snapshot metadata is very large +pub fn dumpAccountFields(_: []const []const u8) !void { + var gpa = std.heap.GeneralPurposeAllocator(.{}){}; + var allocator = gpa.allocator(); + + const snapshot_dir = snapshot_dir_option.value.string.?; + + // iterate through the snapshot dir + const metadata_sub_path = try std.fmt.allocPrint( + allocator, + "{s}/{s}", + .{ snapshot_dir, "snapshots" }, + ); + var metadata_dir = try std.fs.openIterableDirAbsolute(metadata_sub_path, .{}); + var metadata_dir_iter = metadata_dir.iterate(); + var maybe_snapshot_slot: ?usize = null; + while (try metadata_dir_iter.next()) |entry| { + if (entry.kind == std.fs.File.Kind.directory) { + maybe_snapshot_slot = try std.fmt.parseInt(usize, entry.name, 10); + break; + } + } + var snapshot_slot = maybe_snapshot_slot orelse unreachable; + + const metadata_path = try std.fmt.allocPrint( + allocator, + "{s}/{d}/{d}", + .{ metadata_sub_path, snapshot_slot, snapshot_slot }, + ); + + const output_path = try std.fmt.allocPrint( + allocator, + "{s}/{s}", + .{ snapshot_dir, "accounts_db.bincode" }, + ); + + std.debug.print("reading metadata path: {s}\n", .{metadata_path}); + std.debug.print("saving to output path: {s}\n", .{output_path}); + + var snapshot_fields = try SnapshotFields.readFromFilePath(allocator, metadata_path); + const fields = snapshot_fields.getFieldRefs(); + + // rewrite the accounts_db_fields seperate + const db_file = try std.fs.createFileAbsolute(output_path, .{}); + defer db_file.close(); + + var db_buf = try bincode.writeToArray(allocator, fields.accounts_db_fields.*, .{}); + defer db_buf.deinit(); + + _ = try db_file.write(db_buf.items); +} + +pub fn dumpSnapshot(_: []const []const u8) !void { + var gpa = std.heap.GeneralPurposeAllocator(.{}){}; + var allocator = gpa.allocator(); + + const owner_filter_str = owner_filter_option.value.string; + var owner_filter: ?TmpPubkey = null; + if (owner_filter_str) |str| { + owner_filter = try TmpPubkey.fromString(str); + } + + const snapshot_dir = snapshot_dir_option.value.string.?; + const accounts_db_fields_path = try std.fmt.allocPrint( + allocator, + "{s}/{s}", + .{ snapshot_dir, "accounts_db.bincode" }, + ); + const accounts_dir_path = try std.fmt.allocPrint( + allocator, + "{s}/{s}", + .{ snapshot_dir, "accounts" }, + ); + const dump_csv_path = try std.fmt.allocPrint( + allocator, + "{s}/{s}", + .{ snapshot_dir, "accounts.csv" }, + ); + defer { + allocator.free(accounts_db_fields_path); + allocator.free(accounts_dir_path); + allocator.free(dump_csv_path); + } + + const csv_file = try std.fs.createFileAbsolute(dump_csv_path, .{}); + defer csv_file.close(); + + const accounts_db_fields_file = std.fs.openFileAbsolute(accounts_db_fields_path, .{}) catch { + std.debug.print("could not open accounts_db.bincode - run `prepare` first\n", .{}); + return; + }; + var accounts_db_fields = try bincode.read(allocator, AccountsDbFields, accounts_db_fields_file.reader(), .{}); + defer bincode.free(allocator, accounts_db_fields); + + var accounts_dir = try std.fs.openIterableDirAbsolute(accounts_dir_path, .{}); + var accounts_dir_iter = accounts_dir.iterate(); + + var n_threads = @as(u32, @truncate(std.Thread.getCpuCount() catch unreachable)); + var thread_pool = ThreadPool.init(.{ + .max_threads = n_threads, + .stack_size = 2 * 1024 * 1024, + }); 
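+    // NOTE (design sketch, as implemented below): a scheduler thread hands chunks of
+    // append-vec file names to the pre-allocated CsvTask workers on this thread pool;
+    // each task formats its accounts into CSV rows and sends them over `channel`,
+    // while the main thread drains the channel in recvAndWriteCsv and writes the rows
+    // to accounts.csv.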
+ // clean up threadpool once done + defer thread_pool.shutdown(); + + std.debug.print("starting with {d} threads\n", .{n_threads}); + + // compute the total size (to compute time left) + var total_append_vec_count: usize = 0; + while (try accounts_dir_iter.next()) |_| { + total_append_vec_count += 1; + } + accounts_dir_iter = accounts_dir.iterate(); // reset + + var channel = CsvChannel.init(allocator, 1000); + defer channel.deinit(); + + // setup the tasks + const n_tasks = thread_pool.max_threads * 2; + // pre-allocate all the tasks + var tasks = allocator.alloc(CsvTask, n_tasks) catch unreachable; + defer allocator.free(tasks); + + for (0..tasks.len) |i| { + tasks[i] = CsvTask{ + .task = .{ .callback = CsvTask.callback }, + .accounts_db_fields = &accounts_db_fields, + .accounts_dir_path = accounts_dir_path, + .allocator = allocator, + .channel = channel, + .owner_filter = owner_filter, + // to be filled + .file_names = undefined, + }; + } + + var is_done = std.atomic.Atomic(bool).init(false); + + var handle = std.Thread.spawn(.{}, runTaskScheduler, .{ + allocator, + &thread_pool, + &accounts_dir_iter, + tasks, + &is_done, + 20, + }) catch unreachable; + + recvAndWriteCsv( + total_append_vec_count, + csv_file, + channel, + &is_done, + ); + + handle.join(); + + std.debug.print("done!\n", .{}); +} diff --git a/src/cmd/snapshot_verify.zig b/src/cmd/snapshot_verify.zig new file mode 100644 index 000000000..2973c80fa --- /dev/null +++ b/src/cmd/snapshot_verify.zig @@ -0,0 +1,512 @@ +const std = @import("std"); +const ArrayList = std.ArrayList; +const HashMap = std.AutoHashMap; + +const Account = @import("../core/account.zig").Account; +const Hash = @import("../core/hash.zig").Hash; +const Slot = @import("../core/clock.zig").Slot; +const Pubkey = @import("../core/pubkey.zig").Pubkey; +const bincode = @import("../bincode/bincode.zig"); + +const AccountsDbFields = @import("../core/snapshot_fields.zig").AccountsDbFields; +const AppendVecInfo = @import("../core/snapshot_fields.zig").AppendVecInfo; + +const AppendVec = @import("../core/append_vec.zig").AppendVec; +const TmpPubkey = @import("../core/append_vec.zig").TmpPubkey; +const alignToU64 = @import("../core/append_vec.zig").alignToU64; + +const ThreadPool = @import("../sync/thread_pool.zig").ThreadPool; +const Task = ThreadPool.Task; +const Batch = ThreadPool.Batch; + +const hashAccount = @import("../core/account.zig").hashAccount; +const merkleTreeHash = @import("../common/merkle_tree.zig").merkleTreeHash; + +pub const MERKLE_FANOUT: usize = 16; + +const AccountHashData = struct { + pubkey: TmpPubkey, + hash: Hash, + slot: Slot, + lamports: u64, + id: usize, + offset: usize, +}; + +pub fn indexAndBinFiles( + accounts_db_fields: *const AccountsDbFields, + accounts_dir_path: []const u8, + // task specific + file_names: [][]const u8, + bins: *PubkeyBins, +) !void { + const total_append_vec_count = file_names.len; + + var timer = try std.time.Timer.start(); + // TODO: might need to be longer depending on abs path length + var abs_path_buf: [1024]u8 = undefined; + for (file_names, 1..) 
|file_name, append_vec_count| { + // parse "{slot}.{id}" from the file_name + var fiter = std.mem.tokenizeSequence(u8, file_name, "."); + const slot = try std.fmt.parseInt(Slot, fiter.next().?, 10); + const append_vec_id = try std.fmt.parseInt(usize, fiter.next().?, 10); + + // read metadata + const slot_metas: ArrayList(AppendVecInfo) = accounts_db_fields.map.get(slot).?; + std.debug.assert(slot_metas.items.len == 1); + const slot_meta = slot_metas.items[0]; + std.debug.assert(slot_meta.id == append_vec_id); + + // read appendVec from file + const abs_path = try std.fmt.bufPrint(&abs_path_buf, "{s}/{s}", .{ accounts_dir_path, file_name }); + const append_vec_file = try std.fs.openFileAbsolute(abs_path, .{ .mode = .read_write }); + var append_vec = AppendVec.init(append_vec_file, slot_meta, slot) catch |err| { + var buf: [1024]u8 = undefined; + var stream = std.io.fixedBufferStream(&buf); + var writer = stream.writer(); + try std.fmt.format(writer, "failed to *open* appendVec {s}: {s}", .{ file_name, @errorName(err) }); + @panic(stream.getWritten()); + }; + // close after + defer append_vec.deinit(); + + sanitizeAndBin( + &append_vec, + bins, + ) catch |err| { + var buf: [1024]u8 = undefined; + var stream = std.io.fixedBufferStream(&buf); + var writer = stream.writer(); + try std.fmt.format(writer, "failed to *sanitize* appendVec {s}: {s}", .{ file_name, @errorName(err) }); + @panic(stream.getWritten()); + }; + + if (append_vec_count % 1_000 == 0) { + // estimate how long left + printTimeEstimate( + &timer, + total_append_vec_count, + append_vec_count, + "parsing append vecs", + ); + } + } +} + +/// used for initial loading +/// we want to sanitize and index and bin (for hash verification) in one go +pub fn sanitizeAndBin(append_vec: *AppendVec, bins: *PubkeyBins) !void { + var offset: usize = 0; + var n_accounts: usize = 0; + + while (true) { + var account = append_vec.getAccount(offset) catch break; + try account.sanitize(); + + const pubkey = account.store_info.pubkey; + const hash_is_missing = std.mem.eql(u8, &account.hash.data, &Hash.default().data); + const hash = hashAccount( + account.account_info.lamports, + account.data, + &account.account_info.owner.data, + account.account_info.executable, + account.account_info.rent_epoch, + &pubkey.data, + ); + + if (hash_is_missing) { + account.hash.* = hash; + } else { + const hash_matches = std.mem.eql(u8, &account.hash.data, &hash.data); + if (!hash_matches) { + std.debug.print("account hash mismatch: {s} != {s}\n", .{ account.hash, hash }); + } + } + + try bins.insert(AccountHashData{ + .id = append_vec.id, + .pubkey = pubkey, + .hash = hash, + .lamports = account.account_info.lamports, + .slot = append_vec.slot, + .offset = offset, + }); + + offset = offset + account.len; + n_accounts += 1; + } + + if (offset != alignToU64(append_vec.length)) { + return error.InvalidAppendVecLength; + } + + append_vec.n_accounts = n_accounts; +} + +pub fn sortThreadBins( + allocator: std.mem.Allocator, + thread_bins: []PubkeyBins, + bin_start_index: usize, + bin_end_index: usize, +) !void { + const SlotAndIndex = struct { slot: Slot, index: usize }; + var hashmap = HashMap(TmpPubkey, SlotAndIndex).init(allocator); + defer hashmap.deinit(); + + var timer = try std.time.Timer.start(); + const n_threads = thread_bins.len; + + for (bin_start_index..bin_end_index, 1..) 
|bin_i, count| { + + // compute total capacity required + var total_len_required: usize = 0; + for (0..n_threads) |i| { + total_len_required += thread_bins[i].bins[bin_i].items.len; + } + var main_bin = try allocator.alloc(AccountHashData, total_len_required); + + // fill the main bin + var main_bin_index: usize = 0; + for (0..n_threads) |thread_i| { + var thread_bin = &thread_bins[thread_i].bins[bin_i]; + defer thread_bin.deinit(); + + for (thread_bin.items) |account_hash_data| { + if (hashmap.getEntry(account_hash_data.pubkey)) |*entry| { + // only track the most recent slot + if (account_hash_data.slot > entry.value_ptr.slot) { + const index = entry.value_ptr.index; + main_bin[index] = account_hash_data; + entry.value_ptr.slot = account_hash_data.slot; + } + } else { + main_bin[main_bin_index] = account_hash_data; + + try hashmap.putNoClobber(account_hash_data.pubkey, .{ + .slot = account_hash_data.slot, + .index = main_bin_index, + }); + main_bin_index += 1; + } + } + } + + // sort main_bin + std.mem.sort(AccountHashData, main_bin[0..main_bin_index], {}, struct { + fn lessThan(_: void, lhs: AccountHashData, rhs: AccountHashData) bool { + return std.mem.lessThan(u8, &lhs.pubkey.data, &rhs.pubkey.data); + } + }.lessThan); + + // update + var main_bin_array = ArrayList(AccountHashData).fromOwnedSlice(allocator, main_bin); + main_bin_array.items.len = main_bin_index; + + thread_bins[0].bins[bin_i] = main_bin_array; + + // clear mem for next iteration + hashmap.clearRetainingCapacity(); + + if (count % 1000 == 0) { + printTimeEstimate( + &timer, + bin_end_index - bin_start_index, + count, + "sorting pubkey bins", + ); + } + } +} + +pub fn printTimeEstimate( + // timer should be started at the beginning + timer: *std.time.Timer, + total: usize, + i: usize, + comptime name: []const u8, +) void { + if (i == 0 or total == 0) return; + + const p_done = i * 100 / total; + const left = total - i; + + const elapsed = timer.read(); + const ns_per_vec = elapsed / i; + const time_left = ns_per_vec * left; + + const min_left = time_left / std.time.ns_per_min; + const sec_left = (time_left / std.time.ns_per_s) - (min_left * std.time.s_per_min); + + if (sec_left < 10) { + std.debug.print("{s}: {d}/{d} ({d}%) (time left: {d}:0{d})\r", .{ + name, + i, + total, + p_done, + min_left, + sec_left, + }); + } else { + std.debug.print("{s}: {d}/{d} ({d}%) (time left: {d}:{d})\r", .{ + name, + i, + total, + p_done, + min_left, + sec_left, + }); + } +} + +pub const PUBKEY_BINS_FOR_CALCULATING_HASHES: usize = 65_536; + +pub const PubkeyBins = struct { + bins: []ArrayList(AccountHashData), + calculator: PubkeyBinCalculator, + + pub fn init(allocator: std.mem.Allocator, n_bins: usize) !PubkeyBins { + const calculator = PubkeyBinCalculator.init(n_bins); + + var bins = try allocator.alloc(ArrayList(AccountHashData), n_bins); + for (bins) |*bin| { + const INIT_BUCKET_LENGTH = 1_000; + bin.* = try ArrayList(AccountHashData).initCapacity(allocator, INIT_BUCKET_LENGTH); + } + + return PubkeyBins{ + .bins = bins, + .calculator = calculator, + }; + } + + pub fn deinit(self: *PubkeyBins) void { + const allocator = self.bins[0].allocator; + for (self.bins) |*bin| { + bin.deinit(); + } + allocator.free(self.bins); + } + + pub fn insert(self: *PubkeyBins, account: AccountHashData) !void { + const bin_index = self.calculator.binIndex(&account.pubkey); + try self.bins[bin_index].append(account); + } +}; + +pub const PubkeyBinCalculator = struct { + shift_bits: u6, + + pub fn init(n_bins: usize) PubkeyBinCalculator { + // u8 * 3 
(ie, we consider on the first 3 bytes of a pubkey) + const MAX_BITS: u32 = 24; + // within bounds + std.debug.assert(n_bins > 0); + std.debug.assert(n_bins <= (1 << MAX_BITS)); + // power of two + std.debug.assert((n_bins & (n_bins - 1)) == 0); + // eg, + // 8 bins + // => leading zeros = 28 + // => shift_bits = (24 - (32 - 28 - 1)) = 21 + // ie, + // if we have the first 24 bits set (u8 << 16, 8 + 16 = 24) + // want to consider the first 3 bits of those 24 + // 0000 ... [100]0 0000 0000 0000 0000 0000 + // then we want to shift right by 21 + // 0000 ... 0000 0000 0000 0000 0000 0[100] + // those 3 bits can represent 2^3 (= 8) bins + const shift_bits = @as(u6, @intCast(MAX_BITS - (32 - @clz(@as(u32, @intCast(n_bins))) - 1))); + + return PubkeyBinCalculator{ + .shift_bits = shift_bits, + }; + } + + pub fn binIndex(self: *const PubkeyBinCalculator, pubkey: *const TmpPubkey) usize { + const data = &pubkey.data; + return (@as(usize, data[0]) << 16 | + @as(usize, data[1]) << 8 | + @as(usize, data[2])) >> self.shift_bits; + } +}; + +pub fn main() !void { + var gpa = std.heap.GeneralPurposeAllocator(.{}){}; + var allocator = gpa.allocator(); + + const snapshot_path = "/Users/tmp/Documents/zig-solana/snapshots"; + + const accounts_dir_path = try std.fmt.allocPrint( + allocator, + "{s}/{s}", + .{ snapshot_path, "accounts" }, + ); + const accounts_db_fields_path = try std.fmt.allocPrint( + allocator, + "{s}/{s}", + .{ snapshot_path, "accounts_db.bincode" }, + ); + + var accounts_dir = try std.fs.openIterableDirAbsolute(accounts_dir_path, .{}); + var accounts_dir_iter = accounts_dir.iterate(); + + // compute the total size (to compute time left) + var total_append_vec_count: usize = 0; + while (try accounts_dir_iter.next()) |_| { + total_append_vec_count += 1; + } + accounts_dir_iter = accounts_dir.iterate(); // reset + std.debug.print("total_append_vec_count: {d}\n", .{total_append_vec_count}); + + // time it + var full_timer = try std.time.Timer.start(); + var timer = try std.time.Timer.start(); + + // allocate all the filenames + var total_name_size: usize = 0; + while (try accounts_dir_iter.next()) |entry| { + total_name_size += entry.name.len; + } + var filename_mem = try allocator.alloc(u8, total_name_size); + defer allocator.free(filename_mem); + accounts_dir_iter = accounts_dir.iterate(); // reset + + var filename_slices = try ArrayList([]u8).initCapacity(allocator, total_append_vec_count); + defer filename_slices.deinit(); + + var index: usize = 0; + while (try accounts_dir_iter.next()) |file_entry| { + const file_name_len = file_entry.name.len; + @memcpy(filename_mem[index..(index + file_name_len)], file_entry.name); + filename_slices.appendAssumeCapacity(filename_mem[index..(index + file_name_len)]); + index += file_name_len; + } + accounts_dir_iter = accounts_dir.iterate(); // reset + std.debug.assert(filename_slices.items.len == total_append_vec_count); + + // read accounts_db.bincode + const accounts_db_fields_file = std.fs.openFileAbsolute(accounts_db_fields_path, .{}) catch |err| { + std.debug.print("failed to open accounts-db fields file: {s} ... 
skipping test\n", .{@errorName(err)}); + return; + }; + var accounts_db_fields = try bincode.read(allocator, AccountsDbFields, accounts_db_fields_file.reader(), .{}); + defer bincode.free(allocator, accounts_db_fields); + + const accounts_hash_exp = accounts_db_fields.bank_hash_info.accounts_hash; + const total_lamports_exp = accounts_db_fields.bank_hash_info.stats.num_lamports_stored; + std.debug.print("expected hash: {s}\n", .{accounts_hash_exp}); + std.debug.print("expected total lamports: {d}\n", .{total_lamports_exp}); + + // setup the threads + // double the number of CPUs bc of the high I/O from mmap (and cache misses) + timer.reset(); + + var n_threads = @as(u32, @truncate(try std.Thread.getCpuCount())) * 2; + var handles = try ArrayList(std.Thread).initCapacity(allocator, n_threads); + var chunk_size = total_append_vec_count / n_threads; + if (chunk_size == 0) { + n_threads = 1; + } + std.debug.print("starting {d} threads with {d} files per thread\n", .{ n_threads, chunk_size }); + + var start_index: usize = 0; + var end_index: usize = 0; + + // !! + // const n_bins = PUBKEY_BINS_FOR_CALCULATING_HASHES; + const n_bins = 128; + var thread_bins = try allocator.alloc(PubkeyBins, n_threads); + for (thread_bins) |*thread_bin| { + thread_bin.* = try PubkeyBins.init(allocator, n_bins); + } + + for (0..n_threads) |i| { + if (i == (n_threads - 1)) { + end_index = total_append_vec_count; + } else { + end_index = start_index + chunk_size; + } + + const handle = try std.Thread.spawn(.{}, indexAndBinFiles, .{ + &accounts_db_fields, + accounts_dir_path, + filename_slices.items[start_index..end_index], + &thread_bins[i], + }); + handles.appendAssumeCapacity(handle); + start_index = end_index; + } + std.debug.assert(end_index == total_append_vec_count); + + for (handles.items) |handle| { + handle.join(); + } + std.debug.print("\n", .{}); + std.debug.print("done in {d}ms\n", .{timer.read() / std.time.ns_per_ms}); + timer.reset(); + + // process per bin + // no I/O so we use cpu count exact + n_threads = @as(u32, @truncate(try std.Thread.getCpuCount())); + chunk_size = n_bins / n_threads; + if (chunk_size == 0) { + n_threads = 1; + } + std.debug.print("starting {d} threads with {d} bins per thread\n", .{ n_threads, chunk_size }); + + start_index = 0; + end_index = 0; + + handles.clearRetainingCapacity(); + try handles.ensureTotalCapacity(n_threads); + + for (0..n_threads) |i| { + if (i == (n_threads - 1)) { + end_index = n_bins; + } else { + end_index = start_index + chunk_size; + } + + const handle = try std.Thread.spawn(.{}, sortThreadBins, .{ + allocator, + thread_bins, + start_index, + end_index, + }); + handles.appendAssumeCapacity(handle); + start_index = end_index; + } + std.debug.assert(end_index == n_bins); + + for (handles.items) |handle| { + handle.join(); + } + std.debug.print("\n", .{}); + std.debug.print("done in {d}ms\n", .{timer.read() / std.time.ns_per_ms}); + timer.reset(); + + // compute merkle tree over the slices + std.debug.print("computing merkle tree\n", .{}); + var total_count: usize = 0; + for (thread_bins[0].bins) |bin| { + total_count += bin.items.len; + } + + // var dest: [44]u8 = undefined; + var total_lamports: u64 = 0; + var hashes = try ArrayList(Hash).initCapacity(allocator, total_count); + for (thread_bins[0].bins) |bin| { + for (bin.items) |account_info| { + if (account_info.lamports == 0) continue; + // std.debug.print("pubkey: {s} slot: {d} lamports: {d} bin: {d}\n", .{account_info.pubkey.toStringWithBuf(dest[0..44]), account_info.slot, account_info.lamports, 
bin_i}); + hashes.appendAssumeCapacity(account_info.hash); + total_lamports += account_info.lamports; + } + } + std.debug.print("total lamports: {d}\n", .{total_lamports}); + + const root_hash = try merkleTreeHash(hashes.items, MERKLE_FANOUT); + std.debug.print("merkle root: {any}\n", .{root_hash.*}); + + std.debug.print("done in {d}ms\n", .{full_timer.read() / std.time.ns_per_ms}); +} diff --git a/src/common/merkle_tree.zig b/src/common/merkle_tree.zig new file mode 100644 index 000000000..f153fb358 --- /dev/null +++ b/src/common/merkle_tree.zig @@ -0,0 +1,39 @@ +const std = @import("std"); +const Hash = @import("../core/hash.zig").Hash; +const Sha256 = std.crypto.hash.sha2.Sha256; + +pub fn merkleTreeHash(hashes: []Hash, fanout: usize) !*Hash { + var length = hashes.len; + while (true) { + const chunks = try std.math.divCeil(usize, length, fanout); + var index: usize = 0; + for (0..chunks) |i| { + const start = i * fanout; + const end = @min(start + fanout, length); + + var hasher = Sha256.init(.{}); + for (start..end) |j| { + hasher.update(&hashes[j].data); + } + var hash = hasher.finalResult(); + hashes[index] = Hash{ .data = hash }; + index += 1; + } + length = index; + if (length == 1) { + return &hashes[0]; + } + } +} + +test "common.merkle_tree: test tree impl" { + const init_length: usize = 10; + var hashes: [init_length]Hash = undefined; + for (&hashes, 0..) |*hash, i| { + hash.* = Hash{ .data = [_]u8{@intCast(i)} ** 32 }; + } + + const root = try merkleTreeHash(&hashes, 3); + const expected_root: [32]u8 = .{ 56, 239, 163, 39, 169, 252, 144, 195, 85, 228, 99, 82, 225, 185, 237, 141, 186, 90, 36, 220, 86, 140, 59, 47, 18, 172, 250, 231, 79, 178, 51, 100 }; + try std.testing.expect(std.mem.eql(u8, &expected_root, &root.data)); +} diff --git a/src/core/account.zig b/src/core/account.zig index 628684596..16d39c101 100644 --- a/src/core/account.zig +++ b/src/core/account.zig @@ -8,3 +8,68 @@ pub const Account = struct { executable: bool, rent_epoch: Epoch, }; + +const std = @import("std"); +const Blake3 = std.crypto.hash.Blake3; +const Hash = @import("./hash.zig").Hash; + +pub fn hashAccount( + lamports: u64, + data: []u8, + owner_pubkey_data: []const u8, + executable: bool, + rent_epoch: u64, + address_pubkey_data: []const u8, +) Hash { + var hasher = Blake3.init(.{}); + var hash_buf: [32]u8 = undefined; + + var int_buf: [8]u8 = undefined; + std.mem.writeIntLittle(u64, &int_buf, lamports); + hasher.update(&int_buf); + + std.mem.writeIntLittle(u64, &int_buf, rent_epoch); + hasher.update(&int_buf); + + hasher.update(data); + + if (executable) { + hasher.update(&[_]u8{1}); + } else { + hasher.update(&[_]u8{0}); + } + + hasher.update(owner_pubkey_data); + hasher.update(address_pubkey_data); + + hasher.final(&hash_buf); + const hash = Hash{ + .data = hash_buf, + }; + + return hash; +} + +test "core.account: test account hash matches rust" { + var data: [3]u8 = .{ 1, 2, 3 }; + var account = Account{ + .lamports = 10, + .data = &data, + .owner = Pubkey.default(), + .executable = false, + .rent_epoch = 20, + }; + const pubkey = Pubkey.default(); + + const hash = hashAccount( + account.lamports, + account.data, + &account.owner.data, + account.executable, + account.rent_epoch, + &pubkey.data, + ); + + const expected_hash: [32]u8 = .{ 170, 75, 87, 73, 60, 156, 174, 14, 105, 6, 129, 108, 167, 156, 166, 213, 28, 4, 163, 187, 252, 155, 24, 253, 158, 13, 86, 100, 103, 89, 232, 28 }; + try std.testing.expect(std.mem.eql(u8, &expected_hash, &hash.data)); +} diff --git a/src/core/accounts_db.zig 
b/src/core/accounts_db.zig new file mode 100644 index 000000000..3c263b9cc --- /dev/null +++ b/src/core/accounts_db.zig @@ -0,0 +1,528 @@ +const std = @import("std"); +const ArrayList = std.ArrayList; + +const Account = @import("../core/account.zig").Account; +const Hash = @import("../core/hash.zig").Hash; +const Slot = @import("../core/clock.zig").Slot; +const Pubkey = @import("../core/pubkey.zig").Pubkey; +const bincode = @import("../bincode/bincode.zig"); + +const AccountsDbFields = @import("../core/snapshot_fields.zig").AccountsDbFields; +const AppendVecInfo = @import("../core/snapshot_fields.zig").AppendVecInfo; + +const AppendVec = @import("../core/append_vec.zig").AppendVec; +const TmpPubkey = @import("../core/append_vec.zig").TmpPubkey; +const alignToU64 = @import("../core/append_vec.zig").alignToU64; + +const ThreadPool = @import("../sync/thread_pool.zig").ThreadPool; +const Task = ThreadPool.Task; +const Batch = ThreadPool.Batch; +const Channel = @import("../sync/channel.zig").Channel; + +const hashAccount = @import("../core/account.zig").hashAccount; +const merkleTreeHash = @import("../common/merkle_tree.zig").merkleTreeHash; + +pub const MERKLE_FANOUT: usize = 16; + +pub const FileId = usize; +pub const AccountRef = struct { + slot: Slot, + file_id: FileId, + offset: usize, +}; + +pub const AccountsDB = struct { + account_files: std.AutoArrayHashMap(FileId, AppendVec), + index: std.AutoArrayHashMap(TmpPubkey, ArrayList(AccountRef)), + + pub fn init(alloc: std.mem.Allocator) AccountsDB { + return AccountsDB{ + .account_files = std.AutoArrayHashMap(FileId, AppendVec).init(alloc), + .index = std.AutoArrayHashMap(TmpPubkey, ArrayList(AccountRef)).init(alloc), + }; + } +}; + +// accounts-db { +// accounts-files: hashmap +// index: hashmap +// } + +// read account files +// thread1: +// open append_vec +// generate vec +// send vec to channel + +// thread2: +// index vec + +// once all index +// compute_accounts_hash(max_slot) +// iterate over the index and get accounts +// bin the pubkeys +// run sorting algo across bins +// get the full hash across bins +// compute the merkle tree + +// dump_to_csv(max_slot) +// iterate over the index and get accounts +// look up the full accounts in the accounts-db +// dump to csv + +// iterate over the index and get accounts +// get their hash +// compute the merkle tree + +// dump_to_csv(max_slot) +// iterate over the index and get accounts +// look up the full accounts in the accounts-db +// dump to csv + +const PubkeyAccountRef = struct { + pubkey: TmpPubkey, + offset: usize, + slot: Slot, +}; + +const AccountFileChannel = Channel(struct { AppendVec, ArrayList(PubkeyAccountRef) }); + +pub fn openFiles( + allocator: std.mem.Allocator, + accounts_db_fields: *const AccountsDbFields, + accounts_dir_path: []const u8, + // task specific + file_names: [][]const u8, + channel: *AccountFileChannel, +) !void { + // estimate of how many accounts per append vec + const ACCOUNTS_PER_FILE_EST = 20_000; + var refs = try ArrayList(PubkeyAccountRef).initCapacity(allocator, ACCOUNTS_PER_FILE_EST); + + // NOTE: might need to be longer depending on abs path length + var abs_path_buf: [1024]u8 = undefined; + for (file_names) |file_name| { + // parse "{slot}.{id}" from the file_name + var fiter = std.mem.tokenizeSequence(u8, file_name, "."); + const slot = try std.fmt.parseInt(Slot, fiter.next().?, 10); + const append_vec_id = try std.fmt.parseInt(usize, fiter.next().?, 10); + + // read metadata + const slot_metas: ArrayList(AppendVecInfo) = 
accounts_db_fields.map.get(slot).?; + std.debug.assert(slot_metas.items.len == 1); + const slot_meta = slot_metas.items[0]; + std.debug.assert(slot_meta.id == append_vec_id); + + // read appendVec from file + const abs_path = try std.fmt.bufPrint(&abs_path_buf, "{s}/{s}", .{ accounts_dir_path, file_name }); + const append_vec_file = try std.fs.openFileAbsolute(abs_path, .{ .mode = .read_write }); + var append_vec = AppendVec.init(append_vec_file, slot_meta, slot) catch |err| { + var buf: [1024]u8 = undefined; + var stream = std.io.fixedBufferStream(&buf); + var writer = stream.writer(); + try std.fmt.format(writer, "failed to *open* appendVec {s}: {s}", .{ file_name, @errorName(err) }); + @panic(stream.getWritten()); + }; + + sanitizeAndParseAccounts(&append_vec, &refs) catch |err| { + var buf: [1024]u8 = undefined; + var stream = std.io.fixedBufferStream(&buf); + var writer = stream.writer(); + try std.fmt.format(writer, "failed to *sanitize* appendVec {s}: {s}", .{ file_name, @errorName(err) }); + @panic(stream.getWritten()); + }; + + try channel.send(.{ append_vec, refs }); + + // re-allocate + refs = try ArrayList(PubkeyAccountRef).initCapacity(allocator, ACCOUNTS_PER_FILE_EST); + } +} + +pub fn sanitizeAndParseAccounts(append_vec: *AppendVec, refs: *ArrayList(PubkeyAccountRef)) !void { + var offset: usize = 0; + var n_accounts: usize = 0; + + while (true) { + var account = append_vec.getAccount(offset) catch break; + try account.sanitize(); + + const pubkey = account.store_info.pubkey; + + const hash_is_missing = std.mem.eql(u8, &account.hash.data, &Hash.default().data); + if (hash_is_missing) { + const hash = hashAccount( + account.account_info.lamports, + account.data, + &account.account_info.owner.data, + account.account_info.executable, + account.account_info.rent_epoch, + &pubkey.data, + ); + account.hash.* = hash; + } + + try refs.append(PubkeyAccountRef{ + .pubkey = pubkey, + .offset = offset, + .slot = append_vec.slot, + }); + + offset = offset + account.len; + n_accounts += 1; + } + + if (offset != alignToU64(append_vec.length)) { + return error.InvalidAppendVecLength; + } + + append_vec.n_accounts = n_accounts; +} + +pub fn recvFilesAndIndex( + allocator: std.mem.Allocator, + channel: *AccountFileChannel, + accounts_db: *AccountsDB, + total_files: usize, +) !void { + var timer = try std.time.Timer.start(); + var file_count: usize = 0; + + while (true) { + const maybe_task_outputs = channel.try_drain() catch unreachable; + var task_outputs = maybe_task_outputs orelse continue; + defer channel.allocator.free(task_outputs); + + for (task_outputs) |task_output| { + const account_file: AppendVec = task_output[0]; + const refs: ArrayList(PubkeyAccountRef) = task_output[1]; + defer refs.deinit(); + + // track the file + try accounts_db.account_files.putNoClobber(account_file.id, account_file); + + // populate index + for (refs.items) |account_ref| { + var entry = try accounts_db.index.getOrPut(account_ref.pubkey); + if (!entry.found_existing) { + entry.value_ptr.* = ArrayList(AccountRef).init(allocator); + } + + try entry.value_ptr.append(AccountRef{ + .file_id = account_file.id, + .offset = account_ref.offset, + .slot = account_ref.slot, + }); + } + + file_count += 1; + if (file_count % 1000 == 0 or file_count < 1000) { + printTimeEstimate(&timer, total_files, file_count, "recvFilesAndIndex"); + if (file_count == total_files) return; + } + } + } +} + +pub fn printTimeEstimate( + // timer should be started at the beginning + timer: *std.time.Timer, + total: usize, + i: usize, + 
comptime name: []const u8, +) void { + if (i == 0 or total == 0) return; + + const p_done = i * 100 / total; + const left = total - i; + + const elapsed = timer.read(); + const ns_per_vec = elapsed / i; + const time_left = ns_per_vec * left; + + const min_left = time_left / std.time.ns_per_min; + const sec_left = (time_left / std.time.ns_per_s) - (min_left * std.time.s_per_min); + + if (sec_left < 10) { + std.debug.print("{s}: {d}/{d} ({d}%) (time left: {d}:0{d})\r", .{ + name, + i, + total, + p_done, + min_left, + sec_left, + }); + } else { + std.debug.print("{s}: {d}/{d} ({d}%) (time left: {d}:{d})\r", .{ + name, + i, + total, + p_done, + min_left, + sec_left, + }); + } +} + +pub fn readDirectory( + allocator: std.mem.Allocator, + directory: std.fs.IterableDir, +) !struct { filenames: ArrayList([]u8), mem: []u8 } { + var dir_iter = directory.iterate(); + var total_name_size: usize = 0; + var total_files: usize = 0; + while (try dir_iter.next()) |entry| { + total_name_size += entry.name.len; + total_files += 1; + } + var mem = try allocator.alloc(u8, total_name_size); + errdefer allocator.free(mem); + + dir_iter = directory.iterate(); // reset + + var filenames = try ArrayList([]u8).initCapacity(allocator, total_files); + errdefer filenames.deinit(); + + var index: usize = 0; + while (try dir_iter.next()) |file_entry| { + const file_name_len = file_entry.name.len; + @memcpy(mem[index..(index + file_name_len)], file_entry.name); + filenames.appendAssumeCapacity(mem[index..(index + file_name_len)]); + index += file_name_len; + } + dir_iter = directory.iterate(); // reset + + return .{ .filenames = filenames, .mem = mem }; +} + +const PubkeyBinCalculator = @import("../cmd/snapshot_verify.zig").PubkeyBinCalculator; +pub const PUBKEY_BINS_FOR_CALCULATING_HASHES: usize = 65_536; + +pub const PubkeyBins = struct { + bins: []BinType, + calculator: PubkeyBinCalculator, + + const BinType = ArrayList(*const TmpPubkey); + + pub fn init(allocator: std.mem.Allocator, n_bins: usize) !PubkeyBins { + const calculator = PubkeyBinCalculator.init(n_bins); + + var bins = try allocator.alloc(BinType, n_bins); + for (bins) |*bin| { + const INIT_BUCKET_LENGTH = 1_000; + bin.* = try BinType.initCapacity(allocator, INIT_BUCKET_LENGTH); + } + + return PubkeyBins{ + .bins = bins, + .calculator = calculator, + }; + } + + pub fn deinit(self: *PubkeyBins) void { + const allocator = self.bins[0].allocator; + for (self.bins) |*bin| { + bin.deinit(); + } + allocator.free(self.bins); + } + + pub fn insert(self: *PubkeyBins, pubkey: *const TmpPubkey) !void { + const bin_index = self.calculator.binIndex(pubkey); + try self.bins[bin_index].append(pubkey); + } +}; + +pub fn sortBins( + bins: []ArrayList(*const TmpPubkey), + bin_start_index: usize, + bin_end_index: usize, +) !void { + const total_bins = bin_end_index - bin_start_index; + var timer = try std.time.Timer.start(); + + for (bin_start_index..bin_end_index, 1..) 
|bin_i, count| { + var bin = bins[bin_i]; + + std.mem.sort(*const TmpPubkey, bin.items, {}, struct { + fn lessThan(_: void, lhs: *const TmpPubkey, rhs: *const TmpPubkey) bool { + return std.mem.lessThan(u8, &lhs.data, &rhs.data); + } + }.lessThan); + + if (count % 1000 == 0) { + printTimeEstimate(&timer, total_bins, count, "sortBins"); + } + } +} + +pub fn main() !void { + var gpa = std.heap.GeneralPurposeAllocator(.{}){}; + var allocator = gpa.allocator(); + + const snapshot_path = "/Users/tmp/Documents/zig-solana/snapshots"; + + const accounts_dir_path = try std.fmt.allocPrint( + allocator, + "{s}/{s}", + .{ snapshot_path, "accounts" }, + ); + const accounts_db_fields_path = try std.fmt.allocPrint( + allocator, + "{s}/{s}", + .{ snapshot_path, "accounts_db.bincode" }, + ); + + // time it + var full_timer = try std.time.Timer.start(); + var timer = try std.time.Timer.start(); + + var accounts_dir = try std.fs.openIterableDirAbsolute(accounts_dir_path, .{}); + var files = try readDirectory(allocator, accounts_dir); + var filenames = files.filenames; + defer { + filenames.deinit(); + allocator.free(files.mem); + } + var n_account_files: usize = filenames.items.len; + std.debug.print("n_account_files: {d}\n", .{n_account_files}); + + // read accounts_db.bincode + const accounts_db_fields_file = std.fs.openFileAbsolute(accounts_db_fields_path, .{}) catch |err| { + std.debug.print("failed to open accounts-db fields file: {s} ... skipping test\n", .{@errorName(err)}); + return; + }; + var accounts_db_fields = try bincode.read(allocator, AccountsDbFields, accounts_db_fields_file.reader(), .{}); + defer bincode.free(allocator, accounts_db_fields); + + // init db + var accounts_db = AccountsDB.init(allocator); + + // start processing + var n_threads = @as(u32, @truncate(try std.Thread.getCpuCount())) * 2; + var handles = try ArrayList(std.Thread).initCapacity(allocator, n_threads); + var chunk_size = n_account_files / n_threads; + if (chunk_size == 0) { + n_threads = 1; + } + std.debug.print("starting {d} threads with {d} files per thread\n", .{ n_threads, chunk_size }); + + var channel = AccountFileChannel.init(allocator, 10_000); + defer channel.deinit(); + + var start_index: usize = 0; + var end_index: usize = 0; + + // + for (0..n_threads) |i| { + if (i == (n_threads - 1)) { + end_index = n_account_files; + } else { + end_index = start_index + chunk_size; + } + + const handle = try std.Thread.spawn(.{}, openFiles, .{ + allocator, + &accounts_db_fields, + accounts_dir_path, + filenames.items[start_index..end_index], + channel, + }); + handles.appendAssumeCapacity(handle); + start_index = end_index; + } + std.debug.assert(end_index == n_account_files); + + try recvFilesAndIndex(allocator, channel, &accounts_db, n_account_files); + + for (handles.items) |handle| { + handle.join(); + } + std.debug.print("\n", .{}); + std.debug.print("done in {d}ms\n", .{timer.read() / std.time.ns_per_ms}); + timer.reset(); + + // sort the pubkeys + std.debug.print("initializing pubkey bins\n", .{}); + const n_bins = 128; + var bins = try PubkeyBins.init(allocator, n_bins); + for (accounts_db.index.keys()) |*pubkey| { + try bins.insert(pubkey); + } + + n_threads = @as(u32, @truncate(try std.Thread.getCpuCount())); + chunk_size = n_bins / n_threads; + if (chunk_size == 0) { + n_threads = 1; + } + handles.clearRetainingCapacity(); + std.debug.print("starting {d} threads with {d} bins per thread\n", .{ n_threads, chunk_size }); + + start_index = 0; + end_index = 0; + + for (0..n_threads) |i| { + if (i == (n_threads - 
1)) { + end_index = n_bins; + } else { + end_index = start_index + chunk_size; + } + + var handle = try std.Thread.spawn(.{}, sortBins, .{ + bins.bins, + start_index, + end_index, + }); + + handles.appendAssumeCapacity(handle); + start_index = end_index; + } + std.debug.assert(end_index == n_bins); + + for (handles.items) |handle| { + handle.join(); + } + std.debug.print("\n", .{}); + std.debug.print("done in {d}ms\n", .{timer.read() / std.time.ns_per_ms}); + timer.reset(); + + // compute merkle tree over the slices + std.debug.print("computing merkle tree\n", .{}); + var total_count: usize = 0; + for (bins.bins) |*bin| { + total_count += bin.items.len; + } + + var total_lamports: u64 = 0; + var hashes = try ArrayList(Hash).initCapacity(allocator, total_count); + for (bins.bins) |*bin| { + for (bin.items) |pubkey| { + const account_states = accounts_db.index.get(pubkey.*).?; + var max_slot_index: ?usize = null; + var max_slot: usize = 0; + for (account_states.items, 0..) |account_info, i| { + if (max_slot_index == null or max_slot < account_info.slot) { + max_slot = account_info.slot; + max_slot_index = i; + } + } + const newest_account_loc = account_states.items[max_slot_index.?]; + const append_vec: AppendVec = accounts_db.account_files.get(newest_account_loc.file_id).?; + const account = try append_vec.getAccount(newest_account_loc.offset); + const lamports = account.account_info.lamports; + + if (account.account_info.lamports == 0) continue; + // std.debug.print("pubkey: {s} slot: {d} lamports: {d} bin: {d}\n", .{account_info.pubkey.toStringWithBuf(dest[0..44]), account_info.slot, account_info.lamports, bin_i}); + hashes.appendAssumeCapacity(account.hash.*); + total_lamports += lamports; + } + } + std.debug.print("total lamports: {d}\n", .{total_lamports}); + + const root_hash = try merkleTreeHash(hashes.items, MERKLE_FANOUT); + std.debug.print("merkle root: {any}\n", .{root_hash.*}); + + std.debug.print("\n", .{}); + std.debug.print("done in {d}ms\n", .{timer.read() / std.time.ns_per_ms}); + timer.reset(); +} diff --git a/src/core/append_vec.zig b/src/core/append_vec.zig index 5339b4b20..c5836cec8 100644 --- a/src/core/append_vec.zig +++ b/src/core/append_vec.zig @@ -9,17 +9,26 @@ const Epoch = @import("./clock.zig").Epoch; const Pubkey = @import("./pubkey.zig").Pubkey; const bincode = @import("../bincode/bincode.zig"); -const SnapshotFields = @import("./snapshot_fields.zig").SnapshotFields; const AccountsDbFields = @import("./snapshot_fields.zig").AccountsDbFields; const AppendVecInfo = @import("./snapshot_fields.zig").AppendVecInfo; const base58 = @import("base58-zig"); -pub const TmpPubkey = struct { +pub const TmpPubkey = extern struct { data: [32]u8, // note: need to remove cached string to have correct ptr casting - pub fn toString(self: *const TmpPubkey) error{EncodingError}![44]u8 { + pub fn toStringWithBuf(self: *const TmpPubkey, dest: []u8) []u8 { + @memset(dest, 0); + const encoder = base58.Encoder.init(.{}); + var written = encoder.encode(&self.data, dest) catch unreachable; + if (written > 44) { + std.debug.panic("written is > 44, written: {}, dest: {any}, bytes: {any}", .{ written, dest, self.data }); + } + return dest[0..written]; + } + + pub fn toString(self: *const TmpPubkey) error{EncodingError}![]u8 { var dest: [44]u8 = undefined; @memset(&dest, 0); @@ -28,7 +37,17 @@ pub const TmpPubkey = struct { if (written > 44) { std.debug.panic("written is > 44, written: {}, dest: {any}, bytes: {any}", .{ written, dest, self.data }); } - return dest; + return dest[0..written]; 
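+        // NOTE: `dest` is a stack-local buffer, so the slice returned above dangles once
+        // toString returns; toStringWithBuf with a caller-owned buffer avoids this.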
+ } + + pub fn fromString(str: []const u8) !TmpPubkey { + var pubkey = TmpPubkey{ .data = [_]u8{0} ** 32 }; + const decoder = base58.Decoder.init(.{}); + const size = try decoder.decode(str, &pubkey.data); + if (size != 32) { + return error.EncodingError; + } + return pubkey; } pub fn format(self: @This(), comptime _: []const u8, _: std.fmt.FormatOptions, writer: anytype) std.os.WriteError!void { @@ -39,6 +58,14 @@ pub const TmpPubkey = struct { pub fn isDefault(self: *const TmpPubkey) bool { return std.mem.eql(u8, &self.data, &[_]u8{0} ** 32); } + + pub fn default() TmpPubkey { + return TmpPubkey{ .data = [_]u8{0} ** 32 }; + } + + pub fn equals(self: *const TmpPubkey, other: *const TmpPubkey) bool { + return std.mem.eql(u8, &self.data, &other.data); + } }; // metadata which is stored inside an AppendVec @@ -85,8 +112,14 @@ pub const AppendVecAccountInfo = struct { } }; +pub const PubkeyAndAccountInAppendVecRef = struct { + pubkey: TmpPubkey, + account_ref: AccountInAppendVecRef, + // hash: Hash, +}; + const u64_size: usize = @sizeOf(u64); -inline fn alignToU64(addr: usize) usize { +pub inline fn alignToU64(addr: usize) usize { return (addr + (u64_size - 1)) & ~(u64_size - 1); } @@ -94,6 +127,7 @@ pub const AppendVec = struct { // file contents mmap_ptr: []align(std.mem.page_size) u8, id: usize, + slot: Slot, // number of bytes used length: usize, // total bytes available @@ -105,7 +139,7 @@ pub const AppendVec = struct { const Self = @This(); - pub fn init(file: std.fs.File, append_vec_info: AppendVecInfo) !Self { + pub fn init(file: std.fs.File, append_vec_info: AppendVecInfo, slot: Slot) !Self { const file_stat = try file.stat(); const file_size: u64 = @intCast(file_stat.size); @@ -126,6 +160,7 @@ pub const AppendVec = struct { .id = append_vec_info.id, .file_size = file_size, .file = file, + .slot = slot, }; } @@ -138,7 +173,6 @@ pub const AppendVec = struct { var offset: usize = 0; var n_accounts: usize = 0; - // parse all the accounts out of the append vec while (true) { const account = self.getAccount(offset) catch break; try account.sanitize(); @@ -190,6 +224,95 @@ pub const AppendVec = struct { const length = @sizeOf(T); return @alignCast(@ptrCast(try self.getSlice(start_index_ptr, length))); } + + pub fn getAccountsRefs(self: *const Self, allocator: std.mem.Allocator) !ArrayList(PubkeyAndAccountInAppendVecRef) { + var accounts = try ArrayList(PubkeyAndAccountInAppendVecRef).initCapacity(allocator, self.n_accounts); + + var offset: usize = 0; + while (true) { + const account = self.getAccount(offset) catch break; + const pubkey = account.store_info.pubkey; + + const pubkey_account_ref = PubkeyAndAccountInAppendVecRef{ + .pubkey = pubkey, + .account_ref = .{ + .slot = self.slot, + .offset = offset, + .append_vec_id = self.id, + }, + // .hash = Hash.default(), + }; + + accounts.appendAssumeCapacity(pubkey_account_ref); + offset = offset + account.len; + } + + return accounts; + } +}; + +pub const AccountInAppendVecRef = struct { + slot: usize, + append_vec_id: usize, + offset: usize, +}; + +pub const AccountsIndex = struct { + // only support RAM for now + ram_map: HashMap(TmpPubkey, ArrayList(AccountInAppendVecRef)), + // TODO: disk_map + + const Self = @This(); + + pub fn init(allocator: std.mem.Allocator) Self { + return Self{ + .ram_map = HashMap(TmpPubkey, ArrayList(AccountInAppendVecRef)).init(allocator), + }; + } + + pub fn deinit(self: *Self) void { + var iter = self.ram_map.iterator(); + while (iter.next()) |*entry| { + entry.value_ptr.deinit(); + } + self.ram_map.deinit(); 
+ } + + pub fn insertNewAccountRef( + self: *Self, + pubkey: TmpPubkey, + account_ref: AccountInAppendVecRef, + ) !void { + var maybe_entry = self.ram_map.getEntry(pubkey); + + // if the pubkey already exists + if (maybe_entry) |*entry| { + var existing_refs: *ArrayList(AccountInAppendVecRef) = entry.value_ptr; + + // search: if slot already exists, replace the value + var found_matching_slot = false; + for (existing_refs.items) |*existing_ref| { + if (existing_ref.slot == account_ref.slot) { + if (!found_matching_slot) { + existing_ref.* = account_ref; + found_matching_slot = true; + break; + } + // TODO: rust impl continues to scan and removes other slot duplicates + // do we need to do this? + } + } + + // otherwise we append the new slot + if (!found_matching_slot) { + try existing_refs.append(account_ref); + } + } else { + var account_refs = try ArrayList(AccountInAppendVecRef).initCapacity(self.ram_map.allocator, 1); + account_refs.appendAssumeCapacity(account_ref); + try self.ram_map.putNoClobber(pubkey, account_refs); + } + } }; test "core.append_vec: parse accounts out of append vec" { @@ -200,7 +323,7 @@ test "core.append_vec: parse accounts out of append vec" { // 3) run the test const alloc = std.testing.allocator; - const accounts_db_fields_path = "/Users/tmp2/Documents/zig-solana/snapshots/accounts_db.bincode"; + const accounts_db_fields_path = "/Users/tmp/Documents/zig-solana/snapshots/accounts_db.bincode"; const accounts_db_fields_file = std.fs.openFileAbsolute(accounts_db_fields_path, .{}) catch |err| { std.debug.print("failed to open accounts-db fields file: {s} ... skipping test\n", .{@errorName(err)}); return; @@ -209,64 +332,20 @@ test "core.append_vec: parse accounts out of append vec" { var accounts_db_fields = try bincode.read(alloc, AccountsDbFields, accounts_db_fields_file.reader(), .{}); defer bincode.free(alloc, accounts_db_fields); - // - var storage = HashMap(Slot, AppendVec).init(alloc); - defer { - var iter = storage.iterator(); - while (iter.next()) |*entry| { - entry.value_ptr.deinit(); - } - storage.deinit(); - } - - var n_appendvec: usize = 0; - var n_valid_appendvec: usize = 0; - - // const accounts_dir_path = "/Users/tmp/Documents/zig-solana/snapshots/accounts"; - var accounts_dir = std.fs.openIterableDirAbsolute(accounts_dir_path, .{}) catch |err| { - std.debug.print("failed to open accounts dir: {s} ... 
skipping test\n", .{@errorName(err)}); - return; - }; - var accounts_dir_iter = accounts_dir.iterate(); - - while (try accounts_dir_iter.next()) |entry| { - var filename: []const u8 = entry.name; - - // parse "{slot}.{id}" from the filename - var fiter = std.mem.tokenizeSequence(u8, filename, "."); - const slot = try std.fmt.parseInt(Slot, fiter.next().?, 10); - const append_vec_id = try std.fmt.parseInt(usize, fiter.next().?, 10); - - // read metadata - const slot_metas: ArrayList(AppendVecInfo) = accounts_db_fields.map.get(slot).?; - std.debug.assert(slot_metas.items.len == 1); - const slot_meta = slot_metas.items[0]; - std.debug.assert(slot_meta.id == append_vec_id); - - // read appendVec from file - var abs_path_buf: [1024]u8 = undefined; - const abs_path = try std.fmt.bufPrint(&abs_path_buf, "{s}/{s}", .{ accounts_dir_path, filename }); - const append_vec_file = try std.fs.openFileAbsolute(abs_path, .{ .mode = .read_write }); - n_appendvec += 1; - - var append_vec = AppendVec.init(append_vec_file, slot_meta) catch continue; - - // verify its valid - append_vec.sanitize() catch { - append_vec.deinit(); - continue; - }; - n_valid_appendvec += 1; + _ = accounts_dir_path; - // note: newer snapshots shouldnt clobber - try storage.putNoClobber(slot, append_vec); + // // time it + // var timer = try std.time.Timer.start(); - // dont open too many files (just testing) - if (n_appendvec == 10) break; - } + // var accounts_db = AccountsDB.init(alloc); + // defer accounts_db.deinit(); + // try accounts_db.load(alloc, accounts_db_fields, accounts_dir_path, null); + + // const elapsed = timer.read(); + // std.debug.print("elapsed: {d}\n", .{elapsed / std.time.ns_per_s}); // note: didnt untar the full snapshot (bc time) // n_valid_appendvec: 328_811, total_append_vec: 328_812 - std.debug.print("n_valid_appendvec: {d}, total_append_vec: {d}\n", .{ n_valid_appendvec, n_appendvec }); + // std.debug.print("n_valid_appendvec: {d}, total_append_vec: {d}\n", .{ n_valid_appendvec, n_appendvec }); } diff --git a/src/core/snapshot_fields.zig b/src/core/snapshot_fields.zig index b21bf6312..62ad90ee6 100644 --- a/src/core/snapshot_fields.zig +++ b/src/core/snapshot_fields.zig @@ -294,7 +294,7 @@ pub const SnapshotFields = struct { /// NOTE: should call this to get the correct bank_fields instead of accessing it directly /// due to the way snapshot deserialization works - pub fn getFields(self: *@This()) struct { bank_fields: BankFields, accounts_db_fields: AccountsDbFields } { + pub fn getFieldRefs(self: *@This()) struct { bank_fields: *const BankFields, accounts_db_fields: *const AccountsDbFields } { var bank_fields = &self.bank_fields; // if these are availabel they will be parsed (and likely not the default values) // so, we push them on the bank fields here @@ -303,7 +303,21 @@ pub const SnapshotFields = struct { bank_fields.epoch_accounts_hash = self.epoch_accounts_hash; bank_fields.epoch_reward_status = self.epoch_reward_status; - return .{ .bank_fields = self.bank_fields, .accounts_db_fields = self.accounts_db_fields }; + return .{ .bank_fields = bank_fields, .accounts_db_fields = &self.accounts_db_fields }; + } + + pub fn readFromFilePath(allocator: std.mem.Allocator, abs_path: []const u8) !SnapshotFields { + var file = try std.fs.openFileAbsolute(abs_path, .{}); + defer file.close(); + + var file_reader = std.io.bufferedReader(file.reader()); + const file_size = (try file.stat()).size; + + var buf = try std.ArrayList(u8).initCapacity(allocator, file_size); + defer buf.deinit(); + + var snapshot_fields 
= try bincode.read(allocator, SnapshotFields, file_reader.reader(), .{}); + return snapshot_fields; } }; @@ -316,35 +330,18 @@ test "core.snapshot_fields: parse snapshot fields" { // 4) run this // const snapshot_path = "/test_data/slot/slot"; - const snapshot_path = "/Users/tmp2/Documents/zig-solana/snapshots/snapshots/225552163/225552163"; + const snapshot_path = "/Users/tmp/Documents/zig-solana/snapshots/snapshots/225552163/225552163"; const alloc = std.testing.allocator; - // open file - var file = std.fs.openFileAbsolute(snapshot_path, .{}) catch |err| { - std.debug.print("failed to open snapshot file: {s} ... skipping test\n", .{@errorName(err)}); - return; + var snapshot_fields = SnapshotFields.readFromFilePath(alloc, snapshot_path) catch |err| { + if (err == std.fs.File.OpenError.FileNotFound) { + std.debug.print("failed to open snapshot fields file: {s} ... skipping test\n", .{@errorName(err)}); + return; + } + return err; }; - defer file.close(); - - var file_reader = std.io.bufferedReader(file.reader()); - const file_size = (try file.stat()).size; - - var buf = try std.ArrayList(u8).initCapacity(alloc, file_size); - defer buf.deinit(); - - var snapshot_fields = try bincode.read(alloc, SnapshotFields, file_reader.reader(), .{}); defer bincode.free(alloc, snapshot_fields); - const fields = snapshot_fields.getFields(); - - // rewrite the accounts_db_fields seperate - var db_buf = try bincode.writeToArray(alloc, fields.accounts_db_fields, .{}); - defer db_buf.deinit(); - - // write buf to a file - const accounts_db_path = "/Users/tmp/Documents/zig-solana/snapshots/accounts_db.bincode"; - const db_file = try std.fs.createFileAbsolute(accounts_db_path, .{}); - defer db_file.close(); - - _ = try db_file.write(db_buf.items); + const fields = snapshot_fields.getFieldRefs(); + _ = fields; } diff --git a/src/lib.zig b/src/lib.zig index 3b2d4717a..897addfd9 100644 --- a/src/lib.zig +++ b/src/lib.zig @@ -17,6 +17,7 @@ pub const core = struct { pub usingnamespace @import("core/genesis_config.zig"); pub usingnamespace @import("core/snapshot_fields.zig"); pub usingnamespace @import("core/append_vec.zig"); + pub usingnamespace @import("core/accounts_db.zig"); }; pub const gossip = struct { diff --git a/src/sync/thread_pool.zig b/src/sync/thread_pool.zig new file mode 100644 index 000000000..341c63a31 --- /dev/null +++ b/src/sync/thread_pool.zig @@ -0,0 +1,1319 @@ +// Thank you bun.sh: +// https://github.com/oven-sh/bun/blob/main/src/thread_pool.zig +// +// Thank you @kprotty: +// https://github.com/kprotty/zap/blob/blog/src/thread_pool.zig + +const std = @import("std"); +const builtin = @import("builtin"); +const Futex = std.Thread.Futex; +const assert = std.debug.assert; +const Atomic = std.atomic.Atomic; +pub const OnSpawnCallback = *const fn (ctx: ?*anyopaque) ?*anyopaque; + +pub const ThreadPool = struct { + sleep_on_idle_network_thread: bool = true, + /// executed on the thread + on_thread_spawn: ?OnSpawnCallback = null, + threadpool_context: ?*anyopaque = null, + stack_size: u32, + max_threads: u32, + sync: Atomic(u32) = Atomic(u32).init(@as(u32, @bitCast(Sync{}))), + idle_event: Event = .{}, + join_event: Event = .{}, + run_queue: Node.Queue = .{}, + threads: Atomic(?*Thread) = Atomic(?*Thread).init(null), + name: []const u8 = "", + spawned_thread_count: Atomic(u32) = Atomic(u32).init(0), + + const Sync = packed struct { + /// Tracks the number of threads not searching for Tasks + idle: u14 = 0, + /// Tracks the number of threads spawned + spawned: u14 = 0, + /// What you see is what you 
get + unused: bool = false, + /// Used to not miss notifications while state = waking + notified: bool = false, + /// The current state of the thread pool + state: enum(u2) { + /// A notification can be issued to wake up a sleeping as the "waking thread". + pending = 0, + /// The state was notified with a signal. A thread is woken up. + /// The first thread to transition to `waking` becomes the "waking thread". + signaled, + /// There is a "waking thread" among us. + /// No other thread should be woken up until the waking thread transitions the state. + waking, + /// The thread pool was terminated. Start decremented `spawned` so that it can be joined. + shutdown, + } = .pending, + }; + + /// Configuration options for the thread pool. + /// TODO: add CPU core affinity? + pub const Config = struct { + stack_size: u32 = (std.Thread.SpawnConfig{}).stack_size, + max_threads: u32 = 1, + }; + + /// Statically initialize the thread pool using the configuration. + pub fn init(config: Config) ThreadPool { + return .{ + .stack_size = @max(1, config.stack_size), + .max_threads = @max(1, config.max_threads), + }; + } + + pub fn wakeForIdleEvents(this: *ThreadPool) void { + // Wake all the threads to check for idle events. + this.idle_event.wake(Event.NOTIFIED, std.math.maxInt(u32)); + } + + /// Wait for a thread to call shutdown() on the thread pool and kill the worker threads. + pub fn deinit(self: *ThreadPool) void { + self.join(); + self.* = undefined; + } + + /// A Task represents the unit of Work / Job / Execution that the ThreadPool schedules. + /// The user provides a `callback` which is invoked when the *Task can run on a thread. + pub const Task = struct { + node: Node = .{}, + callback: *const (fn (*Task) void), + }; + + /// An unordered collection of Tasks which can be submitted for scheduling as a group. + pub const Batch = struct { + len: usize = 0, + head: ?*Task = null, + tail: ?*Task = null, + + pub fn pop(this: *Batch) ?*Task { + const len = @atomicLoad(usize, &this.len, .Monotonic); + if (len == 0) { + return null; + } + var task = this.head.?; + if (task.node.next) |node| { + this.head = @fieldParentPtr(Task, "node", node); + } else { + if (task != this.tail.?) unreachable; + this.tail = null; + this.head = null; + } + + this.len -= 1; + if (len == 0) { + this.tail = null; + } + return task; + } + + /// Create a batch from a single task. + pub fn from(task: *Task) Batch { + return Batch{ + .len = 1, + .head = task, + .tail = task, + }; + } + + /// Another batch into this one, taking ownership of its tasks. 
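+        /// Pushing an empty batch is a no-op; pushing into an empty batch adopts the other batch's head and tail directly.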
+ pub fn push(self: *Batch, batch: Batch) void { + if (batch.len == 0) return; + if (self.len == 0) { + self.* = batch; + } else { + self.tail.?.node.next = if (batch.head) |h| &h.node else null; + self.tail = batch.tail; + self.len += batch.len; + } + } + }; + + pub const WaitGroup = struct { + mutex: std.Thread.Mutex = .{}, + counter: u32 = 0, + event: std.Thread.ResetEvent, + + pub fn init(self: *WaitGroup) void { + self.* = .{ + .mutex = .{}, + .counter = 0, + .event = undefined, + }; + } + + pub fn deinit(self: *WaitGroup) void { + self.event.reset(); + self.* = undefined; + } + + pub fn start(self: *WaitGroup) void { + self.mutex.lock(); + defer self.mutex.unlock(); + + self.counter += 1; + } + + pub fn isDone(this: *WaitGroup) bool { + return @atomicLoad(u32, &this.counter, .Monotonic) == 0; + } + + pub fn finish(self: *WaitGroup) void { + self.mutex.lock(); + defer self.mutex.unlock(); + + self.counter -= 1; + + if (self.counter == 0) { + self.event.set(); + } + } + + pub fn wait(self: *WaitGroup) void { + while (true) { + self.mutex.lock(); + + if (self.counter == 0) { + self.mutex.unlock(); + return; + } + + self.mutex.unlock(); + self.event.wait(); + } + } + + pub fn reset(self: *WaitGroup) void { + self.event.reset(); + } + }; + + pub fn ConcurrentFunction( + comptime Function: anytype, + ) type { + return struct { + const Fn = Function; + const Args = std.meta.ArgsTuple(@TypeOf(Fn)); + const Runner = @This(); + thread_pool: *ThreadPool, + states: []Routine = undefined, + batch: Batch = .{}, + allocator: std.mem.Allocator, + + pub fn init(allocator: std.mem.Allocator, thread_pool: *ThreadPool, count: usize) !Runner { + return Runner{ + .allocator = allocator, + .thread_pool = thread_pool, + .states = try allocator.alloc(Routine, count), + .batch = .{}, + }; + } + + pub fn call(this: *@This(), args: Args) void { + this.states[this.batch.len] = .{ + .args = args, + }; + this.batch.push(Batch.from(&this.states[this.batch.len].task)); + } + + pub fn run(this: *@This()) void { + this.thread_pool.schedule(this.batch); + } + + pub const Routine = struct { + args: Args, + task: Task = .{ .callback = callback }, + + pub fn callback(task: *Task) void { + var routine = @fieldParentPtr(@This(), "task", task); + @call(.always_inline, Fn, routine.args); + } + }; + + pub fn deinit(this: *@This()) void { + this.allocator.free(this.states); + } + }; + } + + pub fn runner( + this: *ThreadPool, + allocator: std.mem.Allocator, + comptime Function: anytype, + count: usize, + ) !ConcurrentFunction(Function) { + return try ConcurrentFunction(Function).init(allocator, this, count); + } + + /// Loop over an array of tasks and invoke `Run` on each one in a different thread + /// **Blocks the calling thread** until all tasks are completed. 
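+    /// Each value is wrapped in its own Task, the tasks are scheduled as one batch, and a WaitGroup blocks until every callback has run.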
+ pub fn do( + this: *ThreadPool, + allocator: std.mem.Allocator, + wg: ?*WaitGroup, + ctx: anytype, + comptime Run: anytype, + values: anytype, + ) !void { + return try Do(this, allocator, wg, @TypeOf(ctx), ctx, Run, @TypeOf(values), values, false); + } + + pub fn doPtr( + this: *ThreadPool, + allocator: std.mem.Allocator, + wg: ?*WaitGroup, + ctx: anytype, + comptime Run: anytype, + values: anytype, + ) !void { + return try Do(this, allocator, wg, @TypeOf(ctx), ctx, Run, @TypeOf(values), values, true); + } + + pub fn Do( + this: *ThreadPool, + allocator: std.mem.Allocator, + wg: ?*WaitGroup, + comptime Context: type, + ctx: Context, + comptime Function: anytype, + comptime ValuesType: type, + values: ValuesType, + comptime as_ptr: bool, + ) !void { + if (values.len == 0) + return; + var allocated_wait_group: ?*WaitGroup = null; + defer { + if (allocated_wait_group) |group| { + group.deinit(); + allocator.destroy(group); + } + } + + var wait_group = wg orelse brk: { + allocated_wait_group = try allocator.create(WaitGroup); + allocated_wait_group.?.init(); + break :brk allocated_wait_group.?; + }; + const WaitContext = struct { + wait_group: *WaitGroup = undefined, + ctx: Context, + values: ValuesType, + }; + + const RunnerTask = struct { + task: Task, + ctx: *WaitContext, + i: usize = 0, + + pub fn call(task: *Task) void { + var runner_task = @fieldParentPtr(@This(), "task", task); + const i = runner_task.i; + if (comptime as_ptr) { + Function(runner_task.ctx.ctx, &runner_task.ctx.values[i], i); + } else { + Function(runner_task.ctx.ctx, runner_task.ctx.values[i], i); + } + + runner_task.ctx.wait_group.finish(); + } + }; + var wait_context = allocator.create(WaitContext) catch unreachable; + wait_context.* = .{ + .ctx = ctx, + .wait_group = wait_group, + .values = values, + }; + defer allocator.destroy(wait_context); + var tasks = allocator.alloc(RunnerTask, values.len) catch unreachable; + defer allocator.free(tasks); + var batch: Batch = undefined; + var offset = tasks.len - 1; + + { + tasks[0] = .{ + .i = offset, + .task = .{ .callback = RunnerTask.call }, + .ctx = wait_context, + }; + batch = Batch.from(&tasks[0].task); + } + if (tasks.len > 1) { + for (tasks[1..]) |*runner_task| { + offset -= 1; + runner_task.* = .{ + .i = offset, + .task = .{ .callback = RunnerTask.call }, + .ctx = wait_context, + }; + batch.push(Batch.from(&runner_task.task)); + } + } + + wait_group.counter += @as(u32, @intCast(values.len)); + this.schedule(batch); + wait_group.wait(); + } + + /// Schedule a batch of tasks to be executed by some thread on the thread pool. + pub fn schedule(self: *ThreadPool, batch: Batch) void { + // Sanity check + if (batch.len == 0) { + return; + } + + // Extract out the Node's from the Tasks + var list = Node.List{ + .head = &batch.head.?.node, + .tail = &batch.tail.?.node, + }; + + // Push the task Nodes to the most appropriate queue + if (Thread.current) |thread| { + thread.run_buffer.push(&list) catch thread.run_queue.push(list); + } else { + self.run_queue.push(list); + } + + forceSpawn(self); + } + + pub fn forceSpawn(self: *ThreadPool) void { + // Try to notify a thread + const is_waking = false; + return self.notify(is_waking); + } + + inline fn notify(self: *ThreadPool, is_waking: bool) void { + // Fast path to check the Sync state to avoid calling into notifySlow(). 
+ // If we're waking, then we need to update the state regardless + if (!is_waking) { + const sync = @as(Sync, @bitCast(self.sync.load(.Monotonic))); + if (sync.notified) { + return; + } + } + + return self.notifySlow(is_waking); + } + + /// Warm the thread pool up to the given number of threads. + /// https://www.youtube.com/watch?v=ys3qcbO5KWw + pub fn warm(self: *ThreadPool, count: u14) void { + var sync = @as(Sync, @bitCast(self.sync.load(.Monotonic))); + if (sync.spawned >= count) + return; + + const to_spawn = @min(count - sync.spawned, @as(u14, @truncate(self.max_threads))); + while (sync.spawned < to_spawn) { + var new_sync = sync; + new_sync.spawned += 1; + sync = @as(Sync, @bitCast(self.sync.tryCompareAndSwap( + @as(u32, @bitCast(sync)), + @as(u32, @bitCast(new_sync)), + .Release, + .Monotonic, + ) orelse break)); + const spawn_config = if (builtin.os.tag.isDarwin()) + // stack size must be a multiple of page_size + // macOS will fail to spawn a thread if the stack size is not a multiple of page_size + std.Thread.SpawnConfig{ .stack_size = ((std.Thread.SpawnConfig{}).stack_size + (std.mem.page_size / 2) / std.mem.page_size) * std.mem.page_size } + else + std.Thread.SpawnConfig{}; + + const thread = std.Thread.spawn(spawn_config, Thread.run, .{self}) catch return self.unregister(null); + thread.detach(); + } + } + + noinline fn notifySlow(self: *ThreadPool, is_waking: bool) void { + var sync = @as(Sync, @bitCast(self.sync.load(.Monotonic))); + while (sync.state != .shutdown) { + const can_wake = is_waking or (sync.state == .pending); + if (is_waking) { + assert(sync.state == .waking); + } + + var new_sync = sync; + new_sync.notified = true; + if (can_wake and sync.idle > 0) { // wake up an idle thread + new_sync.state = .signaled; + } else if (can_wake and sync.spawned < self.max_threads) { // spawn a new thread + new_sync.state = .signaled; + new_sync.spawned += 1; + } else if (is_waking) { // no other thread to pass on "waking" status + new_sync.state = .pending; + } else if (sync.notified) { // nothing to update + return; + } + + // Release barrier synchronizes with Acquire in wait() + // to ensure pushes to run queues happen before observing a posted notification. + sync = @as(Sync, @bitCast(self.sync.tryCompareAndSwap( + @as(u32, @bitCast(sync)), + @as(u32, @bitCast(new_sync)), + .Release, + .Monotonic, + ) orelse { + // We signaled to notify an idle thread + if (can_wake and sync.idle > 0) { + return self.idle_event.notify(); + } + + // We signaled to spawn a new thread + if (can_wake and sync.spawned < self.max_threads) { + const spawn_config = if (builtin.os.tag.isDarwin()) + // stack size must be a multiple of page_size + // macOS will fail to spawn a thread if the stack size is not a multiple of page_size + std.Thread.SpawnConfig{ .stack_size = ((std.Thread.SpawnConfig{}).stack_size + (std.mem.page_size / 2) / std.mem.page_size) * std.mem.page_size } + else + std.Thread.SpawnConfig{}; + + const thread = std.Thread.spawn(spawn_config, Thread.run, .{self}) catch return self.unregister(null); + // if (self.name.len > 0) thread.setName(self.name) catch {}; + return thread.detach(); + } + + return; + })); + } + } + + noinline fn wait(self: *ThreadPool, _is_waking: bool) error{Shutdown}!bool { + var is_idle = false; + var is_waking = _is_waking; + var sync = @as(Sync, @bitCast(self.sync.load(.Monotonic))); + + while (true) { + if (sync.state == .shutdown) return error.Shutdown; + if (is_waking) assert(sync.state == .waking); + + // Consume a notification made by notify(). 
+ if (sync.notified) { + var new_sync = sync; + new_sync.notified = false; + if (is_idle) + new_sync.idle -= 1; + if (sync.state == .signaled) + new_sync.state = .waking; + + // Acquire barrier synchronizes with notify() + // to ensure that pushes to run queue are observed after wait() returns. + sync = @as(Sync, @bitCast(self.sync.tryCompareAndSwap( + @as(u32, @bitCast(sync)), + @as(u32, @bitCast(new_sync)), + .Acquire, + .Monotonic, + ) orelse { + return is_waking or (sync.state == .signaled); + })); + } else if (!is_idle) { + var new_sync = sync; + new_sync.idle += 1; + if (is_waking) + new_sync.state = .pending; + + sync = @as(Sync, @bitCast(self.sync.tryCompareAndSwap( + @as(u32, @bitCast(sync)), + @as(u32, @bitCast(new_sync)), + .Monotonic, + .Monotonic, + ) orelse { + is_waking = false; + is_idle = true; + continue; + })); + } else { + if (Thread.current) |current| { + current.drainIdleEvents(); + } + + self.idle_event.wait(); + sync = @as(Sync, @bitCast(self.sync.load(.Monotonic))); + } + } + } + + /// Marks the thread pool as shutdown + pub noinline fn shutdown(self: *ThreadPool) void { + var sync = @as(Sync, @bitCast(self.sync.load(.Monotonic))); + while (sync.state != .shutdown) { + var new_sync = sync; + new_sync.notified = true; + new_sync.state = .shutdown; + new_sync.idle = 0; + + // Full barrier to synchronize with both wait() and notify() + sync = @as(Sync, @bitCast(self.sync.tryCompareAndSwap( + @as(u32, @bitCast(sync)), + @as(u32, @bitCast(new_sync)), + .AcqRel, + .Monotonic, + ) orelse { + // Wake up any threads sleeping on the idle_event. + // TODO: I/O polling notification here. + if (sync.idle > 0) self.idle_event.shutdown(); + return; + })); + } + } + + fn register(noalias self: *ThreadPool, noalias thread: *Thread) void { + // Push the thread onto the threads stack in a lock-free manner. + var threads = self.threads.load(.Monotonic); + while (true) { + thread.next = threads; + threads = self.threads.tryCompareAndSwap( + threads, + thread, + .Release, + .Monotonic, + ) orelse break; + } + } + + pub fn setThreadContext(noalias pool: *ThreadPool, ctx: ?*anyopaque) void { + pool.threadpool_context = ctx; + + var thread = pool.threads.load(.Monotonic) orelse return; + thread.ctx = pool.threadpool_context; + while (thread.next) |next| { + next.ctx = pool.threadpool_context; + thread = next; + } + } + + fn unregister(noalias self: *ThreadPool, noalias maybe_thread: ?*Thread) void { + // Un-spawn one thread, either due to a failed OS thread spawning or the thread is exiting. + const one_spawned = @as(u32, @bitCast(Sync{ .spawned = 1 })); + const sync = @as(Sync, @bitCast(self.sync.fetchSub(one_spawned, .Release))); + assert(sync.spawned > 0); + + // The last thread to exit must wake up the thread pool join()er + // who will start the chain to shutdown all the threads. + if (sync.state == .shutdown and sync.spawned == 1) { + self.join_event.notify(); + } + + // If this is a thread pool thread, wait for a shutdown signal by the thread pool join()er. + const thread = maybe_thread orelse return; + thread.join_event.wait(); + + // After receiving the shutdown signal, shutdown the next thread in the pool. + // We have to do that without touching the thread pool itself since it's memory is invalidated by now. + // So just follow our .next link. 
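+        // If there is no next thread, we were the last link in the chain and nothing is left to signal.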
+ const next_thread = thread.next orelse return; + next_thread.join_event.notify(); + } + + fn join(self: *ThreadPool) void { + // Wait for the thread pool to be shutdown() then for all threads to enter a joinable state + var sync = @as(Sync, @bitCast(self.sync.load(.Monotonic))); + if (!(sync.state == .shutdown and sync.spawned == 0)) { + self.join_event.wait(); + sync = @as(Sync, @bitCast(self.sync.load(.Monotonic))); + } + + assert(sync.state == .shutdown); + assert(sync.spawned == 0); + + // If there are threads, start off the chain sending it the shutdown signal. + // The thread receives the shutdown signal and sends it to the next thread, and the next.. + const thread = self.threads.load(.Acquire) orelse return; + thread.join_event.notify(); + } + + pub const Thread = struct { + next: ?*Thread = null, + target: ?*Thread = null, + join_event: Event = .{}, + run_queue: Node.Queue = .{}, + idle_queue: Node.Queue = .{}, + run_buffer: Node.Buffer = .{}, + ctx: ?*anyopaque = null, + + pub threadlocal var current: ?*Thread = null; + + pub fn pushIdleTask(self: *Thread, task: *Task) void { + const list = Node.List{ + .head = &task.node, + .tail = &task.node, + }; + self.idle_queue.push(list); + } + + /// Thread entry point which runs a worker for the ThreadPool + fn run(thread_pool: *ThreadPool) void { + var self_ = Thread{}; + var self = &self_; + current = self; + + if (thread_pool.on_thread_spawn) |spawn| { + current.?.ctx = spawn(thread_pool.threadpool_context); + } + + thread_pool.register(self); + + defer thread_pool.unregister(self); + + var is_waking = false; + while (true) { + is_waking = thread_pool.wait(is_waking) catch return; + + while (self.pop(thread_pool)) |result| { + if (result.pushed or is_waking) + thread_pool.notify(is_waking); + is_waking = false; + + const task = @fieldParentPtr(Task, "node", result.node); + (task.callback)(task); + } + + self.drainIdleEvents(); + } + } + + pub fn drainIdleEvents(noalias self: *Thread) void { + var consumer = self.idle_queue.tryAcquireConsumer() catch return; + defer self.idle_queue.releaseConsumer(consumer); + while (self.idle_queue.pop(&consumer)) |node| { + const task = @fieldParentPtr(Task, "node", node); + (task.callback)(task); + } + } + + /// Try to dequeue a Node/Task from the ThreadPool. + /// Spurious reports of dequeue() returning empty are allowed. + pub fn pop(noalias self: *Thread, noalias thread_pool: *ThreadPool) ?Node.Buffer.Stole { + // Check our local buffer first + if (self.run_buffer.pop()) |node| { + return Node.Buffer.Stole{ + .node = node, + .pushed = false, + }; + } + + // Then check our local queue + if (self.run_buffer.consume(&self.run_queue)) |stole| { + return stole; + } + + // Then the global queue + if (self.run_buffer.consume(&thread_pool.run_queue)) |stole| { + return stole; + } + + // Then try work stealing from other threads + var num_threads: u32 = @as(Sync, @bitCast(thread_pool.sync.load(.Monotonic))).spawned; + while (num_threads > 0) : (num_threads -= 1) { + // Traverse the stack of registered threads on the thread pool + const target = self.target orelse thread_pool.threads.load(.Acquire) orelse unreachable; + self.target = target.next; + + // Try to steal from their queue first to avoid contention (the target steal's from queue last). + if (self.run_buffer.consume(&target.run_queue)) |stole| { + return stole; + } + + // Skip stealing from the buffer if we're the target. + // We still steal from our own queue above given it may have just been locked the first time we tried. 
+ if (target == self) { + continue; + } + + // Steal from the buffer of a remote thread as a last resort + if (self.run_buffer.steal(&target.run_buffer)) |stole| { + return stole; + } + } + + return null; + } + }; + + /// An event which stores 1 semaphore token and is multi-threaded safe. + /// The event can be shutdown(), waking up all wait()ing threads and + /// making subsequent wait()'s return immediately. + const Event = struct { + state: Atomic(u32) = Atomic(u32).init(EMPTY), + + const EMPTY = 0; + const WAITING = 1; + pub const NOTIFIED = 2; + const SHUTDOWN = 3; + + /// Wait for and consume a notification + /// or wait for the event to be shutdown entirely + noinline fn wait(self: *Event) void { + var acquire_with: u32 = EMPTY; + var state = self.state.load(.Monotonic); + + while (true) { + // If we're shutdown then exit early. + // Acquire barrier to ensure operations before the shutdown() are seen after the wait(). + // Shutdown is rare so it's better to have an Acquire barrier here instead of on CAS failure + load which are common. + if (state == SHUTDOWN) { + std.atomic.fence(.Acquire); + return; + } + + // Consume a notification when it pops up. + // Acquire barrier to ensure operations before the notify() appear after the wait(). + if (state == NOTIFIED) { + state = self.state.tryCompareAndSwap( + state, + acquire_with, + .Acquire, + .Monotonic, + ) orelse return; + continue; + } + + // There is no notification to consume, we should wait on the event by ensuring its WAITING. + if (state != WAITING) blk: { + state = self.state.tryCompareAndSwap( + state, + WAITING, + .Monotonic, + .Monotonic, + ) orelse break :blk; + continue; + } + + // Wait on the event until a notify() or shutdown(). + // If we wake up to a notification, we must acquire it with WAITING instead of EMPTY + // since there may be other threads sleeping on the Futex who haven't been woken up yet. + // + // Acquiring to WAITING will make the next notify() or shutdown() wake a sleeping futex thread + // who will either exit on SHUTDOWN or acquire with WAITING again, ensuring all threads are awoken. + // This unfortunately results in the last notify() or shutdown() doing an extra futex wake but that's fine. + Futex.wait(&self.state, WAITING); + state = self.state.load(.Monotonic); + acquire_with = WAITING; + } + } + + /// Wait for and consume a notification + /// or wait for the event to be shutdown entirely + noinline fn waitFor(self: *Event, timeout: usize) void { + _ = timeout; + var acquire_with: u32 = EMPTY; + var state = self.state.load(.Monotonic); + + while (true) { + // If we're shutdown then exit early. + // Acquire barrier to ensure operations before the shutdown() are seen after the wait(). + // Shutdown is rare so it's better to have an Acquire barrier here instead of on CAS failure + load which are common. + if (state == SHUTDOWN) { + std.atomic.fence(.Acquire); + return; + } + + // Consume a notification when it pops up. + // Acquire barrier to ensure operations before the notify() appear after the wait(). + if (state == NOTIFIED) { + state = self.state.tryCompareAndSwap( + state, + acquire_with, + .Acquire, + .Monotonic, + ) orelse return; + continue; + } + + // There is no notification to consume, we should wait on the event by ensuring its WAITING. + if (state != WAITING) blk: { + state = self.state.tryCompareAndSwap( + state, + WAITING, + .Monotonic, + .Monotonic, + ) orelse break :blk; + continue; + } + + // Wait on the event until a notify() or shutdown(). 
+ // If we wake up to a notification, we must acquire it with WAITING instead of EMPTY + // since there may be other threads sleeping on the Futex who haven't been woken up yet. + // + // Acquiring to WAITING will make the next notify() or shutdown() wake a sleeping futex thread + // who will either exit on SHUTDOWN or acquire with WAITING again, ensuring all threads are awoken. + // This unfortunately results in the last notify() or shutdown() doing an extra futex wake but that's fine. + Futex.wait(&self.state, WAITING); + state = self.state.load(.Monotonic); + acquire_with = WAITING; + } + } + + /// Post a notification to the event if it doesn't have one already + /// then wake up a waiting thread if there is one as well. + fn notify(self: *Event) void { + return self.wake(NOTIFIED, 1); + } + + /// Marks the event as shutdown, making all future wait()'s return immediately. + /// Then wakes up any threads currently waiting on the Event. + fn shutdown(self: *Event) void { + return self.wake(SHUTDOWN, std.math.maxInt(u32)); + } + + fn wake(self: *Event, release_with: u32, wake_threads: u32) void { + // Update the Event to notify it with the new `release_with` state (either NOTIFIED or SHUTDOWN). + // Release barrier to ensure any operations before this are this to happen before the wait() in the other threads. + const state = self.state.swap(release_with, .Release); + + // Only wake threads sleeping in futex if the state is WAITING. + // Avoids unnecessary wake ups. + if (state == WAITING) { + Futex.wake(&self.state, wake_threads); + } + } + }; + + /// Linked list intrusive memory node and lock-free data structures to operate with it + pub const Node = struct { + next: ?*Node = null, + + /// A linked list of Nodes + const List = struct { + head: *Node, + tail: *Node, + }; + + /// An unbounded multi-producer-(non blocking)-multi-consumer queue of Node pointers. + const Queue = struct { + stack: Atomic(usize) = Atomic(usize).init(0), + cache: ?*Node = null, + + const HAS_CACHE: usize = 0b01; + const IS_CONSUMING: usize = 0b10; + const PTR_MASK: usize = ~(HAS_CACHE | IS_CONSUMING); + + comptime { + assert(@alignOf(Node) >= ((IS_CONSUMING | HAS_CACHE) + 1)); + } + + fn push(noalias self: *Queue, list: List) void { + var stack = self.stack.load(.Monotonic); + while (true) { + // Attach the list to the stack (pt. 1) + list.tail.next = @as(?*Node, @ptrFromInt(stack & PTR_MASK)); + + // Update the stack with the list (pt. 2). + // Don't change the HAS_CACHE and IS_CONSUMING bits of the consumer. + var new_stack = @intFromPtr(list.head); + assert(new_stack & ~PTR_MASK == 0); + new_stack |= (stack & ~PTR_MASK); + + // Push to the stack with a release barrier for the consumer to see the proper list links. + stack = self.stack.tryCompareAndSwap( + stack, + new_stack, + .Release, + .Monotonic, + ) orelse break; + } + } + + fn tryAcquireConsumer(self: *Queue) error{ Empty, Contended }!?*Node { + var stack = self.stack.load(.Monotonic); + while (true) { + if (stack & IS_CONSUMING != 0) + return error.Contended; // The queue already has a consumer. + if (stack & (HAS_CACHE | PTR_MASK) == 0) + return error.Empty; // The queue is empty when there's nothing cached and nothing in the stack. + + // When we acquire the consumer, also consume the pushed stack if the cache is empty. 
+ var new_stack = stack | HAS_CACHE | IS_CONSUMING; + if (stack & HAS_CACHE == 0) { + assert(stack & PTR_MASK != 0); + new_stack &= ~PTR_MASK; + } + + // Acquire barrier on getting the consumer to see cache/Node updates done by previous consumers + // and to ensure our cache/Node updates in pop() happen after that of previous consumers. + stack = self.stack.tryCompareAndSwap( + stack, + new_stack, + .Acquire, + .Monotonic, + ) orelse return self.cache orelse @as(*Node, @ptrFromInt(stack & PTR_MASK)); + } + } + + fn releaseConsumer(noalias self: *Queue, noalias consumer: ?*Node) void { + // Stop consuming and remove the HAS_CACHE bit as well if the consumer's cache is empty. + // When HAS_CACHE bit is zeroed, the next consumer will acquire the pushed stack nodes. + var remove = IS_CONSUMING; + if (consumer == null) + remove |= HAS_CACHE; + + // Release the consumer with a release barrier to ensure cache/node accesses + // happen before the consumer was released and before the next consumer starts using the cache. + self.cache = consumer; + const stack = self.stack.fetchSub(remove, .Release); + assert(stack & remove != 0); + } + + fn pop(noalias self: *Queue, noalias consumer_ref: *?*Node) ?*Node { + // Check the consumer cache (fast path) + if (consumer_ref.*) |node| { + consumer_ref.* = node.next; + return node; + } + + // Load the stack to see if there was anything pushed that we could grab. + var stack = self.stack.load(.Monotonic); + assert(stack & IS_CONSUMING != 0); + if (stack & PTR_MASK == 0) { + return null; + } + + // Nodes have been pushed to the stack, grab then with an Acquire barrier to see the Node links. + stack = self.stack.swap(HAS_CACHE | IS_CONSUMING, .Acquire); + assert(stack & IS_CONSUMING != 0); + assert(stack & PTR_MASK != 0); + + const node = @as(*Node, @ptrFromInt(stack & PTR_MASK)); + consumer_ref.* = node.next; + return node; + } + }; + + /// A bounded single-producer, multi-consumer ring buffer for node pointers. + const Buffer = struct { + head: Atomic(Index) = Atomic(Index).init(0), + tail: Atomic(Index) = Atomic(Index).init(0), + array: [capacity]Atomic(*Node) = undefined, + + const Index = u32; + const capacity = 256; // Appears to be a pretty good trade-off in space vs contended throughput + comptime { + assert(std.math.maxInt(Index) >= capacity); + assert(std.math.isPowerOfTwo(capacity)); + } + + fn push(noalias self: *Buffer, noalias list: *List) error{Overflow}!void { + var head = self.head.load(.Monotonic); + var tail = self.tail.loadUnchecked(); // we're the only thread that can change this + + while (true) { + var size = tail -% head; + assert(size <= capacity); + + // Push nodes from the list to the buffer if it's not empty.. + if (size < capacity) { + var nodes: ?*Node = list.head; + while (size < capacity) : (size += 1) { + const node = nodes orelse break; + nodes = node.next; + + // Array written atomically with weakest ordering since it could be getting atomically read by steal(). + self.array[tail % capacity].store(node, .Unordered); + tail +%= 1; + } + + // Release barrier synchronizes with Acquire loads for steal()ers to see the array writes. + self.tail.store(tail, .Release); + + // Update the list with the nodes we pushed to the buffer and try again if there's more. + list.head = nodes orelse return; + std.atomic.spinLoopHint(); + head = self.head.load(.Monotonic); + continue; + } + + // Try to steal/overflow half of the tasks in the buffer to make room for future push()es. 
+ // Migrating half amortizes the cost of stealing while requiring future pops to still use the buffer. + // Acquire barrier to ensure the linked list creation after the steal only happens after we successfully steal. + var migrate = size / 2; + head = self.head.tryCompareAndSwap( + head, + head +% migrate, + .Acquire, + .Monotonic, + ) orelse { + // Link the migrated Nodes together + const first = self.array[head % capacity].loadUnchecked(); + while (migrate > 0) : (migrate -= 1) { + const prev = self.array[head % capacity].loadUnchecked(); + head +%= 1; + prev.next = self.array[head % capacity].loadUnchecked(); + } + + // Append the list that was supposed to be pushed to the end of the migrated Nodes + const last = self.array[(head -% 1) % capacity].loadUnchecked(); + last.next = list.head; + list.tail.next = null; + + // Return the migrated nodes + the original list as overflowed + list.head = first; + return error.Overflow; + }; + } + } + + fn pop(self: *Buffer) ?*Node { + var head = self.head.load(.Monotonic); + var tail = self.tail.loadUnchecked(); // we're the only thread that can change this + + while (true) { + // Quick sanity check and return null when not empty + var size = tail -% head; + assert(size <= capacity); + if (size == 0) { + return null; + } + + // Dequeue with an acquire barrier to ensure any writes done to the Node + // only happens after we successfully claim it from the array. + head = self.head.tryCompareAndSwap( + head, + head +% 1, + .Acquire, + .Monotonic, + ) orelse return self.array[head % capacity].loadUnchecked(); + } + } + + const Stole = struct { + node: *Node, + pushed: bool, + }; + + fn consume(noalias self: *Buffer, noalias queue: *Queue) ?Stole { + var consumer = queue.tryAcquireConsumer() catch return null; + defer queue.releaseConsumer(consumer); + + const head = self.head.load(.Monotonic); + const tail = self.tail.loadUnchecked(); // we're the only thread that can change this + + const size = tail -% head; + assert(size <= capacity); + assert(size == 0); // we should only be consuming if our array is empty + + // Pop nodes from the queue and push them to our array. + // Atomic stores to the array as steal() threads may be atomically reading from it. + var pushed: Index = 0; + while (pushed < capacity) : (pushed += 1) { + const node = queue.pop(&consumer) orelse break; + self.array[(tail +% pushed) % capacity].store(node, .Unordered); + } + + // We will be returning one node that we stole from the queue. + // Get an extra, and if that's not possible, take one from our array. + const node = queue.pop(&consumer) orelse blk: { + if (pushed == 0) return null; + pushed -= 1; + break :blk self.array[(tail +% pushed) % capacity].loadUnchecked(); + }; + + // Update the array tail with the nodes we pushed to it. + // Release barrier to synchronize with Acquire barrier in steal()'s to see the written array Nodes. 
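+                // pushed == 0 means no extra nodes were left in our buffer, so the tail stays as-is.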
+ if (pushed > 0) self.tail.store(tail +% pushed, .Release); + return Stole{ + .node = node, + .pushed = pushed > 0, + }; + } + + fn steal(noalias self: *Buffer, noalias buffer: *Buffer) ?Stole { + const head = self.head.load(.Monotonic); + const tail = self.tail.loadUnchecked(); // we're the only thread that can change this + + const size = tail -% head; + assert(size <= capacity); + assert(size == 0); // we should only be stealing if our array is empty + + while (true) : (std.atomic.spinLoopHint()) { + const buffer_head = buffer.head.load(.Acquire); + const buffer_tail = buffer.tail.load(.Acquire); + + // Overly large size indicates the the tail was updated a lot after the head was loaded. + // Reload both and try again. + const buffer_size = buffer_tail -% buffer_head; + if (buffer_size > capacity) { + continue; + } + + // Try to steal half (divCeil) to amortize the cost of stealing from other threads. + const steal_size = buffer_size - (buffer_size / 2); + if (steal_size == 0) { + return null; + } + + // Copy the nodes we will steal from the target's array to our own. + // Atomically load from the target buffer array as it may be pushing and atomically storing to it. + // Atomic store to our array as other steal() threads may be atomically loading from it as above. + var i: Index = 0; + while (i < steal_size) : (i += 1) { + const node = buffer.array[(buffer_head +% i) % capacity].load(.Unordered); + self.array[(tail +% i) % capacity].store(node, .Unordered); + } + + // Try to commit the steal from the target buffer using: + // - an Acquire barrier to ensure that we only interact with the stolen Nodes after the steal was committed. + // - a Release barrier to ensure that the Nodes are copied above prior to the committing of the steal + // because if they're copied after the steal, the could be getting rewritten by the target's push(). + _ = buffer.head.compareAndSwap( + buffer_head, + buffer_head +% steal_size, + .AcqRel, + .Monotonic, + ) orelse { + // Pop one from the nodes we stole as we'll be returning it + const pushed = steal_size - 1; + const node = self.array[(tail +% pushed) % capacity].loadUnchecked(); + + // Update the array tail with the nodes we pushed to it. + // Release barrier to synchronize with Acquire barrier in steal()'s to see the written array Nodes. 
+ if (pushed > 0) self.tail.store(tail +% pushed, .Release); + return Stole{ + .node = node, + .pushed = pushed > 0, + }; + }; + } + } + }; + }; +}; + +// test "parallel for loop" { +// var thread_pool = ThreadPool.init(.{ .max_threads = 12 }); +// var sleepy_time: u32 = 100; +// var random = std.rand.DefaultPrng.init(1); +// var rng = random.random(); + +// var huge_array = &[_]u32{ +// sleepy_time + rng.uintAtMost(u32, 20), +// sleepy_time + rng.uintAtMost(u32, 20), +// sleepy_time + rng.uintAtMost(u32, 20), +// sleepy_time + rng.uintAtMost(u32, 20), +// sleepy_time + rng.uintAtMost(u32, 20), +// sleepy_time + rng.uintAtMost(u32, 20), +// sleepy_time + rng.uintAtMost(u32, 20), +// sleepy_time + rng.uintAtMost(u32, 20), +// sleepy_time + rng.uintAtMost(u32, 20), +// sleepy_time + rng.uintAtMost(u32, 20), +// sleepy_time + rng.uintAtMost(u32, 20), +// sleepy_time + rng.uintAtMost(u32, 20), +// sleepy_time + rng.uintAtMost(u32, 20), +// sleepy_time + rng.uintAtMost(u32, 20), +// sleepy_time + rng.uintAtMost(u32, 20), +// sleepy_time + rng.uintAtMost(u32, 20), +// sleepy_time + rng.uintAtMost(u32, 20), +// sleepy_time + rng.uintAtMost(u32, 20), +// sleepy_time + rng.uintAtMost(u32, 20), +// sleepy_time + rng.uintAtMost(u32, 20), +// }; +// const Runner = struct { +// completed: usize = 0, +// total: usize = 0, +// pub fn run(ctx: *@This(), value: u32, _: usize) void { +// std.time.sleep(value); +// ctx.completed += 1; +// std.debug.assert(ctx.completed <= ctx.total); +// } +// }; +// var runny = try std.heap.page_allocator.create(Runner); +// runny.* = .{ .total = huge_array.len }; +// try thread_pool.doAndWait(std.heap.page_allocator, null, runny, Runner.run, std.mem.span(huge_array)); +// try std.testing.expectEqual(huge_array.len, runny.completed); +// } + +// pub fn NewWorkPool(comptime max_threads: ?usize) type { +// return struct { +// var pool: ThreadPool = undefined; +// var loaded: bool = false; + +// fn create() *ThreadPool { +// @setCold(true); + +// pool = ThreadPool.init(.{ +// .max_threads = max_threads orelse @max(@as(u32, @truncate(std.Thread.getCpuCount() catch 0)), 2), +// .stack_size = 2 * 1024 * 1024, +// }); +// return &pool; +// } + +// pub fn deinit() void { +// get().deinit(); +// } + +// pub inline fn get() *ThreadPool { +// // lil racy +// if (loaded) return &pool; +// loaded = true; + +// return create(); +// } + +// pub fn scheduleBatch(batch: ThreadPool.Batch) void { +// get().schedule(batch); +// } + +// pub fn scheduleTask(task: *ThreadPool.Task) void { +// get().schedule(ThreadPool.Batch.from(task)); +// } + +// pub fn go(allocator: std.mem.Allocator, comptime Context: type, context: Context, comptime function: *const fn (Context) void) !void { +// const TaskType = struct { +// task: ThreadPool.Task, +// context: Context, +// allocator: std.mem.Allocator, + +// pub fn callback(task: *ThreadPool.Task) void { +// var this_task = @fieldParentPtr(@This(), "task", task); +// function(this_task.context); +// this_task.allocator.destroy(this_task); +// } +// }; + +// var task_ = try allocator.create(TaskType); +// task_.* = .{ +// .task = .{ .callback = TaskType.callback }, +// .context = context, +// .allocator = allocator, +// }; +// scheduleTask(&task_.task); +// } +// }; +// } + +// pub const WorkPool = NewWorkPool(null); +// const testing = std.testing; + +// const CrdsTableTrimContext = struct { +// index: usize, +// max_trim: usize, +// self: *CrdsTable, +// }; + +// const CrdsTable = struct { +// pub fn trim(context: CrdsTableTrimContext) void { +// const 
self = context.self; +// _ = self; +// const max_trim = context.max_trim; +// _ = max_trim; +// const index = context.index; +// _ = index; + +// std.debug.print("I ran!\n\n", .{}); +// // todo + +// } +// }; + +// test "sync.thread_pool: workpool works" { +// var crds: CrdsTable = CrdsTable{}; +// var a = CrdsTableTrimContext{ .index = 1, .max_trim = 2, .self = &crds }; +// defer WorkPool.deinit(); +// try WorkPool.go(testing.allocator, CrdsTableTrimContext, a, CrdsTable.trim); + +// std.time.sleep(std.time.ns_per_s * 1); +// WorkPool.pool.shutdown(); +// }
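+
+// Illustrative usage sketch (not from the upstream bun.sh / zap sources; the names
+// Ctx, run, and done are made up for the example): schedule a single Task on the
+// pool, wait for it to run, then shut the pool down.
+//
+// test "sync.thread_pool: schedule a single task" {
+//     const Ctx = struct {
+//         task: ThreadPool.Task = .{ .callback = run },
+//         done: Atomic(bool) = Atomic(bool).init(false),
+
+//         fn run(task: *ThreadPool.Task) void {
+//             const ctx = @fieldParentPtr(@This(), "task", task);
+//             ctx.done.store(true, .SeqCst);
+//         }
+//     };
+
+//     var pool = ThreadPool.init(.{ .max_threads = 2 });
+//     var ctx = Ctx{};
+
+//     // scheduling wakes (or spawns) a worker thread which invokes Ctx.run
+//     pool.schedule(ThreadPool.Batch.from(&ctx.task));
+//     while (!ctx.done.load(.SeqCst)) std.time.sleep(std.time.ns_per_ms);
+
+//     // shutdown() must be called before deinit(), since deinit() joins the workers
+//     pool.shutdown();
+//     pool.deinit();
+// }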