Add helper for thread local variables that can be used to add metadata to the output stream (#1052)

* Add helper for thread local variables that can be used to add metadata to the output stream

* Add devtools to help with debugging
* Extend LLM generate() to calculate metadata

* Fix mocked test
eric-anderson authored Dec 4, 2024
1 parent 7e6b626 commit 1f05347
Showing 11 changed files with 214 additions and 13 deletions.
11 changes: 11 additions & 0 deletions lib/sycamore/sycamore/data/metadata.py
@@ -0,0 +1,11 @@
from sycamore.data import MetadataDocument
from sycamore.utils.thread_local import ThreadLocalAccess, ADD_METADATA_TO_OUTPUT


def add_metadata(**metadata):
ThreadLocalAccess(ADD_METADATA_TO_OUTPUT).get().append(MetadataDocument(**metadata))


# At some point we should define particular forms of metadata like metrics
# Maybe following https://github.com/prometheus/OpenMetrics/blob/main/specification/OpenMetrics.md
# as a structure for the metrics -- too complex for now.
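
For orientation, a minimal sketch (not part of the diff) of how the new helper is meant to be called from inside a map-style transform; it mirrors the inject_metadata functions in the tests below, and the metadata keys here are purely illustrative:

from sycamore.data import Document
from sycamore.data.metadata import add_metadata


def tag_doc(doc: Document) -> Document:
    # Appends a MetadataDocument to the thread-local list that the transform
    # machinery (see the base.py change below) installs around the call.
    add_metadata(source="tag_doc", doc_chars=len(doc["doc"]))
    return doc

A docset would then run .map(tag_doc) and, as in the tests, .take_all(include_metadata=True) to receive both the documents and the injected MetadataDocuments.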
26 changes: 23 additions & 3 deletions lib/sycamore/sycamore/llms/bedrock.py
@@ -1,4 +1,5 @@
from dataclasses import dataclass
import datetime
from enum import Enum
import boto3
import json
@@ -124,18 +125,33 @@ def _get_generate_kwargs(self, prompt_kwargs: Dict, llm_kwargs: Optional[Dict] =

return kwargs

def generate(self, *, prompt_kwargs: dict, llm_kwargs: Optional[dict] = None) -> str:
def generate_metadata(self, *, prompt_kwargs: dict, llm_kwargs: Optional[dict] = None) -> dict:
key, ret = self._cache_get(prompt_kwargs, llm_kwargs)
if ret is not None:
if isinstance(ret, dict):
return ret

kwargs = self._get_generate_kwargs(prompt_kwargs, llm_kwargs)
body = json.dumps(kwargs)
start = datetime.datetime.now()
response = self._client.invoke_model(
body=body, modelId=self.model.name, accept="application/json", contentType="application/json"
)
wall_latency = datetime.datetime.now() - start
md = response["ResponseMetadata"]
assert md["HTTPStatusCode"] == 200, f"Request failed {md['HTTPStatusCode']}"
hdrs = md["HTTPHeaders"]
server_latency = datetime.timedelta(milliseconds=int(hdrs["x-amzn-bedrock-invocation-latency"]))
in_tokens = int(hdrs["x-amzn-bedrock-input-token-count"])
out_tokens = int(hdrs["x-amzn-bedrock-output-token-count"])
response_body = json.loads(response.get("body").read())
ret = response_body.get("content", {})[0].get("text", "")
output = response_body.get("content", {})[0].get("text", "")
ret = {
"output": output,
"wall_latency": wall_latency,
"server_latency": server_latency,
"in_tokens": in_tokens,
"out_tokens": out_tokens,
}
value = {
"result": ret,
"prompt_kwargs": prompt_kwargs,
@@ -144,3 +160,7 @@ def generate(self, *, prompt_kwargs: dict, llm_kwargs: Optional[dict] = None) ->
}
self._cache_set(key, value)
return ret

def generate(self, *, prompt_kwargs: dict, llm_kwargs: Optional[dict] = None) -> str:
d = self.generate_metadata(prompt_kwargs=prompt_kwargs, llm_kwargs=llm_kwargs)
return d["output"]
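
A hedged sketch of how a caller might use the new method (the module path follows the file location above; the model and prompt are just examples):

from sycamore.llms.bedrock import Bedrock, BedrockModels

llm = Bedrock(BedrockModels.CLAUDE_3_HAIKU)
res = llm.generate_metadata(prompt_kwargs={"prompt": "Write a haiku about metadata."}, llm_kwargs={})
# res["output"] is the string generate() would have returned; wall_latency and
# server_latency are datetime.timedelta values, and the token counts come from
# the Bedrock response headers.
print(res["output"], res["wall_latency"], res["in_tokens"], res["out_tokens"])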
6 changes: 5 additions & 1 deletion lib/sycamore/sycamore/llms/llms.py
@@ -40,7 +40,11 @@ def _get_cache_key(self, prompt_kwargs: dict, llm_kwargs: Optional[dict] = None)
data = pickle.dumps(combined)
return self._cache.get_hash_context(data).hexdigest()

def _cache_get(self, prompt_kwargs: dict, llm_kwargs: Optional[dict] = None) -> Tuple[Optional[str], Optional[str]]:
# TODO fix cache_get and cache_set to have more consistent typing on the value.
# cache_set takes a value which is a dictionary and has a bunch of required values
# so that cache_get can verify the return value. cache_get then reaches inside that
# dictionary for its return value.
def _cache_get(self, prompt_kwargs: dict, llm_kwargs: Optional[dict] = None) -> Tuple[Optional[str], Any]:
"""Get a cached result for the given prompt and LLM parameters. Returns the cache key
and the cached result if found, otherwise returns None for both."""
if (llm_kwargs or {}).get("temperature", 0) != 0 or not self._cache:
13 changes: 13 additions & 0 deletions lib/sycamore/sycamore/tests/integration/llms/test_bedrock.py
@@ -110,3 +110,16 @@ def test_cached_bedrock_different_models(tmp_path: Path):
# check for difference with model change
assert key_HAIKU != key_SONNET
assert res_HAIKU != res_SONNET


def test_metadata():
llm = Bedrock(BedrockModels.CLAUDE_3_HAIKU)
prompt_kwargs = {"prompt": "Write a limerick about large language models."}

res = llm.generate_metadata(prompt_kwargs=prompt_kwargs, llm_kwargs={})

assert "output" in res
assert "wall_latency" in res
assert "server_latency" in res
assert "in_tokens" in res
assert "out_tokens" in res
32 changes: 32 additions & 0 deletions lib/sycamore/sycamore/tests/integration/transforms/test_map.py
@@ -49,3 +49,35 @@ def __call__(self, d):
print(f"Actor {a} got {count[a]} items")
assert count[a] >= min_count
assert count[a] <= max_count


def test_map_metadata() -> None:
dicts = [
{"index": 1, "doc": "Members of a strike at Yale University."},
{"index": 2, "doc": "A woman is speaking at a podium outdoors."},
]
in_docs = [Document(d) for d in dicts]

def inject_metadata(d):
from sycamore.data.metadata import add_metadata

idx = d["index"]
for i in range(idx):
add_metadata(index=idx, value=i)
return d

docs = (
sycamore.init(exec_mode=sycamore.EXEC_RAY)
.read.document(in_docs)
.map(inject_metadata)
.take_all(include_metadata=True)
)

inject_md = [d for d in docs if "metadata" in d and "index" in d.metadata]
assert len(inject_md) == 3
assert inject_md[0].metadata["index"] == 1
assert inject_md[0].metadata["value"] == 0
assert inject_md[1].metadata["index"] == 2
assert inject_md[1].metadata["value"] == 0
assert inject_md[2].metadata["index"] == 2
assert inject_md[2].metadata["value"] == 1
34 changes: 28 additions & 6 deletions lib/sycamore/sycamore/tests/unit/llms/test_bedrock.py
@@ -7,10 +7,32 @@
from sycamore.utils.cache import DiskCache


class BedrockBody:
def __init__(self, body):
self.body = body

def read(self):
return self.body


def bedrock_reply(body):
return {
"ResponseMetadata": {
"HTTPStatusCode": 200,
"HTTPHeaders": {
"x-amzn-bedrock-invocation-latency": 1111,
"x-amzn-bedrock-input-token-count": 30,
"x-amzn-bedrock-output-token-count": 50,
},
},
"body": BedrockBody(body),
}


@patch("boto3.client")
def test_bedrock(mock_boto3_client):
mock_boto3_client.return_value.invoke_model.return_value.get.return_value.read.return_value = (
'{"content": [{"text": "Here is your result: 56"}]}'
def test_bedrock_simple(mock_boto3_client):
mock_boto3_client.return_value.invoke_model.return_value = bedrock_reply(
'{ "content": [{"text": "Here is your result: 56"}]}'
)

client = Bedrock(BedrockModels.CLAUDE_3_5_SONNET)
@@ -37,7 +59,7 @@ def test_bedrock(mock_boto3_client):

@patch("boto3.client")
def test_bedrock_system_role(mock_boto3_client):
mock_boto3_client.return_value.invoke_model.return_value.get.return_value.read.return_value = (
mock_boto3_client.return_value.invoke_model.return_value = bedrock_reply(
'{"content": [{"text": "Here is your result: 56"}]}'
)

@@ -68,7 +90,7 @@ def test_bedrock_system_role(mock_boto3_client):

@patch("boto3.client")
def test_bedrock_with_llm_kwargs(mock_boto3_client):
mock_boto3_client.return_value.invoke_model.return_value.get.return_value.read.return_value = (
mock_boto3_client.return_value.invoke_model.return_value = bedrock_reply(
'{"content": [{"text": "Here is your result: 56"}]}'
)

@@ -97,7 +119,7 @@ def test_bedrock_with_llm_kwargs(mock_boto3_client):

@patch("boto3.client")
def test_bedrock_with_cache(mock_boto3_client):
mock_boto3_client.return_value.invoke_model.return_value.get.return_value.read.return_value = (
mock_boto3_client.return_value.invoke_model.return_value = bedrock_reply(
'{"content": [{"text": "Here is your result: 56"}]}'
)

32 changes: 32 additions & 0 deletions lib/sycamore/sycamore/tests/unit/transforms/test_mapping.py
@@ -3,6 +3,7 @@
import pytest
import ray.data

import sycamore
from sycamore.data import Document
from sycamore.plan_nodes import Node
from sycamore.transforms import Map, FlatMap, MapBatch
@@ -68,6 +69,37 @@ def sort_key(d: dict) -> int:
assert dicts[0]["index"] == 2
assert dicts[1]["index"] == 3

def test_map_metadata(self) -> None:
dicts = [
{"index": 1, "doc": "Members of a strike at Yale University."},
{"index": 2, "doc": "A woman is speaking at a podium outdoors."},
]
in_docs = [Document(d) for d in dicts]

def inject_metadata(d):
from sycamore.data.metadata import add_metadata

idx = d["index"]
for i in range(idx):
add_metadata(index=idx, value=i)
return d

docs = (
sycamore.init(exec_mode=sycamore.EXEC_LOCAL)
.read.document(in_docs)
.map(inject_metadata)
.take_all(include_metadata=True)
)

inject_md = [d for d in docs if "metadata" in d and "index" in d.metadata]
assert len(inject_md) == 3
assert inject_md[0].metadata["index"] == 1
assert inject_md[0].metadata["value"] == 0
assert inject_md[1].metadata["index"] == 2
assert inject_md[1].metadata["value"] == 0
assert inject_md[2].metadata["index"] == 2
assert inject_md[2].metadata["value"] == 1

class Empty:
def __init__(self):
pass
11 changes: 9 additions & 2 deletions lib/sycamore/sycamore/transforms/base.py
@@ -8,6 +8,7 @@
from sycamore.data.document import split_data_metadata
from sycamore.plan_nodes import Node, UnaryNode
from sycamore.utils.ray_utils import check_serializable
from sycamore.utils.thread_local import ThreadLocal, ADD_METADATA_TO_OUTPUT

if TYPE_CHECKING:
from ray.data import Dataset, Datasink
@@ -159,11 +160,14 @@ def execute(
def local_execute(self, all_docs: list[Document]) -> list[Document]:
docs = [d for d in all_docs if not isinstance(d, MetadataDocument)]
metadata = [d for d in all_docs if isinstance(d, MetadataDocument)]
outputs = self._local_process(docs)
extra_metadata: list[MetadataDocument] = []
with ThreadLocal(ADD_METADATA_TO_OUTPUT, extra_metadata):
outputs = self._local_process(docs)
to_docs = [d for d in outputs if not isinstance(d, MetadataDocument)]
if self._enable_auto_metadata and (len(docs) > 0 or len(to_docs) > 0):
outputs.extend(update_lineage(docs, to_docs))
outputs.extend(metadata)
outputs.extend(extra_metadata)
return outputs

def _local_process(self, in_docs: list[Document]) -> list[Document]:
@@ -248,7 +252,9 @@ def _process_ray(
all_docs = [Document.deserialize(s) for s in ray_input.get("doc", [])]
docs = [d for d in all_docs if not isinstance(d, MetadataDocument)]
metadata = [d for d in all_docs if isinstance(d, MetadataDocument)]
outputs = f(docs)
extra_metadata: list[MetadataDocument] = []
with ThreadLocal(ADD_METADATA_TO_OUTPUT, extra_metadata):
outputs = f(docs)
if outputs is None:
logging.warn(f"Function {name} returned nothing. If it has no outputs it should return an empty list")
outputs = []
@@ -264,6 +270,7 @@ def _process_ray(
if enable_auto_metadata and (len(docs) > 0 or len(to_docs) > 0):
outputs.extend(update_lineage(docs, to_docs))
outputs.extend(metadata)
outputs.extend(extra_metadata)
return {"doc": [d.serialize() for d in outputs]}


43 changes: 43 additions & 0 deletions lib/sycamore/sycamore/utils/thread_local.py
@@ -0,0 +1,43 @@
import logging
import threading

logger = logging.getLogger(__name__)
# Used to inject metadata into the docset output stream to make it easier to use with
# transforms like map.
ADD_METADATA_TO_OUTPUT = "add_metadata_to_output"

# Create a thread-local data object
thread_local_data = threading.local()


class ThreadLocalAccess:
def __init__(self, var_name):
self.var_name = var_name

def present(self):
return hasattr(thread_local_data, self.var_name)

def get(self):
assert hasattr(thread_local_data, self.var_name), f"{self.var_name} not present in TLS"
return getattr(thread_local_data, self.var_name)

def set(self, value):
assert hasattr(thread_local_data, self.var_name), f"{self.var_name} not present in TLS"
setattr(thread_local_data, self.var_name, value)


class ThreadLocal(ThreadLocalAccess):
def __init__(self, var_name, var_value):
self.var_name = var_name
self.var_value = var_value

def __enter__(self):
assert not hasattr(thread_local_data, self.var_name), f"{self.var_name} already set in TLS"
setattr(thread_local_data, self.var_name, self.var_value)
logger.debug(f"Thread-local variable '{self.var_name}' removed.")
return self

def __exit__(self, exc_type, exc_val, exc_tb):
assert hasattr(thread_local_data, self.var_name), f"{self.var_name} vanished from TLS"
delattr(thread_local_data, self.var_name)
logger.debug(f"Thread-local variable '{self.var_name}' removed.")
18 changes: 17 additions & 1 deletion poetry.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -48,6 +48,7 @@ pre-commit = "^3.4.0"
mypy = "^1.11.0"
nbmake = "^1.4.5"
pip = "^24.2"
devtools = "^0.12"

[tool.poetry.group.notebook.dependencies]
jupyterlab = "^4.0.11"