Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into kg/add-run-exists-func
Browse files (browse the repository at this point in the history)
  • Loading branch information
kgodlewski committed Jan 13, 2025
2 parents 98cedc5 + e7ee538 commit de309ad
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 103 deletions.
2 changes: 1 addition & 1 deletion src/neptune_scale/api/attribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def log(
)

for operation, metadata_size in splitter:
self._operations_queue.enqueue(operation=operation, size=metadata_size, key=step)
self._operations_queue.enqueue(operation=operation, size=metadata_size)


class Attribute:
Expand Down
62 changes: 15 additions & 47 deletions src/neptune_scale/sync/aggregating_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def commit(self) -> None:
def get(self) -> BatchedOperations:
start = time.monotonic()

batch_operations: dict[Optional[float], RunOperation] = {}
batch_operations: list[RunOperation] = []
batch_sequence_id: Optional[int] = None
batch_timestamp: Optional[float] = None

Expand All @@ -95,7 +95,7 @@ def get(self) -> BatchedOperations:
if not batch_operations:
new_operation = RunOperation()
new_operation.ParseFromString(element.operation)
batch_operations[element.batch_key] = new_operation
batch_operations.append(new_operation)
batch_bytes += len(element.operation)
else:
if not element.is_batchable:
Expand All @@ -110,10 +110,7 @@ def get(self) -> BatchedOperations:

new_operation = RunOperation()
new_operation.ParseFromString(element.operation)
if element.batch_key not in batch_operations:
batch_operations[element.batch_key] = new_operation
else:
merge_run_operation(batch_operations[element.batch_key], new_operation)
batch_operations.append(new_operation)
batch_bytes += element.metadata_size

batch_sequence_id = element.sequence_id
Expand Down Expand Up @@ -157,54 +154,25 @@ def get(self) -> BatchedOperations:
)


def create_run_batch(operations: list[RunOperation]) -> RunOperation:
    """
    Combine a list of run operations into a single `RunOperation`.

    A single-element list is returned as-is: the very same object, whatever its
    operation type. For longer lists a new `RunOperation` is created which copies
    `project`, `run_id`, `create_missing_project` and `api_key` from the first
    element, and appends the `update` payload of every element, in list order,
    to `update_batch.snapshots`.

    Args:
        operations: Operations to combine, in the order they should be batched.

    Returns:
        The sole operation if only one was given, otherwise a new batch operation.

    Raises:
        Empty: If `operations` is empty.
        ValueError: If more than one operation is given and any of them is not
            an `update` operation.
    """
    if not operations:
        raise Empty

    if len(operations) == 1:
        return operations[0]

    # The first operation supplies the fields shared by the whole batch.
    head = operations[0]
    batch = RunOperation()
    batch.project = head.project
    batch.run_id = head.run_id
    batch.create_missing_project = head.create_missing_project
    batch.api_key = head.api_key

    for operation in operations:
        operation_type = operation.WhichOneof("operation")
        if operation_type != "update":
            # Interpolate the type into the message. Exception constructors do not
            # apply %-formatting, so passing it as a second argument would leave
            # the "%s" placeholder in the text.
            raise ValueError(f"Cannot batch operation of type {operation_type}")
        batch.update_batch.snapshots.append(operation.update)

    return batch


def merge_run_operation(batch: RunOperation, operation: RunOperation) -> None:
    """
    Fold `operation` into `batch` in place, combining `modify_sets` entry by entry.

    For a map key present in both messages, protobuf's merge overwrites the
    existing value (A['foo'] = B['foo']) instead of calling `MergeFrom` on it.
    Given:

        batch     = {'sys/tags': {'string': {'values': {'foo': ADD}}}}
        operation = {'sys/tags': {'string': {'values': {'bar': ADD}}}}

    a plain `batch.MergeFrom(operation)` would leave only the incoming value:

        result    = {'sys/tags': {'string': {'values': {'bar': ADD}}}}

    whereas the result we want keeps both:

        result    = {'sys/tags': {'string': {'values': {'foo': ADD, 'bar': ADD}}}}

    `modify_sets` is therefore merged per key, while every other field gets the
    default `MergeFrom` behaviour. Note that the `modify_sets` field is cleared
    on `operation` as part of the merge.
    """

    # Take the set modifications out of `operation` first, so that the generic
    # merge below does not overwrite the entries already present in `batch`.
    incoming_sets = operation.update.modify_sets
    operation.update.ClearField("modify_sets")

    batch.MergeFrom(operation)

    # Merge each incoming set modification into the entry under the same key.
    for attribute, set_changes in incoming_sets.items():
        batch.update.modify_sets[attribute].MergeFrom(set_changes)
3 changes: 1 addition & 2 deletions src/neptune_scale/sync/operations_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def last_timestamp(self) -> Optional[float]:
with self._lock:
return self._last_timestamp

def enqueue(self, *, operation: RunOperation, size: Optional[int] = None, key: Optional[float] = None) -> None:
def enqueue(self, *, operation: RunOperation, size: Optional[int] = None) -> None:
try:
is_metadata_update = operation.HasField("update")
serialized_operation = operation.SerializeToString()
Expand All @@ -75,7 +75,6 @@ def enqueue(self, *, operation: RunOperation, size: Optional[int] = None, key: O
operation=serialized_operation,
metadata_size=size,
is_batchable=is_metadata_update,
batch_key=key,
),
block=True,
timeout=None,
Expand Down
2 changes: 0 additions & 2 deletions src/neptune_scale/sync/queue_element.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,3 @@ class SingleOperation(NamedTuple):
is_batchable: bool
# Size of the metadata in the operation (without project, family, run_id etc.)
metadata_size: Optional[int]
# Update metadata key
batch_key: Optional[float]
Loading

0 comments on commit de309ad

Please sign in to comment.