"""Python implementation of llama grammar parser directly translated from C++ source file in vendor/llama.cpp/common/grammar-parser.cpp.""" # flake8: noqa from pathlib import Path from itertools import groupby from typing import ( Any, Set, List, Optional, Tuple, Union, ) LLAMA_GRAMMAR_DEFAULT_ROOT = "root" class LlamaGrammar: def __init__(self, *args, _grammar: str, **kwargs): self._grammar = _grammar self._root = LLAMA_GRAMMAR_DEFAULT_ROOT @classmethod def from_string(cls, grammar: str, verbose: bool = True) -> "LlamaGrammar": return cls(_grammar=grammar) @classmethod def from_file(cls, file: Union[str, Path], verbose: bool = True) -> "LlamaGrammar": try: with open(file) as f: grammar = f.read() except Exception as err: raise Exception( f"{cls.from_file.__name__}: error reading grammar file: {err}" ) if grammar: return cls.from_string(grammar, verbose=verbose) raise ValueError( f"{cls.from_file.__name__}: error parsing grammar file: params_grammer is empty" ) @classmethod def from_json_schema(cls, json_schema: str, verbose: bool = True) -> "LlamaGrammar": return cls.from_string(json_schema_to_gbnf(json_schema), verbose=verbose) """llama.cpp gbnf rules from vendor/llama.cpp/grammars""" ARITHMETIC_GBNF = r""" root ::= (expr "=" ws term "\n")+ expr ::= term ([-+*/] term)* term ::= ident | num | "(" ws expr ")" ws ident ::= [a-z] [a-z0-9_]* ws num ::= [0-9]+ ws ws ::= [ \t\n]* """ C_GBNF = r""" root ::= (declaration)* declaration ::= dataType identifier "(" parameter? ")" "{" statement* "}" dataType ::= "int" ws | "float" ws | "char" ws identifier ::= [a-zA-Z_] [a-zA-Z_0-9]* parameter ::= dataType identifier statement ::= ( dataType identifier ws "=" ws expression ";" ) | ( identifier ws "=" ws expression ";" ) | ( identifier ws "(" argList? ")" ";" ) | ( "return" ws expression ";" ) | ( "while" "(" condition ")" "{" statement* "}" ) | ( "for" "(" forInit ";" ws condition ";" ws forUpdate ")" "{" statement* "}" ) | ( "if" "(" condition ")" "{" statement* "}" ("else" "{" statement* "}")? ) | ( singleLineComment ) | ( multiLineComment ) forInit ::= dataType identifier ws "=" ws expression | identifier ws "=" ws expression forUpdate ::= identifier ws "=" ws expression condition ::= expression relationOperator expression relationOperator ::= ("<=" | "<" | "==" | "!=" | ">=" | ">") expression ::= term (("+" | "-") term)* term ::= factor(("*" | "/") factor)* factor ::= identifier | number | unaryTerm | funcCall | parenExpression unaryTerm ::= "-" factor funcCall ::= identifier "(" argList? ")" parenExpression ::= "(" ws expression ws ")" argList ::= expression ("," ws expression)* number ::= [0-9]+ singleLineComment ::= "//" [^\n]* "\n" multiLineComment ::= "/*" ( [^*] | ("*" [^/]) )* "*/" ws ::= ([ \t\n]+) """ CHESS_GBNF = r""" root ::= object value ::= object | array | string | number | ("true" | "false" | "null") ws object ::= "{" ws ( string ":" ws value ("," ws string ":" ws value)* )? "}" ws array ::= "[" ws ( value ("," ws value)* )? "]" ws string ::= "\"" ( [^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes )* "\"" ws number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws # Optional space: by convention, applied in this grammar after literal chars when allowed ws ::= ([ \t\n] ws)? """ JAPANESE_GBNF = r""" root ::= object value ::= object | array | string | number | ("true" | "false" | "null") ws object ::= "{" ws ( string ":" ws value ("," ws string ":" ws value)* )? "}" ws array ::= "[" ws ( value ("," ws value)* )? "]" ws string ::= "\"" ( [^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes )* "\"" ws number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws # Optional space: by convention, applied in this grammar after literal chars when allowed ws ::= ([ \t\n] ws)? """ JSON_ARR_GBNF = r""" # This is the same as json.gbnf but we restrict whitespaces at the end of the root array # Useful for generating JSON arrays root ::= arr value ::= object | array | string | number | ("true" | "false" | "null") ws arr ::= "[\n" ws ( value (",\n" ws value)* )? "]" object ::= "{" ws ( string ":" ws value ("," ws string ":" ws value)* )? "}" ws array ::= "[" ws ( value ("," ws value)* )? "]" ws string ::= "\"" ( [^"\\\x7F\x00-\x1F] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes )* "\"" ws number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws # Optional space: by convention, applied in this grammar after literal chars when allowed ws ::= ([ \t\n] ws)? """ JSON_GBNF = r""" root ::= object value ::= object | array | string | number | ("true" | "false" | "null") ws object ::= "{" ws ( string ":" ws value ("," ws string ":" ws value)* )? "}" ws array ::= "[" ws ( value ("," ws value)* )? "]" ws string ::= "\"" ( [^"\\\x7F\x00-\x1F] | "\\" (["\\bfnrt] | "u" [0-9a-fA-F]{4}) # escapes )* "\"" ws number ::= ("-"? ([0-9] | [1-9] [0-9]{0,15})) ("." [0-9]+)? ([eE] [-+]? [0-9] [1-9]{0,15})? ws # Optional space: by convention, applied in this grammar after literal chars when allowed ws ::= | " " | "\n" [ \t]{0,20} """ LIST_GBNF = r""" root ::= item+ # Excludes various line break characters item ::= "- " [^\r\n\x0b\x0c\x85\u2028\u2029]+ "\n" """ """llama.cpp json-schema to grammar converter from vendor/llama.cpp/examples/json-schema-to-grammar.py""" import json import re from typing import List, Optional # whitespace is constrained to a single space char to prevent model "running away" in # whitespace. Also maybe improves generation quality? SPACE_RULE = '" "?' INVALID_RULE_CHARS_RE = re.compile(r"[^a-zA-Z0-9-]+") GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"]') GRAMMAR_LITERAL_ESCAPES = {"\r": "\\r", "\n": "\\n", '"': '\\"'} # whitespace is constrained to a single space char to prevent model "running away" in # whitespace. Also maybe improves generation quality? SPACE_RULE = '" "?' def _build_repetition( item_rule, min_items, max_items, separator_rule=None, item_rule_is_literal=False ): if not separator_rule: if min_items == 0 and max_items == 1: return f"{item_rule}?" elif min_items == 1 and max_items is None: return f"{item_rule}+" result = "" if min_items > 0: if item_rule_is_literal and separator_rule is None: result = '"' + (item_rule[1:-1] * min_items) + '"' else: result = (f" {separator_rule} " if separator_rule else " ").join( [item_rule] * min_items ) def opt_repetitions(up_to_n, prefix_with_sep=False): """ - n=4, no sep: '(a (a (a (a)?)?)?)?' - n=4, sep=',', prefix: '("," a ("," a ("," a ("," a)?)?)?)?' - n=4, sep=',', no prefix: '(a ("," a ("," a ("," a)?)?)?)?' """ content = ( f"{separator_rule} {item_rule}" if prefix_with_sep and separator_rule else item_rule ) if up_to_n == 0: return "" elif up_to_n == 1: return f"({content})?" elif separator_rule and not prefix_with_sep: return f"({content} {opt_repetitions(up_to_n - 1, prefix_with_sep=True)})?" else: return (f"({content} " * up_to_n).rstrip() + (")?" * up_to_n) if min_items > 0 and max_items != min_items: result += " " if max_items is not None: result += opt_repetitions(max_items - min_items, prefix_with_sep=min_items > 0) else: item_operator = f'({separator_rule + " " if separator_rule else ""}{item_rule})' if min_items == 0 and separator_rule: result = f"({item_rule} {item_operator}*)?" else: result += f"{item_operator}*" return result class BuiltinRule: def __init__(self, content: str, deps: list = None): self.content = content self.deps = deps or [] _up_to_15_digits = _build_repetition("[0-9]", 0, 15) PRIMITIVE_RULES = { "boolean": BuiltinRule('("true" | "false") space', []), "decimal-part": BuiltinRule("[0-9] " + _up_to_15_digits, []), "integral-part": BuiltinRule("[0-9] | [1-9] " + _up_to_15_digits, []), "number": BuiltinRule( '("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space', ["integral-part", "decimal-part"], ), "integer": BuiltinRule('("-"? integral-part) space', ["integral-part"]), "value": BuiltinRule( "object | array | string | number | boolean | null", ["object", "array", "string", "number", "boolean", "null"], ), "object": BuiltinRule( '"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space', ["string", "value"], ), "array": BuiltinRule( '"[" space ( value ("," space value)* )? "]" space', ["value"] ), "uuid": BuiltinRule( r'"\"" ' + ' "-" '.join("[0-9a-fA-F]" * n for n in [8, 4, 4, 4, 12]) + r' "\"" space', [], ), "char": BuiltinRule( r'[^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])', [], ), "string": BuiltinRule(r'"\"" char* "\"" space', ["char"]), "null": BuiltinRule('"null" space', []), } # TODO: support "uri", "email" string formats STRING_FORMAT_RULES = { "date": BuiltinRule( '[0-9] [0-9] [0-9] [0-9] "-" ( "0" [1-9] | "1" [0-2] ) "-" ( "0" [1-9] | [1-2] [0-9] | "3" [0-1] )', [], ), "time": BuiltinRule( '([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9] [0-9] [0-9] )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )', [], ), "date-time": BuiltinRule('date "T" time', ["date", "time"]), "date-string": BuiltinRule('"\\"" date "\\"" space', ["date"]), "time-string": BuiltinRule('"\\"" time "\\"" space', ["time"]), "date-time-string": BuiltinRule('"\\"" date-time "\\"" space', ["date-time"]), } DOTALL = "[\\U00000000-\\U0010FFFF]" DOT = "[^\\x0A\\x0D]" RESERVED_NAMES = set( ["root", "dot", *PRIMITIVE_RULES.keys(), *STRING_FORMAT_RULES.keys()] ) NON_LITERAL_SET = set("|.()[]{}*+?") ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = set("[]()|{}*+?") class SchemaConverter: def __init__(self, *, prop_order, allow_fetch, dotall, raw_pattern): self._prop_order = prop_order self._allow_fetch = allow_fetch self._dotall = dotall self._raw_pattern = raw_pattern self._rules = { "space": SPACE_RULE, } self._refs = {} self._refs_being_resolved = set() def _format_literal(self, literal): escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub( lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)), literal ) return f'"{escaped}"' def not_literal( self, literal: str, dotall: bool = True, maybe_escaped_underscores=False ) -> str: """ not_literal('a') -> '[^a]' not_literal('abc') -> '([^a] | "a" ([^b] | "b" ([^c])?)?)?' """ assert len(literal) > 0, "Empty literal not supported" def recurse(i: int): c = literal[i] if maybe_escaped_underscores and c == "_": yield f"[^{c}\\\\]" yield " | " yield f'"\\\\"? "{c}"' else: yield f"[^{c}]" if i < len(literal) - 1: yield " | " yield self._format_literal(c) yield " (" yield from recurse(i + 1) yield ")?" return "".join(("(", *recurse(0), ")")) def _add_rule(self, name, rule): esc_name = INVALID_RULE_CHARS_RE.sub("-", name) if esc_name not in self._rules or self._rules[esc_name] == rule: key = esc_name else: i = 0 while ( f"{esc_name}{i}" in self._rules and self._rules[f"{esc_name}{i}"] != rule ): i += 1 key = f"{esc_name}{i}" self._rules[key] = rule return key def resolve_refs(self, schema: dict, url: str): """ Resolves all $ref fields in the given schema, fetching any remote schemas, replacing $ref with absolute reference URL and populating self._refs with the respective referenced (sub)schema dictionaries. """ def visit(n: dict): if isinstance(n, list): return [visit(x) for x in n] elif isinstance(n, dict): ref = n.get("$ref") if ref is not None and ref not in self._refs: if ref.startswith("https://"): assert ( self._allow_fetch ), "Fetching remote schemas is not allowed (use --allow-fetch for force)" import requests frag_split = ref.split("#") base_url = frag_split[0] target = self._refs.get(base_url) if target is None: target = self.resolve_refs( requests.get(ref).json(), base_url ) self._refs[base_url] = target if len(frag_split) == 1 or frag_split[-1] == "": return target elif ref.startswith("#/"): target = schema ref = f"{url}{ref}" n["$ref"] = ref else: raise ValueError(f"Unsupported ref {ref}") for sel in ref.split("#")[-1].split("/")[1:]: assert ( target is not None and sel in target ), f"Error resolving ref {ref}: {sel} not in {target}" target = target[sel] self._refs[ref] = target else: for v in n.values(): visit(v) return n return visit(schema) def _generate_union_rule(self, name, alt_schemas): return " | ".join( ( self.visit(alt_schema, f'{name}{"-" if name else "alternative-"}{i}') for i, alt_schema in enumerate(alt_schemas) ) ) def _visit_pattern(self, pattern, name): """ Transforms a regular expression pattern into a GBNF rule. Input: https://json-schema.org/understanding-json-schema/reference/regular_expressions Output: https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md Unsupported features: negative/positive lookaheads, greedy/non-greedy modifiers. Mostly a 1:1 translation, except for {x} / {x,} / {x,y} quantifiers for which we define sub-rules to keep the output lean. """ assert pattern.startswith("^") and pattern.endswith( "$" ), 'Pattern must start with "^" and end with "$"' pattern = pattern[1:-1] sub_rule_ids = {} i = 0 length = len(pattern) def to_rule(s: Tuple[str, bool]) -> str: (txt, is_literal) = s return '"' + txt + '"' if is_literal else txt def transform() -> Tuple[str, bool]: """ Parse a unit at index i (advancing it), and return its string representation + whether it's a literal. """ nonlocal i nonlocal pattern nonlocal sub_rule_ids start = i # For each component of this sequence, store its string representation and whether it's a literal. # We only need a flat structure here to apply repetition operators to the last item, and # to merge literals at the and (we're parsing grouped ( sequences ) recursively and don't treat '|' specially # (GBNF's syntax is luckily very close to regular expressions!) seq: list[Tuple[str, bool]] = [] def get_dot(): if self._dotall: rule = DOTALL else: # Accept any character... except \n and \r line break chars (\x0A and \xOD) rule = DOT return self._add_rule(f"dot", rule) def join_seq(): nonlocal seq ret = [] for is_literal, g in groupby(seq, lambda x: x[1]): if is_literal: ret.append(("".join(x[0] for x in g), True)) else: ret.extend(g) if len(ret) == 1: return ret[0] return (" ".join(to_rule(x) for x in seq), False) while i < length: c = pattern[i] if c == ".": seq.append((get_dot(), False)) i += 1 elif c == "(": i += 1 if i < length: assert ( pattern[i] != "?" ), f'Unsupported pattern syntax "{pattern[i]}" at index {i} of /{pattern}/' seq.append((f"({to_rule(transform())})", False)) elif c == ")": i += 1 assert ( start > 0 and pattern[start - 1] == "(" ), f"Unbalanced parentheses; start = {start}, i = {i}, pattern = {pattern}" return join_seq() elif c == "[": square_brackets = c i += 1 while i < length and pattern[i] != "]": if pattern[i] == "\\": square_brackets += pattern[i : i + 2] i += 2 else: square_brackets += pattern[i] i += 1 assert ( i < length ), f"Unbalanced square brackets; start = {start}, i = {i}, pattern = {pattern}" square_brackets += "]" i += 1 seq.append((square_brackets, False)) elif c == "|": seq.append(("|", False)) i += 1 elif c in ("*", "+", "?"): seq[-1] = (to_rule(seq[-1]) + c, False) i += 1 elif c == "{": curly_brackets = c i += 1 while i < length and pattern[i] != "}": curly_brackets += pattern[i] i += 1 assert ( i < length ), f"Unbalanced curly brackets; start = {start}, i = {i}, pattern = {pattern}" curly_brackets += "}" i += 1 nums = [s.strip() for s in curly_brackets[1:-1].split(",")] min_times = 0 max_times = None try: if len(nums) == 1: min_times = int(nums[0]) max_times = min_times else: assert len(nums) == 2 min_times = int(nums[0]) if nums[0] else 0 max_times = int(nums[1]) if nums[1] else None except ValueError: raise ValueError( f"Invalid quantifier {curly_brackets} in /{pattern}/" ) (sub, sub_is_literal) = seq[-1] if not sub_is_literal: id = sub_rule_ids.get(sub) if id is None: id = self._add_rule(f"{name}-{len(sub_rule_ids) + 1}", sub) sub_rule_ids[sub] = id sub = id seq[-1] = ( _build_repetition( f'"{sub}"' if sub_is_literal else sub, min_times, max_times, item_rule_is_literal=sub_is_literal, ), False, ) else: literal = "" while i < length: if pattern[i] == "\\" and i < length - 1: next = pattern[i + 1] if next in ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS: i += 1 literal += pattern[i] i += 1 else: literal += pattern[i : i + 2] i += 2 elif pattern[i] == '"' and not self._raw_pattern: literal += '\\"' i += 1 elif pattern[i] not in NON_LITERAL_SET and ( i == length - 1 or literal == "" or pattern[i + 1] == "." or pattern[i + 1] not in NON_LITERAL_SET ): literal += pattern[i] i += 1 else: break if literal: seq.append((literal, True)) return join_seq() return self._add_rule( name, ( to_rule(transform()) if self._raw_pattern else '"\\"" ' + to_rule(transform()) + ' "\\"" space' ), ) def _resolve_ref(self, ref): ref_name = ref.split("/")[-1] if ref_name not in self._rules and ref not in self._refs_being_resolved: self._refs_being_resolved.add(ref) resolved = self._refs[ref] ref_name = self.visit(resolved, ref_name) self._refs_being_resolved.remove(ref) return ref_name def _generate_constant_rule(self, value): return self._format_literal(json.dumps(value)) def visit(self, schema, name): schema_type = schema.get("type") schema_format = schema.get("format") rule_name = name + "-" if name in RESERVED_NAMES else name or "root" if (ref := schema.get("$ref")) is not None: return self._add_rule(rule_name, self._resolve_ref(ref)) elif "oneOf" in schema or "anyOf" in schema: return self._add_rule( rule_name, self._generate_union_rule(name, schema.get("oneOf") or schema["anyOf"]), ) elif isinstance(schema_type, list): return self._add_rule( rule_name, self._generate_union_rule(name, [{"type": t} for t in schema_type]), ) elif "const" in schema: return self._add_rule( rule_name, self._generate_constant_rule(schema["const"]) ) elif "enum" in schema: rule = " | ".join((self._generate_constant_rule(v) for v in schema["enum"])) return self._add_rule(rule_name, rule) elif schema_type in (None, "object") and ( "properties" in schema or ( "additionalProperties" in schema and schema["additionalProperties"] is not True ) ): required = set(schema.get("required", [])) properties = list(schema.get("properties", {}).items()) return self._add_rule( rule_name, self._build_object_rule( properties, required, name, schema.get("additionalProperties") ), ) elif schema_type in (None, "object") and "allOf" in schema: required = set() properties = [] hybrid_name = name def add_component(comp_schema, is_required): if (ref := comp_schema.get("$ref")) is not None: comp_schema = self._refs[ref] if "properties" in comp_schema: for prop_name, prop_schema in comp_schema["properties"].items(): properties.append((prop_name, prop_schema)) if is_required: required.add(prop_name) for t in schema["allOf"]: if "anyOf" in t: for tt in t["anyOf"]: add_component(tt, is_required=False) else: add_component(t, is_required=True) return self._add_rule( rule_name, self._build_object_rule( properties, required, hybrid_name, additional_properties=[] ), ) elif schema_type in (None, "array") and ( "items" in schema or "prefixItems" in schema ): items = schema.get("items") or schema["prefixItems"] if isinstance(items, list): return self._add_rule( rule_name, '"[" space ' + ' "," space '.join( self.visit(item, f'{name}{"-" if name else ""}tuple-{i}') for i, item in enumerate(items) ) + ' "]" space', ) else: item_rule_name = self.visit(items, f'{name}{"-" if name else ""}item') min_items = schema.get("minItems", 0) max_items = schema.get("maxItems") return self._add_rule( rule_name, '"[" space ' + _build_repetition( item_rule_name, min_items, max_items, separator_rule='"," space' ) + ' "]" space', ) elif schema_type in (None, "string") and "pattern" in schema: return self._visit_pattern(schema["pattern"], rule_name) elif schema_type in (None, "string") and re.match( r"^uuid[1-5]?$", schema_format or "" ): return self._add_primitive( "root" if rule_name == "root" else schema_format, PRIMITIVE_RULES["uuid"], ) elif ( schema_type in (None, "string") and f"{schema_format}-string" in STRING_FORMAT_RULES ): prim_name = f"{schema_format}-string" return self._add_rule( rule_name, self._add_primitive(prim_name, STRING_FORMAT_RULES[prim_name]), ) elif schema_type == "string" and ( "minLength" in schema or "maxLength" in schema ): char_rule = self._add_primitive("char", PRIMITIVE_RULES["char"]) min_len = schema.get("minLength", 0) max_len = schema.get("maxLength") return self._add_rule( rule_name, r'"\"" ' + _build_repetition(char_rule, min_len, max_len) + r' "\"" space', ) elif (schema_type == "object") or (len(schema) == 0): return self._add_rule( rule_name, self._add_primitive("object", PRIMITIVE_RULES["object"]) ) else: assert schema_type in PRIMITIVE_RULES, f"Unrecognized schema: {schema}" # TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero return self._add_primitive( "root" if rule_name == "root" else schema_type, PRIMITIVE_RULES[schema_type], ) def _add_primitive(self, name: str, rule: BuiltinRule): n = self._add_rule(name, rule.content) for dep in rule.deps: dep_rule = PRIMITIVE_RULES.get(dep) or STRING_FORMAT_RULES.get(dep) assert dep_rule, f"Rule {dep} not known" if dep not in self._rules: self._add_primitive(dep, dep_rule) return n def _build_object_rule( self, properties: List[Tuple[str, Any]], required: Set[str], name: str, additional_properties: Union[bool, Any], ): prop_order = self._prop_order # sort by position in prop_order (if specified) then by original order sorted_props = [ kv[0] for _, kv in sorted( enumerate(properties), key=lambda ikv: (prop_order.get(ikv[1][0], len(prop_order)), ikv[0]), ) ] prop_kv_rule_names = {} for prop_name, prop_schema in properties: prop_rule_name = self.visit( prop_schema, f'{name}{"-" if name else ""}{prop_name}' ) prop_kv_rule_names[prop_name] = self._add_rule( f'{name}{"-" if name else ""}{prop_name}-kv', rf'{self._format_literal(json.dumps(prop_name))} space ":" space {prop_rule_name}', ) required_props = [k for k in sorted_props if k in required] optional_props = [k for k in sorted_props if k not in required] if additional_properties == True or isinstance(additional_properties, dict): sub_name = f'{name}{"-" if name else ""}additional' value_rule = self.visit( {} if additional_properties == True else additional_properties, f"{sub_name}-value", ) prop_kv_rule_names["*"] = self._add_rule( f"{sub_name}-kv", self._add_primitive("string", PRIMITIVE_RULES["string"]) + f' ":" space {value_rule}', ) optional_props.append("*") rule = '"{" space ' rule += ' "," space '.join(prop_kv_rule_names[k] for k in required_props) if optional_props: rule += " (" if required_props: rule += ' "," space ( ' def get_recursive_refs(ks, first_is_optional): [k, *rest] = ks kv_rule_name = prop_kv_rule_names[k] if k == "*": res = self._add_rule( f'{name}{"-" if name else ""}additional-kvs', f'{kv_rule_name} ( "," space ' + kv_rule_name + " )*", ) elif first_is_optional: res = f'( "," space {kv_rule_name} )?' else: res = kv_rule_name if len(rest) > 0: res += " " + self._add_rule( f'{name}{"-" if name else ""}{k}-rest', get_recursive_refs(rest, first_is_optional=True), ) return res rule += " | ".join( get_recursive_refs(optional_props[i:], first_is_optional=False) for i in range(len(optional_props)) ) if required_props: rule += " )" rule += " )?" rule += ' "}" space' return rule def format_grammar(self): return "\n".join( f"{name} ::= {rule}" for name, rule in sorted(self._rules.items(), key=lambda kv: kv[0]) ) def json_schema_to_gbnf(schema: str, prop_order: Optional[List[str]] = None): prop_order = prop_order or [] schema = json.loads(schema) prop_order = {name: idx for idx, name in enumerate(prop_order)} converter = SchemaConverter( prop_order=prop_order, allow_fetch=False, dotall=False, raw_pattern=False ) schema = converter.resolve_refs(schema, "stdin") converter.visit(schema, "") return converter.format_grammar()