Skip to content

Commit

Permalink
Latest attempt to get consistent token results
Browse files Browse the repository at this point in the history
  • Loading branch information
riedgar-ms committed May 8, 2024
1 parent b728b0f commit 9f330c3
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 4 deletions.
1 change: 1 addition & 0 deletions guidance/models/_guidance_engine_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
class GuidanceEngineMetrics(BaseModel):
    """Running token-usage counters for a guidance engine call.

    All counters are non-negative and default to zero; they are
    incremented as chunks are processed during generation.
    """

    # Tokens sampled from the model during generation
    # (incremented from chunk.new_token_count when the chunk is generated).
    generated_tokens: NonNegativeInt = 0
    # Tokens forced by the grammar rather than sampled
    # (incremented from chunk.new_token_count when the chunk is not generated).
    forced_tokens: NonNegativeInt = 0
    # Tokens fed to the model as input when computing logits
    # (incremented from chunk.last_model_token_count).
    model_input_tokens: NonNegativeInt = 0
13 changes: 12 additions & 1 deletion guidance/models/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ class EngineCallResponse:
capture_groups: dict
capture_group_log_probs: dict
new_token_count: int
last_model_token_count: int

def __init__(
    self,
    new_bytes,
    is_generated,
    new_bytes_prob,
    capture_groups,
    capture_group_log_probs,
    new_token_count,
    last_model_token_count,
):
    """Initialize an engine call response chunk.

    Args:
        new_bytes: Raw bytes produced for this chunk.
        is_generated: Whether the bytes were sampled from the model
            (as opposed to forced by the grammar).
        new_bytes_prob: Probability associated with the new bytes.
        capture_groups: Captured group values for this chunk.
        capture_group_log_probs: Log-probabilities of the captured groups.
        new_token_count: Number of new tokens represented by this chunk.
        last_model_token_count: Number of tokens sent to the model as
            input for the most recent logits computation.
    """
    # NOTE(review): the first three parameters were reconstructed from the
    # attribute assignments below; the page extraction collapsed those
    # signature lines — confirm against the repository before relying on
    # the exact parameter order.
    self.new_bytes = new_bytes
    self.is_generated = is_generated
    self.new_bytes_prob = new_bytes_prob
    self.capture_groups = capture_groups
    self.capture_group_log_probs = capture_group_log_probs
    self.new_token_count = new_token_count
    self.last_model_token_count = last_model_token_count

def _to_proto(self):
"""Converts an EngineCallResponse object to its Protobuf representation.
Expand Down Expand Up @@ -739,6 +742,7 @@ def __call__(self, parser, grammar, ensure_bos_token=True):
# TODO: remove this after the next release. This verifies that calling Rust works.
assert "def" == engine_start("abc", "def", 1)

last_model_token_count = 0
logits = None
while True:
is_done, logits_state, response_state = self.next(logits)
Expand All @@ -765,13 +769,19 @@ def __call__(self, parser, grammar, ensure_bos_token=True):
capture_groups=response_capture_groups,
capture_group_log_probs=response_capture_group_log_probs,
new_token_count=response_new_token_count,
last_model_token_count=last_model_token_count,
)
last_model_token_count = 0

if logits_state is not None:
token_ids, forced_bytes, current_temp = logits_state
logits = self.get_logits(token_ids, forced_bytes, current_temp)
logits, model_token_count = self.get_logits(
token_ids, forced_bytes, current_temp
)
last_model_token_count = model_token_count

if is_done:
assert last_model_token_count == 0, "Unyielded input tokens"
break

def _tokenize_prefix(self, byte_string):
Expand Down Expand Up @@ -1393,6 +1403,7 @@ def _run_stateless(self, stateless_function, temperature=0.0, top_p=1.0, n=1):
self.engine_metrics.generated_tokens += chunk.new_token_count
else:
self.engine_metrics.forced_tokens += chunk.new_token_count
self.engine_metrics.model_input_tokens += chunk.last_model_token_count

# convert the bytes to a string (delaying if we don't yet have a valid unicode string)
lm.token_count += chunk.new_token_count
Expand Down
2 changes: 1 addition & 1 deletion guidance/models/transformers/_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def get_logits(self, token_ids, forced_bytes, current_temp):
model_out.logits[0, -1, : len(self.tokenizer.tokens)].cpu().numpy()
)

return self._cached_logits
return self._cached_logits, len(new_token_ids)


class Transformers(Model):
Expand Down
7 changes: 5 additions & 2 deletions tests/library/test_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,11 @@ def test_metrics_alt_expressions(selected_model: models.Model):
assert str(lm) == str(lm2)
assert lm.engine_metrics.generated_tokens == 10
assert lm2.engine_metrics.generated_tokens == 10
assert lm.engine_metrics.forced_tokens == 0
assert lm2.engine_metrics.forced_tokens == 0

assert (
lm.engine_metrics.forced_tokens + lm.engine_metrics.model_input_tokens
== lm2.engine_metrics.forced_tokens + lm2.engine_metrics.model_input_tokens
)


def test_unicode(selected_model):
Expand Down

0 comments on commit 9f330c3

Please sign in to comment.