Merge pull request #90 from neptune-ai/kg/dict-like-api
Add initial version of dict-like API (similar to `neptune-client` 1.x)
kgodlewski authored Dec 6, 2024
2 parents 432c447 + fed359c commit bc69b17
Showing 10 changed files with 353 additions and 11 deletions.
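For orientation, here is a minimal usage sketch of the dict-like API added by this PR, based on the `Attribute` methods defined in `src/neptune_scale/api/attribute.py` below. The `Run` constructor arguments are illustrative placeholders (other required arguments such as `api_token` are omitted), and `Run` is assumed to be importable from the package root.

from neptune_scale import Run

run = Run(project="workspace/project", run_id="example-run")  # placeholder arguments

# Dict-like assignment of configs; nested dicts are flattened into slash-separated paths.
run["params/lr"] = 0.001
run["params"] = {"optimizer": "adam", "model": {"depth": 8}}

# Metrics are appended per step, individually or in batches.
run["metrics/loss"].append(0.25, step=1)
run["metrics/loss"].extend([0.21, 0.18], steps=[2, 3])

# Tags are added or removed as strings or collections of strings.
run["sys/tags"].add(["baseline", "dict-api"])
run["sys/tags"].remove("baseline")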
6 changes: 3 additions & 3 deletions .github/actions/test-unit/action.yml
@@ -23,10 +23,10 @@ runs:

- name: Test
run: |
- pytest -v ./tests/unit/ \
-   --timeout=120 --timeout_method=thread \
+ pytest -v --timeout=120 --timeout_method=thread \
    --color=yes \
-   --junitxml="./test-results/test-unit-new-${{ inputs.os }}-${{ inputs.python-version }}.xml"
+   --junitxml="./test-results/test-unit-new-${{ inputs.os }}-${{ inputs.python-version }}.xml" \
+   ./tests/unit/ ./src/
shell: bash

- name: Upload test reports
1 change: 1 addition & 0 deletions dev_requirements.txt
@@ -4,5 +4,6 @@
pre-commit
pytest
pytest-timeout
pytest-xdist
freezegun
neptune-fetcher
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -92,3 +92,6 @@ check_untyped_defs = "True"
warn_return_any = "True"
show_error_codes = "True"
# warn_unused_ignores = "True"

[tool.pytest.ini_options]
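# Collect doctests from modules (hence ./src/ in the CI test step above) and parallelize via pytest-xdist.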
addopts = "--doctest-modules -n auto"
244 changes: 244 additions & 0 deletions src/neptune_scale/api/attribute.py
@@ -0,0 +1,244 @@
import functools
import itertools
import warnings
from datetime import datetime
from typing import (
TYPE_CHECKING,
Any,
Callable,
Collection,
Dict,
Iterator,
List,
Optional,
Set,
Tuple,
Union,
cast,
)

if TYPE_CHECKING:
from neptune_scale.api.run import Run

__all__ = ("Attribute", "AttributeStore")


def warn_unsupported_params(fn: Callable) -> Callable:
# Perform some simple heuristics to detect if a method is called with parameters
# that are not supported by Scale
warn = functools.partial(warnings.warn, stacklevel=3)

@functools.wraps(fn)
def wrapper(*args, **kwargs): # type: ignore
if kwargs.get("wait") is not None:
warn("The `wait` parameter is not yet implemented and will be ignored.")

extra_kwargs = set(kwargs.keys()) - {"wait", "step", "timestamp", "steps", "timestamps"}
if extra_kwargs:
warn(
f"`{fn.__name__}()` was called with additional keyword argument(s): `{', '.join(extra_kwargs)}`. "
"These arguments are not supported by Neptune Scale and will be ignored."
)

return fn(*args, **kwargs)

return wrapper


# TODO: proper typehinting
ValueType = Any # Union[float, int, str, bool, datetime, Tuple, List, Dict, Set]


class AttributeStore:
def __init__(self, run: "Run") -> None:
self._run = run
self._attributes: Dict[str, Attribute] = {}

def __getitem__(self, path: str) -> "Attribute":
path = cleanup_path(path)
attr = self._attributes.get(path)
if attr is None:
attr = Attribute(self, path)
self._attributes[path] = attr

return attr

def __setitem__(self, key: str, value: ValueType) -> None:
# TODO: validate type if attr is already known
attr = self[key]
attr.assign(value)

def log(
self,
step: Optional[Union[float, int]] = None,
timestamp: Optional[Union[datetime, float]] = None,
configs: Optional[Dict[str, Union[float, bool, int, str, datetime, list, set, tuple]]] = None,
metrics: Optional[Dict[str, Union[float, int]]] = None,
tags_add: Optional[Dict[str, Union[List[str], Set[str], Tuple[str]]]] = None,
tags_remove: Optional[Dict[str, Union[List[str], Set[str], Tuple[str]]]] = None,
) -> None:
# TODO: This should not call Run.log, but do the actual work. Reverse the current dependency so that this
# class handles all the logging
timestamp = datetime.now() if timestamp is None else timestamp

# TODO: Remove this and teach MetadataSplitter to handle Nones
configs = {} if configs is None else configs
metrics = {} if metrics is None else metrics
tags_add = {} if tags_add is None else tags_add
tags_remove = {} if tags_remove is None else tags_remove

# TODO: remove once Run.log accepts Union[datetime, float]
timestamp = cast(datetime, timestamp)
self._run.log(
step=step, timestamp=timestamp, configs=configs, metrics=metrics, tags_add=tags_add, tags_remove=tags_remove
)


class Attribute:
"""Objects of this class are returned on dict-like access to Run. Attributes have a path and
allow logging values under it.
run = Run(...)
run['foo'] = 1
run['nested'] = {'foo': 1, 'bar': {'baz': 2}}
run['bar'].append(1, step=10)
"""

def __init__(self, store: AttributeStore, path: str) -> None:
self._store = store
self._path = path

# TODO: typehint value properly
@warn_unsupported_params
def assign(self, value: Any, *, wait: bool = False) -> None:
data = accumulate_dict_values(value, self._path)
self._store.log(configs=data)

@warn_unsupported_params
def append(
self,
value: Union[Dict[str, Any], float],
*,
step: Union[float, int],
timestamp: Optional[Union[float, datetime]] = None,
wait: bool = False,
**kwargs: Any,
) -> None:
data = accumulate_dict_values(value, self._path)
self._store.log(metrics=data, step=step, timestamp=timestamp)

@warn_unsupported_params
# TODO: this should be Iterable in Run as well
# def add(self, values: Union[str, Iterable[str]], *, wait: bool = False) -> None:
def add(self, values: Union[str, Union[List[str], Set[str], Tuple[str]]], *, wait: bool = False) -> None:
if isinstance(values, str):
values = (values,)
self._store.log(tags_add={self._path: values})

@warn_unsupported_params
# TODO: this should be Iterable in Run as well
# def remove(self, values: Union[str, Iterable[str]], *, wait: bool = False) -> None:
def remove(self, values: Union[str, Union[List[str], Set[str], Tuple[str]]], *, wait: bool = False) -> None:
if isinstance(values, str):
values = (values,)
self._store.log(tags_remove={self._path: values})

@warn_unsupported_params
def extend(
self,
values: Collection[Union[float, int]],
*,
steps: Collection[Union[float, int]],
timestamps: Optional[Collection[Union[float, datetime]]] = None,
wait: bool = False,
**kwargs: Any,
) -> None:
# TODO: make this compatible with the old client
assert len(values) == len(steps)

if timestamps is not None:
assert len(timestamps) == len(values)
else:
timestamps = cast(tuple, itertools.repeat(datetime.now()))

for value, step, timestamp in zip(values, steps, timestamps):
self.append(value, step=step, timestamp=timestamp, wait=wait)

# TODO: add value type validation to all the methods
# TODO: change Run API to typehint timestamp as Union[datetime, float]


def iter_nested(dict_: Dict[str, ValueType], path: str) -> Iterator[Tuple[Tuple[str, ...], ValueType]]:
"""Iterate a nested dictionary, yielding a tuple of path components and value.
>>> list(iter_nested({"foo": 1, "bar": {"baz": 2}}, "base"))
[(('base', 'foo'), 1), (('base', 'bar', 'baz'), 2)]
>>> list(iter_nested({"foo":{"bar": 1}, "bar":{"baz": 2}}, "base"))
[(('base', 'foo', 'bar'), 1), (('base', 'bar', 'baz'), 2)]
>>> list(iter_nested({"foo": 1, "bar": 2}, "base"))
[(('base', 'foo'), 1), (('base', 'bar'), 2)]
>>> list(iter_nested({"foo": {}}, ""))
Traceback (most recent call last):
...
ValueError: The dictionary cannot be empty or contain empty nested dictionaries.
"""

parts = tuple(path.split("/"))
yield from _iter_nested(dict_, parts)


def _iter_nested(dict_: Dict[str, ValueType], path_acc: Tuple[str, ...]) -> Iterator[Tuple[Tuple[str, ...], ValueType]]:
if not dict_:
raise ValueError("The dictionary cannot be empty or contain empty nested dictionaries.")

for key, value in dict_.items():
current_path = path_acc + (key,)
if isinstance(value, dict):
yield from _iter_nested(value, current_path)
else:
yield current_path, value


def cleanup_path(path: str) -> str:
"""
>>> cleanup_path('/a/b/c')
'a/b/c'
>>> cleanup_path('a/b/c/')
Traceback (most recent call last):
...
ValueError: Invalid path: `a/b/c/`. Path must not end with a slash.
>>> cleanup_path('a//b/c')
Traceback (most recent call last):
...
ValueError: Invalid path: `a//b/c`. Path components must not be empty.
"""

path = path.strip()
if path in ("", "/"):
raise ValueError(f"Invalid path: `{path}`.")

if path.startswith("/"):
path = path[1:]

if path.endswith("/"):
raise ValueError(f"Invalid path: `{path}`. Path must not end with a slash.")
if not all(path.split("/")):
raise ValueError(f"Invalid path: `{path}`. Path components must not be empty.")

return path


def accumulate_dict_values(value: Union[ValueType, Dict[str, ValueType]], path_or_base: str) -> Dict:
"""
>>> accumulate_dict_values(1, "foo")
{'foo': 1}
>>> accumulate_dict_values({"bar": 1, 'l0/l1': 2, 'l3':{"l4": 3}}, "foo")
{'foo/bar': 1, 'foo/l0/l1': 2, 'foo/l3/l4': 3}
"""

if isinstance(value, dict):
data = {"/".join(path): value for path, value in iter_nested(value, path_or_base)}
else:
data = {path_or_base: value}

return data
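To make the flow concrete, here is a short illustrative sketch of what a dict-style assignment reduces to internally, mirroring the doctests above:

from neptune_scale.api.attribute import accumulate_dict_values, cleanup_path

# Roughly what run["params"] = {"lr": 0.001, "model": {"depth": 8}} does under the hood:
path = cleanup_path("/params")  # -> 'params'
configs = accumulate_dict_values({"lr": 0.001, "model": {"depth": 8}}, path)
# -> {'params/lr': 0.001, 'params/model/depth': 8}
# AttributeStore.log(configs=configs) then forwards these flat paths to Run.log().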
13 changes: 13 additions & 0 deletions src/neptune_scale/api/run.py
@@ -28,6 +28,10 @@
from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import Run as CreateRun
from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation

from neptune_scale.api.attribute import (
Attribute,
AttributeStore,
)
from neptune_scale.api.validation import (
verify_collection_type,
verify_max_length,
@@ -199,6 +203,7 @@ def __init__(

self._project: str = input_project
self._run_id: str = run_id
self._attr_store: AttributeStore = AttributeStore(self)

self._lock = threading.RLock()
self._operations_queue: OperationsQueue = OperationsQueue(
@@ -388,6 +393,12 @@ def _create_run(
)
self._operations_queue.enqueue(operation=operation)

def __getitem__(self, key: str) -> Attribute:
return self._attr_store[key]

def __setitem__(self, key: str, value: Any) -> None:
self._attr_store[key] = value

def log_metrics(
self,
data: Dict[str, Union[float, int]],
@@ -530,11 +541,13 @@ def log(
verify_type("tags_remove", tags_remove, (dict, type(None)))

timestamp = datetime.now() if timestamp is None else timestamp
# TODO: move this into AttributeStore
configs = {} if configs is None else configs
metrics = {} if metrics is None else metrics
tags_add = {} if tags_add is None else tags_add
tags_remove = {} if tags_remove is None else tags_remove

# TODO: refactor this into something like `verify_dict_types(name, allowed_key_types, allowed_value_types)`
verify_collection_type("`configs` keys", list(configs.keys()), str)
verify_collection_type("`metrics` keys", list(metrics.keys()), str)
verify_collection_type("`tags_add` keys", list(tags_add.keys()), str)
7 changes: 6 additions & 1 deletion src/neptune_scale/api/validation.py
@@ -53,7 +53,12 @@ def verify_project_qualified_name(var_name: str, var: Any) -> None:
raise ValueError(f"{var_name} is not in expected format, should be 'workspace-name/project-name")


-def verify_collection_type(var_name: str, var: Union[list, set, tuple], expected_type: Union[type, tuple]) -> None:
+def verify_collection_type(
+    var_name: str, var: Union[list, set, tuple], expected_type: Union[type, tuple], allow_none: bool = True
+) -> None:
+    if var is None and not allow_none:
+        raise ValueError(f"{var_name} must not be None")
+
verify_type(var_name, var, (list, set, tuple))

for value in var:
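A brief sketch of the new `allow_none` flag in action (the argument values here are illustrative):

from neptune_scale.api.validation import verify_collection_type

verify_collection_type("tags", ["a", "b"], str)  # passes: every element is a str

try:
    verify_collection_type("tags", None, str, allow_none=False)
except ValueError as e:
    print(e)  # "tags must not be None"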
1 change: 1 addition & 0 deletions src/neptune_scale/net/serialization.py
@@ -50,6 +50,7 @@ def make_step(number: Union[float, int], raise_on_step_precision_loss: bool = Fa
"""
Converts a number to protobuf Step value. Example:
>>> assert make_step(7.654321, True) == Step(whole=7, micro=654321)
Args:
number: step expressed as number
raise_on_step_precision_loss: inform converter whether it should silently drop precision and
9 changes: 9 additions & 0 deletions tests/unit/conftest.py
@@ -0,0 +1,9 @@
import base64
import json

import pytest


@pytest.fixture(scope="session")
def api_token():
return base64.b64encode(json.dumps({"api_address": "aa", "api_url": "bb"}).encode("utf-8")).decode("utf-8")
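As a sanity check, a test consuming this fixture could simply decode it back; the test below is a hypothetical example, not part of the PR:

import base64
import json


def test_api_token_is_base64_encoded_json(api_token):
    decoded = json.loads(base64.b64decode(api_token))
    assert decoded == {"api_address": "aa", "api_url": "bb"}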