Skip to content

Commit

Permalink
Merge branch 'main' into strict_properties
Browse files Browse the repository at this point in the history
  • Loading branch information
hudson-ai committed Oct 29, 2024
2 parents f43b5c3 + c8c6a11 commit fab1385
Show file tree
Hide file tree
Showing 2 changed files with 200 additions and 31 deletions.
72 changes: 58 additions & 14 deletions guidance/library/_json.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from json import dumps as json_dumps
from json import dumps as json_dumps, loads as json_loads
from enum import Enum
import math
from typing import (
Expand Down Expand Up @@ -165,17 +165,17 @@ class ObjectKeywords(str, Enum):
JSONType.OBJECT: ObjectKeywords,
}

DEFS_KEYS = {"$defs", "definitions"}

IGNORED_KEYS = {
"$anchor",
"$defs",
"$schema",
"$id",
"id",
"$comment",
"title",
"description",
"default",
"definitions",
"description",
"examples",
}

Expand All @@ -188,7 +188,7 @@ class ObjectKeywords(str, Enum):
IGNORED_KEYS.add("discriminator")

WHITESPACE = {b" ", b"\t", b"\n", b"\r"}
VALID_KEYS = set(Keyword) | IGNORED_KEYS | DEFS_KEYS | set(NumberKeywords) | set(StringKeywords) | set(ArrayKeywords) | set(ObjectKeywords)
VALID_KEYS = set(Keyword) | set(NumberKeywords) | set(StringKeywords) | set(ArrayKeywords) | set(ObjectKeywords) | IGNORED_KEYS

FORMAT_PATTERNS: dict[str, Optional[str]] = {
# https://json-schema.org/understanding-json-schema/reference/string#built-in-formats
Expand Down Expand Up @@ -398,6 +398,11 @@ def validate_json_node_keys(node: Mapping[str, Any]):
)


def get_sibling_keys(node: Mapping[str, Any], key: str) -> set[str]:
    """Return the functional schema keywords that appear in ``node`` next to ``key``.

    A key counts as "functional" when it is a member of ``VALID_KEYS`` and is
    not listed in ``IGNORED_KEYS``. ``key`` itself is never part of the result.

    Parameters
    ----------
    node : Mapping[str, Any]
        The schema node whose keys are inspected.
    key : str
        The keyword whose siblings are requested.

    Returns
    -------
    set[str]
        Keys of ``node`` that are valid, not ignored, and different from ``key``.
    """
    # Build the candidate set first so the grouping is explicit rather than
    # relying on `-` binding more tightly than `&`.
    functional_keys = VALID_KEYS.difference(IGNORED_KEYS, {key})
    present_keys = set(node)
    return present_keys & functional_keys


class GenJson:
item_separator = ", "
key_separator = ": "
Expand Down Expand Up @@ -724,7 +729,20 @@ def const(
lm,
*,
value: Union[None, bool, int, float, str, Mapping, Sequence],
instance_type: Optional[Union[str, Sequence[str]]] = None,
enum: Optional[Sequence[Union[None, bool, int, float, str, Mapping, Sequence]]] = None,
):
schema_to_validate_against: dict[str, Any] = {}
if instance_type is not None:
schema_to_validate_against["type"] = instance_type
if enum is not None:
schema_to_validate_against["enum"] = enum
if schema_to_validate_against:
# Raise a validation error if the value doesn't match the given type and/or enum
jsonschema.validate(
instance=value,
schema=schema_to_validate_against,
)
# Base case
if isinstance(value, (type(None), bool, int, float, str)):
return lm + json_dumps(value)
Expand Down Expand Up @@ -757,14 +775,18 @@ def enum(
self,
lm,
*,
options: Sequence[Mapping[str, Any]]
options: Sequence[Union[None, bool, int, float, str, Mapping, Sequence]],
instance_type: Optional[Union[str, Sequence[str]]] = None,
):
# TODO: can we support a whitespace-flexible version of this?
all_opts: list[GrammarFunction] = []
for opt in options:
all_opts.append(
self.const(value=opt)
)
for instance in options:
try:
grm = self.const(value=instance, instance_type=instance_type)
except jsonschema.ValidationError:
continue
all_opts.append(grm)
if not all_opts:
raise ValueError(f"No valid options found for enum with type {instance_type!r}: {options}")
return lm + select(options=all_opts)


Expand Down Expand Up @@ -802,29 +824,47 @@ def json(
validate_json_node_keys(json_schema)

if Keyword.ANYOF in json_schema:
sibling_keys = get_sibling_keys(json_schema, Keyword.ANYOF)
if sibling_keys:
raise NotImplementedError(f"anyOf with sibling keys is not yet supported. Got {sibling_keys}")
return lm + self.anyOf(anyof_list=json_schema[Keyword.ANYOF])

if Keyword.ALLOF in json_schema:
sibling_keys = get_sibling_keys(json_schema, Keyword.ALLOF)
if sibling_keys:
raise NotImplementedError(f"allOf with sibling keys is not yet supported. Got {sibling_keys}")
allof_list = json_schema[Keyword.ALLOF]
if len(allof_list) != 1:
raise ValueError("Only support allOf with exactly one item")
return lm + self.json(json_schema=allof_list[0])

if Keyword.ONEOF in json_schema:
sibling_keys = get_sibling_keys(json_schema, Keyword.ONEOF)
if sibling_keys:
raise NotImplementedError(f"oneOf with sibling keys is not yet supported. Got {sibling_keys}")
oneof_list = json_schema[Keyword.ONEOF]
if len(oneof_list) == 1:
return lm + self.json(json_schema=oneof_list[0])
warnings.warn("oneOf not fully supported, falling back to anyOf. This may cause validation errors in some cases.")
return lm + self.anyOf(anyof_list=oneof_list)

if Keyword.REF in json_schema:
sibling_keys = get_sibling_keys(json_schema, Keyword.REF)
if sibling_keys:
raise NotImplementedError(f"$ref with sibling keys is not yet supported. Got {sibling_keys}")
return lm + self.ref(reference=json_schema[Keyword.REF])

if Keyword.CONST in json_schema:
return lm + self.const(value=json_schema[Keyword.CONST])
sibling_keys = get_sibling_keys(json_schema, Keyword.CONST) - {Keyword.TYPE, Keyword.ENUM}
if sibling_keys:
raise NotImplementedError(f"const with sibling keys is not yet supported. Got {sibling_keys}")
return lm + self.const(value=json_schema[Keyword.CONST], instance_type=json_schema.get(Keyword.TYPE, None), enum=json_schema.get(Keyword.ENUM, None))

if Keyword.ENUM in json_schema:
return lm + self.enum(options=json_schema[Keyword.ENUM])
sibling_keys = get_sibling_keys(json_schema, Keyword.ENUM) - {Keyword.TYPE}
if sibling_keys:
raise NotImplementedError(f"enum with sibling keys is not yet supported. Got {sibling_keys}")
return lm + self.enum(options=json_schema[Keyword.ENUM], instance_type=json_schema.get(Keyword.TYPE, None))

if Keyword.TYPE in json_schema:
target_types = cast(Union[str, Sequence[str]], json_schema[Keyword.TYPE])
Expand Down Expand Up @@ -911,6 +951,7 @@ def json(
*,
schema: Union[
None,
str,
JSONSchema,
Type["pydantic.BaseModel"],
"pydantic.TypeAdapter",
Expand Down Expand Up @@ -960,6 +1001,7 @@ def json(
schema : Union[None, str, Mapping[str, Any], Type[pydantic.BaseModel], pydantic.TypeAdapter]
One of:
- None, in which case any valid JSON will be generated
- A string representing a JSON schema which will be parsed using ``json.loads()``
- A JSON schema object. This is a JSON schema string which has been passed to ``json.loads()``
- A subclass of ``pydantic.BaseModel``
- An instance of ``pydantic.TypeAdapter``
Expand Down Expand Up @@ -993,7 +1035,9 @@ def json(
# In this case, we don't want to use strict_properties
# because we're not actually validating against a schema
strict_properties = False
elif isinstance(schema, (Mapping, bool)):
elif isinstance(schema, (Mapping, bool, str)):
if isinstance(schema, str):
schema = cast(JSONSchema, json_loads(schema))
# Raises jsonschema.exceptions.SchemaError or ValueError
# if schema is not valid
jsonschema.validators.Draft202012Validator.check_schema(schema)
Expand Down
Loading

0 comments on commit fab1385

Please sign in to comment.