diff --git a/build.zig b/build.zig index 25c3865ba..6cfea8ba4 100644 --- a/build.zig +++ b/build.zig @@ -23,6 +23,7 @@ pub fn build(b: *std.Build) void { const zig_network_module = b.dependency("zig-network", opts).module("network"); const zig_cli_module = b.dependency("zig-cli", opts).module("zig-cli"); const getty_mod = b.dependency("getty", opts).module("getty"); + const httpz_mod = b.dependency("httpz", opts).module("httpz"); const lib = b.addStaticLibrary(.{ .name = "sig", @@ -53,6 +54,10 @@ pub fn build(b: *std.Build) void { .name = "getty", .module = getty_mod, }, + .{ + .name = "httpz", + .module = httpz_mod, + }, }, }); @@ -60,6 +65,7 @@ pub fn build(b: *std.Build) void { lib.addModule("zig-network", zig_network_module); lib.addModule("zig-cli", zig_cli_module); lib.addModule("getty", getty_mod); + lib.addModule("httpz", httpz_mod); // This declares intent for the library to be installed into the standard // location when the user invokes the "install" step (the default step when @@ -77,6 +83,8 @@ pub fn build(b: *std.Build) void { tests.addModule("base58-zig", base58_module); tests.addModule("zig-cli", zig_cli_module); tests.addModule("getty", getty_mod); + tests.addModule("httpz", httpz_mod); + const run_tests = b.addRunArtifact(tests); const test_step = b.step("test", "Run library tests"); test_step.dependOn(&lib.step); @@ -94,6 +102,7 @@ pub fn build(b: *std.Build) void { exe.addModule("zig-network", zig_network_module); exe.addModule("zig-cli", zig_cli_module); exe.addModule("getty", getty_mod); + exe.addModule("httpz", httpz_mod); // This declares intent for the executable to be installed into the // standard location when the user invokes the "install" step (the default @@ -137,6 +146,8 @@ pub fn build(b: *std.Build) void { fuzz_exe.addModule("zig-network", zig_network_module); fuzz_exe.addModule("zig-cli", zig_cli_module); fuzz_exe.addModule("getty", getty_mod); + fuzz_exe.addModule("httpz", httpz_mod); + b.installArtifact(fuzz_exe); const fuzz_cmd = b.addRunArtifact(fuzz_exe); if (b.args) |args| { @@ -158,11 +169,12 @@ pub fn build(b: *std.Build) void { benchmark_exe.addModule("zig-network", zig_network_module); benchmark_exe.addModule("zig-cli", zig_cli_module); benchmark_exe.addModule("getty", getty_mod); + benchmark_exe.addModule("httpz", httpz_mod); + b.installArtifact(benchmark_exe); const benchmark_cmd = b.addRunArtifact(benchmark_exe); if (b.args) |args| { benchmark_cmd.addArgs(args); } - b.step("benchmark", "benchmark gossip").dependOn(&benchmark_cmd.step); } diff --git a/build.zig.zon b/build.zig.zon index c97036005..bd2c5fc23 100644 --- a/build.zig.zon +++ b/build.zig.zon @@ -18,5 +18,9 @@ .url = "https://github.com/getty-zig/getty/archive/5b0e750d92ee4ef8e46ad743bb8ced63723acd00.tar.gz", .hash = "12209398657d260abcd6dae946d8da4cd3057b8c7990608476a9f8011aae570d2ebb", }, + .httpz = .{ + .url = "https://github.com/karlseguin/http.zig/archive/7a751549a751d9b45952037abdb127b3225b2ac1.tar.gz", + .hash = "122004f74adf46001fe9129d8cec54bd4a98895ce89f0897790e13b60fa99e527b99", + }, }, } diff --git a/src/cmd/cmd.zig b/src/cmd/cmd.zig index ff73d8674..bc29717d5 100644 --- a/src/cmd/cmd.zig +++ b/src/cmd/cmd.zig @@ -8,6 +8,9 @@ const io = std.io; const Pubkey = @import("../core/pubkey.zig").Pubkey; const SocketAddr = @import("../net/net.zig").SocketAddr; const GossipService = @import("../gossip/gossip_service.zig").GossipService; +const servePrometheus = @import("../prometheus/http.zig").servePrometheus; +const global_registry = 
@import("../prometheus/registry.zig").global_registry; +const Registry = @import("../prometheus/registry.zig").Registry; var gpa = std.heap.GeneralPurposeAllocator(.{}){}; const gpa_allocator = gpa.allocator(); @@ -31,11 +34,21 @@ var gossip_entrypoints_option = cli.Option{ .value_name = "Entrypoints", }; +var metrics_port_option = cli.Option{ + .long_name = "metrics-port", + .help = "port to expose prometheus metrics via http", + .short_alias = 'm', + .value = cli.OptionValue{ .int = 12345 }, + .required = false, + .value_name = "port_number", +}; + var app = &cli.App{ .name = "sig", .description = "Sig is a Solana client implementation written in Zig.\nThis is still a WIP, PRs welcome.", .version = "0.1.1", .author = "Syndica & Contributors", + .options = &.{&metrics_port_option}, .subcommands = &.{ &cli.Command{ .name = "identity", @@ -76,6 +89,8 @@ fn gossip(_: []const []const u8) !void { // var logger: Logger = .noop; + const metrics_thread = try spawnMetrics(gpa_allocator); + var my_keypair = try getOrInitIdentity(gpa_allocator, logger); var gossip_port: u16 = @intCast(gossip_port_option.value.int.?); @@ -119,6 +134,16 @@ fn gossip(_: []const []const u8) !void { ); handle.join(); + metrics_thread.detach(); +} + +/// Initializes the global registry. Returns error if registry was already initialized. +/// Spawns a thread to serve the metrics over http on the CLI configured port. +/// Uses same allocator for both registry and http adapter. +fn spawnMetrics(allocator: std.mem.Allocator) !std.Thread { + var metrics_port: u16 = @intCast(metrics_port_option.value.int.?); + const registry = try global_registry.initialize(Registry(.{}).init, .{allocator}); + return try std.Thread.spawn(.{}, servePrometheus, .{ allocator, registry, metrics_port }); } pub fn run() !void { diff --git a/src/lib.zig b/src/lib.zig index 4deb81ee0..826e685e9 100644 --- a/src/lib.zig +++ b/src/lib.zig @@ -46,6 +46,7 @@ pub const sync = struct { pub usingnamespace @import("sync/mpmc.zig"); pub usingnamespace @import("sync/ref.zig"); pub usingnamespace @import("sync/mux.zig"); + pub usingnamespace @import("sync/once_cell.zig"); pub usingnamespace @import("sync/thread_pool.zig"); }; @@ -75,3 +76,13 @@ pub const net = struct { pub usingnamespace @import("net/net.zig"); pub usingnamespace @import("net/echo.zig"); }; + +pub const prometheus = struct { + pub usingnamespace @import("prometheus/counter.zig"); + pub usingnamespace @import("prometheus/gauge.zig"); + pub usingnamespace @import("prometheus/gauge_fn.zig"); + pub usingnamespace @import("prometheus/http.zig"); + pub usingnamespace @import("prometheus/histogram.zig"); + pub usingnamespace @import("prometheus/metric.zig"); + pub usingnamespace @import("prometheus/registry.zig"); +}; diff --git a/src/prometheus/counter.zig b/src/prometheus/counter.zig new file mode 100644 index 000000000..9044f6af4 --- /dev/null +++ b/src/prometheus/counter.zig @@ -0,0 +1,89 @@ +const std = @import("std"); +const mem = std.mem; +const testing = std.testing; + +const Metric = @import("metric.zig").Metric; + +pub const Counter = struct { + const Self = @This(); + + metric: Metric = Metric{ .getResultFn = getResult }, + value: std.atomic.Atomic(u64) = std.atomic.Atomic(u64).init(0), + + pub fn inc(self: *Self) void { + _ = self.value.fetchAdd(1, .Monotonic); + } + + pub fn add(self: *Self, value: anytype) void { + switch (@typeInfo(@TypeOf(value))) { + .Int, .Float, .ComptimeInt, .ComptimeFloat => {}, + else => @compileError("can't add a non-number"), + } + + _ = 
self.value.fetchAdd(@intCast(value), .Monotonic); + } + + pub fn get(self: *const Self) u64 { + return self.value.load(.Monotonic); + } + + pub fn reset(self: *Self) void { + _ = self.value.store(0, .Monotonic); + } + + fn getResult(metric: *Metric, _: mem.Allocator) Metric.Error!Metric.Result { + const self = @fieldParentPtr(Self, "metric", metric); + return Metric.Result{ .counter = self.get() }; + } +}; + +test "prometheus.counter: inc/add/dec/set/get" { + var buffer = std.ArrayList(u8).init(testing.allocator); + defer buffer.deinit(); + + var counter = Counter{}; + + try testing.expectEqual(@as(u64, 0), counter.get()); + + counter.inc(); + try testing.expectEqual(@as(u64, 1), counter.get()); + + counter.add(200); + try testing.expectEqual(@as(u64, 201), counter.get()); +} + +test "prometheus.counter: concurrent" { + var counter = Counter{}; + + var threads: [4]std.Thread = undefined; + for (&threads) |*thread| { + thread.* = try std.Thread.spawn( + .{}, + struct { + fn run(c: *Counter) void { + var i: usize = 0; + while (i < 20) : (i += 1) { + c.inc(); + } + } + }.run, + .{&counter}, + ); + } + + for (&threads) |*thread| thread.join(); + + try testing.expectEqual(@as(u64, 80), counter.get()); +} + +test "prometheus.counter: write" { + var counter = Counter{ .value = .{ .value = 340 } }; + + var buffer = std.ArrayList(u8).init(testing.allocator); + defer buffer.deinit(); + + var metric = &counter.metric; + try metric.write(testing.allocator, buffer.writer(), "mycounter"); + + try testing.expectEqualStrings("mycounter 340\n", buffer.items); +} diff --git a/src/prometheus/gauge.zig b/src/prometheus/gauge.zig new file mode 100644 index 000000000..92cd108f6 --- /dev/null +++ b/src/prometheus/gauge.zig @@ -0,0 +1,50 @@ +const std = @import("std"); + +const Metric = @import("metric.zig").Metric; + +/// A gauge that stores the value it reports. +/// Read and write operations are atomic and monotonic. 
+pub fn Gauge(comptime T: type) type {
+    return struct {
+        value: std.atomic.Atomic(T) = .{ .value = 0 },
+        metric: Metric = .{ .getResultFn = getResult },
+
+        const Self = @This();
+
+        pub fn inc(self: *Self) void {
+            _ = self.value.fetchAdd(1, .Monotonic);
+        }
+
+        pub fn add(self: *Self, v: T) void {
+            _ = self.value.fetchAdd(v, .Monotonic);
+        }
+
+        pub fn dec(self: *Self) void {
+            _ = self.value.fetchSub(1, .Monotonic);
+        }
+
+        pub fn sub(self: *Self, v: T) void {
+            _ = self.value.fetchSub(v, .Monotonic);
+        }
+
+        pub fn set(self: *Self, v: T) void {
+            self.value.store(v, .Monotonic);
+        }
+
+        pub fn get(self: *Self) T {
+            return self.value.load(.Monotonic);
+        }
+
+        fn getResult(metric: *Metric, allocator: std.mem.Allocator) Metric.Error!Metric.Result {
+            _ = allocator;
+
+            const self = @fieldParentPtr(Self, "metric", metric);
+
+            return switch (T) {
+                f64 => Metric.Result{ .gauge = self.get() },
+                u64 => Metric.Result{ .gauge_int = self.get() },
+                else => unreachable, // Gauge may only hold 'f64' or 'u64'
+            };
+        }
+    };
+}
diff --git a/src/prometheus/gauge_fn.zig b/src/prometheus/gauge_fn.zig
new file mode 100644
index 000000000..0f07c1e7d
--- /dev/null
+++ b/src/prometheus/gauge_fn.zig
@@ -0,0 +1,191 @@
+const std = @import("std");
+const mem = std.mem;
+const testing = std.testing;
+
+const Metric = @import("metric.zig").Metric;
+
+pub fn GaugeCallFnType(comptime StateType: type, comptime Return: type) type {
+    const CallFnArgType = switch (@typeInfo(StateType)) {
+        .Pointer => StateType,
+        .Optional => |opt| opt.child,
+        .Void => void,
+        else => *StateType,
+    };
+
+    return *const fn (state: CallFnArgType) Return;
+}
+
+pub fn GaugeFn(comptime StateType: type, comptime Return: type) type {
+    const CallFnType = GaugeCallFnType(StateType, Return);
+
+    return struct {
+        const Self = @This();
+
+        metric: Metric = .{ .getResultFn = getResult },
+        callFn: CallFnType = undefined,
+        state: StateType = undefined,
+
+        pub fn init(callFn: CallFnType, state: StateType) Self {
+            return .{
+                .callFn = callFn,
+                .state = state,
+            };
+        }
+
+        pub fn get(self: *Self) Return {
+            const TypeInfo = @typeInfo(StateType);
+            switch (TypeInfo) {
+                .Pointer, .Void => {
+                    return self.callFn(self.state);
+                },
+                .Optional => {
+                    if (self.state) |state| {
+                        return self.callFn(state);
+                    }
+                    return 0;
+                },
+                else => {
+                    return self.callFn(&self.state);
+                },
+            }
+        }
+
+        fn getResult(metric: *Metric, allocator: mem.Allocator) Metric.Error!Metric.Result {
+            _ = allocator;
+
+            const self = @fieldParentPtr(Self, "metric", metric);
+
+            return switch (Return) {
+                f64 => Metric.Result{ .gauge = self.get() },
+                u64 => Metric.Result{ .gauge_int = self.get() },
+                else => unreachable, // Gauge Return may only be 'f64' or 'u64'
+            };
+        }
+    };
+}
+
+test "prometheus.gauge_fn: get" {
+    const TestCase = struct {
+        state_type: type,
+        typ: type,
+    };
+
+    const testCases = [_]TestCase{
+        .{
+            .state_type = struct {
+                value: f64,
+            },
+            .typ = f64,
+        },
+    };
+
+    inline for (testCases) |tc| {
+        const State = tc.state_type;
+        const InnerType = tc.typ;
+
+        var state = State{ .value = 20 };
+
+        var gauge = GaugeFn(*State, InnerType).init(
+            struct {
+                fn get(s: *State) InnerType {
+                    return s.value + 1;
+                }
+            }.get,
+            &state,
+        );
+
+        try testing.expectEqual(@as(InnerType, 21), gauge.get());
+    }
+}
+
+test "prometheus.gauge_fn: optional state" {
+    const State = struct {
+        value: f64,
+    };
+    var state = State{ .value = 20.0 };
+
+    var gauge = GaugeFn(?*State, f64).init(
+        struct {
+            fn get(s: *State) f64 {
+                return s.value + 1.0;
+            }
+        }.get,
+        &state,
); + + try testing.expectEqual(@as(f64, 21.0), gauge.get()); +} + +test "prometheus.gauge_fn: non-pointer state" { + var gauge = GaugeFn(f64, f64).init( + struct { + fn get(s: *f64) f64 { + s.* += 1.0; + return s.*; + } + }.get, + 0.0, + ); + + try testing.expectEqual(@as(f64, 1.0), gauge.get()); +} + +test "prometheus.gauge_fn: shared state" { + const State = struct { + mutex: std.Thread.Mutex = .{}, + items: std.ArrayList(usize) = std.ArrayList(usize).init(testing.allocator), + }; + var shared_state = State{}; + defer shared_state.items.deinit(); + + var gauge = GaugeFn(*State, f64).init( + struct { + fn get(state: *State) f64 { + return @floatFromInt(state.items.items.len); + } + }.get, + &shared_state, + ); + + var threads: [4]std.Thread = undefined; + for (&threads, 0..) |*thread, thread_index| { + thread.* = try std.Thread.spawn( + .{}, + struct { + fn run(thread_idx: usize, state: *State) !void { + var i: usize = 0; + while (i < 4) : (i += 1) { + state.mutex.lock(); + defer state.mutex.unlock(); + try state.items.append(thread_idx + i); + } + } + }.run, + .{ thread_index, &shared_state }, + ); + } + + for (&threads) |*thread| thread.join(); + + try testing.expectEqual(@as(usize, 16), @as(usize, @intFromFloat(gauge.get()))); +} + +test "prometheus.gauge_fn: write" { + var gauge = GaugeFn(usize, f64).init( + struct { + fn get(state: *usize) f64 { + state.* += 340; + return @floatFromInt(state.*); + } + }.get, + @as(usize, 0), + ); + + var buffer = std.ArrayList(u8).init(testing.allocator); + defer buffer.deinit(); + + var metric = &gauge.metric; + try metric.write(testing.allocator, buffer.writer(), "mygauge"); + + try testing.expectEqualStrings("mygauge 340.000000\n", buffer.items); +} diff --git a/src/prometheus/histogram.zig b/src/prometheus/histogram.zig new file mode 100644 index 000000000..b85eeeb15 --- /dev/null +++ b/src/prometheus/histogram.zig @@ -0,0 +1,318 @@ +const std = @import("std"); +const Allocator = std.mem.Allocator; +const ArrayList = std.ArrayList; +const Atomic = std.atomic.Atomic; +const Ordering = std.atomic.Ordering; + +const Metric = @import("metric.zig").Metric; + +pub const default_buckets: [11]f64 = .{ 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0 }; + +/// Histogram optimized for fast concurrent writes. +/// Reads and writes are thread-safe if you use the public methods. +/// Writes are lock-free. Reads are locked with a mutex because they occupy a shard. +/// +/// The histogram state is represented in a shard. There are two shards, hot and cold. +/// Writes incremenent the hot shard. +/// Reads flip a switch to change which shard is considered hot for writes, +/// then wait for the previous hot shard to cool down before reading it. +pub const Histogram = struct { + allocator: Allocator, + + /// The highest value to include in each bucket. + upper_bounds: ArrayList(f64), + + /// One hot shard for writing, one cold shard for reading. + shards: [2]struct { + /// Total of all observed values. + sum: Atomic(f64) = Atomic(f64).init(0.0), + /// Total number of observations that have finished being recorded to this shard. + count: Atomic(u64) = Atomic(u64).init(0), + /// Cumulative counts for each upper bound. + buckets: ArrayList(Atomic(u64)), + }, + + /// Used to ensure reads and writes occur on separate shards. + /// Atomic representation of `ShardSync`. 
+ shard_sync: Atomic(u64) = Atomic(u64).init(0), + + /// Prevents more than one reader at a time, since read operations actually + /// execute an internal write by swapping the hot and cold shards. + read_mutex: std.Thread.Mutex = .{}, + + /// Used by registry to report the histogram + metric: Metric = .{ .getResultFn = getResult }, + + const ShardSync = packed struct { + /// The total count of events that have started to be recorded (including those that finished). + /// If this is larger than the shard count, it means a write is in progress. + count: u63 = 0, + /// Index of the shard currently being used for writes. + shard: u1 = 0, + }; + + const Self = @This(); + + pub fn init(allocator: Allocator, buckets: []const f64) !Self { + var upper_bounds = try ArrayList(f64).initCapacity(allocator, buckets.len); + upper_bounds.appendSliceAssumeCapacity(buckets); + return Self{ + .allocator = allocator, + .upper_bounds = upper_bounds, + .shards = .{ + .{ .buckets = try shardBuckets(allocator, buckets.len) }, + .{ .buckets = try shardBuckets(allocator, buckets.len) }, + }, + }; + } + + pub fn deinit(self: *Self) void { + self.shards[0].buckets.deinit(); + self.shards[1].buckets.deinit(); + self.upper_bounds.deinit(); + } + + fn shardBuckets(allocator: Allocator, size: usize) !ArrayList(Atomic(u64)) { + var slice = try allocator.alloc(u64, size); + @memset(slice, 0); + return ArrayList(Atomic(u64)).fromOwnedSlice(allocator, @ptrCast(slice)); + } + + /// Writes a value into the histogram. + pub fn observe(self: *Self, value: f64) void { + const shard_sync = self.incrementCount(.Acquire); // acquires lock. must be first step. + const shard = &self.shards[shard_sync.shard]; + for (0.., self.upper_bounds.items) |i, bound| { + if (value <= bound) { + _ = shard.buckets.items[i].fetchAdd(1, .Monotonic); + break; + } + } + _ = shard.sum.fetchAdd(value, .Monotonic); + _ = shard.count.fetchAdd(1, .Release); // releases lock. must be last step. + } + + /// Reads the current state of the histogram. + pub fn getSnapshot(self: *Self, allocator: ?Allocator) !HistogramSnapshot { + var alloc = self.allocator; + if (allocator) |a| alloc = a; + + // Acquire the lock so no one else executes this function at the same time. + self.read_mutex.lock(); + defer self.read_mutex.unlock(); + + // Make the hot shard cold. Some writers may still be writing to it, + // but no more will start after this. + const shard_sync = self.flipShard(.Monotonic); + const cold_shard = &self.shards[shard_sync.shard]; + const hot_shard = &self.shards[shard_sync.shard +% 1]; + + // Wait until all writers are done writing to the cold shard + // TODO: switch to a condvar. see: `std.Thread.Condition` + while (cold_shard.count.tryCompareAndSwap(shard_sync.count, 0, .Acquire, .Monotonic)) |_| { + // Acquire on success: keeps shard usage after. + } + + // Now the cold shard is totally cold and unused by other threads. + // - read the cold shard's data + // - zero out the cold shard. + // - write the cold shard's data into the hot shard. 
+ const cold_shard_sum = cold_shard.sum.swap(0.0, .Monotonic); + var buckets = try ArrayList(Bucket).initCapacity(alloc, self.upper_bounds.items.len); + var cumulative_count: u64 = 0; + for (0.., self.upper_bounds.items) |i, upper_bound| { + const count = cold_shard.buckets.items[i].swap(0, .Monotonic); + cumulative_count += count; + buckets.appendAssumeCapacity(.{ + .cumulative_count = cumulative_count, + .upper_bound = upper_bound, + }); + _ = hot_shard.buckets.items[i].fetchAdd(count, .Monotonic); + } + _ = hot_shard.sum.fetchAdd(cold_shard_sum, .Monotonic); + _ = hot_shard.count.fetchAdd(shard_sync.count, .Monotonic); + + return HistogramSnapshot.init(cold_shard_sum, shard_sync.count, buckets); + } + + fn getResult(metric: *Metric, allocator: Allocator) Metric.Error!Metric.Result { + const self = @fieldParentPtr(Self, "metric", metric); + const snapshot = try self.getSnapshot(allocator); + return Metric.Result{ .histogram = snapshot }; + } + + /// Increases the global count (used for synchronization), not a count within a shard. + /// Returns the state from before this operation, which was replaced by this operation. + fn incrementCount(self: *@This(), comptime ordering: Ordering) ShardSync { + return @bitCast(self.shard_sync.fetchAdd(1, ordering)); + } + + /// Makes the hot shard cold and vice versa. + /// Returns the state from before this operation, which was replaced by this operation. + fn flipShard(self: *@This(), comptime ordering: Ordering) ShardSync { + const data = self.shard_sync.fetchAdd(@bitCast(ShardSync{ .shard = 1 }), ordering); + return @bitCast(data); + } +}; + +/// A snapshot of the histogram state from a point in time. +pub const HistogramSnapshot = struct { + /// Sum of all values observed by the histogram. + sum: f64, + /// Total number of events observed by the histogram. + count: u64, + /// Cumulative histogram counts. + /// + /// The len *must* be the same as the amount of memory that was + /// allocated for this slice, or else the memory will leak. + buckets: []Bucket, + /// Allocator that was used to allocate the buckets. 
+ allocator: Allocator, + + pub fn init(sum: f64, count: u64, buckets: ArrayList(Bucket)) @This() { + std.debug.assert(buckets.capacity == buckets.items.len); + return .{ + .sum = sum, + .count = count, + .buckets = buckets.items, + .allocator = buckets.allocator, + }; + } + + pub fn deinit(self: *const @This()) void { + self.allocator.free(self.buckets); + } +}; + +pub const Bucket = struct { + cumulative_count: u64 = 0, + upper_bound: f64 = 0, +}; + +test "prometheus.histogram: empty" { + const allocator = std.testing.allocator; + var hist = try Histogram.init(allocator, &default_buckets); + defer hist.deinit(); + + var snapshot = try hist.getSnapshot(null); + defer snapshot.deinit(); + + try expectSnapshot(0, &default_buckets, &(.{0} ** 11), snapshot); +} + +test "prometheus.histogram: data goes in correct buckets" { + const allocator = std.testing.allocator; + var hist = try Histogram.init(allocator, &default_buckets); + defer hist.deinit(); + + const expected_buckets = observeVarious(&hist); + + var snapshot = try hist.getSnapshot(null); + defer snapshot.deinit(); + + try expectSnapshot(7, &default_buckets, &expected_buckets, snapshot); +} + +test "prometheus.histogram: repeated snapshots measure the same thing" { + const allocator = std.testing.allocator; + var hist = try Histogram.init(allocator, &default_buckets); + defer hist.deinit(); + + const expected_buckets = observeVarious(&hist); + + var snapshot1 = try hist.getSnapshot(null); + snapshot1.deinit(); + var snapshot = try hist.getSnapshot(null); + defer snapshot.deinit(); + + try expectSnapshot(7, &default_buckets, &expected_buckets, snapshot); +} + +test "prometheus.histogram: values accumulate across snapshots" { + const allocator = std.testing.allocator; + var hist = try Histogram.init(allocator, &default_buckets); + defer hist.deinit(); + + _ = observeVarious(&hist); + + var snapshot1 = try hist.getSnapshot(null); + snapshot1.deinit(); + + hist.observe(1.0); + + var snapshot = try hist.getSnapshot(null); + defer snapshot.deinit(); + + const expected_buckets: [11]u64 = .{ 1, 1, 1, 1, 4, 4, 4, 6, 7, 7, 7 }; + try expectSnapshot(8, &default_buckets, &expected_buckets, snapshot); +} + +test "prometheus.histogram: totals add up after concurrent reads and writes" { + const allocator = std.testing.allocator; + var hist = try Histogram.init(allocator, &default_buckets); + defer hist.deinit(); + + var threads: [4]std.Thread = undefined; + for (&threads) |*thread| { + thread.* = try std.Thread.spawn( + .{}, + struct { + fn run(h: *Histogram) void { + for (0..1000) |i| { + _ = observeVarious(h); + if (i % 10 == 0) { + (h.getSnapshot(null) catch @panic("snapshot")).deinit(); + } + } + } + }.run, + .{&hist}, + ); + } + for (&threads) |*thread| thread.join(); + + const snapshot = try hist.getSnapshot(allocator); + defer snapshot.deinit(); + + var expected = ArrayList(u64).init(allocator); + defer expected.deinit(); + for (result) |r| { + try expected.append(4000 * r); + } + try expectSnapshot(28000, &default_buckets, expected.items, snapshot); +} + +fn observeVarious(hist: *Histogram) [11]u64 { + hist.observe(1.0); + hist.observe(0.1); + hist.observe(2.0); + hist.observe(0.1); + hist.observe(0.0000000001); + hist.observe(0.1); + hist.observe(100.0); + return result; +} + +const result: [11]u64 = .{ 1, 1, 1, 1, 4, 4, 4, 5, 6, 6, 6 }; + +fn expectSnapshot( + expected_total: u64, + expected_bounds: []const f64, + expected_buckets: []const u64, + snapshot: anytype, +) !void { + try std.testing.expectEqual(expected_total, snapshot.count); + 
try std.testing.expectEqual(default_buckets.len, snapshot.buckets.len); + for (0.., snapshot.buckets) |i, bucket| { + try expectEqual(expected_buckets[i], bucket.cumulative_count, "value in bucket {}\n", .{i}); + try expectEqual(expected_bounds[i], bucket.upper_bound, "bound for bucket {}\n", .{i}); + } +} + +fn expectEqual(expected: anytype, actual: anytype, comptime fmt: anytype, args: anytype) !void { + std.testing.expectEqual(expected, actual) catch |e| { + std.debug.print(fmt, args); + return e; + }; + return; +} diff --git a/src/prometheus/http.zig b/src/prometheus/http.zig new file mode 100644 index 000000000..201321b8c --- /dev/null +++ b/src/prometheus/http.zig @@ -0,0 +1,71 @@ +const std = @import("std"); + +const httpz = @import("httpz"); + +const Level = @import("../trace/level.zig").Level; +const Registry = @import("registry.zig").Registry; +const global_registry = @import("registry.zig").global_registry; +const default_buckets = @import("histogram.zig").default_buckets; + +pub fn servePrometheus( + allocator: std.mem.Allocator, + registry: *Registry(.{}), + port: u16, +) !void { + const endpoint = MetricsEndpoint{ + .allocator = allocator, + .registry = registry, + }; + var server = try httpz.ServerCtx(*const MetricsEndpoint, *const MetricsEndpoint).init( + allocator, + .{ .port = port }, + &endpoint, + ); + var router = server.router(); + router.get("/metrics", getMetrics); + return server.listen(); +} + +const MetricsEndpoint = struct { + allocator: std.mem.Allocator, + registry: *Registry(.{}), +}; + +pub fn getMetrics( + self: *const MetricsEndpoint, + _: *httpz.Request, + response: *httpz.Response, +) !void { + try self.registry.write(self.allocator, response.writer()); +} + +/// Runs a test prometheus endpoint with dummy data. +pub fn main() !void { + const alloc = std.heap.page_allocator; + _ = try global_registry.initialize(Registry(.{}).init, .{alloc}); + + _ = try std.Thread.spawn( + .{}, + struct { + fn run() !void { + const reg = try global_registry.get(); + var secs_counter = try reg.getOrCreateCounter("seconds_since_start"); + var gauge = try reg.getOrCreateGauge("seconds_hand", u64); + var hist = try reg.getOrCreateHistogram("hist", &default_buckets); + while (true) { + std.time.sleep(1_000_000_000); + secs_counter.inc(); + gauge.set(@as(u64, @intCast(std.time.timestamp())) % @as(u64, 60)); + hist.observe(1.1); + hist.observe(0.02); + } + } + }.run, + .{}, + ); + try servePrometheus( + alloc, + try global_registry.get(), + 12345, + ); +} diff --git a/src/prometheus/metric.zig b/src/prometheus/metric.zig new file mode 100644 index 000000000..849475f1f --- /dev/null +++ b/src/prometheus/metric.zig @@ -0,0 +1,160 @@ +const std = @import("std"); +const fmt = std.fmt; +const mem = std.mem; +const testing = std.testing; + +const HistogramSnapshot = @import("histogram.zig").HistogramSnapshot; + +pub const Metric = struct { + pub const Error = error{OutOfMemory} || std.os.WriteError || std.http.Server.Response.Writer.Error; + + pub const Result = union(enum) { + const Self = @This(); + + counter: u64, + gauge: f64, + gauge_int: u64, + histogram: HistogramSnapshot, + + pub fn deinit(self: Self, allocator: mem.Allocator) void { + switch (self) { + .histogram => |v| { + allocator.free(v.buckets); + }, + else => {}, + } + } + }; + + getResultFn: *const fn (self: *Metric, allocator: mem.Allocator) Error!Result, + + pub fn write(self: *Metric, allocator: mem.Allocator, writer: anytype, name: []const u8) Error!void { + const result = try self.getResultFn(self, allocator); + 
defer result.deinit(allocator); + + switch (result) { + .counter, .gauge_int => |v| { + return try writer.print("{s} {d}\n", .{ name, v }); + }, + .gauge => |v| { + return try writer.print("{s} {d:.6}\n", .{ name, v }); + }, + .histogram => |v| { + if (v.buckets.len <= 0) return; + + const name_and_labels = splitName(name); + + if (name_and_labels.labels.len > 0) { + for (v.buckets) |bucket| { + try writer.print("{s}_bucket{{{s},le=\"{s}\"}} {d:.6}\n", .{ + name_and_labels.name, + name_and_labels.labels, + floatMetric(bucket.upper_bound), + bucket.cumulative_count, + }); + } + try writer.print("{s}_sum{{{s}}} {:.6}\n", .{ + name_and_labels.name, + name_and_labels.labels, + floatMetric(v.sum), + }); + try writer.print("{s}_count{{{s}}} {d}\n", .{ + name_and_labels.name, + name_and_labels.labels, + v.count, + }); + } else { + for (v.buckets) |bucket| { + try writer.print("{s}_bucket{{le=\"{s}\"}} {d:.6}\n", .{ + name_and_labels.name, + floatMetric(bucket.upper_bound), + bucket.cumulative_count, + }); + } + try writer.print("{s}_sum {:.6}\n", .{ + name_and_labels.name, + floatMetric(v.sum), + }); + try writer.print("{s}_count {d}\n", .{ + name_and_labels.name, + v.count, + }); + } + }, + } + } +}; + +/// Converts a float into an anonymous type that can be formatted properly for prometheus. +pub fn floatMetric(value: anytype) struct { + value: @TypeOf(value), + + pub fn format(self: @This(), comptime format_string: []const u8, options: fmt.FormatOptions, writer: anytype) !void { + _ = format_string; + + const as_int: u64 = @intFromFloat(self.value); + if (@as(f64, @floatFromInt(as_int)) == self.value) { + try fmt.formatInt(as_int, 10, .lower, options, writer); + } else { + try fmt.formatFloatDecimal(self.value, options, writer); + } + } +} { + return .{ .value = value }; +} + +const NameAndLabels = struct { + name: []const u8, + labels: []const u8 = "", +}; + +fn splitName(name: []const u8) NameAndLabels { + const bracket_pos = mem.indexOfScalar(u8, name, '{'); + if (bracket_pos) |pos| { + return NameAndLabels{ + .name = name[0..pos], + .labels = name[pos + 1 .. 
name.len - 1], + }; + } else { + return NameAndLabels{ + .name = name, + }; + } +} + +test "prometheus.metric: ensure splitName works" { + const TestCase = struct { + input: []const u8, + exp: NameAndLabels, + }; + + const test_cases = &[_]TestCase{ + .{ + .input = "foobar", + .exp = .{ + .name = "foobar", + }, + }, + .{ + .input = "foobar{route=\"/home\"}", + .exp = .{ + .name = "foobar", + .labels = "route=\"/home\"", + }, + }, + .{ + .input = "foobar{route=\"/home\",status=\"500\"}", + .exp = .{ + .name = "foobar", + .labels = "route=\"/home\",status=\"500\"", + }, + }, + }; + + inline for (test_cases) |tc| { + const res = splitName(tc.input); + + try testing.expectEqualStrings(tc.exp.name, res.name); + try testing.expectEqualStrings(tc.exp.labels, res.labels); + } +} diff --git a/src/prometheus/registry.zig b/src/prometheus/registry.zig new file mode 100644 index 000000000..4b4eb3ff2 --- /dev/null +++ b/src/prometheus/registry.zig @@ -0,0 +1,368 @@ +const std = @import("std"); +const fmt = std.fmt; +const hash_map = std.hash_map; +const heap = std.heap; +const mem = std.mem; +const testing = std.testing; + +const OnceCell = @import("../sync/once_cell.zig").OnceCell; + +const Metric = @import("metric.zig").Metric; +const Counter = @import("counter.zig").Counter; +const Gauge = @import("gauge.zig").Gauge; +const GaugeFn = @import("gauge_fn.zig").GaugeFn; +const GaugeCallFnType = @import("gauge_fn.zig").GaugeCallFnType; +const Histogram = @import("histogram.zig").Histogram; +const default_buckets = @import("histogram.zig").default_buckets; + +pub const GetMetricError = error{ + /// Returned when trying to add a metric to an already full registry. + TooManyMetrics, + /// Returned when the name of name is bigger than the configured max_name_len. + NameTooLong, + + OutOfMemory, + /// Attempted to get a metric of the wrong type. + InvalidType, +}; + +/// Global registry singleton for convenience. +pub const global_registry: *OnceCell(Registry(.{})) = &global_registry_owned; +var global_registry_owned: OnceCell(Registry(.{})) = .{}; + +const RegistryOptions = struct { + max_metrics: comptime_int = 8192, + max_name_len: comptime_int = 1024, +}; + +pub fn Registry(comptime options: RegistryOptions) type { + return struct { + const Self = @This(); + + const MetricMap = hash_map.StringHashMapUnmanaged(struct { + /// Used to validate the pointer is cast into a valid  type. 
+ type_name: []const u8, + metric: *Metric, + }); + + arena_state: heap.ArenaAllocator, + mutex: std.Thread.Mutex, + metrics: MetricMap, + + pub fn init(allocator: mem.Allocator) Self { + return .{ + .arena_state = heap.ArenaAllocator.init(allocator), + .mutex = .{}, + .metrics = MetricMap{}, + }; + } + + pub fn deinit(self: *Self) void { + self.arena_state.deinit(); + } + + fn nbMetrics(self: *const Self) usize { + return self.metrics.count(); + } + + pub fn getOrCreateCounter(self: *Self, name: []const u8) GetMetricError!*Counter { + return self.getOrCreateMetric(name, Counter, .{}); + } + + pub fn getOrCreateGauge(self: *Self, name: []const u8, comptime T: type) GetMetricError!*Gauge(T) { + return self.getOrCreateMetric(name, Gauge(T), .{}); + } + + pub fn getOrCreateGaugeFn( + self: *Self, + name: []const u8, + state: anytype, + callFn: GaugeCallFnType(@TypeOf(state), f64), + ) GetMetricError!*GaugeFn(@TypeOf(state), Return(@TypeOf(callFn))) { + return self.getOrCreateMetric( + name, + GaugeFn(@TypeOf(state), Return(@TypeOf(callFn))), + .{ callFn, state }, + ); + } + + pub fn getOrCreateHistogram( + self: *Self, + name: []const u8, + buckets: []const f64, + ) GetMetricError!*Histogram { + return self.getOrCreateMetric(name, Histogram, .{buckets}); + } + + /// MetricType must be initializable in one of these ways: + /// - try MetricType.init(allocator, ...args) + /// - MetricType.init(...args) + /// - as args struct (only if no init method is defined) + fn getOrCreateMetric( + self: *Self, + name: []const u8, + comptime MetricType: type, + args: anytype, + ) GetMetricError!*MetricType { + if (self.nbMetrics() >= options.max_metrics) return error.TooManyMetrics; + if (name.len > options.max_name_len) return error.NameTooLong; + + var allocator = self.arena_state.allocator(); + + const duped_name = try allocator.dupe(u8, name); + + self.mutex.lock(); + defer self.mutex.unlock(); + + const gop = try self.metrics.getOrPut(allocator, duped_name); + if (!gop.found_existing) { + var real_metric = try allocator.create(MetricType); + if (@hasDecl(MetricType, "init")) { + const params = @typeInfo(@TypeOf(MetricType.init)).Fn.params; + if (params.len != 0 and params[0].type.? 
== mem.Allocator) { + real_metric.* = try @call(.auto, MetricType.init, .{allocator} ++ args); + } else { + real_metric.* = @call(.auto, MetricType.init, args); + } + } else { + real_metric.* = args; + } + gop.value_ptr.* = .{ + .type_name = @typeName(MetricType), + .metric = &real_metric.metric, + }; + } else if (!std.mem.eql(u8, gop.value_ptr.*.type_name, @typeName(MetricType))) { + return GetMetricError.InvalidType; + } + + return @fieldParentPtr(MetricType, "metric", gop.value_ptr.*.metric); + } + + pub fn write(self: *Self, allocator: mem.Allocator, writer: anytype) !void { + var arena_state = heap.ArenaAllocator.init(allocator); + defer arena_state.deinit(); + + self.mutex.lock(); + defer self.mutex.unlock(); + + try writeMetrics(arena_state.allocator(), self.metrics, writer); + } + + fn writeMetrics(allocator: mem.Allocator, map: MetricMap, writer: anytype) !void { + // Get the keys, sorted + const keys = blk: { + var key_list = try std.ArrayList([]const u8).initCapacity(allocator, map.count()); + + var key_iter = map.keyIterator(); + while (key_iter.next()) |key| { + key_list.appendAssumeCapacity(key.*); + } + + break :blk key_list.items; + }; + defer allocator.free(keys); + + std.mem.sort([]const u8, keys, {}, stringLessThan); + + // Write each metric in key order + for (keys) |key| { + var value = map.get(key) orelse unreachable; + try value.metric.write(allocator, writer, key); + } + } + }; +} + +/// Gets the return type of a function or function pointer +fn Return(comptime FnPtr: type) type { + return switch (@typeInfo(FnPtr)) { + .Fn => |fun| fun.return_type.?, + .Pointer => |ptr| @typeInfo(ptr.child).Fn.return_type.?, + else => @compileError("not a function or function pointer"), + }; +} + +fn stringLessThan(context: void, lhs: []const u8, rhs: []const u8) bool { + _ = context; + return mem.lessThan(u8, lhs, rhs); +} + +test "prometheus.registry: getOrCreateCounter" { + var registry = Registry(.{}).init(testing.allocator); + defer registry.deinit(); + + const name = try fmt.allocPrint(testing.allocator, "http_requests{{status=\"{d}\"}}", .{500}); + defer testing.allocator.free(name); + + var i: usize = 0; + while (i < 10) : (i += 1) { + var counter = try registry.getOrCreateCounter(name); + counter.inc(); + } + + var counter = try registry.getOrCreateCounter(name); + try testing.expectEqual(@as(u64, 10), counter.get()); +} + +test "prometheus.registry: getOrCreateX requires the same type" { + var registry = Registry(.{}).init(testing.allocator); + defer registry.deinit(); + + const name = try fmt.allocPrint(testing.allocator, "http_requests{{status=\"{d}\"}}", .{500}); + defer testing.allocator.free(name); + + _ = try registry.getOrCreateCounter(name); + if (registry.getOrCreateGauge(name, u64)) |_| try testing.expect(false) else |_| {} +} + +test "prometheus.registry: write" { + const TestCase = struct { + counter_name: []const u8, + gauge_name: []const u8, + gauge_fn_name: []const u8, + histogram_name: []const u8, + exp: []const u8, + }; + + const exp1 = + \\http_conn_pool_size 4.000000 + \\http_gauge 13 + \\http_request_size_bucket{le="0.005"} 0 + \\http_request_size_bucket{le="0.01"} 0 + \\http_request_size_bucket{le="0.025"} 0 + \\http_request_size_bucket{le="0.05"} 0 + \\http_request_size_bucket{le="0.1"} 0 + \\http_request_size_bucket{le="0.25"} 0 + \\http_request_size_bucket{le="0.5"} 0 + \\http_request_size_bucket{le="1"} 0 + \\http_request_size_bucket{le="2.5"} 1 + \\http_request_size_bucket{le="5"} 1 + \\http_request_size_bucket{le="10"} 2 + \\http_request_size_sum 
18.703600 + \\http_request_size_count 3 + \\http_requests 2 + \\ + ; + + const exp2 = + \\http_conn_pool_size{route="/api/v2/users"} 4.000000 + \\http_gauge{route="/api/v2/users"} 13 + \\http_request_size_bucket{route="/api/v2/users",le="0.005"} 0 + \\http_request_size_bucket{route="/api/v2/users",le="0.01"} 0 + \\http_request_size_bucket{route="/api/v2/users",le="0.025"} 0 + \\http_request_size_bucket{route="/api/v2/users",le="0.05"} 0 + \\http_request_size_bucket{route="/api/v2/users",le="0.1"} 0 + \\http_request_size_bucket{route="/api/v2/users",le="0.25"} 0 + \\http_request_size_bucket{route="/api/v2/users",le="0.5"} 0 + \\http_request_size_bucket{route="/api/v2/users",le="1"} 0 + \\http_request_size_bucket{route="/api/v2/users",le="2.5"} 1 + \\http_request_size_bucket{route="/api/v2/users",le="5"} 1 + \\http_request_size_bucket{route="/api/v2/users",le="10"} 2 + \\http_request_size_sum{route="/api/v2/users"} 18.703600 + \\http_request_size_count{route="/api/v2/users"} 3 + \\http_requests{route="/api/v2/users"} 2 + \\ + ; + + const test_cases = &[_]TestCase{ + .{ + .counter_name = "http_requests", + .gauge_name = "http_gauge", + .gauge_fn_name = "http_conn_pool_size", + .histogram_name = "http_request_size", + .exp = exp1, + }, + .{ + .counter_name = "http_requests{route=\"/api/v2/users\"}", + .gauge_name = "http_gauge{route=\"/api/v2/users\"}", + .gauge_fn_name = "http_conn_pool_size{route=\"/api/v2/users\"}", + .histogram_name = "http_request_size{route=\"/api/v2/users\"}", + .exp = exp2, + }, + }; + + inline for (test_cases) |tc| { + var registry = Registry(.{}).init(testing.allocator); + defer registry.deinit(); + + // Add some counters + { + var counter = try registry.getOrCreateCounter(tc.counter_name); + counter.* = .{ .value = .{ .value = 2 } }; + } + + // Add some gauges + { + var counter = try registry.getOrCreateGauge(tc.gauge_name, u64); + counter.* = .{ .value = .{ .value = 13 } }; + } + + // Add some gauge_fns + { + _ = try registry.getOrCreateGaugeFn( + tc.gauge_fn_name, + @as(f64, 4.0), + struct { + fn get(s: *f64) f64 { + return s.*; + } + }.get, + ); + } + + // Add a histogram + { + var histogram = try registry.getOrCreateHistogram(tc.histogram_name, &default_buckets); + + histogram.observe(5.0012); + histogram.observe(12.30240); + histogram.observe(1.40); + } + + // Write to a buffer + { + var buffer = std.ArrayList(u8).init(testing.allocator); + defer buffer.deinit(); + + try registry.write(testing.allocator, buffer.writer()); + + try testing.expectEqualStrings(tc.exp, buffer.items); + } + + // Write to a file + { + const filename = "prometheus_metrics.txt"; + var file = try std.fs.cwd().createFile(filename, .{ .read = true }); + defer { + file.close(); + std.fs.cwd().deleteFile(filename) catch {}; + } + + try registry.write(testing.allocator, file.writer()); + + try file.seekTo(0); + const file_data = try file.readToEndAlloc(testing.allocator, std.math.maxInt(usize)); + defer testing.allocator.free(file_data); + + try testing.expectEqualStrings(tc.exp, file_data); + } + } +} + +test "prometheus.registry: options" { + var registry = Registry(.{ .max_metrics = 1, .max_name_len = 4 }).init(testing.allocator); + defer registry.deinit(); + + { + try testing.expectError(error.NameTooLong, registry.getOrCreateCounter("hello")); + _ = try registry.getOrCreateCounter("foo"); + } + + { + try testing.expectError(error.TooManyMetrics, registry.getOrCreateCounter("bar")); + } +} + +test { + testing.refAllDecls(@This()); +} diff --git a/src/sync/once_cell.zig 
b/src/sync/once_cell.zig new file mode 100644 index 000000000..f8610fa7c --- /dev/null +++ b/src/sync/once_cell.zig @@ -0,0 +1,200 @@ +const std = @import("std"); + +/// Thread-safe data structure that can only be written to once. +/// WARNING: This does not make the inner type thread-safe. +/// +/// All fields are private. Direct access leads to undefined behavior. +/// +/// 1. When this struct is initialized, the contained type is missing. +/// 2. Call one of the init methods to initialize the contained type. +/// 3. After initialization: +/// - get methods will return the initialized value. +/// - value may not be re-initialized. +pub fn OnceCell(comptime T: type) type { + return struct { + value: T = undefined, + started: std.atomic.Atomic(bool) = std.atomic.Atomic(bool).init(false), + finished: std.atomic.Atomic(bool) = std.atomic.Atomic(bool).init(false), + + const Self = @This(); + + pub fn init() Self { + return .{}; + } + + /// Initializes the inner value and returns pointer to it. + /// Returns error if it was already initialized. + /// Blocks while other threads are in the process of initialization. + pub fn initialize(self: *Self, initLogic: anytype, init_args: anytype) error{AlreadyInitialized}!*T { + if (!self.acquire()) return error.AlreadyInitialized; + self.value = @call(.auto, initLogic, init_args); + self.finished.store(true, .Release); + return &self.value; + } + + /// Tries to initialize the inner value and returns pointer to it, or return error if it fails. + /// Returns error if it was already initialized. + /// Blocks while other threads are in the process of initialization. + pub fn tryInit(self: *Self, initLogic: anytype, init_args: anytype) !*T { + if (!self.acquire()) return error.AlreadyInitialized; + errdefer self.started.store(false, .Release); + self.value = try @call(.auto, initLogic, init_args); + self.finished.store(true, .Release); + return &self.value; + } + + /// Returns pointer to inner value if already initialized. + /// Otherwise initializes the value and returns it. + /// Blocks while other threads are in the process of initialization. + pub fn getOrInit(self: *Self, initLogic: anytype, init_args: anytype) *T { + if (self.acquire()) { + self.value = @call(.auto, initLogic, init_args); + self.finished.store(true, .Release); + } + return &self.value; + } + + /// Returns pointer to inner value if already initialized. + /// Otherwise tries to initialize the value and returns it, or return error if it fails. + /// Blocks while other threads are in the process of initialization. + pub fn getOrTryInit(self: *Self, initLogic: anytype, init_args: anytype) !*T { + if (self.acquire()) { + errdefer self.started.store(false, .Release); + self.value = try @call(.auto, initLogic, init_args); + self.finished.store(true, .Release); + } + return &self.value; + } + + /// Tries to acquire the write lock. + /// returns: + /// - true if write lock is acquired. + /// - false if write lock is not acquirable because a write was already completed. + /// - waits if another thread has a write in progress. if the other thread fails, this may acquire the lock. + fn acquire(self: *Self) bool { + while (self.started.compareAndSwap(false, true, .Acquire, .Monotonic)) |_| { + if (self.finished.load(.Acquire)) { + return false; + } + } + return true; + } + + /// Returns the value if initialized. + /// Returns error if not initialized. + /// Blocks while other threads are in the process of initialization. 
+ pub fn get(self: *Self) error{NotInitialized}!*T { + if (self.finished.load(.Acquire)) { + return &self.value; + } + while (self.started.load(.Monotonic)) { + if (self.finished.load(.Acquire)) { + return &self.value; + } + } + return error.NotInitialized; + } + }; +} + +test "sync.once_cell: init returns correctly" { + var oc = OnceCell(u64).init(); + const x = try oc.initialize(returns(10), .{}); + try std.testing.expect(10 == x.*); +} + +test "sync.once_cell: cannot get uninitialized" { + var oc = OnceCell(u64).init(); + if (oc.get()) |_| { + try std.testing.expect(false); + } else |_| {} +} + +test "sync.once_cell: can get initialized" { + var oc = OnceCell(u64).init(); + _ = try oc.initialize(returns(10), .{}); + const x = try oc.get(); + try std.testing.expect(10 == x.*); +} + +test "sync.once_cell: tryInit returns error on failure" { + var oc = OnceCell(u64).init(); + const err = oc.tryInit(returnErr, .{}); + try std.testing.expectError(error.TestErr, err); +} + +test "sync.once_cell: tryInit works on success" { + var oc = OnceCell(u64).init(); + const x1 = try oc.tryInit(returnNotErr(10), .{}); + const x2 = try oc.get(); + try std.testing.expect(10 == x1.*); + try std.testing.expect(10 == x2.*); +} + +test "sync.once_cell: tryInit returns error if initialized" { + var oc = OnceCell(u64).init(); + const x1 = try oc.tryInit(returnNotErr(10), .{}); + const err = oc.tryInit(returnNotErr(11), .{}); + const x2 = try oc.get(); + try std.testing.expect(10 == x1.*); + try std.testing.expectError(error.AlreadyInitialized, err); + try std.testing.expect(10 == x2.*); +} + +test "sync.once_cell: getOrInit can initialize when needed" { + var oc = OnceCell(u64).init(); + const x1 = oc.getOrInit(returns(10), .{}); + const x2 = try oc.get(); + try std.testing.expect(10 == x1.*); + try std.testing.expect(10 == x2.*); +} + +test "sync.once_cell: getOrInit uses already initialized value" { + var oc = OnceCell(u64).init(); + const x1 = oc.getOrInit(returns(10), .{}); + const x2 = oc.getOrInit(returns(11), .{}); + try std.testing.expect(10 == x1.*); + try std.testing.expect(10 == x2.*); +} + +test "sync.once_cell: getOrTryInit returns error on failure" { + var oc = OnceCell(u64).init(); + const err = oc.getOrTryInit(returnErr, .{}); + try std.testing.expectError(error.TestErr, err); +} + +test "sync.once_cell: getOrTryInit works on success" { + var oc = OnceCell(u64).init(); + const x1 = try oc.getOrTryInit(returnNotErr(10), .{}); + const x2 = try oc.get(); + try std.testing.expect(10 == x1.*); + try std.testing.expect(10 == x2.*); +} + +test "sync.once_cell: getOrTryInit uses already initialized value" { + var oc = OnceCell(u64).init(); + const x1 = try oc.getOrTryInit(returnNotErr(10), .{}); + const x2 = try oc.getOrTryInit(returnNotErr(11), .{}); + try std.testing.expect(10 == x1.*); + try std.testing.expect(10 == x2.*); +} + +fn returns(comptime x: u64) fn () u64 { + return struct { + fn get() u64 { + return x; + } + }.get; +} + +fn returnNotErr(comptime x: u64) fn () error{}!u64 { + return struct { + fn get() !u64 { + return x; + } + }.get; +} + +fn returnErr() !u64 { + return error.TestErr; +}
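
Example usage of the metrics API added by this patch. This is a minimal sketch, not part of the diff: the registration and serving calls (global_registry.initialize, Registry(.{}).init, getOrCreateCounter, getOrCreateHistogram, servePrometheus) are the ones introduced above and mirror spawnMetrics in src/cmd/cmd.zig and the demo main in src/prometheus/http.zig, while the metric names and the import paths are illustrative assumptions; port 12345 matches the new --metrics-port default.

const std = @import("std");

// Illustrative import paths; adjust to the caller's location relative to src/prometheus/.
const Registry = @import("prometheus/registry.zig").Registry;
const global_registry = @import("prometheus/registry.zig").global_registry;
const default_buckets = @import("prometheus/histogram.zig").default_buckets;
const servePrometheus = @import("prometheus/http.zig").servePrometheus;

pub fn main() !void {
    var gpa = std.heap.GeneralPurposeAllocator(.{}){};
    const allocator = gpa.allocator();

    // One-time setup: create the global registry and serve it over HTTP on /metrics.
    const registry = try global_registry.initialize(Registry(.{}).init, .{allocator});
    const port: u16 = 12345;
    const metrics_thread = try std.Thread.spawn(.{}, servePrometheus, .{ allocator, registry, port });
    metrics_thread.detach();

    // Anywhere else in the process: metrics are looked up (or lazily created) by name.
    const packets = try registry.getOrCreateCounter("packets_received"); // hypothetical metric name
    const latency = try registry.getOrCreateHistogram("handle_seconds", &default_buckets); // hypothetical metric name

    while (true) {
        std.time.sleep(std.time.ns_per_s);
        packets.inc();
        latency.observe(0.0042);
    }
}

Scraping http://localhost:12345/metrics then returns the counter and histogram in the text exposition format produced by Metric.write.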
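A GaugeFn, by contrast, stores no value: it calls back into user-provided state each time the registry is written out, as exercised by the registry tests above. A minimal sketch of wiring one up, assuming a made-up queue type and metric name (getOrCreateGaugeFn and the callback shape come from src/prometheus/gauge_fn.zig and registry.zig):

const std = @import("std");

const global_registry = @import("prometheus/registry.zig").global_registry; // illustrative path

// Hypothetical application state whose size should be exported as a gauge.
const Queue = struct {
    mutex: std.Thread.Mutex = .{},
    items: std.ArrayList(u64),

    fn len(self: *Queue) f64 {
        self.mutex.lock();
        defer self.mutex.unlock();
        return @floatFromInt(self.items.items.len);
    }
};

fn registerQueueGauge(queue: *Queue) !void {
    const registry = try global_registry.get();
    // The callback runs on every scrape, so the reported value is always
    // current without the queue having to update a metric on each push/pop.
    _ = try registry.getOrCreateGaugeFn("queue_len", queue, Queue.len);
}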