diff --git a/limitador-server/src/main.rs b/limitador-server/src/main.rs index d17f0d9f..b210cfbb 100644 --- a/limitador-server/src/main.rs +++ b/limitador-server/src/main.rs @@ -43,7 +43,6 @@ use std::path::Path; use std::sync::Arc; use std::time::Duration; use std::{env, process}; -use tracing_subscriber::Layer; #[cfg(feature = "distributed_storage")] use clap::parser::ValuesRef; @@ -52,9 +51,10 @@ use sysinfo::{MemoryRefreshKind, RefreshKind, System}; use thiserror::Error; use tokio::runtime::Handle; use tracing::level_filters::LevelFilter; +use tracing::Subscriber; use tracing_subscriber::fmt::format::FmtSpan; -use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::util::SubscriberInitExt; +use tracing_subscriber::{layer::SubscriberExt, Layer}; mod envoy_rls; mod http_api; @@ -217,63 +217,7 @@ async fn main() -> Result<(), Box> { let (config, version) = create_config(); println!("{LIMITADOR_HEADER} {version}"); - let level = config.log_level.unwrap_or_else(|| { - tracing_subscriber::filter::EnvFilter::from_default_env() - .max_level_hint() - .unwrap_or(LevelFilter::ERROR) - }); - - let fmt_layer = tracing_subscriber::fmt::layer() - .with_span_events(if level >= LevelFilter::DEBUG { - FmtSpan::CLOSE - } else { - FmtSpan::NONE - }) - .with_filter(level); - - let metrics_layer = MetricsLayer::new() - .gather( - "should_rate_limit", - PrometheusMetrics::record_datastore_latency, - vec!["datastore"], - ) - .gather( - "flush_batcher_and_update_counters", - PrometheusMetrics::record_datastore_latency, - vec!["datastore"], - ); - - if !config.tracing_endpoint.is_empty() { - global::set_text_map_propagator(TraceContextPropagator::new()); - - let tracer = opentelemetry_otlp::new_pipeline() - .tracing() - .with_exporter( - opentelemetry_otlp::new_exporter() - .tonic() - .with_endpoint(config.tracing_endpoint.clone()), - ) - .with_trace_config(trace::config().with_resource(Resource::new(vec![ - KeyValue::new("service.name", "limitador"), - ]))) - .install_batch(opentelemetry_sdk::runtime::Tokio)?; - - let telemetry_layer = tracing_opentelemetry::layer().with_tracer(tracer); - - // Init tracing subscriber with telemetry - tracing_subscriber::registry() - .with(metrics_layer) - .with(fmt_layer) - .with(level.max(LevelFilter::INFO)) - .with(telemetry_layer) - .init(); - } else { - // Init tracing subscriber without telemetry - tracing_subscriber::registry() - .with(metrics_layer) - .with(fmt_layer) - .init(); - }; + configure_tracing_subscriber(&config); info!("Version: {}", version); info!("Using config: {:?}", config); @@ -808,3 +752,91 @@ fn guess_cache_size() -> Option { fn leak(s: D) -> &'static str { return Box::leak(format!("{}", s).into_boxed_str()); } + +fn configure_tracing_subscriber(config: &Configuration) { + let level = config.log_level.unwrap_or_else(|| { + tracing_subscriber::filter::EnvFilter::from_default_env() + .max_level_hint() + .unwrap_or(LevelFilter::ERROR) + }); + + let metrics_layer = MetricsLayer::default() + .gather( + "should_rate_limit", + PrometheusMetrics::record_datastore_latency, + vec!["datastore"], + ) + .gather( + "flush_batcher_and_update_counters", + PrometheusMetrics::record_datastore_latency, + vec!["datastore"], + ); + + if !config.tracing_endpoint.is_empty() { + // Init tracing subscriber with telemetry + // If running in memory initialize without metrics + match config.storage { + StorageConfiguration::InMemory(_) => tracing_subscriber::registry() + .with(fmt_layer(level)) + .with(telemetry_layer(&config.tracing_endpoint, level)) + .init(), + _ => tracing_subscriber::registry() + .with(metrics_layer) + .with(fmt_layer(level)) + .with(telemetry_layer(&config.tracing_endpoint, level)) + .init(), + } + } else { + // If running in memory initialize without metrics + match config.storage { + StorageConfiguration::InMemory(_) => { + tracing_subscriber::registry().with(fmt_layer(level)).init() + } + _ => tracing_subscriber::registry() + .with(metrics_layer) + .with(fmt_layer(level)) + .init(), + } + } +} + +fn fmt_layer(level: LevelFilter) -> impl Layer +where + S: Subscriber + for<'a> tracing_subscriber::registry::LookupSpan<'a>, +{ + tracing_subscriber::fmt::layer() + .with_span_events(if level >= LevelFilter::DEBUG { + FmtSpan::CLOSE + } else { + FmtSpan::NONE + }) + .with_filter(level) +} + +fn telemetry_layer(tracing_endpoint: &String, level: LevelFilter) -> impl Layer +where + S: Subscriber + for<'a> tracing_subscriber::registry::LookupSpan<'a>, +{ + global::set_text_map_propagator(TraceContextPropagator::new()); + + let tracer = opentelemetry_otlp::new_pipeline() + .tracing() + .with_exporter( + opentelemetry_otlp::new_exporter() + .tonic() + .with_endpoint(tracing_endpoint), + ) + .with_trace_config( + trace::config().with_resource(Resource::new(vec![KeyValue::new( + "service.name", + "limitador", + )])), + ) + .install_batch(opentelemetry_sdk::runtime::Tokio) + .expect("error installing tokio tracing exporter"); + + // Set the level to minimum info if tracing enabled + tracing_opentelemetry::layer() + .with_tracer(tracer) + .with_filter(level.max(LevelFilter::INFO)) +} diff --git a/limitador-server/src/metrics.rs b/limitador-server/src/metrics.rs index d41df57f..cd3708e8 100644 --- a/limitador-server/src/metrics.rs +++ b/limitador-server/src/metrics.rs @@ -57,17 +57,16 @@ struct SpanState { impl SpanState { fn new(group: String) -> Self { - Self { - group_times: HashMap::from([(group, Timings::new())]), - } + let mut group_times = HashMap::new(); + group_times.insert(group, Timings::new()); + Self { group_times } } - fn increment(&mut self, group: &String, timings: Timings) -> &mut Self { + fn increment(&mut self, group: String, timings: Timings) { self.group_times - .entry(group.to_string()) + .entry(group) .and_modify(|x| *x += timings) .or_insert(timings); - self } } @@ -82,23 +81,18 @@ impl MetricsGroup { } } +#[derive(Default)] pub struct MetricsLayer { groups: HashMap, } impl MetricsLayer { - pub fn new() -> Self { - Self { - groups: HashMap::new(), - } - } - pub fn gather(mut self, aggregate: &str, consumer: fn(Timings), records: Vec<&str>) -> Self { // TODO(adam-cattermole): does not handle case where aggregate already exists - let rec = records.iter().map(|r| r.to_string()).collect(); + let rec = records.iter().map(|&r| r.to_string()).collect(); self.groups .entry(aggregate.to_string()) - .or_insert(MetricsGroup::new(Box::new(consumer), rec)); + .or_insert_with(|| MetricsGroup::new(Box::new(consumer), rec)); self } } @@ -111,7 +105,7 @@ where fn on_new_span(&self, _attrs: &Attributes<'_>, id: &Id, ctx: Context<'_, S>) { let span = ctx.span(id).expect("Span not found, this is a bug"); let mut extensions = span.extensions_mut(); - let name = span.name(); + let name = span.name().to_string(); // if there's a parent if let Some(parent) = span.parent() { @@ -122,33 +116,32 @@ where } // if we are an aggregator - if self.groups.contains_key(name) { + if self.groups.contains_key(&name) { if let Some(span_state) = extensions.get_mut::() { // if the SpanState has come from parent and we must append // (we are a second level aggregator) span_state .group_times - .entry(name.to_string()) - .or_insert(Timings::new()); + .entry(name.clone()) + .or_insert_with(Timings::new); } else { // otherwise create a new SpanState with ourselves - extensions.insert(SpanState::new(name.to_string())) + extensions.insert(SpanState::new(name.to_owned())) } } if let Some(span_state) = extensions.get_mut::() { // either we are an aggregator or nested within one for group in span_state.group_times.keys() { - for record in &self + if self .groups .get(group) .expect("Span state contains group times for an unconfigured group") .records + .contains(&name) { - if name == record { - extensions.insert(Timings::new()); - return; - } + extensions.insert(Timings::new()); + return; } } // if here we are an intermediate span that should not be recorded @@ -156,54 +149,48 @@ where } fn on_enter(&self, id: &Id, ctx: Context<'_, S>) { - let span = ctx.span(id).expect("Span not found, this is a bug"); - let mut extensions = span.extensions_mut(); + if let Some(span) = ctx.span(id) { + if let Some(timings) = span.extensions_mut().get_mut::() { + let now = Instant::now(); + timings.idle += (now - timings.last).as_nanos() as u64; + timings.last = now; - if let Some(timings) = extensions.get_mut::() { - let now = Instant::now(); - timings.idle += (now - timings.last).as_nanos() as u64; - timings.last = now; - timings.updated = true; + timings.updated = true; + } } } fn on_exit(&self, id: &Id, ctx: Context<'_, S>) { - let span = ctx.span(id).expect("Span not found, this is a bug"); - let mut extensions = span.extensions_mut(); - - if let Some(timings) = extensions.get_mut::() { - let now = Instant::now(); - timings.busy += (now - timings.last).as_nanos() as u64; - timings.last = now; - timings.updated = true; + if let Some(span) = ctx.span(id) { + if let Some(timings) = span.extensions_mut().get_mut::() { + let now = Instant::now(); + timings.busy += (now - timings.last).as_nanos() as u64; + timings.last = now; + timings.updated = true; + } } } fn on_close(&self, id: Id, ctx: Context<'_, S>) { let span = ctx.span(&id).expect("Span not found, this is a bug"); let mut extensions = span.extensions_mut(); - let name = span.name(); - let mut t: Option = None; + let name = span.name().to_string(); - if let Some(timing) = extensions.get_mut::() { - let mut time = *timing; - time.idle += (Instant::now() - time.last).as_nanos() as u64; - t = Some(time); - } + let timing = extensions.get_mut::().map(|t| { + let now = Instant::now(); + t.idle += (now - t.last).as_nanos() as u64; + *t + }); if let Some(span_state) = extensions.get_mut::() { - if let Some(timing) = t { - let group_times = span_state.group_times.clone(); + if let Some(timing) = timing { // iterate over the groups this span belongs to - 'aggregate: for group in group_times.keys() { + for group in span_state.group_times.keys().cloned().collect::>() { // find the set of records related to these groups in the layer - for record in &self.groups.get(group).unwrap().records { + if self.groups.get(&group).unwrap().records.contains(&name) { // if we are a record for this group then increment the relevant // span-local timing and continue to the next group - if name == record { - span_state.increment(group, timing); - continue 'aggregate; - } + span_state.increment(group, timing); } } } @@ -214,8 +201,8 @@ where parent.extensions_mut().replace(span_state.clone()); } // IF we are aggregator call consume function - if let Some(metrics_group) = self.groups.get(name) { - if let Some(t) = span_state.group_times.get(name).filter(|&t| t.updated) { + if let Some(metrics_group) = self.groups.get(&name) { + if let Some(t) = span_state.group_times.get(&name).filter(|&t| t.updated) { (metrics_group.consumer)(*t); } } @@ -285,14 +272,14 @@ mod tests { #[test] fn span_state_increment() { let group = String::from("group"); - let mut span_state = SpanState::new(group.clone()); + let mut span_state = SpanState::new(group.to_owned()); let t1 = Timings { idle: 5, busy: 5, last: Instant::now(), updated: true, }; - span_state.increment(&group, t1); + span_state.increment(group.to_owned(), t1); assert_eq!(span_state.group_times.get(&group).unwrap().idle, t1.idle); assert_eq!(span_state.group_times.get(&group).unwrap().busy, t1.busy); } @@ -300,7 +287,7 @@ mod tests { #[test] fn metrics_layer() { let consumer = |_| println!("group/record"); - let ml = MetricsLayer::new().gather("group", consumer, vec!["record"]); + let ml = MetricsLayer::default().gather("group", consumer, vec!["record"]); assert_eq!(ml.groups.get("group").unwrap().records, vec!["record"]); } } diff --git a/limitador/src/storage/disk/rocksdb_storage.rs b/limitador/src/storage/disk/rocksdb_storage.rs index 148af984..6f2c743d 100644 --- a/limitador/src/storage/disk/rocksdb_storage.rs +++ b/limitador/src/storage/disk/rocksdb_storage.rs @@ -14,7 +14,7 @@ use std::collections::{BTreeSet, HashSet}; use std::ops::Deref; use std::sync::Arc; use std::time::{Duration, SystemTime}; -use tracing::trace_span; +use tracing::debug_span; pub struct RocksDbStorage { db: DBWithThreadMode, @@ -53,7 +53,7 @@ impl CounterStorage for RocksDbStorage { let key = key_for_counter(counter); let slice: &[u8] = key.as_ref(); let entry = { - let span = trace_span!("datastore"); + let span = debug_span!("datastore"); let _entered = span.enter(); self.db.get(slice)? }; @@ -100,7 +100,7 @@ impl CounterStorage for RocksDbStorage { let mut iterator = self.db.prefix_iterator(prefix_for_namespace(ns)); loop { let option = { - let span = trace_span!("datastore"); + let span = debug_span!("datastore"); let _entered = span.enter(); iterator.next() }; @@ -138,7 +138,7 @@ impl CounterStorage for RocksDbStorage { fn delete_counters(&self, limits: &HashSet>) -> Result<(), StorageErr> { let counters = self.get_counters(limits)?; for counter in &counters { - let span = trace_span!("datastore"); + let span = debug_span!("datastore"); let _entered = span.enter(); self.db.delete(key_for_counter(counter))?; } @@ -147,10 +147,10 @@ impl CounterStorage for RocksDbStorage { #[tracing::instrument(skip_all)] fn clear(&self) -> Result<(), StorageErr> { - let span = trace_span!("datastore"); + let span = debug_span!("datastore"); let _entered = span.enter(); for entry in self.db.iterator(IteratorMode::Start) { - let span = trace_span!("datastore"); + let span = debug_span!("datastore"); let _entered = span.enter(); self.db.delete(entry?.0)? } @@ -203,7 +203,7 @@ impl RocksDbStorage { ) -> Result { let now = SystemTime::now(); let entry = { - let span = trace_span!("datastore"); + let span = debug_span!("datastore"); let _entered = span.enter(); self.db.get(key)? }; @@ -217,7 +217,7 @@ impl RocksDbStorage { if value.value_at(now) + delta <= counter.max_value() { let expiring_value = ExpiringValue::new(delta, now + Duration::from_secs(counter.limit().seconds())); - let span = trace_span!("datastore"); + let span = debug_span!("datastore"); let _entered = span.enter(); self.db .merge(key, >>::into(expiring_value))?; diff --git a/limitador/src/storage/redis/redis_async.rs b/limitador/src/storage/redis/redis_async.rs index d29e7b3a..ac2fdc51 100644 --- a/limitador/src/storage/redis/redis_async.rs +++ b/limitador/src/storage/redis/redis_async.rs @@ -15,7 +15,7 @@ use std::ops::Deref; use std::str::FromStr; use std::sync::Arc; use std::time::Duration; -use tracing::{debug_span, Instrument}; +use tracing::{info_span, Instrument}; // Note: this implementation does not guarantee exact limits. Ensuring that we // never go over the limits would hurt performance. This implementation @@ -39,7 +39,7 @@ impl AsyncCounterStorage for AsyncRedisStorage { match con .get::>(key_for_counter(counter)) - .instrument(debug_span!("datastore")) + .instrument(info_span!("datastore")) .await? { Some(val) => Ok(u64::try_from(val).unwrap_or(0) + delta <= counter.max_value()), @@ -57,7 +57,7 @@ impl AsyncCounterStorage for AsyncRedisStorage { .arg(counter.window().as_secs()) .arg(delta) .invoke_async::<_, _>(&mut con) - .instrument(debug_span!("datastore")) + .instrument(info_span!("datastore")) .await?; Ok(()) @@ -84,7 +84,7 @@ impl AsyncCounterStorage for AsyncRedisStorage { let script_res: Vec> = { script_invocation .invoke_async(&mut con) - .instrument(debug_span!("datastore")) + .instrument(info_span!("datastore")) .await? }; if let Some(res) = is_limited(counters, delta, script_res) { @@ -95,7 +95,7 @@ impl AsyncCounterStorage for AsyncRedisStorage { redis::cmd("MGET") .arg(counter_keys.clone()) .query_async(&mut con) - .instrument(debug_span!("datastore")) + .instrument(info_span!("datastore")) .await? }; @@ -121,7 +121,7 @@ impl AsyncCounterStorage for AsyncRedisStorage { .arg(counter.window().as_secs()) .arg(delta) .invoke_async::<_, _>(&mut con) - .instrument(debug_span!("datastore")) + .instrument(info_span!("datastore")) .await? } @@ -140,7 +140,7 @@ impl AsyncCounterStorage for AsyncRedisStorage { for limit in limits { let counter_keys = { con.smembers::>(key_for_counters_of_limit(limit)) - .instrument(debug_span!("datastore")) + .instrument(info_span!("datastore")) .await? }; @@ -157,14 +157,14 @@ impl AsyncCounterStorage for AsyncRedisStorage { // unnecessarily. let option = { con.get::>(counter_key.clone()) - .instrument(debug_span!("datastore")) + .instrument(info_span!("datastore")) .await? }; if let Some(val) = option { counter.set_remaining(limit.max_value() - u64::try_from(val).unwrap_or(0)); let ttl: i64 = { con.ttl(&counter_key) - .instrument(debug_span!("datastore")) + .instrument(info_span!("datastore")) .await? }; counter.set_expires_in(Duration::from_secs(u64::try_from(ttl).unwrap_or(0))); @@ -181,7 +181,7 @@ impl AsyncCounterStorage for AsyncRedisStorage { async fn delete_counters(&self, limits: &HashSet>) -> Result<(), StorageErr> { for limit in limits { self.delete_counters_associated_with_limit(limit.deref()) - .instrument(debug_span!("datastore")) + .instrument(info_span!("datastore")) .await? } Ok(()) @@ -192,7 +192,7 @@ impl AsyncCounterStorage for AsyncRedisStorage { let mut con = self.conn_manager.clone(); redis::cmd("FLUSHDB") .query_async(&mut con) - .instrument(debug_span!("datastore")) + .instrument(info_span!("datastore")) .await?; Ok(()) } @@ -219,13 +219,13 @@ impl AsyncRedisStorage { let counter_keys = { con.smembers::>(key_for_counters_of_limit(limit)) - .instrument(debug_span!("datastore")) + .instrument(info_span!("datastore")) .await? }; for counter_key in counter_keys { con.del(counter_key) - .instrument(debug_span!("datastore")) + .instrument(info_span!("datastore")) .await?; } diff --git a/limitador/src/storage/redis/redis_cached.rs b/limitador/src/storage/redis/redis_cached.rs index 9a3ae681..b23efc67 100644 --- a/limitador/src/storage/redis/redis_cached.rs +++ b/limitador/src/storage/redis/redis_cached.rs @@ -20,7 +20,7 @@ use std::str::FromStr; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use std::time::{Duration, SystemTime, UNIX_EPOCH}; -use tracing::{debug_span, error, info, warn, Instrument}; +use tracing::{error, info, info_span, warn, Instrument}; // This is just a first version. // @@ -313,7 +313,7 @@ async fn update_counters( // The redis crate is not working with tables, thus the response will be a Vec of counter values let script_res: Vec = match script_invocation .invoke_async(redis_conn) - .instrument(debug_span!("datastore")) + .instrument(info_span!("datastore")) .await { Ok(res) => res,