From c016afc1a8f95f63b415a9d7e9e0b988695fdf34 Mon Sep 17 00:00:00 2001 From: Michael Grosse Huelsewiesche Date: Tue, 3 Dec 2024 12:23:39 -0500 Subject: [PATCH] Improving threading safety for telemetry --- .../analytics/kotlin/core/Telemetry.kt | 59 +++++++++++-------- .../analytics/kotlin/core/TelemetryTest.kt | 12 ++-- 2 files changed, 41 insertions(+), 30 deletions(-) diff --git a/core/src/main/java/com/segment/analytics/kotlin/core/Telemetry.kt b/core/src/main/java/com/segment/analytics/kotlin/core/Telemetry.kt index 7b7210c4..5f51343d 100644 --- a/core/src/main/java/com/segment/analytics/kotlin/core/Telemetry.kt +++ b/core/src/main/java/com/segment/analytics/kotlin/core/Telemetry.kt @@ -14,6 +14,7 @@ import java.util.concurrent.ConcurrentLinkedQueue import java.util.concurrent.Executors import kotlin.math.min import kotlin.math.roundToInt +import java.util.concurrent.atomic.AtomicBoolean class MetricsRequestFactory : RequestFactory() { override fun upload(apiHost: String): HttpURLConnection { @@ -76,7 +77,14 @@ object Telemetry: Subscriber { var host: String = Constants.DEFAULT_API_HOST // 1.0 is 100%, will get set by Segment setting before start() // Values are adjusted by the sampleRate on send - var sampleRate: Double = 1.0 + @Volatile private var _sampleRate: Double = 1.0 + var sampleRate: Double + get() = _sampleRate + set(value) { + synchronized(this) { + _sampleRate = value + } + } var flushTimer: Int = 30 * 1000 // 30s var httpClient: HTTPClient = HTTPClient("", MetricsRequestFactory()) var sendWriteKeyOnError: Boolean = true @@ -93,9 +101,9 @@ object Telemetry: Subscriber { private val queue = ConcurrentLinkedQueue() private var queueBytes = 0 - private var started = false + private var started = AtomicBoolean(false) private var rateLimitEndTime: Long = 0 - private var flushFirstError = true + private var flushFirstError = AtomicBoolean(true) private val exceptionHandler = CoroutineExceptionHandler { _, t -> errorHandler?.let { it( Exception( @@ -113,8 +121,8 @@ object Telemetry: Subscriber { * Called automatically when Telemetry.enable is set to true and when configuration data is received from Segment. */ fun start() { - if (!enable || started || sampleRate == 0.0) return - started = true + if (!enable || started.get() || sampleRate == 0.0) return + started.set(true) // Everything queued was sampled at default 100%, downsample adjustment and send will adjust values if (Math.random() > sampleRate) { @@ -124,7 +132,7 @@ object Telemetry: Subscriber { telemetryJob = telemetryScope.launch(telemetryDispatcher) { while (isActive) { if (!enable) { - started = false + started.set(false) return@launch } try { @@ -148,7 +156,7 @@ object Telemetry: Subscriber { fun reset() { telemetryJob?.cancel() resetQueue() - started = false + started.set(false) rateLimitEndTime = 0 } @@ -202,8 +210,8 @@ object Telemetry: Subscriber { addRemoteMetric(metric, filteredTags, log=logData) - if(flushFirstError) { - flushFirstError = false + if(flushFirstError.get()) { + flushFirstError.set(false) flush() } } @@ -218,7 +226,6 @@ object Telemetry: Subscriber { try { send() - queueBytes = 0 } catch (error: Throwable) { errorHandler?.invoke(error) sampleRate = 0.0 @@ -227,16 +234,14 @@ object Telemetry: Subscriber { private fun send() { if (sampleRate == 0.0) return - var queueCount = queue.size - // Reset queue data size counter since all current queue items will be removed - queueBytes = 0 - val sendQueue = mutableListOf() - while (queueCount-- > 0 && !queue.isEmpty()) { - val m = queue.poll() - if(m != null) { - m.value = (m.value / sampleRate).roundToInt() - sendQueue.add(m) - } + val sendQueue: MutableList + synchronized(queue) { + sendQueue = queue.toMutableList() + queue.clear() + queueBytes = 0 + } + sendQueue.forEach { m -> + m.value = (m.value / sampleRate).roundToInt() } try { // Json.encodeToString by default does not include default values @@ -309,9 +314,11 @@ object Telemetry: Subscriber { tags = fullTags ) val newMetricSize = newMetric.toString().toByteArray().size - if (queueBytes + newMetricSize <= maxQueueBytes) { - queue.add(newMetric) - queueBytes += newMetricSize + synchronized(queue) { + if (queueBytes + newMetricSize <= maxQueueBytes) { + queue.add(newMetric) + queueBytes += newMetricSize + } } } @@ -338,7 +345,9 @@ object Telemetry: Subscriber { } private fun resetQueue() { - queue.clear() - queueBytes = 0 + synchronized(queue) { + queue.clear() + queueBytes = 0 + } } } \ No newline at end of file diff --git a/core/src/test/kotlin/com/segment/analytics/kotlin/core/TelemetryTest.kt b/core/src/test/kotlin/com/segment/analytics/kotlin/core/TelemetryTest.kt index df1ff354..ea2ed543 100644 --- a/core/src/test/kotlin/com/segment/analytics/kotlin/core/TelemetryTest.kt +++ b/core/src/test/kotlin/com/segment/analytics/kotlin/core/TelemetryTest.kt @@ -10,13 +10,15 @@ import java.util.concurrent.ConcurrentLinkedQueue import java.util.concurrent.CountDownLatch import java.util.concurrent.Executors import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicBoolean import kotlin.random.Random class TelemetryTest { fun TelemetryResetFlushFirstError() { val field: Field = Telemetry::class.java.getDeclaredField("flushFirstError") field.isAccessible = true - field.set(true, true) + val atomicBoolean = field.get(Telemetry) as AtomicBoolean + atomicBoolean.set(true) } fun TelemetryQueueSize(): Int { val queueField: Field = Telemetry::class.java.getDeclaredField("queue") @@ -29,11 +31,11 @@ class TelemetryTest { queueBytesField.isAccessible = true return queueBytesField.get(Telemetry) as Int } - var TelemetryStarted: Boolean + var TelemetryStarted: AtomicBoolean get() { val startedField: Field = Telemetry::class.java.getDeclaredField("started") startedField.isAccessible = true - return startedField.get(Telemetry) as Boolean + return startedField.get(Telemetry) as AtomicBoolean } set(value) { val startedField: Field = Telemetry::class.java.getDeclaredField("started") @@ -78,11 +80,11 @@ class TelemetryTest { Telemetry.sampleRate = 0.0 Telemetry.enable = true Telemetry.start() - assertEquals(false, TelemetryStarted) + assertEquals(false, TelemetryStarted.get()) Telemetry.sampleRate = 1.0 Telemetry.start() - assertEquals(true, TelemetryStarted) + assertEquals(true, TelemetryStarted.get()) assertEquals(0,errors.size) }