diff --git a/src/neptune_scale/api/attribute.py b/src/neptune_scale/api/attribute.py
index ea43313..734bf3e 100644
--- a/src/neptune_scale/api/attribute.py
+++ b/src/neptune_scale/api/attribute.py
@@ -100,7 +100,7 @@ def log(
         )
 
         for operation, metadata_size in splitter:
-            self._operations_queue.enqueue(operation=operation, size=metadata_size, key=step)
+            self._operations_queue.enqueue(operation=operation, size=metadata_size)
 
 
 class Attribute:
diff --git a/src/neptune_scale/sync/aggregating_queue.py b/src/neptune_scale/sync/aggregating_queue.py
index 8e3fc40..0bf4722 100644
--- a/src/neptune_scale/sync/aggregating_queue.py
+++ b/src/neptune_scale/sync/aggregating_queue.py
@@ -71,7 +71,7 @@ def commit(self) -> None:
     def get(self) -> BatchedOperations:
         start = time.monotonic()
 
-        batch_operations: dict[Optional[float], RunOperation] = {}
+        batch_operations: list[RunOperation] = []
         batch_sequence_id: Optional[int] = None
         batch_timestamp: Optional[float] = None
 
@@ -95,7 +95,7 @@ def get(self) -> BatchedOperations:
                 if not batch_operations:
                     new_operation = RunOperation()
                     new_operation.ParseFromString(element.operation)
-                    batch_operations[element.batch_key] = new_operation
+                    batch_operations.append(new_operation)
                     batch_bytes += len(element.operation)
                 else:
                     if not element.is_batchable:
@@ -110,10 +110,7 @@
 
                     new_operation = RunOperation()
                     new_operation.ParseFromString(element.operation)
-                    if element.batch_key not in batch_operations:
-                        batch_operations[element.batch_key] = new_operation
-                    else:
-                        merge_run_operation(batch_operations[element.batch_key], new_operation)
+                    batch_operations.append(new_operation)
                     batch_bytes += element.metadata_size
 
                 batch_sequence_id = element.sequence_id
@@ -157,54 +154,25 @@ def get(self) -> BatchedOperations:
         )
 
 
-def create_run_batch(operations: dict[Optional[float], RunOperation]) -> RunOperation:
+def create_run_batch(operations: list[RunOperation]) -> RunOperation:
+    if not operations:
+        raise Empty
+
     if len(operations) == 1:
-        return next(iter(operations.values()))
+        return operations[0]
 
-    batch = None
-    for _, operation in sorted(operations.items(), key=lambda x: (x[0] is not None, x[0])):
-        if batch is None:
-            batch = RunOperation()
-            batch.project = operation.project
-            batch.run_id = operation.run_id
-            batch.create_missing_project = operation.create_missing_project
-            batch.api_key = operation.api_key
+    head = operations[0]
+    batch = RunOperation()
+    batch.project = head.project
+    batch.run_id = head.run_id
+    batch.create_missing_project = head.create_missing_project
+    batch.api_key = head.api_key
 
+    for operation in operations:
         operation_type = operation.WhichOneof("operation")
         if operation_type == "update":
             batch.update_batch.snapshots.append(operation.update)
         else:
             raise ValueError("Cannot batch operation of type %s", operation_type)
 
-    if batch is None:
-        raise Empty
     return batch
-
-
-def merge_run_operation(batch: RunOperation, operation: RunOperation) -> None:
-    """
-    Merge the `operation` into `batch`, taking into account the special case of `modify_sets`.
-
-    Protobuf merges existing map keys by simply overwriting values, instead of calling
-    `MergeFrom` on the existing value, eg: A['foo'] = B['foo'].
-
-    We want this instead:
-
-        batch = {'sys/tags': 'string': { 'values': {'foo': ADD}}}
-        operation = {'sys/tags': 'string': { 'values': {'bar': ADD}}}
-        result = {'sys/tags': 'string': { 'values': {'foo': ADD, 'bar': ADD}}}
-
-    If we called `batch.MergeFrom(operation)` we would get an overwritten value:
-        result = {'sys/tags': 'string': { 'values': {'bar': ADD}}}
-
-    This function ensures that the `modify_sets` are merged correctly, leaving the default
-    behaviour for all other fields.
-    """
-
-    modify_sets = operation.update.modify_sets
-    operation.update.ClearField("modify_sets")
-
-    batch.MergeFrom(operation)
-
-    for k, v in modify_sets.items():
-        batch.update.modify_sets[k].MergeFrom(v)
diff --git a/src/neptune_scale/sync/operations_queue.py b/src/neptune_scale/sync/operations_queue.py
index 0069b47..d518da1 100644
--- a/src/neptune_scale/sync/operations_queue.py
+++ b/src/neptune_scale/sync/operations_queue.py
@@ -57,7 +57,7 @@ def last_timestamp(self) -> Optional[float]:
         with self._lock:
             return self._last_timestamp
 
-    def enqueue(self, *, operation: RunOperation, size: Optional[int] = None, key: Optional[float] = None) -> None:
+    def enqueue(self, *, operation: RunOperation, size: Optional[int] = None) -> None:
         try:
             is_metadata_update = operation.HasField("update")
             serialized_operation = operation.SerializeToString()
@@ -75,7 +75,6 @@ def enqueue(self, *, operation: RunOperation, size: Optional[int] = None, key: O
                     operation=serialized_operation,
                     metadata_size=size,
                     is_batchable=is_metadata_update,
-                    batch_key=key,
                 ),
                 block=True,
                 timeout=None,
diff --git a/src/neptune_scale/sync/queue_element.py b/src/neptune_scale/sync/queue_element.py
index 1c37ff1..521f89e 100644
--- a/src/neptune_scale/sync/queue_element.py
+++ b/src/neptune_scale/sync/queue_element.py
@@ -26,5 +26,3 @@ class SingleOperation(NamedTuple):
     is_batchable: bool
     # Size of the metadata in the operation (without project, family, run_id etc.)
     metadata_size: Optional[int]
-    # Update metadata key
-    batch_key: Optional[float]
diff --git a/tests/unit/test_aggregating_queue.py b/tests/unit/test_aggregating_queue.py
index de9394a..0f736f5 100644
--- a/tests/unit/test_aggregating_queue.py
+++ b/tests/unit/test_aggregating_queue.py
@@ -6,6 +6,7 @@
 
 import pytest
 from freezegun import freeze_time
+from google.protobuf.timestamp_pb2 import Timestamp
 from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import Run as CreateRun
 from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import (
     Step,
@@ -32,7 +33,6 @@ def test__simple():
         operation=operation.SerializeToString(),
         is_batchable=True,
         metadata_size=update.ByteSize(),
-        batch_key=None,
     )
 
     # and
@@ -60,7 +60,6 @@ def test__max_size_exceeded():
         operation=operation1.SerializeToString(),
         is_batchable=True,
         metadata_size=0,
-        batch_key=None,
     )
     element2 = SingleOperation(
         sequence_id=2,
@@ -68,7 +67,6 @@
         operation=operation2.SerializeToString(),
         is_batchable=True,
         metadata_size=0,
-        batch_key=None,
     )
 
     # and
@@ -108,7 +106,6 @@ def test__batch_size_limit():
         operation=operation1.SerializeToString(),
         is_batchable=True,
         metadata_size=update1.ByteSize(),
-        batch_key=None,
     )
     element2 = SingleOperation(
         sequence_id=2,
@@ -116,7 +113,6 @@
         operation=operation2.SerializeToString(),
         is_batchable=True,
         metadata_size=update2.ByteSize(),
-        batch_key=None,
     )
 
     # and
@@ -148,7 +144,6 @@ def test__batching():
         operation=operation1.SerializeToString(),
         is_batchable=True,
         metadata_size=update1.ByteSize(),
-        batch_key=None,
     )
     element2 = SingleOperation(
         sequence_id=2,
@@ -156,7 +151,6 @@
         operation=operation2.SerializeToString(),
         is_batchable=True,
         metadata_size=update2.ByteSize(),
-        batch_key=None,
     )
 
     # and
@@ -179,7 +173,8 @@
     assert batch.project == "project"
     assert batch.run_id == "run_id"
 
-    assert all(k in batch.update.assign for k in ["aa0", "aa1", "bb0", "bb1"])
+    assert all(k in batch.update_batch.snapshots[0].assign for k in ["aa0", "aa1"])
+    assert all(k in batch.update_batch.snapshots[1].assign for k in ["bb0", "bb1"])
 
 
 @freeze_time("2024-09-01")
@@ -189,21 +184,21 @@ def test__queue_element_size_limit_with_different_steps():
     update1 = UpdateRunSnapshot(step=Step(whole=1), assign={f"aa{i}": Value(int64=(i * 97)) for i in range(2)})
     update2 = UpdateRunSnapshot(step=Step(whole=2), assign={f"bb{i}": Value(int64=(i * 25)) for i in range(2)})
     operation1 = RunOperation(update=update1)
     operation2 = RunOperation(update=update2)
+    timestamp1 = time.process_time()
+    timestamp2 = timestamp1 + 1
     element1 = SingleOperation(
         sequence_id=1,
-        timestamp=time.process_time(),
+        timestamp=timestamp1,
         operation=operation1.SerializeToString(),
         is_batchable=True,
         metadata_size=update1.ByteSize(),
-        batch_key=1.0,
     )
     element2 = SingleOperation(
         sequence_id=2,
-        timestamp=time.process_time(),
+        timestamp=timestamp2,
         operation=operation2.SerializeToString(),
         is_batchable=True,
         metadata_size=update2.ByteSize(),
-        batch_key=2.0,
     )
     # and
@@ -235,7 +230,6 @@ def test__not_merge_two_run_creation():
         operation=operation1.SerializeToString(),
         is_batchable=False,
         metadata_size=0,
-        batch_key=None,
     )
     element2 = SingleOperation(
         sequence_id=2,
@@ -243,7 +237,6 @@
         operation=operation2.SerializeToString(),
         is_batchable=False,
         metadata_size=0,
-        batch_key=None,
     )
 
     # and
@@ -301,7 +294,6 @@ def test__not_merge_run_creation_with_metadata_update():
         operation=operation1.SerializeToString(),
         is_batchable=False,
         metadata_size=0,
-        batch_key=None,
     )
     element2 = SingleOperation(
         sequence_id=2,
@@ -309,7 +301,6 @@ def test__not_merge_run_creation_with_metadata_update():
         operation=operation2.SerializeToString(),
         is_batchable=True,
         metadata_size=update.ByteSize(),
-        batch_key=None,
     )
 
     # and
@@ -351,7 +342,7 @@
 
 
 @freeze_time("2024-09-01")
-def test__merge_same_key():
+def test__batch_same_key():
     # given
     update1 = UpdateRunSnapshot(step=Step(whole=1, micro=0), assign={f"aa{i}": Value(int64=(i * 97)) for i in range(2)})
     update2 = UpdateRunSnapshot(step=Step(whole=1, micro=0), assign={f"bb{i}": Value(int64=(i * 25)) for i in range(2)})
@@ -361,21 +352,20 @@
     operation1 = RunOperation(update=update1, project="project", run_id="run_id")
     operation2 = RunOperation(update=update2, project="project", run_id="run_id")
     # and
+    timestamp0 = time.process_time()
     element1 = SingleOperation(
         sequence_id=1,
-        timestamp=time.process_time(),
+        timestamp=timestamp0,
         operation=operation1.SerializeToString(),
         is_batchable=True,
         metadata_size=update1.ByteSize(),
-        batch_key=1.0,
     )
     element2 = SingleOperation(
         sequence_id=2,
-        timestamp=time.process_time(),
+        timestamp=timestamp0,
         operation=operation2.SerializeToString(),
         is_batchable=True,
         metadata_size=update2.ByteSize(),
-        batch_key=1.0,
     )
 
     # and
@@ -390,7 +380,7 @@
 
     # then
     assert result.sequence_id == 2
-    assert result.timestamp == element2.timestamp
+    assert result.timestamp == timestamp0
 
     # and
     batch = RunOperation()
@@ -398,12 +388,14 @@
     assert batch.project == "project"
     assert batch.run_id == "run_id"
 
-    assert batch.update.step == Step(whole=1, micro=0)
-    assert all(k in batch.update.assign for k in ["aa0", "aa1", "bb0", "bb1"])
+    assert batch.update_batch.snapshots[0].step == Step(whole=1, micro=0)
+    assert batch.update_batch.snapshots[1].step == Step(whole=1, micro=0)
+    assert all(k in batch.update_batch.snapshots[0].assign for k in ["aa0", "aa1"])
+    assert all(k in batch.update_batch.snapshots[1].assign for k in ["bb0", "bb1"])
 
 
 @freeze_time("2024-09-01")
-def test__merge_two_different_steps():
+def test__batch_two_different_steps():
     # given
     update1 = UpdateRunSnapshot(step=Step(whole=1, micro=0), assign={f"aa{i}": Value(int64=(i * 97)) for i in range(2)})
     update2 = UpdateRunSnapshot(step=Step(whole=2, micro=0), assign={f"bb{i}": Value(int64=(i * 25)) for i in range(2)})
@@ -413,21 +405,21 @@
     operation1 = RunOperation(update=update1, project="project", run_id="run_id")
     operation2 = RunOperation(update=update2, project="project", run_id="run_id")
     # and
+    timestamp1 = time.process_time()
+    timestamp2 = timestamp1 + 1
     element1 = SingleOperation(
         sequence_id=1,
-        timestamp=time.process_time(),
+        timestamp=timestamp1,
         operation=operation1.SerializeToString(),
         is_batchable=True,
         metadata_size=0,
-        batch_key=1.0,
     )
     element2 = SingleOperation(
         sequence_id=2,
-        timestamp=time.process_time(),
+        timestamp=timestamp2,
         operation=operation2.SerializeToString(),
         is_batchable=True,
         metadata_size=0,
-        batch_key=2.0,
     )
 
     # and
@@ -454,7 +446,7 @@
 
 
 @freeze_time("2024-09-01")
-def test__merge_step_with_none():
+def test__batch_step_with_none():
     # given
     update1 = UpdateRunSnapshot(step=Step(whole=1, micro=0), assign={f"aa{i}": Value(int64=(i * 97)) for i in range(2)})
     update2 = UpdateRunSnapshot(step=None, assign={f"bb{i}": Value(int64=(i * 25)) for i in range(2)})
@@ -464,13 +456,13 @@
     operation2 = RunOperation(update=update2, project="project", run_id="run_id")
 
     # and
+    timestamp1 = time.process_time()
     element1 = SingleOperation(
         sequence_id=1,
-        timestamp=time.process_time(),
+        timestamp=timestamp1,
         operation=operation1.SerializeToString(),
         is_batchable=True,
         metadata_size=0,
-        batch_key=1.0,
     )
     element2 = SingleOperation(
         sequence_id=2,
@@ -478,7 +470,6 @@
         operation=operation2.SerializeToString(),
         is_batchable=True,
         metadata_size=0,
-        batch_key=None,
     )
 
     # and
@@ -501,16 +492,33 @@
     assert batch.project == "project"
     assert batch.run_id == "run_id"
 
-    assert batch.update_batch.snapshots == [update2, update1]  # None is always first
+    assert batch.update_batch.snapshots == [update1, update2]
 
 
 @freeze_time("2024-09-01")
-def test__merge_two_steps_two_metrics():
+def test__batch_two_steps_two_metrics():
     # given
-    update1a = UpdateRunSnapshot(step=Step(whole=1, micro=0), assign={"aa": Value(int64=10)})
-    update2a = UpdateRunSnapshot(step=Step(whole=2, micro=0), assign={"aa": Value(int64=20)})
-    update1b = UpdateRunSnapshot(step=Step(whole=1, micro=0), assign={"bb": Value(int64=100)})
-    update2b = UpdateRunSnapshot(step=Step(whole=2, micro=0), assign={"bb": Value(int64=200)})
+    timestamp0 = int(time.process_time())
+    update1a = UpdateRunSnapshot(
+        step=Step(whole=1, micro=0),
+        timestamp=Timestamp(seconds=timestamp0 + 1, nanos=0),
+        assign={"aa": Value(int64=10)},
+    )
+    update2a = UpdateRunSnapshot(
+        step=Step(whole=2, micro=0),
+        timestamp=Timestamp(seconds=timestamp0 + 2, nanos=0),
+        assign={"aa": Value(int64=20)},
+    )
+    update1b = UpdateRunSnapshot(
+        step=Step(whole=1, micro=0),
+        timestamp=Timestamp(seconds=timestamp0 + 3, nanos=0),
+        assign={"bb": Value(int64=100)},
+    )
+    update2b = UpdateRunSnapshot(
+        step=Step(whole=2, micro=0),
+        timestamp=Timestamp(seconds=timestamp0 + 4, nanos=0),
+        assign={"bb": Value(int64=200)},
+    )
 
     # and
     operations = [
@@ -522,13 +530,12 @@
     elements = [
         SingleOperation(
             sequence_id=sequence_id,
-            timestamp=time.process_time(),
+            timestamp=timestamp0 + sequence_id,
             operation=operation.SerializeToString(),
             is_batchable=True,
             metadata_size=0,
-            batch_key=batch_key,
         )
-        for sequence_id, batch_key, operation in [
+        for sequence_id, step, operation in [
             (1, 1.0, operations[0]),
             (2, 2.0, operations[1]),
             (3, 1.0, operations[2]),
@@ -554,13 +561,6 @@
     batch = RunOperation()
     batch.ParseFromString(result.operation)
 
-    update1_merged = UpdateRunSnapshot(
-        step=Step(whole=1, micro=0), assign={"aa": Value(int64=10), "bb": Value(int64=100)}
-    )
-    update2_merged = UpdateRunSnapshot(
-        step=Step(whole=2, micro=0), assign={"aa": Value(int64=20), "bb": Value(int64=200)}
-    )
-
     assert batch.project == "project"
     assert batch.run_id == "run_id"
-    assert batch.update_batch.snapshots == [update1_merged, update2_merged]
+    assert batch.update_batch.snapshots == [update1a, update2a, update1b, update2b]
diff --git a/tests/unit/test_sync_process.py b/tests/unit/test_sync_process.py
index 1709510..25b1791 100644
--- a/tests/unit/test_sync_process.py
+++ b/tests/unit/test_sync_process.py
@@ -35,7 +35,6 @@ def single_operation(update: UpdateRunSnapshot, sequence_id):
         operation=operation.SerializeToString(),
         is_batchable=True,
         metadata_size=update.ByteSize(),
-        batch_key=None,
     )
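
Note on the deleted merge_run_operation() helper: its docstring (removed above) records why it existed — protobuf's MergeFrom overwrites message-typed map values key by key (A['foo'] = B['foo']) instead of merging them recursively, which forced the special-case handling of modify_sets. Since updates are now appended to update_batch as separate snapshots and never merged client-side, the workaround has no remaining callers. A minimal sketch of the underlying pitfall, using google.protobuf.Struct as a stand-in for the modify_sets schema (the "sys/tags"/"foo"/"bar" keys mirror the docstring's example and are illustrative only):

```python
from google.protobuf.struct_pb2 import Struct

# batch holds {"sys/tags": {"foo": "ADD"}}; the map value is itself a message.
batch = Struct()
batch.fields["sys/tags"].struct_value.fields["foo"].string_value = "ADD"

# operation holds {"sys/tags": {"bar": "ADD"}} under the same map key.
operation = Struct()
operation.fields["sys/tags"].struct_value.fields["bar"].string_value = "ADD"

# MergeFrom replaces the whole value stored under "sys/tags" instead of
# recursively merging it, so "foo" is lost -- the behaviour the deleted
# helper worked around by re-merging modify_sets entries by hand.
batch.MergeFrom(operation)
assert "foo" not in batch.fields["sys/tags"].struct_value.fields
assert "bar" in batch.fields["sys/tags"].struct_value.fields
```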
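And the new create_run_batch() behaviour in one self-contained example: operations are batched in submission order, and two updates for the same step stay separate snapshots instead of being merged into one. A sketch under one assumption — the RunOperation import path (ingest_pb2) is not shown in this diff and is taken from the neptune_api package layout:

```python
from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import (
    Step,
    UpdateRunSnapshot,
    Value,
)
from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation  # assumed path

from neptune_scale.sync.aggregating_queue import create_run_batch

# Two metadata updates for the same step, as two enqueue() calls would produce.
op1 = RunOperation(
    project="project",
    run_id="run_id",
    update=UpdateRunSnapshot(step=Step(whole=1, micro=0), assign={"aa": Value(int64=10)}),
)
op2 = RunOperation(
    project="project",
    run_id="run_id",
    update=UpdateRunSnapshot(step=Step(whole=1, micro=0), assign={"bb": Value(int64=20)}),
)

batch = create_run_batch([op1, op2])

# Same step, still two snapshots, in submission order; the single-element
# and empty cases return operations[0] and raise Empty, respectively.
assert len(batch.update_batch.snapshots) == 2
assert "aa" in batch.update_batch.snapshots[0].assign
assert "bb" in batch.update_batch.snapshots[1].assign
```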