[Feature] Monitor token consumption #806

Merged May 9, 2024 (30 commits)

Commits
0b9b43b  Create the basic class for holding metrics (riedgar-ms, May 7, 2024)
6faf8db  Put in, along with a very basic test (riedgar-ms, May 7, 2024)
bbbec17  Another test to watch the metrics (riedgar-ms, May 7, 2024)
25f42bf  Getting things close to working..... (riedgar-ms, May 7, 2024)
230f782  Remove minor hangover (riedgar-ms, May 7, 2024)
bdd80e7  Another oversight (riedgar-ms, May 7, 2024)
392a479  Add a comment (riedgar-ms, May 7, 2024)
76d533e  Need to be able to reset the metrics on the Model (riedgar-ms, May 7, 2024)
34881c9  Thinking about another metric (riedgar-ms, May 8, 2024)
96de164  Figure out how to call tokeniser (riedgar-ms, May 8, 2024)
2fa6521  Reformat (riedgar-ms, May 8, 2024)
c3a0c6b  Do some renaming (riedgar-ms, May 8, 2024)
ff46ec1  Some more output (riedgar-ms, May 8, 2024)
2faa583  Trying to count forced tokens (riedgar-ms, May 8, 2024)
f8de7c8  Try following things through (riedgar-ms, May 8, 2024)
822f8e1  I don't think I need these bits (riedgar-ms, May 8, 2024)
5c50051  Tweak where stats are grabbed (riedgar-ms, May 8, 2024)
67e21c6  Tidy up tests (riedgar-ms, May 8, 2024)
bcc269f  Remove extra (riedgar-ms, May 8, 2024)
b728b0f  Try to figure out if syntax makes a difference (riedgar-ms, May 8, 2024)
9f330c3  Latest attempt to get consistent token results (riedgar-ms, May 8, 2024)
216a5de  Rethink the metrics (riedgar-ms, May 9, 2024)
a083f1b  Add a reset method (riedgar-ms, May 9, 2024)
66a3b05  Undo another change (riedgar-ms, May 9, 2024)
6af11ba  Fix tests (riedgar-ms, May 9, 2024)
b25381e  Better name (riedgar-ms, May 9, 2024)
f557f4d  Merge remote-tracking branch 'upstream/main' into riedgar-ms/model-me… (riedgar-ms, May 9, 2024)
268d4a0  Don't have things for common_chat_testing yet (riedgar-ms, May 9, 2024)
4d851b0  Hook metrics into llamacpp (riedgar-ms, May 9, 2024)
d860cb2  Fix test (riedgar-ms, May 9, 2024)
6 changes: 6 additions & 0 deletions guidance/models/_guidance_engine_metrics.py
@@ -0,0 +1,6 @@
from pydantic import BaseModel, NonNegativeInt


class GuidanceEngineMetrics(BaseModel):
engine_input_tokens: NonNegativeInt = 0
engine_output_tokens: NonNegativeInt = 0
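
For context (outside the diff), a minimal sketch of how this pydantic model behaves: the counters accumulate via ordinary attribute updates, and NonNegativeInt rejects negative values at construction.

from pydantic import ValidationError

from guidance.models._guidance_engine_metrics import GuidanceEngineMetrics

metrics = GuidanceEngineMetrics()
metrics.engine_input_tokens += 12  # counters accumulate by plain assignment
metrics.engine_output_tokens += 1
print(metrics)  # engine_input_tokens=12 engine_output_tokens=1

try:
    GuidanceEngineMetrics(engine_input_tokens=-1)  # NonNegativeInt rejects this
except ValidationError:
    print("negative counts are rejected")
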
9 changes: 8 additions & 1 deletion guidance/models/_model.py
@@ -36,6 +36,8 @@
"Failed to load guidance.cpp, falling back to Python mirror implementations..."
)
from .. import _cpp as cpp

from ._guidance_engine_metrics import GuidanceEngineMetrics
from .._rust.guidancerust import engine_start
from .._utils import softmax, CaptureEvents
from .._parser import EarleyCommitParser, Parser
@@ -203,6 +205,11 @@ def __init__(self, tokenizer, compute_log_probs=False):
self._token_trie.match = True
self._token_trie.match_version = 0

self.metrics = GuidanceEngineMetrics()

def reset_metrics(self):
self.metrics = GuidanceEngineMetrics()

def start(self, parser, grammar, ensure_bos_token=True):
"""Start processing parser state executed through the grammar.

@@ -1626,4 +1633,4 @@ def _check_dominated(node, parser, match_version, next_byte_mask):
parser.pos = curr_pos
if not child_dominate:
return False
return True
return True
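
To see the shape of the plumbing Engine gains here, a self-contained sketch (the _EngineSketch class below is a stand-in for illustration, not part of the PR):

from guidance.models._guidance_engine_metrics import GuidanceEngineMetrics


class _EngineSketch:
    # Mirrors just the metrics attributes added to Engine in the diff above.
    def __init__(self):
        self.metrics = GuidanceEngineMetrics()

    def reset_metrics(self):
        # As in the diff: swap in a fresh object rather than zeroing fields.
        self.metrics = GuidanceEngineMetrics()


engine = _EngineSketch()
engine.metrics.engine_input_tokens += 5
engine.reset_metrics()
assert engine.metrics.engine_input_tokens == 0
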
3 changes: 3 additions & 0 deletions guidance/models/llama_cpp/_llama_cpp.py
@@ -193,9 +193,12 @@ def get_logits(self, token_ids, forced_bytes, current_temp):
batch.logits[n_tokens - 1] = True

ret = llama_cpp.llama_decode(self.model_obj.ctx, batch)
self.metrics.engine_input_tokens += n_tokens
if ret != 0:
raise Exception(f"Call to llama_cpp.llama_decode returned {ret}.")

self.metrics.engine_output_tokens += 1

# get the logits
logits = llama_cpp.llama_get_logits(self.model_obj.ctx)
if llama_cpp.__version__ < "0.2.58":
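
The accounting rule here: each llama_decode call charges the whole batch (n_tokens) to engine_input_tokens, and each call yields one logit vector, so engine_output_tokens advances by one per call. A schematic of that rule (record_decode is a hypothetical helper, not in the PR):

from guidance.models._guidance_engine_metrics import GuidanceEngineMetrics


def record_decode(metrics: GuidanceEngineMetrics, n_tokens: int) -> None:
    # Mirrors the counting wrapped around llama_decode in the diff above.
    metrics.engine_input_tokens += n_tokens  # the whole batch counts as input
    metrics.engine_output_tokens += 1        # one logit vector per decode call


m = GuidanceEngineMetrics()
record_decode(m, 7)  # e.g. a 7-token prompt batch
record_decode(m, 1)  # a single sampled token fed back in
assert (m.engine_input_tokens, m.engine_output_tokens) == (8, 2)
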
6 changes: 6 additions & 0 deletions guidance/models/transformers/_transformers.py
@@ -122,6 +122,10 @@ def _tokenizer(self, model, **kwargs):

return tokenizer

def __call__(self, byte_string):
tokenisation = self._orig_tokenizer(byte_string)
return tokenisation["input_ids"]


class TransformersEngine(Engine):
def __init__(self, model, tokenizer, compute_log_probs, **kwargs):
@@ -265,6 +269,8 @@ def get_logits(self, token_ids, forced_bytes, current_temp):
self._cached_logits = (
model_out.logits[0, -1, : len(self.tokenizer.tokens)].cpu().numpy()
)
self.metrics.engine_input_tokens += len(new_token_ids)
self.metrics.engine_output_tokens += 1

return self._cached_logits

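
The new __call__ simply forwards to the wrapped Hugging Face tokenizer and keeps only the input_ids. A rough standalone sketch of that behaviour ("gpt2" is an arbitrary example checkpoint, and the free function stands in for the method):

from transformers import AutoTokenizer

_orig_tokenizer = AutoTokenizer.from_pretrained("gpt2")  # example checkpoint


def tokenize(byte_string):
    # Same shape as the wrapper's __call__: tokenize, keep only the IDs.
    tokenisation = _orig_tokenizer(byte_string)
    return tokenisation["input_ids"]


print(tokenize("Hello world"))  # a list of token IDs, e.g. [15496, 995]
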
63 changes: 60 additions & 3 deletions tests/library/test_gen.py
@@ -2,7 +2,7 @@

import pytest

from guidance import gen, models
from guidance import gen, models, select


def test_basic():
@@ -73,6 +73,56 @@ def test_stop_quote(selected_model):
assert not lm["title"].endswith('"')


def test_metrics_smoke(selected_model: models.Model):
lm = selected_model
lm.engine.reset_metrics()

lm += "abcd"
print(f"{lm.engine.metrics=}")
lm += gen("first", max_tokens=1)
print(f"{lm.engine.metrics=}")
# Can't be sure of exact count due to token healing
assert (
lm.engine.metrics.engine_output_tokens == 1
or lm.engine.metrics.engine_output_tokens == 2
)
assert lm.engine.metrics.engine_input_tokens >= 1
last_input_tokens = lm.engine.metrics.engine_input_tokens

lm += "fg"
lm += gen("second", max_tokens=1)
# Again, trouble with healing
assert 2 <= lm.engine.metrics.engine_output_tokens <= 4
assert lm.engine.metrics.engine_input_tokens > last_input_tokens


def test_metrics_select(selected_model: models.Model):
lm = selected_model
lm.engine.reset_metrics()

lm += "I will "
lm += select(
[
"ride a bicycle down the road",
"row in a boat along the river",
"go for a swim in the ocean",
]
)
print(f"lm={str(lm)}")
print(f"{lm.engine.metrics=}")
assert lm.engine.metrics.engine_input_tokens > 1
assert lm.engine.metrics.engine_output_tokens > 0
# Guidance should be able to force the generation after only a couple of tokens
# so even though the options are long, relatively few output tokens should be
# needed
assert (
lm.engine.metrics.engine_input_tokens > lm.engine.metrics.engine_output_tokens
)


def test_unicode(selected_model):
# black makes this test ugly -- easier to read with fmt: off
# fmt: off
@@ -85,11 +135,18 @@ def test_unicode2(selected_model):
# fmt: on


def test_unicode2(selected_model):
def test_unicode2(selected_model: models.Model):
lm = selected_model
lm.engine.reset_metrics()
prompt = "Janet’s ducks lay 16 eggs per day"
lm += prompt + gen(max_tokens=10)
assert True
assert lm.engine.metrics.engine_input_tokens > 1
# Due to token healing, we can't be sure of the
# precise output count
assert (
lm.engine.metrics.engine_output_tokens == 10
or lm.engine.metrics.engine_output_tokens == 11
)


def test_gsm8k():