# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import inspect import json import re import types from contextlib import contextmanager from datetime import datetime from functools import lru_cache from inspect import isfunction from typing import ( Any, Callable, Literal, Optional, Union, get_args, get_origin, get_type_hints, ) from packaging import version from . import logging from .import_utils import is_jinja_available, is_torch_available, is_vision_available logger = logging.get_logger(__name__) if is_jinja_available(): import jinja2 from jinja2.ext import Extension from jinja2.sandbox import ImmutableSandboxedEnvironment else: jinja2 = None if is_vision_available(): from PIL.Image import Image if is_torch_available(): from torch import Tensor BASIC_TYPES = (int, float, str, bool, Any, type(None), ...) 
# Extracts the initial segment of the docstring, containing the function description
description_re = re.compile(r"^(.*?)[\n\s]*(Args:|Returns:|Raises:|\Z)", re.DOTALL)
# Extracts the Args: block from the docstring. The block may either open the docstring or follow a newline.
args_re = re.compile(r"(?:^|\n)\s*Args:\n\s*(.*?)[\n\s]*(Returns:|Raises:|\Z)", re.DOTALL)
# Splits the Args: block into individual arguments
args_split_re = re.compile(
    r"""
(?:^|\n)                            # Match the start of the args block, or a newline
\s*(\w+)(?:\s*\([^)\n]*\))?:\s*     # Capture the argument name, skip an optional "(type)" note, strip spacing
(.*?)\s*                            # Capture the argument description, which can span multiple lines
(?=\n\s*\w+(?:\s*\([^)\n]*\))?:|\Z) # Stop when you hit the next argument or the end of the block
""",
    re.DOTALL | re.VERBOSE,
)
# Extracts the Returns: block from the docstring, if present. Note that most chat templates ignore the return type/doc!
returns_re = re.compile(r"(?:^|\n)\s*Returns:\n\s*(.*?)[\n\s]*(Raises:|\Z)", re.DOTALL)


class TypeHintParsingException(Exception):
    """Exception raised for errors in parsing type hints to generate JSON schemas"""

    pass


class DocstringParsingException(Exception):
    """Exception raised for errors in parsing docstrings to generate JSON schemas"""

    pass


def _get_json_schema_type(param_type: type) -> dict[str, str]:
    """
    Map a non-generic Python type to its JSON schema fragment.

    Args:
        param_type: The (unparameterized) type to convert.

    Returns:
        A freshly built dict, so callers may mutate it. Types that are not in the mapping fall back to
        `{"type": "object"}`; `Any` maps to the empty schema `{}`.
    """
    type_mapping = {
        int: {"type": "integer"},
        float: {"type": "number"},
        str: {"type": "string"},
        bool: {"type": "boolean"},
        type(None): {"type": "null"},
        Any: {},
    }
    # `Image` and `Tensor` only exist as names when the corresponding backend is available
    if is_vision_available():
        type_mapping[Image] = {"type": "image"}
    if is_torch_available():
        type_mapping[Tensor] = {"type": "audio"}
    return type_mapping.get(param_type, {"type": "object"})


def _parse_type_hint(hint: Any) -> dict:
    """
    Recursively convert a type hint into a JSON schema fragment.

    Args:
        hint: A type or a parameterized `typing` construct (`Union`, `Literal`, `list`, `tuple`, `dict`).

    Returns:
        The JSON schema fragment describing `hint`.

    Raises:
        TypeHintParsingException: If the hint is a single-element tuple, a tuple containing `...`, a `Literal`
            with unsupported values, or a generic whose origin is not handled here.
    """
    origin = get_origin(hint)
    args = get_args(hint)

    if origin is None:
        # Plain types never fail here: unknown classes fall back to {"type": "object"}
        return _get_json_schema_type(hint)

    if origin is Union or (hasattr(types, "UnionType") and origin is types.UnionType):
        # Recurse into each of the subtypes in the Union, except None, which is handled separately at the end
        subtypes = [_parse_type_hint(t) for t in args if t is not type(None)]
        if len(subtypes) == 1:
            # A single non-null type can be expressed directly
            return_dict = subtypes[0]
        elif all(isinstance(subtype.get("type"), str) for subtype in subtypes):
            # A union of basic types can be expressed as a list in the schema
            return_dict = {"type": sorted([subtype["type"] for subtype in subtypes])}
        else:
            # A union of more complex types (or of schemas without a "type" key, such as the one produced
            # for `Any`) requires "anyOf"
            return_dict = {"anyOf": subtypes}
        if type(None) in args:
            return_dict["nullable"] = True
        return return_dict

    if origin is Literal and len(args) > 0:
        literal_types = (int, float, str, bool, type(None))
        args_types = []
        for arg in args:
            if type(arg) not in literal_types:
                raise TypeHintParsingException("Only the valid python literals can be listed in typing.Literal.")
            arg_type = _get_json_schema_type(type(arg)).get("type")
            if arg_type is not None and arg_type not in args_types:
                args_types.append(arg_type)
        return {
            # A single JSON type is emitted as a string, several as a list
            "type": args_types[0] if len(args_types) == 1 else args_types,
            "enum": list(args),
        }

    if origin is list:
        if not args:
            return {"type": "array"}
        # Lists can only have a single type argument, so recurse into it
        return {"type": "array", "items": _parse_type_hint(args[0])}

    if origin is tuple:
        if not args:
            return {"type": "array"}
        if len(args) == 1:
            raise TypeHintParsingException(
                f"The type hint {str(hint).replace('typing.', '')} is a Tuple with a single element, which "
                "we do not automatically convert to JSON schema as it is rarely necessary. If this input can contain "
                "more than one element, we recommend "
                "using a list[] type instead, or if it really is a single element, remove the tuple[] wrapper and just "
                "pass the element directly."
            )
        if ... in args:
            raise TypeHintParsingException(
                "Conversion of '...' is not supported in Tuple type hints. "
                "Use list[] types for variable-length"
                " inputs instead."
            )
        return {"type": "array", "prefixItems": [_parse_type_hint(t) for t in args]}

    if origin is dict:
        # The JSON equivalent to a dict is 'object', which mandates that all keys are strings
        # However, we can specify the type of the dict values with "additionalProperties"
        out = {"type": "object"}
        if len(args) == 2:
            out["additionalProperties"] = _parse_type_hint(args[1])
        return out

    raise TypeHintParsingException(
        f"Couldn't parse this type hint, likely due to a custom class or object: {hint}"
    )


def _convert_type_hints_to_json_schema(func: Callable) -> dict:
    """
    Build the `parameters` JSON schema of a function from its type hints.

    Args:
        func: The function whose signature should be converted.

    Returns:
        A schema of the form `{"type": "object", "properties": {...}}`, plus a `"required"` list when at least one
        parameter has no default. If the function has a return annotation it appears under the `"return"` key of
        `properties`.

    Raises:
        TypeHintParsingException: If a parameter has no type hint, or a hint cannot be converted.
    """
    type_hints = get_type_hints(func)
    signature = inspect.signature(func)
    required = []
    for param_name, param in signature.parameters.items():
        if param.annotation == inspect.Parameter.empty:
            raise TypeHintParsingException(f"Argument {param.name} is missing a type hint in function {func.__name__}")
        if param.default == inspect.Parameter.empty:
            required.append(param_name)

    properties = {param_name: _parse_type_hint(param_type) for param_name, param_type in type_hints.items()}

    schema = {"type": "object", "properties": properties}
    if required:
        schema["required"] = required

    return schema


def parse_google_format_docstring(docstring: str) -> tuple[Optional[str], Optional[dict], Optional[str]]:
    """
    Parses a Google-style docstring to extract the function description, argument descriptions, and return description.

    Argument entries may be written either as `name: description` or as `name (type): description`; the
    parenthesised part is ignored. The `Args:` and `Returns:` sections are recognised both at the very start of the
    docstring and after a newline.

    Args:
        docstring (str): The docstring to parse.

    Returns:
        The function description, arguments, and return description. The arguments are always returned as a dict
        (empty when there is no `Args:` block); the return description is `None` when there is no `Returns:` block.
    """

    # Extract the sections
    description_match = description_re.search(docstring)
    args_match = args_re.search(docstring)
    returns_match = returns_re.search(docstring)

    # Clean and store the sections
    description = description_match.group(1).strip() if description_match else None
    docstring_args = args_match.group(1).strip() if args_match else None
    returns = returns_match.group(1).strip() if returns_match else None

    # Parsing the arguments into a dictionary
    if docstring_args is not None:
        docstring_args = "\n".join([line for line in docstring_args.split("\n") if line.strip()])  # Remove blank lines
        matches = args_split_re.findall(docstring_args)
        # Multi-line descriptions are collapsed onto a single line
        args_dict = {match[0]: re.sub(r"\s*\n+\s*", " ", match[1].strip()) for match in matches}
    else:
        args_dict = {}

    return description, args_dict, returns


def get_json_schema(func: Callable) -> dict:
    """
    This function generates a JSON schema for a given function, based on its docstring and type hints. This is
    mostly used for passing lists of tools to a chat template. The JSON schema contains the name and description of
    the function, as well as the names, types and descriptions for each of its arguments. `get_json_schema()` requires
    that the function has a docstring, and that each argument has a description in the docstring, in the standard
    Google docstring format shown below. It also requires that all the function arguments have a valid Python type hint.

    Although it is not required, a `Returns` block can also be added, which will be included in the schema. This is
    optional because most chat templates ignore the return value of the function.

    Args:
        func: The function to generate a JSON schema for.

    Returns:
        A dictionary containing the JSON schema for the function.

    Raises:
        DocstringParsingException: If the function has no docstring, or an argument has no description in it.
        TypeHintParsingException: If an argument has no type hint, or a type hint cannot be converted.

    Examples:
    ```python
    >>> def multiply(x: float, y: float):
    >>>    '''
    >>>    A function that multiplies two numbers
    >>>
    >>>    Args:
    >>>        x: The first number to multiply
    >>>        y: The second number to multiply
    >>>    '''
    >>>    return x * y
    >>>
    >>> print(get_json_schema(multiply))
    {
        "type": "function",
        "function": {
            "name": "multiply",
            "description": "A function that multiplies two numbers",
            "parameters": {
                "type": "object",
                "properties": {
                    "x": {"type": "number", "description": "The first number to multiply"},
                    "y": {"type": "number", "description": "The second number to multiply"}
                },
                "required": ["x", "y"]
            }
        }
    }
    ```

    The general use for these schemas is that they are used to generate tool descriptions for chat templates that
    support them, like so:

    ```python
    >>> from transformers import AutoTokenizer
    >>> from transformers.utils import get_json_schema
    >>>
    >>> def multiply(x: float, y: float):
    >>>    '''
    >>>    A function that multiplies two numbers
    >>>
    >>>    Args:
    >>>        x: The first number to multiply
    >>>        y: The second number to multiply
    >>>    '''
    >>>    return x * y
    >>>
    >>> multiply_schema = get_json_schema(multiply)
    >>> tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01")
    >>> messages = [{"role": "user", "content": "What is 179 x 4571?"}]
    >>> formatted_chat = tokenizer.apply_chat_template(
    >>>     messages,
    >>>     tools=[multiply_schema],
    >>>     chat_template="tool_use",
    >>>     return_dict=True,
    >>>     return_tensors="pt",
    >>>     add_generation_prompt=True
    >>> )
    >>> # The formatted chat can now be passed to model.generate()
    ```

    Each argument description can also have an optional `(choices: ...)` block at the end, such as
    `(choices: ["tea", "coffee"])`, which will be parsed into an `enum` field in the schema. The content of the block
    is parsed as JSON. Note that this will only be parsed correctly if it is at the end of the line:

    ```python
    >>> def drink_beverage(beverage: str):
    >>>    '''
    >>>    A function that drinks a beverage
    >>>
    >>>    Args:
    >>>        beverage: The beverage to drink (choices: ["tea", "coffee"])
    >>>    '''
    >>>    pass
    >>>
    >>> print(get_json_schema(drink_beverage))
    {
        'type': 'function',
        'function': {
            'name': 'drink_beverage',
            'description': 'A function that drinks a beverage',
            'parameters': {
                'type': 'object',
                'properties': {
                    'beverage': {
                        'type': 'string',
                        'enum': ['tea', 'coffee'],
                        'description': 'The beverage to drink'
                    }
                },
                'required': ['beverage']
            }
        }
    }
    ```
    """
    doc = inspect.getdoc(func)
    if not doc:
        raise DocstringParsingException(
            f"Cannot generate JSON schema for {func.__name__} because it has no docstring!"
        )
    doc = doc.strip()
    main_doc, param_descriptions, return_doc = parse_google_format_docstring(doc)

    json_schema = _convert_type_hints_to_json_schema(func)
    # The return annotation is reported separately from the parameters
    if (return_dict := json_schema["properties"].pop("return", None)) is not None:
        if return_doc is not None:  # We allow a missing return docstring since most templates ignore it
            return_dict["description"] = return_doc
    for arg, schema in json_schema["properties"].items():
        if arg not in param_descriptions:
            raise DocstringParsingException(
                f"Cannot generate JSON schema for {func.__name__} because the docstring has no description for the argument '{arg}'"
            )
        desc = param_descriptions[arg]
        enum_choices = re.search(r"\(choices:\s*(.*?)\)\s*$", desc, flags=re.IGNORECASE)
        if enum_choices:
            # Choices are JSON, so they may be numbers or booleans as well as strings; only strings are stripped
            schema["enum"] = [
                choice.strip() if isinstance(choice, str) else choice
                for choice in json.loads(enum_choices.group(1))
            ]
            desc = desc[: enum_choices.start()].strip()
        schema["description"] = desc

    output = {"name": func.__name__, "description": main_doc, "parameters": json_schema}
    if return_dict is not None:
        output["return"] = return_dict
    return {"type": "function", "function": output}


def _render_with_assistant_indices(
    compiled_template, messages, tools, documents, add_generation_prompt, **template_kwargs
):
    """
    Render a compiled template while recording the character spans produced inside `{% generation %}` blocks.

    Args:
        compiled_template: A template whose environment exposes `activate_tracker` (see `_compile_jinja_template`).
        messages: Passed to the template as `messages`.
        tools: Passed to the template as `tools`.
        documents: Passed to the template as `documents`.
        add_generation_prompt: Passed to the template as `add_generation_prompt`.
        **template_kwargs: Extra variables forwarded to the template.

    Returns:
        A `(rendered_chat, generation_indices)` tuple, where `generation_indices` is the list of
        `(start_index, end_index)` pairs appended by the tracker during rendering.
    """
    rendered_blocks = []
    generation_indices = []
    with compiled_template.environment.activate_tracker(rendered_blocks, generation_indices):
        # Blocks are collected one by one so the tracker can measure how much text precedes each generation block
        for block in compiled_template.generate(
            messages=messages,
            tools=tools,
            documents=documents,
            add_generation_prompt=add_generation_prompt,
            **template_kwargs,
        ):
            rendered_blocks.append(block)
        rendered_chat = "".join(rendered_blocks)
    return rendered_chat, generation_indices


@lru_cache
def _compile_jinja_template(chat_template):
    """
    Compile a chat template string into a Jinja template, caching the result per template string.

    The environment is an `ImmutableSandboxedEnvironment` with the `generation` tag extension, loop controls, a
    `tojson` filter and the `raise_exception` / `strftime_now` globals.

    Args:
        chat_template: The Jinja source of the chat template.

    Returns:
        The compiled template.

    Raises:
        ImportError: If jinja2 is not installed or is older than 3.1.0.
    """
    if not is_jinja_available():
        raise ImportError(
            "apply_chat_template requires jinja2 to be installed. Please install it using `pip install jinja2`."
        )

    # Checked before the extension class is defined, because the class body already uses jinja2 APIs
    # (`jinja2.pass_eval_context`) and would otherwise fail with an unrelated error on old versions.
    if version.parse(jinja2.__version__) < version.parse("3.1.0"):
        raise ImportError(
            f"apply_chat_template requires jinja2>=3.1.0 to be installed. Your version is {jinja2.__version__}."
        )

    class AssistantTracker(Extension):
        # This extension is used to track the indices of assistant-generated tokens in the rendered chat
        tags = {"generation"}

        def __init__(self, environment: ImmutableSandboxedEnvironment):
            # The class is only initiated by jinja.
            super().__init__(environment)
            environment.extend(activate_tracker=self.activate_tracker)
            self._rendered_blocks = None
            self._generation_indices = None

        def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.CallBlock:
            lineno = next(parser.stream).lineno
            body = parser.parse_statements(["name:endgeneration"], drop_needle=True)
            return jinja2.nodes.CallBlock(self.call_method("_generation_support"), [], [], body).set_lineno(lineno)

        @jinja2.pass_eval_context
        def _generation_support(self, context: jinja2.nodes.EvalContext, caller: jinja2.runtime.Macro) -> str:
            rv = caller()
            if self.is_active():
                # Only track generation indices if the tracker is active
                start_index = sum(len(block) for block in self._rendered_blocks)
                end_index = start_index + len(rv)
                self._generation_indices.append((start_index, end_index))
            return rv

        def is_active(self) -> bool:
            # The tracker is handed empty lists on activation, so emptiness must not be mistaken for "inactive"
            return self._rendered_blocks is not None or self._generation_indices is not None

        @contextmanager
        def activate_tracker(self, rendered_blocks: list[str], generation_indices: list[tuple[int, int]]):
            # Reject reuse before touching any state, so a failed nested activation cannot reset the active one
            if self.is_active():
                raise ValueError("AssistantTracker should not be reused before closed")
            self._rendered_blocks = rendered_blocks
            self._generation_indices = generation_indices
            try:
                yield
            finally:
                self._rendered_blocks = None
                self._generation_indices = None

    def raise_exception(message):
        raise jinja2.exceptions.TemplateError(message)

    def tojson(x, ensure_ascii=False, indent=None, separators=None, sort_keys=False):
        # We override the built-in tojson filter because Jinja's default filter escapes HTML characters
        # We also expose some options like custom indents and separators
        return json.dumps(x, ensure_ascii=ensure_ascii, indent=indent, separators=separators, sort_keys=sort_keys)

    def strftime_now(format):
        return datetime.now().strftime(format)

    jinja_env = ImmutableSandboxedEnvironment(
        trim_blocks=True, lstrip_blocks=True, extensions=[AssistantTracker, jinja2.ext.loopcontrols]
    )
    jinja_env.filters["tojson"] = tojson
    jinja_env.globals["raise_exception"] = raise_exception
    jinja_env.globals["strftime_now"] = strftime_now
    return jinja_env.from_string(chat_template)


def _truncate_at_final_message(rendered_chat: str, final_message: Any) -> str:
    """Cut `rendered_chat` right after the last occurrence of the final message's text.

    `final_message` is the `content` of the last message: either a string, or a list/tuple of content blocks, in
    which case the last block with a `"text"` key is used.
    """
    if isinstance(final_message, (list, tuple)):
        for content_block in reversed(final_message):
            if "text" in content_block:
                # Pick the last text block in the message (the first one we hit while iterating in reverse)
                final_message = content_block["text"]
                break
        else:
            raise ValueError(
                "continue_final_message is set but we could not find any text to continue in the final message!"
            )
    stripped_message = final_message.strip()
    if stripped_message not in rendered_chat:
        raise ValueError(
            "continue_final_message is set but the final message does not appear in the chat after "
            "applying the chat template! This can happen if the chat template deletes portions of "
            "the final message. Please verify the chat template and final message in your chat to "
            "ensure they are compatible."
        )
    final_msg_loc = rendered_chat.rindex(stripped_message)
    lstripped_length = len(final_message.lstrip())
    if rendered_chat[final_msg_loc : final_msg_loc + lstripped_length] == final_message:
        # The template preserves spacing or the message doesn't have trailing spacing, so things are simple
        return rendered_chat[: final_msg_loc + lstripped_length]
    # The message has trailing spacing that was trimmed, so we must be more cautious
    return rendered_chat[: final_msg_loc + len(stripped_message)]


def render_jinja_template(
    conversations: list[list[dict[str, str]]],
    tools: Optional[list[Union[dict, Callable]]] = None,
    documents: Optional[list[dict[str, str]]] = None,
    chat_template: Optional[str] = None,
    return_assistant_tokens_mask: Optional[bool] = False,
    continue_final_message: Optional[bool] = False,
    add_generation_prompt: Optional[bool] = False,
    **kwargs,
) -> tuple[list[str], list[list[tuple[int, int]]]]:
    """
    Render a batch of conversations with a Jinja chat template.

    Args:
        conversations: The conversations to render. Each one is a list of message dicts, or an object with a
            `messages` attribute holding such a list.
        tools: Tools to expose to the template, as JSON schema dicts or as functions, which are converted with
            `get_json_schema()`.
        documents: Documents to expose to the template; every entry must be a dict.
        chat_template: The Jinja source of the chat template.
        return_assistant_tokens_mask: Whether to record the character spans rendered inside `{% generation %}`
            blocks.
        continue_final_message: Whether to cut each rendered chat right after the text of its final message.
        add_generation_prompt: Passed to the template as `add_generation_prompt`.
        **kwargs: Extra variables forwarded to the template.

    Returns:
        A `(rendered, all_generation_indices)` tuple: the rendered chats, and one list of `(start, end)` spans per
        conversation. The second list is empty unless `return_assistant_tokens_mask` is set.

    Raises:
        ValueError: If a tool is neither a dict nor a function, or if `continue_final_message` is set and the final
            message text cannot be located in the rendered chat.
        TypeError: If an entry of `documents` is not a dict.
    """
    if return_assistant_tokens_mask and not re.search(r"\{\%-?\s*generation\s*-?\%\}", chat_template):
        logger.warning_once(
            "return_assistant_tokens_mask==True but chat template does not contain `{% generation %}` keyword."
        )

    # Compilation function uses a cache to avoid recompiling the same template
    compiled_template = _compile_jinja_template(chat_template)

    # We accept either JSON schemas or functions for tools. If we get functions, we convert them to schemas
    if tools is not None:
        tool_schemas = []
        for tool in tools:
            if isinstance(tool, dict):
                tool_schemas.append(tool)
            elif isfunction(tool):
                tool_schemas.append(get_json_schema(tool))
            else:
                raise ValueError(
                    "Tools should either be a JSON schema, or a callable function with type hints "
                    "and a docstring suitable for auto-conversion to a schema."
                )
    else:
        tool_schemas = None

    if documents is not None:
        for document in documents:
            if not isinstance(document, dict):
                raise TypeError("Documents should be a list of dicts with 'title' and 'text' keys!")

    rendered = []
    all_generation_indices = []
    for chat in conversations:
        if hasattr(chat, "messages"):
            # Indicates it's a Conversation object
            chat = chat.messages
        if return_assistant_tokens_mask:
            rendered_chat, generation_indices = _render_with_assistant_indices(
                compiled_template=compiled_template,
                messages=chat,
                tools=tool_schemas,
                documents=documents,
                add_generation_prompt=add_generation_prompt,
                **kwargs,
            )
            all_generation_indices.append(generation_indices)
        else:
            rendered_chat = compiled_template.render(
                messages=chat,
                tools=tool_schemas,
                documents=documents,
                add_generation_prompt=add_generation_prompt,
                **kwargs,
            )
        if continue_final_message:
            rendered_chat = _truncate_at_final_message(rendered_chat, chat[-1]["content"])
        rendered.append(rendered_chat)

    return rendered, all_generation_indices