Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Allow json-loads-able strings to be passed as schema #1028

Merged
merged 5 commits into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions guidance/library/_json.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from json import dumps as json_dumps
from json import dumps as json_dumps, loads as json_loads
from enum import Enum
import math
from typing import (
Expand Down Expand Up @@ -957,6 +957,7 @@ def json(
*,
schema: Union[
None,
str,
JSONSchema,
Type["pydantic.BaseModel"],
"pydantic.TypeAdapter",
Expand Down Expand Up @@ -1005,6 +1006,7 @@ def json(
schema : Union[None, str, Mapping[str, Any], Type[pydantic.BaseModel], pydantic.TypeAdapter]
One of:
- None, in which case any valid JSON will be generated
- A string representing a JSON schema which will be parsed using ``json.loads()``
- A JSON schema object, i.e. the result of passing a JSON schema string to ``json.loads()``
- A subclass of ``pydantic.BaseModel``
- An instance of ``pydantic.TypeAdapter``
Expand All @@ -1018,7 +1020,9 @@ def json(
# Default schema is empty, "anything goes" schema
# TODO: consider default being `{"type": "object"}`
schema = {}
elif isinstance(schema, (Mapping, bool)):
elif isinstance(schema, (Mapping, bool, str)):
if isinstance(schema, str):
schema = cast(JSONSchema, json_loads(schema))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@riedgar-ms done :)

# Raises jsonschema.exceptions.SchemaError or ValueError
# if schema is not valid
jsonschema.validators.Draft202012Validator.check_schema(schema)
Expand Down
27 changes: 22 additions & 5 deletions tests/unit/library/test_json.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,27 @@
import json
from functools import partial
from typing import Any, Dict, Set, Union, Optional
from typing import Any, Set, Union, Optional

import pytest
from jsonschema import validate, ValidationError
from json import dumps as json_dumps
from json import dumps as json_dumps, loads as json_loads

from guidance import json as gen_json
from guidance import models

from guidance.library._json import IGNORED_KEYS
from guidance.library._json import IGNORED_KEYS, JSONSchema

from ...utils import check_match_failure as _check_match_failure
from ...utils import check_run_with_temperature
from ...utils import generate_and_check as _generate_and_check


def generate_and_check(
target_obj: Any, schema_obj, desired_temperature: Optional[float] = None
target_obj: Any, schema_obj: Union[str, JSONSchema], desired_temperature: Optional[float] = None
):
if isinstance(schema_obj, str):
schema_obj = json_loads(schema_obj)

# Sanity check what we're being asked
validate(instance=target_obj, schema=schema_obj)
prepared_json = json_dumps(target_obj)
Expand Down Expand Up @@ -46,7 +49,7 @@ def check_match_failure(
good_bytes: Optional[bytes] = None,
failure_byte: Optional[bytes] = None,
allowed_bytes: Optional[Set[bytes]] = None,
schema_obj: Dict[str, Any],
schema_obj: Union[str, JSONSchema],
):
grammar = gen_json(schema=schema_obj)

Expand Down Expand Up @@ -3270,3 +3273,17 @@ def test_whitespace_flexibility(self, indent, separators, schema, obj):
assert grammar.match(prepared_json, raise_exceptions=True) is not None
model = models.Mock(f"<s>{prepared_json}".encode())
assert str(model + grammar) == prepared_json


class TestStringSchema:
    """Exercise the helpers with a schema supplied as raw JSON text."""

    # Kept as an unparsed JSON string (not a mapping) so that the
    # string-schema code path is the one being tested.
    _SCHEMA_TEXT = '{"type": "object", "properties": {"a": {"type": "string"}}}'

    def test_good(self):
        """A conforming object passes ``generate_and_check`` with a string schema."""
        generate_and_check({"a": "hello"}, self._SCHEMA_TEXT)

    def test_bad(self):
        """A non-conforming document is passed to ``check_match_failure``."""
        check_match_failure(bad_string='{"a": 42}', schema_obj=self._SCHEMA_TEXT)
Loading