# Llama-3.1-8B-DALv0.1/venv/lib/python3.12/site-packages/transformers/utils/chat_template_utils.py
# Copyright 2024 The HuggingFace Team. All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import inspect
import json
import re
import types
from typing import Any, Callable, Dict, Optional, Tuple, Union, get_args, get_origin, get_type_hints
# Scalar-like types plus Any, None and Ellipsis. NOTE(review): not referenced by the code in this file;
# presumably imported by other modules - confirm before changing or removing.
BASIC_TYPES = (int, float, str, bool, Any, type(None), ...)
# Extracts the initial segment of the docstring, containing the function description
description_re = re.compile(r"^(.*?)[\n\s]*(Args:|Returns:|Raises:|\Z)", re.DOTALL)
# Extracts the Args: block from the docstring
args_re = re.compile(r"\n\s*Args:\n\s*(.*?)[\n\s]*(Returns:|Raises:|\Z)", re.DOTALL)
# Splits the Args: block into individual arguments. Each argument may carry an optional parenthesised type
# between its name and the colon, as in the Google style `x (int): The first number`. The type text is
# discarded here (JSON schema types come from the function's type hints), and only two groups are captured:
# the argument name and its description.
args_split_re = re.compile(
    r"""
(?:^|\n)  # Match the start of the args block, or a newline
\s*(\w+)[ \t]*(?:\([^)\n]*\))?:\s*  # Capture the argument name, skip an optional (type) annotation, strip spacing
(.*?)\s*  # Capture the argument description, which can span multiple lines, and strip trailing spacing
(?=\n\s*\w+[ \t]*(?:\([^)\n]*\))?:|\Z)  # Stop when you hit the next argument (typed or not) or the end of the block
""",
    re.DOTALL | re.VERBOSE,
)
# Extracts the Returns: block from the docstring, if present. Note that most chat templates ignore the return type/doc!
returns_re = re.compile(r"\n\s*Returns:\n\s*(.*?)[\n\s]*(Raises:|\Z)", re.DOTALL)
class TypeHintParsingException(Exception):
    """
    Exception raised for errors in parsing type hints to generate JSON schemas, e.g. when an argument has no type
    hint or a hint has no JSON schema equivalent.
    """
class DocstringParsingException(Exception):
    """
    Exception raised for errors in parsing docstrings to generate JSON schemas, e.g. when a function has no
    docstring or one of its arguments has no description in it.
    """
def _get_json_schema_type(param_type: str) -> Dict[str, str]: | |
type_mapping = { | |
int: {"type": "integer"}, | |
float: {"type": "number"}, | |
str: {"type": "string"}, | |
bool: {"type": "boolean"}, | |
Any: {}, | |
} | |
return type_mapping.get(param_type, {"type": "object"}) | |
def _parse_type_hint(hint: str) -> Dict: | |
origin = get_origin(hint) | |
args = get_args(hint) | |
if origin is None: | |
try: | |
return _get_json_schema_type(hint) | |
except KeyError: | |
raise TypeHintParsingException( | |
"Couldn't parse this type hint, likely due to a custom class or object: ", hint | |
) | |
elif origin is Union: | |
# Recurse into each of the subtypes in the Union, except None, which is handled separately at the end | |
subtypes = [_parse_type_hint(t) for t in args if t is not type(None)] | |
if len(subtypes) == 1: | |
# A single non-null type can be expressed directly | |
return_dict = subtypes[0] | |
elif all(isinstance(subtype["type"], str) for subtype in subtypes): | |
# A union of basic types can be expressed as a list in the schema | |
return_dict = {"type": sorted([subtype["type"] for subtype in subtypes])} | |
else: | |
# A union of more complex types requires "anyOf" | |
return_dict = {"anyOf": subtypes} | |
if type(None) in args: | |
return_dict["nullable"] = True | |
return return_dict | |
elif origin is list: | |
if not args: | |
return {"type": "array"} | |
else: | |
# Lists can only have a single type argument, so recurse into it | |
return {"type": "array", "items": _parse_type_hint(args[0])} | |
elif origin is tuple: | |
if not args: | |
return {"type": "array"} | |
if len(args) == 1: | |
raise TypeHintParsingException( | |
f"The type hint {str(hint).replace('typing.', '')} is a Tuple with a single element, which " | |
"we do not automatically convert to JSON schema as it is rarely necessary. If this input can contain " | |
"more than one element, we recommend " | |
"using a List[] type instead, or if it really is a single element, remove the Tuple[] wrapper and just " | |
"pass the element directly." | |
) | |
if ... in args: | |
raise TypeHintParsingException( | |
"Conversion of '...' is not supported in Tuple type hints. " | |
"Use List[] types for variable-length" | |
" inputs instead." | |
) | |
return {"type": "array", "prefixItems": [_parse_type_hint(t) for t in args]} | |
elif origin is dict: | |
# The JSON equivalent to a dict is 'object', which mandates that all keys are strings | |
# However, we can specify the type of the dict values with "additionalProperties" | |
out = {"type": "object"} | |
if len(args) == 2: | |
out["additionalProperties"] = _parse_type_hint(args[1]) | |
return out | |
raise TypeHintParsingException("Couldn't parse this type hint, likely due to a custom class or object: ", hint) | |
def _convert_type_hints_to_json_schema(func: Callable) -> Dict:
    """
    Build the JSON schema object describing `func`'s parameters from its signature and type hints.

    Args:
        func: The function whose parameters should be described.

    Returns:
        A dict of the form `{"type": "object", "properties": {...}, "required": [...]}`. `properties` has one entry
        per type-hinted name as returned by `typing.get_type_hints`, which includes a `"return"` entry when the
        function has a return annotation (`get_json_schema` pops it back out). `required` lists the parameters
        that have no default value and is omitted when empty.

    Raises:
        TypeHintParsingException: If a parameter has no type hint, or a hint cannot be converted.
    """
    type_hints = get_type_hints(func)
    signature = inspect.signature(func)

    required = []
    for param_name, param in signature.parameters.items():
        # Compare against the `Parameter.empty` sentinel by identity: `==` would invoke the default value's own
        # `__eq__`, which for e.g. array-like defaults is elementwise and has no unambiguous truth value.
        if param.annotation is inspect.Parameter.empty:
            raise TypeHintParsingException(f"Argument {param.name} is missing a type hint in function {func.__name__}")
        if param.default is inspect.Parameter.empty:
            required.append(param_name)

    properties = {param_name: _parse_type_hint(param_type) for param_name, param_type in type_hints.items()}

    schema = {"type": "object", "properties": properties}
    if required:
        schema["required"] = required
    return schema
def parse_google_format_docstring(docstring: str) -> Tuple[Optional[str], Optional[Dict], Optional[str]]:
    """
    Split a Google-style docstring into its description, per-argument descriptions and return description.

    Args:
        docstring: The docstring to parse.

    Returns:
        A `(description, args, returns)` tuple. `description` and `returns` are stripped strings, or None when
        the corresponding section is not found. `args` maps each argument name to its description, with line
        breaks (and the whitespace around them) collapsed to single spaces; it is an empty dict when there is no
        `Args:` section.
    """

    def _section_text(pattern):
        # First capture group of a section regex, stripped, or None when that section is absent
        found = pattern.search(docstring)
        return None if found is None else found.group(1).strip()

    description = _section_text(description_re)
    args_block = _section_text(args_re)
    returns = _section_text(returns_re)

    args_dict = {}
    if args_block is not None:
        # Remove blank lines before splitting the block into individual arguments
        compacted = "\n".join(line for line in args_block.split("\n") if line.strip())
        for arg_name, arg_text in args_split_re.findall(compacted):
            args_dict[arg_name] = re.sub(r"\s*\n+\s*", " ", arg_text.strip())

    return description, args_dict, returns
def get_json_schema(func: Callable) -> Dict:
    """
    This function generates a JSON schema for a given function, based on its docstring and type hints. This is
    mostly used for passing lists of tools to a chat template. The JSON schema contains the name and description of
    the function, as well as the names, types and descriptions for each of its arguments. `get_json_schema()` requires
    that the function has a docstring, and that each argument has a description in the docstring, in the standard
    Google docstring format shown below. It also requires that all the function arguments have a valid Python type hint.

    Although it is not required, a `Returns` block can also be added, which will be included in the schema. This is
    optional because most chat templates ignore the return value of the function.

    Args:
        func: The function to generate a JSON schema for.

    Returns:
        A dictionary of the form `{"type": "function", "function": {...}}`. The inner dictionary contains the
        function's `name`, `description` and `parameters` schema, plus a `return` schema if the function has a
        return type hint.

    Raises:
        DocstringParsingException: If `func` has no docstring, if an argument has no description in the
            docstring, or if a `(choices: ...)` block does not hold a JSON list.
        TypeHintParsingException: If an argument has no type hint, or a type hint cannot be converted.
        json.JSONDecodeError: If a `(choices: ...)` block is not valid JSON.

    Examples:

    ```python
    >>> def multiply(x: float, y: float):
    >>>     '''
    >>>     A function that multiplies two numbers
    >>>
    >>>     Args:
    >>>         x: The first number to multiply
    >>>         y: The second number to multiply
    >>>     '''
    >>>     return x * y
    >>>
    >>> print(get_json_schema(multiply))
    {
        "type": "function",
        "function": {
            "name": "multiply",
            "description": "A function that multiplies two numbers",
            "parameters": {
                "type": "object",
                "properties": {
                    "x": {"type": "number", "description": "The first number to multiply"},
                    "y": {"type": "number", "description": "The second number to multiply"}
                },
                "required": ["x", "y"]
            }
        }
    }
    ```

    The general use for these schemas is that they are used to generate tool descriptions for chat templates that
    support them, like so:

    ```python
    >>> from transformers import AutoTokenizer
    >>> from transformers.utils import get_json_schema
    >>>
    >>> def multiply(x: float, y: float):
    >>>     '''
    >>>     A function that multiplies two numbers
    >>>
    >>>     Args:
    >>>         x: The first number to multiply
    >>>         y: The second number to multiply
    >>>     '''
    >>>     return x * y
    >>>
    >>> multiply_schema = get_json_schema(multiply)
    >>> tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01")
    >>> messages = [{"role": "user", "content": "What is 179 x 4571?"}]
    >>> formatted_chat = tokenizer.apply_chat_template(
    >>>     messages,
    >>>     tools=[multiply_schema],
    >>>     chat_template="tool_use",
    >>>     return_dict=True,
    >>>     return_tensors="pt",
    >>>     add_generation_prompt=True
    >>> )
    >>> # The formatted chat can now be passed to model.generate()
    ```

    Each argument description can also have an optional `(choices: ...)` block at the end, such as
    `(choices: ["tea", "coffee"])`, which will be parsed into an `enum` field in the schema. The choices must be a
    JSON list (strings, numbers, booleans or null). Note that this will only be parsed correctly if it is at the
    end of the line:

    ```python
    >>> def drink_beverage(beverage: str):
    >>>     '''
    >>>     A function that drinks a beverage
    >>>
    >>>     Args:
    >>>         beverage: The beverage to drink (choices: ["tea", "coffee"])
    >>>     '''
    >>>     pass
    >>>
    >>> print(get_json_schema(drink_beverage))
    {
        'type': 'function',
        'function': {
            'name': 'drink_beverage',
            'description': 'A function that drinks a beverage',
            'parameters': {
                'type': 'object',
                'properties': {
                    'beverage': {
                        'type': 'string',
                        'enum': ['tea', 'coffee'],
                        'description': 'The beverage to drink'
                    }
                },
                'required': ['beverage']
            }
        }
    }
    ```
    """
    doc = inspect.getdoc(func)
    if not doc:
        raise DocstringParsingException(
            f"Cannot generate JSON schema for {func.__name__} because it has no docstring!"
        )
    doc = doc.strip()
    main_doc, param_descriptions, return_doc = parse_google_format_docstring(doc)

    json_schema = _convert_type_hints_to_json_schema(func)
    # The return annotation comes back inside "properties"; move it out so it is not treated as an argument
    if (return_dict := json_schema["properties"].pop("return", None)) is not None:
        if return_doc is not None:  # We allow a missing return docstring since most templates ignore it
            return_dict["description"] = return_doc

    for arg, schema in json_schema["properties"].items():
        if arg not in param_descriptions:
            raise DocstringParsingException(
                f"Cannot generate JSON schema for {func.__name__} because the docstring has no description for the argument '{arg}'"
            )
        desc = param_descriptions[arg]
        # A trailing "(choices: <JSON list>)" becomes the schema's "enum" and is removed from the description
        enum_choices = re.search(r"\(choices:\s*(.*?)\)\s*$", desc, flags=re.IGNORECASE)
        if enum_choices:
            choices = json.loads(enum_choices.group(1))
            if not isinstance(choices, list):
                raise DocstringParsingException(
                    f"Cannot generate JSON schema for {func.__name__} because the choices for the argument '{arg}' "
                    f"are not a JSON list: {enum_choices.group(1)}"
                )
            # Only strings can carry stray whitespace; numbers, booleans and null are kept as-is
            schema["enum"] = [c.strip() if isinstance(c, str) else c for c in choices]
            desc = enum_choices.string[: enum_choices.start()].strip()
        schema["description"] = desc

    output = {"name": func.__name__, "description": main_doc, "parameters": json_schema}
    if return_dict is not None:
        output["return"] = return_dict
    return {"type": "function", "function": output}