Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
Raalsky committed Jul 30, 2024
1 parent 8013147 commit 86cbd39
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 26 deletions.
23 changes: 7 additions & 16 deletions src/neptune_scale/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
from datetime import datetime
from typing import Callable

from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation

from neptune_scale.core.components.abstract import (
Resource,
WithResources,
Expand Down Expand Up @@ -95,13 +93,11 @@ def __enter__(self) -> Run:
def resources(self) -> tuple[Resource, ...]:
return (self._operations_queue,)

def _prepare_common_message(self) -> RunOperation:
return RunOperation(
project=self._project,
run_id=self._run_id,
create_missing_project=False,
api_key=b"",
)
def close(self) -> None:
"""
Stops the connection to Neptune and synchronizes all data.
"""
super().close()

def log(
self,
Expand Down Expand Up @@ -140,7 +136,8 @@ def log(
verify_collection_type("`remove_tags` values", list(remove_tags.values()), (list, set))

for operation in MessageBuilder(
common_message=self._prepare_common_message(),
project=self._project,
run_id=self._run_id,
step=step,
timestamp=timestamp,
fields=fields,
Expand All @@ -149,9 +146,3 @@ def log(
remove_tags=remove_tags,
):
self._operations_queue.enqueue(operation=operation)

def close(self) -> None:
"""
Stops the connection to Neptune and synchronizes all data.
"""
super().close()
8 changes: 6 additions & 2 deletions src/neptune_scale/core/message_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ class MessageBuilder:
def __init__(
self,
*,
common_message: RunOperation,
project: str,
run_id: str,
step: int | float | None,
timestamp: datetime,
fields: dict[str, float | int | str],
Expand All @@ -34,8 +35,11 @@ def __init__(
self._metrics = metrics
self._add_tags = add_tags
self._remove_tags = remove_tags
self._common_message = common_message

self._common_message = RunOperation(
project=project,
run_id=run_id,
)
self._ith: int = 0

def __iter__(self) -> MessageBuilder:
Expand Down
28 changes: 20 additions & 8 deletions tests/unit/test_message_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
def test_empty():
# given
builder = MessageBuilder(
common_message=RunOperation(),
project="workspace/project",
run_id="run_id",
step=1,
timestamp=datetime.now(),
fields={},
Expand All @@ -35,15 +36,18 @@ def test_empty():
# then
assert len(result) == 1
assert result[0] == RunOperation(
update=UpdateRunSnapshot(step=Step(whole=1, micro=0), timestamp=Timestamp(seconds=1722341532, nanos=21934))
project="workspace/project",
run_id="run_id",
update=UpdateRunSnapshot(step=Step(whole=1, micro=0), timestamp=Timestamp(seconds=1722341532, nanos=21934)),
)


@freeze_time("2024-07-30 12:12:12.000022")
def test_fields():
# given
builder = MessageBuilder(
common_message=RunOperation(),
project="workspace/project",
run_id="run_id",
step=1,
timestamp=datetime.now(),
fields={
Expand All @@ -65,6 +69,8 @@ def test_fields():
# then
assert len(result) == 1
assert result[0] == RunOperation(
project="workspace/project",
run_id="run_id",
update=UpdateRunSnapshot(
step=Step(whole=1, micro=0),
timestamp=Timestamp(seconds=1722341532, nanos=21934),
Expand All @@ -76,15 +82,16 @@ def test_fields():
"some/datetime": Value(timestamp=Timestamp(seconds=1722341532, nanos=21934)),
"some/tags": Value(string_set=StringSet(values={"tag1", "tag2"})),
},
)
),
)


@freeze_time("2024-07-30 12:12:12.000022")
def test_metrics():
# given
builder = MessageBuilder(
common_message=RunOperation(),
project="workspace/project",
run_id="run_id",
step=1,
timestamp=datetime.now(),
fields={},
Expand All @@ -101,21 +108,24 @@ def test_metrics():
# then
assert len(result) == 1
assert result[0] == RunOperation(
project="workspace/project",
run_id="run_id",
update=UpdateRunSnapshot(
step=Step(whole=1, micro=0),
timestamp=Timestamp(seconds=1722341532, nanos=21934),
append={
"some/metric": Value(float64=3.14),
},
)
),
)


@freeze_time("2024-07-30 12:12:12.000022")
def test_tags():
# given
builder = MessageBuilder(
common_message=RunOperation(),
project="workspace/project",
run_id="run_id",
step=1,
timestamp=datetime.now(),
fields={},
Expand All @@ -136,6 +146,8 @@ def test_tags():
# then
assert len(result) == 1
assert result[0] == RunOperation(
project="workspace/project",
run_id="run_id",
update=UpdateRunSnapshot(
step=Step(whole=1, micro=0),
timestamp=Timestamp(seconds=1722341532, nanos=21934),
Expand All @@ -153,5 +165,5 @@ def test_tags():
string=ModifyStringSet(values={"tag2": SET_OPERATION.REMOVE, "tag3": SET_OPERATION.REMOVE})
),
},
)
),
)

0 comments on commit 86cbd39

Please sign in to comment.