Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
hudson-ai committed Oct 29, 2024
1 parent b4ef5f1 commit 35e7010
Showing 1 changed file with 22 additions and 5 deletions.
27 changes: 22 additions & 5 deletions tests/unit/library/test_json.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,27 @@
import json
from functools import partial
from typing import Any, Dict, Set, Union, Optional
from typing import Any, Set, Union, Optional

import pytest
from jsonschema import validate, ValidationError
from json import dumps as json_dumps
from json import dumps as json_dumps, loads as json_loads

from guidance import json as gen_json
from guidance import models

from guidance.library._json import IGNORED_KEYS
from guidance.library._json import IGNORED_KEYS, JSONSchema

from ...utils import check_match_failure as _check_match_failure
from ...utils import check_run_with_temperature
from ...utils import generate_and_check as _generate_and_check


def generate_and_check(
target_obj: Any, schema_obj, desired_temperature: Optional[float] = None
target_obj: Any, schema_obj: Union[str, JSONSchema], desired_temperature: Optional[float] = None
):
if isinstance(schema_obj, str):
schema_obj = json_loads(schema_obj)

# Sanity check what we're being asked
validate(instance=target_obj, schema=schema_obj)
prepared_json = json_dumps(target_obj)
Expand Down Expand Up @@ -46,7 +49,7 @@ def check_match_failure(
good_bytes: Optional[bytes] = None,
failure_byte: Optional[bytes] = None,
allowed_bytes: Optional[Set[bytes]] = None,
schema_obj: Dict[str, Any],
schema_obj: Union[str, JSONSchema],
):
grammar = gen_json(schema=schema_obj)

Expand Down Expand Up @@ -3161,3 +3164,17 @@ def test_whitespace_flexibility(self, indent, separators, schema, obj):
assert grammar.match(prepared_json, raise_exceptions=True) is not None
model = models.Mock(f"<s>{prepared_json}".encode())
assert str(model + grammar) == prepared_json


class TestStringSchema:
def test_good(self):
schema = """{"type": "object", "properties": {"a": {"type": "string"}}}"""
target_obj = {"a": "hello"}
generate_and_check(target_obj, schema)

def test_bad(self):
schema = """{"type": "object", "properties": {"a": {"type": "string"}}}"""
check_match_failure(
bad_string='{"a": 42}',
schema_obj=schema,
)

0 comments on commit 35e7010

Please sign in to comment.