refactor(accountsdb): Make the fields field optional, and validate that it's non-null before working with it #170

Merged (5 commits) on Jun 12, 2024
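
In short, the PR replaces the `fields: AccountsDbFields = undefined` default with an optional `?AccountsDbFields = null` and asserts or unwraps it at every point of use, so a missing snapshot load is detected instead of reading uninitialized memory. A minimal standalone sketch of that pattern (hypothetical `Fields` and `Db` types, not the real `AccountsDB`):

const std = @import("std");

// Hypothetical stand-ins for AccountsDbFields / AccountsDB, only to show the shape.
const Fields = struct { slot: u64 };

const Db = struct {
    // Optional with a null default instead of `= undefined`, so an unset
    // value is detectable rather than silently read as garbage.
    fields: ?Fields = null,

    fn useFields(self: *const Db) void {
        // Validate before use, then unwrap with `.?`.
        std.debug.assert(self.fields != null);
        const fields = self.fields.?;
        std.debug.print("slot = {d}\n", .{fields.slot});
    }
};

pub fn main() void {
    var db = Db{};
    db.fields = Fields{ .slot = 123 };
    db.useFields();
}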
src/accountsdb/db.zig: 36 changes (22 additions, 14 deletions)
@@ -67,7 +67,7 @@ pub const AccountsDB = struct {

logger: Logger,
config: AccountsDBConfig,
fields: AccountsDbFields = undefined,
fields: ?AccountsDbFields = null,

const Self = @This();

@@ -81,24 +81,28 @@ pub const AccountsDB = struct {
config.storage_cache_size,
);

var disk_allocator_ptr: ?*DiskMemoryAllocator = null;
var reference_allocator = std.heap.page_allocator;
if (config.disk_index_path) |disk_index_path| {
var ptr = try allocator.create(DiskMemoryAllocator);
ptr.* = try DiskMemoryAllocator.init(disk_index_path);
reference_allocator = ptr.allocator();
disk_allocator_ptr = ptr;
}
const maybe_disk_allocator_ptr: ?*DiskMemoryAllocator, //
const reference_allocator: std.mem.Allocator //
= blk: {
if (config.disk_index_path) |disk_index_path| {
const ptr = try allocator.create(DiskMemoryAllocator);
errdefer allocator.destroy(ptr);
ptr.* = try DiskMemoryAllocator.init(disk_index_path);
break :blk .{ ptr, ptr.allocator() };
} else {
break :blk .{ null, std.heap.page_allocator };
}
};

const account_index = try AccountIndex.init(
allocator,
reference_allocator,
config.number_of_index_bins,
);

return Self{
return .{
.allocator = allocator,
.disk_allocator_ptr = disk_allocator_ptr,
.disk_allocator_ptr = maybe_disk_allocator_ptr,
.storage = storage,
.account_index = account_index,
.logger = logger,
@@ -280,7 +284,8 @@ pub const AccountsDB = struct {
timer.reset();
}

/// multithread entrypoint into parseAndBinAccountFiles
/// multithread entrypoint into parseAndBinAccountFiles.
/// Assumes that `loading_threads[thread_id].fields != null`.
pub fn loadAndVerifyAccountsFilesMultiThread(
loading_threads: []AccountsDB,
filenames: [][]const u8,
@@ -303,6 +308,7 @@ pub const AccountsDB = struct {

/// loads and verifies the account files into the threads file map
/// and stores the accounts into the threads index
/// Assumes `self.fields != null`.
pub fn loadAndVerifyAccountsFiles(
self: *Self,
accounts_dir_path: []const u8,
@@ -311,6 +317,8 @@ pub const AccountsDB = struct {
// when we multithread this function we only want to print on the first thread
print_progress: bool,
) !void {
std.debug.assert(self.fields != null);

var file_map = &self.storage.file_map;
try file_map.ensureTotalCapacity(file_names.len);

@@ -340,7 +348,7 @@ pub const AccountsDB = struct {
const accounts_file_id = try std.fmt.parseInt(usize, fiter.next().?, 10);

// read metadata
const file_infos: ArrayList(AccountFileInfo) = self.fields.file_map.get(slot) orelse {
const file_infos: ArrayList(AccountFileInfo) = self.fields.?.file_map.get(slot) orelse {
// dont read account files which are not in the file_map
// note: this can happen when we load from a snapshot and there are extra account files
// in the directory which dont correspond to the snapshot were loading
@@ -612,7 +620,7 @@ pub const AccountsDB = struct {
full_snapshot_slot: Slot,
expected_full_lamports: u64,
) !void {
const expected_accounts_hash = self.fields.bank_hash_info.accounts_hash;
const expected_accounts_hash = self.fields.?.bank_hash_info.accounts_hash;

// validate the full snapshot
self.logger.infof("validating the full snapshot", .{});
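
The rewritten init above also switches from two mutable locals to a destructured `const` pair bound out of one labeled block, which lets the `errdefer allocator.destroy(ptr)` cover only the fallible setup inside the block. A standalone sketch of that Zig pattern (hypothetical names, not the `DiskMemoryAllocator` code):

const std = @import("std");

pub fn main() !void {
    var gpa = std.heap.GeneralPurposeAllocator(.{}){};
    defer _ = gpa.deinit();
    const allocator = gpa.allocator();

    const use_heap = true;

    // Destructure a two-element tuple out of a labeled block. The errdefer
    // only needs to cover the fallible init inside the block, and both
    // results end up as immutable `const` bindings instead of mutated `var`s.
    const maybe_ptr: ?*u32, const source: []const u8 = blk: {
        if (use_heap) {
            const ptr = try allocator.create(u32);
            errdefer allocator.destroy(ptr);
            ptr.* = try std.fmt.parseInt(u32, "42", 10);
            break :blk .{ ptr, "heap" };
        } else {
            break :blk .{ null, "none" };
        }
    };
    defer if (maybe_ptr) |ptr| allocator.destroy(ptr);

    if (maybe_ptr) |ptr| std.debug.print("{s}: {d}\n", .{ source, ptr.* });
}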
src/accountsdb/download.zig: 22 changes (13 additions, 9 deletions)
@@ -39,15 +39,19 @@ const PeerSearchResult = struct {
pub fn findPeersToDownloadFromAssumeCapacity(
allocator: std.mem.Allocator,
table: *const GossipTable,
contact_infos: []ContactInfo,
contact_infos: []const ContactInfo,
my_shred_version: usize,
my_pubkey: Pubkey,
blacklist: []Pubkey,
trusted_validators: ?std.ArrayList(Pubkey),
blacklist: []const Pubkey,
trusted_validators: ?[]const Pubkey,
/// `.capacity` must be >= `contact_infos.len`.
/// The arraylist is first cleared, and then the outputs
/// are appended to it.
valid_peers: *std.ArrayList(PeerSnapshotHash),
) !PeerSearchResult {
// clear the list
valid_peers.clearRetainingCapacity();
std.debug.assert(valid_peers.capacity >= contact_infos.len);

const TrustedMapType = std.AutoHashMap(
SlotAndHash, // full snapshot hash
@@ -62,7 +66,7 @@ pub fn findPeersToDownloadFromAssumeCapacity(
// populate with the hashes of trusted validators
var trusted_count: usize = 0;
// SAFE: the perf is safe because maybe_ is non null only if trusted_validators is non-null
for (trusted_validators.?.items) |trusted_validator| {
for (trusted_validators.?) |trusted_validator| {
const gossip_data = table.get(.{ .SnapshotHashes = trusted_validator }) orelse continue;
const trusted_hashes = gossip_data.value.data.SnapshotHashes;
trusted_count += 1;
@@ -154,7 +158,7 @@ pub fn downloadSnapshotsFromGossip(
allocator: std.mem.Allocator,
logger: Logger,
// if null, then we trust any peer for snapshot download
maybe_trusted_validators: ?std.ArrayList(Pubkey),
maybe_trusted_validators: ?[]const Pubkey,
gossip_service: *GossipService,
output_dir: []const u8,
min_mb_per_sec: usize,
@@ -209,8 +213,8 @@ pub fn downloadSnapshotsFromGossip(
defer allocator.free(snapshot_filename);

const rpc_socket = peer.contact_info.getSocket(socket_tag.RPC).?;
const r = rpc_socket.toString();
const rpc_url = r[0][0..r[1]];
const rpc_url_bounded = rpc_socket.toStringBounded();
const rpc_url = rpc_url_bounded.constSlice();

const snapshot_url = try std.fmt.allocPrintZ(allocator, "http://{s}/{s}", .{
rpc_url,
@@ -520,7 +524,7 @@ test "accounts_db.download: test remove untrusted peers" {
my_shred_version,
my_pubkey,
&.{},
trusted_validators,
trusted_validators.items,
&valid_peers,
);
try std.testing.expectEqual(valid_peers.items.len, 10);
@@ -535,7 +539,7 @@
my_shred_version,
my_pubkey,
&.{},
trusted_validators,
trusted_validators.items,
&valid_peers,
);
try std.testing.expectEqual(valid_peers.items.len, 8);
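
The signature changes above (taking `[]const ContactInfo`, `[]const Pubkey`, and `?[]const Pubkey` instead of `ArrayList` values) mean callers can pass any contiguous, read-only buffer. A sketch of what now coerces to `?[]const Pubkey` (hypothetical `Pubkey` stand-in; the real call sites pass `.items`, as in the test update above and the cmd.zig changes below):

const std = @import("std");

// Stand-in for sig's Pubkey type, only to keep the sketch self-contained.
const Pubkey = struct { data: [32]u8 };

fn countTrusted(trusted_validators: ?[]const Pubkey) usize {
    const trusted = trusted_validators orelse return 0;
    return trusted.len;
}

pub fn main() !void {
    var list = std.ArrayList(Pubkey).init(std.heap.page_allocator);
    defer list.deinit();
    try list.append(.{ .data = [_]u8{0} ** 32 });

    const fixed = [_]Pubkey{.{ .data = [_]u8{1} ** 32 }};

    // All of these coerce to `?[]const Pubkey`:
    _ = countTrusted(list.items); // an ArrayList's backing slice
    _ = countTrusted(&fixed); // a pointer to a fixed-size array
    _ = countTrusted(null); // "trust any peer"
}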
src/cmd/cmd.zig: 4 changes (2 additions, 2 deletions)
@@ -722,7 +722,7 @@ fn downloadSnapshot() !void {
try downloadSnapshotsFromGossip(
gpa_allocator,
logger,
trusted_validators,
if (trusted_validators) |trusted| trusted.items else null,
&gossip_service,
snapshot_dir_str,
@intCast(min_mb_per_sec),
@@ -795,7 +795,7 @@ fn getOrDownloadSnapshots(
try downloadSnapshotsFromGossip(
allocator,
logger,
trusted_validators,
if (trusted_validators) |trusted| trusted.items else null,
gossip_service orelse return error.SnapshotsNotFoundAndNoGossipService,
snapshot_dir_str,
@intCast(min_mb_per_sec),
src/net/net.zig: 15 changes (13 additions, 2 deletions)
@@ -240,9 +240,20 @@ pub const SocketAddr = union(enum(u8)) {
/// - integer: length of the string within the array
pub fn toString(self: Self) struct { [53]u8, usize } {
var buf: [53]u8 = undefined;
var stream = std.io.fixedBufferStream(&buf);
const len = self.toStringBuf(&buf);
return .{ buf, len };
}

pub fn toStringBounded(self: Self) std.BoundedArray(u8, 53) {
var buf: [53]u8 = undefined;
const len = self.toStringBuf(&buf);
return std.BoundedArray(u8, 53).fromSlice(buf[0..len]) catch unreachable;
}

pub fn toStringBuf(self: Self, buf: *[53]u8) std.math.IntFittingRange(0, 53) {
var stream = std.io.fixedBufferStream(buf);
self.toAddress().format("", .{}, stream.writer()) catch unreachable;
return .{ buf, stream.pos };
return @intCast(stream.pos);
}

pub fn isUnspecified(self: *const Self) bool {
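
The new `toStringBounded` returns a `std.BoundedArray(u8, 53)` that carries its own length, so callers take a ready-to-use slice via `constSlice()` rather than unpacking the `{ [53]u8, usize }` tuple, as the download.zig call site above now does. A self-contained sketch of that return-type shape (a hypothetical `formatPort` helper, not the real `SocketAddr` code):

const std = @import("std");

// Hypothetical helper mirroring the shape of SocketAddr.toStringBounded:
// format into a fixed stack buffer and return it as a BoundedArray, so the
// caller can take a correctly sized slice with `constSlice()`.
fn formatPort(port: u16) std.BoundedArray(u8, 53) {
    var buf: [53]u8 = undefined;
    const text = std.fmt.bufPrint(&buf, "0.0.0.0:{d}", .{port}) catch unreachable;
    return std.BoundedArray(u8, 53).fromSlice(text) catch unreachable;
}

pub fn main() void {
    const bounded = formatPort(8899);
    const url: []const u8 = bounded.constSlice();
    std.debug.print("http://{s}/snapshot.tar.zst\n", .{url});
}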