Merge pull request #806 from riedgar-ms/riedgar-ms/model-metrics-01
[Feature] Monitor token consumption
paulbkoch authored May 9, 2024
2 parents 7b4d85f + d860cb2 commit 5ad2304
Showing 5 changed files with 83 additions and 4 deletions.
6 changes: 6 additions & 0 deletions guidance/models/_guidance_engine_metrics.py
@@ -0,0 +1,6 @@
+from pydantic import BaseModel, NonNegativeInt
+
+
+class GuidanceEngineMetrics(BaseModel):
+    engine_input_tokens: NonNegativeInt = 0
+    engine_output_tokens: NonNegativeInt = 0
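The metrics object is a plain Pydantic model, so both counters default to zero and negative values are rejected when the model is constructed. A minimal sketch of that behaviour (the rejection comes from pydantic's NonNegativeInt, nothing guidance-specific):

from pydantic import ValidationError

from guidance.models._guidance_engine_metrics import GuidanceEngineMetrics

metrics = GuidanceEngineMetrics()   # both counters start at 0
metrics.engine_input_tokens += 5    # ordinary integer arithmetic thereafter

try:
    GuidanceEngineMetrics(engine_output_tokens=-1)
except ValidationError as err:
    print(err)  # NonNegativeInt rejects negative counts at construction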
9 changes: 8 additions & 1 deletion guidance/models/_model.py
@@ -36,6 +36,8 @@
         "Failed to load guidance.cpp, falling back to Python mirror implementations..."
     )
     from .. import _cpp as cpp
+
+from ._guidance_engine_metrics import GuidanceEngineMetrics
 from .._rust.guidancerust import engine_start
 from .._utils import softmax, CaptureEvents
 from .._parser import EarleyCommitParser, Parser
@@ -203,6 +205,11 @@ def __init__(self, tokenizer, compute_log_probs=False):
         self._token_trie.match = True
         self._token_trie.match_version = 0

+        self.metrics = GuidanceEngineMetrics()
+
+    def reset_metrics(self):
+        self.metrics = GuidanceEngineMetrics()
+
     def start(self, parser, grammar, ensure_bos_token=True):
         """Start processing parser state executed through the grammar.
@@ -1626,4 +1633,4 @@ def _check_dominated(node, parser, match_version, next_byte_mask):
         parser.pos = curr_pos
         if not child_dominate:
             return False
-    return True
+    return True
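Taken together, the Engine now carries a metrics attribute that callers can reset and read around any generation. A hedged usage sketch (the model path is a placeholder; the engine attribute layout is the one exercised by the tests below):

from guidance import gen, models

lm = models.LlamaCpp("path/to/model.gguf")  # placeholder path
lm.engine.reset_metrics()

lm += "The capital of France is " + gen("answer", max_tokens=3)

m = lm.engine.metrics
print(f"tokens fed to the model: {m.engine_input_tokens}")
print(f"logit vectors produced:  {m.engine_output_tokens}")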
3 changes: 3 additions & 0 deletions guidance/models/llama_cpp/_llama_cpp.py
@@ -193,9 +193,12 @@ def get_logits(self, token_ids, forced_bytes, current_temp):
             batch.logits[n_tokens - 1] = True

             ret = llama_cpp.llama_decode(self.model_obj.ctx, batch)
+            self.metrics.engine_input_tokens += n_tokens
             if ret != 0:
                 raise Exception(f"Call to llama_cpp.llama_decode returned {ret}.")

+        self.metrics.engine_output_tokens += 1
+
         # get the logits
         logits = llama_cpp.llama_get_logits(self.model_obj.ctx)
         if llama_cpp.__version__ < "0.2.58":
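The accounting convention here: every batch pushed through llama_decode adds its n_tokens to engine_input_tokens, and each get_logits call yields exactly one logit vector, so engine_output_tokens grows by one per forward pass. An illustrative trace under assumed token counts (not taken from the PR):

# pass 1: decode a 7-token prompt        -> input += 7, output += 1
# pass 2: decode the token just sampled  -> input += 1, output += 1
# pass 3: decode the next sampled token  -> input += 1, output += 1
# totals after three passes:
#   engine_input_tokens == 9, engine_output_tokens == 3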
6 changes: 6 additions & 0 deletions guidance/models/transformers/_transformers.py
@@ -122,6 +122,10 @@ def _tokenizer(self, model, **kwargs):

         return tokenizer

+    def __call__(self, byte_string):
+        tokenisation = self._orig_tokenizer(byte_string)
+        return tokenisation["input_ids"]
+

 class TransformersEngine(Engine):
     def __init__(self, model, tokenizer, compute_log_probs, **kwargs):
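The new __call__ simply makes the tokenizer wrapper callable, delegating to the wrapped Hugging Face tokenizer and returning only the input ids. A self-contained sketch of the same pattern (CallableTokenizer is a stand-in; the real class and its _orig_tokenizer attribute come from surrounding code this diff does not show):

from transformers import AutoTokenizer

class CallableTokenizer:
    def __init__(self, name: str):
        # keep a handle to the underlying HF tokenizer
        self._orig_tokenizer = AutoTokenizer.from_pretrained(name)

    def __call__(self, byte_string):
        tokenisation = self._orig_tokenizer(byte_string)
        return tokenisation["input_ids"]

ids = CallableTokenizer("gpt2")("Hello world")
print(ids)  # e.g. [15496, 995] for the GPT-2 vocabulary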
@@ -265,6 +269,8 @@ def get_logits(self, token_ids, forced_bytes, current_temp):
             self._cached_logits = (
                 model_out.logits[0, -1, : len(self.tokenizer.tokens)].cpu().numpy()
             )
+            self.metrics.engine_input_tokens += len(new_token_ids)
+            self.metrics.engine_output_tokens += 1

         return self._cached_logits

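Unlike the llama.cpp path, the transformers engine bills only new_token_ids — the suffix of the context not already covered by its cache — so a re-used prefix is not counted twice. A hedged sketch of that arithmetic (variable contents assumed for illustration):

cached = [1, 2, 3]                     # token ids already processed
current = [1, 2, 3, 4, 5]              # full context for this call
new_token_ids = current[len(cached):]  # -> [4, 5]

# metrics.engine_input_tokens += len(new_token_ids)  # += 2, not += 5
# metrics.engine_output_tokens += 1                  # one logit vector per call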
63 changes: 60 additions & 3 deletions tests/library/test_gen.py
@@ -2,7 +2,7 @@

 import pytest

-from guidance import gen, models
+from guidance import gen, models, select


 def test_basic():
@@ -73,6 +73,56 @@ def test_stop_quote(selected_model):
     assert not lm["title"].endswith('"')


+def test_metrics_smoke(selected_model: models.Model):
+    lm = selected_model
+    lm.engine.reset_metrics()
+
+    lm += "abcd"
+    print(f"{lm.engine.metrics=}")
+    lm += gen("first", max_tokens=1)
+    print(f"{lm.engine.metrics=}")
+    # Can't be sure of the exact count due to token healing
+    assert (
+        lm.engine.metrics.engine_output_tokens == 1
+        or lm.engine.metrics.engine_output_tokens == 2
+    )
+    assert lm.engine.metrics.engine_input_tokens >= 1
+    last_input_tokens = lm.engine.metrics.engine_input_tokens
+
+    lm += "fg"
+    lm += gen("second", max_tokens=1)
+    # Again, token healing makes the exact count uncertain
+    assert (
+        lm.engine.metrics.engine_output_tokens >= 2
+        and lm.engine.metrics.engine_output_tokens <= 4
+    )
+    assert lm.engine.metrics.engine_input_tokens > last_input_tokens
+
+
+def test_metrics_select(selected_model: models.Model):
+    lm = selected_model
+    lm.engine.reset_metrics()
+
+    lm += "I will "
+    lm += select(
+        [
+            "ride a bicycle down the road",
+            "row in a boat along the river",
+            "go for a swim in the ocean",
+        ]
+    )
+    print(f"lm={str(lm)}")
+    print(f"{lm.engine.metrics=}")
+    assert lm.engine.metrics.engine_input_tokens > 1
+    assert lm.engine.metrics.engine_output_tokens > 0
+    # Guidance should be able to force the generation after only a couple of
+    # tokens, so even though the options are long, relatively few output
+    # tokens should be needed
+    assert (
+        lm.engine.metrics.engine_input_tokens > lm.engine.metrics.engine_output_tokens
+    )
+
+
 def test_unicode(selected_model):
     # black makes this test ugly -- easier to read with fmt: off
     # fmt: off
@@ -85,11 +135,18 @@
     # fmt: on


-def test_unicode2(selected_model):
+def test_unicode2(selected_model: models.Model):
     lm = selected_model
+    lm.engine.reset_metrics()
     prompt = "Janet’s ducks lay 16 eggs per day"
     lm += prompt + gen(max_tokens=10)
-    assert True
+    assert lm.engine.metrics.engine_input_tokens > 1
+    # Due to token healing, we can't be sure of the
+    # precise output count
+    assert (
+        lm.engine.metrics.engine_output_tokens == 10
+        or lm.engine.metrics.engine_output_tokens == 11
+    )


 def test_gsm8k():
