# Spaces:
# Sleeping
# Sleeping
from __future__ import annotations | |
import sys | |
sys.path.append("..") | |
import json | |
from abc import ABC, abstractmethod | |
from pathlib import Path | |
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union | |
import yaml | |
from pydantic import Field | |
from ....base import BotBase | |
from .formatter import FStringFormatter, JinjiaFormatter | |
from .parser import BaseOutputParser, DictParser, ListParser, StrParser | |
# Registry of formatter instances keyed by template-format name; each
# formatter class is instantiated once, when the module is imported.
DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = {
    format_name: formatter_cls()
    for format_name, formatter_cls in (
        ("f-string", FStringFormatter),
        ("jinja2", JinjiaFormatter),
    )
}

# Registry of output-parser classes (not instances) keyed by class name.
_OUTPUT_PARSER = dict(
    StrParser=StrParser,
    ListParser=ListParser,
    DictParser=DictParser,
)
def check_valid_template(
    template: str, template_format: str, input_variables: List[str]
) -> None:
    """Check that a template string is valid for the given template format.

    Args:
        template: The template text to validate.
        template_format: Key into ``DEFAULT_FORMATTER_MAPPING`` selecting the
            formatter used for validation.
        input_variables: Variable names passed to the formatter's
            ``validate`` method together with the template.

    Raises:
        ValueError: If ``template_format`` is not a registered format, or if
            the formatter's ``validate`` call raises ``KeyError``.
    """
    if template_format not in DEFAULT_FORMATTER_MAPPING:
        valid_formats = list(DEFAULT_FORMATTER_MAPPING)
        raise ValueError(
            f"Invalid template format. Got `{template_format}`;"
            f" should be one of {valid_formats}"
        )

    # The lookup cannot fail after the membership check above, so only the
    # validation call itself is guarded.
    formatter = DEFAULT_FORMATTER_MAPPING[template_format]
    try:
        formatter.validate(template, input_variables)
    except KeyError as e:
        # Chain explicitly so the original KeyError is kept as __cause__.
        raise ValueError(
            "Invalid prompt schema; check for mismatched or missing input parameters. "
            + str(e)
        ) from e
def _get_jinja2_variables_from_template(template: str) -> Set[str]:
    """Return the undeclared variable names referenced by a jinja2 template.

    Args:
        template: Source text of the jinja2 template.

    Returns:
        The set produced by ``jinja2.meta.find_undeclared_variables`` for the
        parsed template.

    Raises:
        ImportError: If jinja2 is not installed.
    """
    # Imported lazily so jinja2 is only required when this helper is used.
    try:
        from jinja2 import Environment, meta
    except ImportError as err:
        raise ImportError(
            "jinja2 not installed, which is needed to use the jinja2_formatter. "
            "Please install it with `pip install jinja2`."
        ) from err

    parsed_template = Environment().parse(template)
    return meta.find_undeclared_variables(parsed_template)
class BasePromptTemplate(BotBase, ABC):
    """Base class for all prompt templates, returning a prompt."""

    input_variables: List[str]
    """A list of the names of the variables the prompt template expects."""
    output_parser: Optional[BaseOutputParser] = None
    """How to parse the output of calling an LLM on this formatted prompt."""
    # Values bound ahead of time: each is either a string or a zero-argument
    # callable that is invoked when the variables are merged for formatting.
    partial_variables: Mapping[str, Union[str, Callable[[], str]]] = Field(
        default_factory=dict
    )

    class Config:
        """Configuration for this pydantic object."""

        extra = "forbid"
        arbitrary_types_allowed = True

    def partial(self, **kwargs: Union[str, Callable[[], str]]) -> BasePromptTemplate:
        """Return a copy of the prompt template with some variables pre-bound.

        Args:
            kwargs: Variable names mapped to a string or a zero-argument
                callable returning a string.

        Returns:
            A new instance of the same class whose ``input_variables`` no
            longer contain the bound names and whose ``partial_variables``
            include them.
        """
        prompt_dict = self.__dict__.copy()
        # Keep the declared order of the remaining variables (duplicates are
        # still collapsed); a set difference would return them in arbitrary
        # order.
        prompt_dict["input_variables"] = list(
            dict.fromkeys(
                name for name in self.input_variables if name not in kwargs
            )
        )
        prompt_dict["partial_variables"] = {**self.partial_variables, **kwargs}
        return type(self)(**prompt_dict)

    def _merge_partial_and_user_variables(self, **kwargs: Any) -> Dict[str, Any]:
        """Combine resolved partial variables with caller-supplied values.

        Non-string partial values are called with no arguments to obtain
        their value. On a name clash the value from ``kwargs`` wins.
        """
        partial_kwargs = {
            k: v if isinstance(v, str) else v()
            for k, v in self.partial_variables.items()
        }
        return {**partial_kwargs, **kwargs}

    def format(self, **kwargs: Any) -> str:
        """Format the prompt with the inputs.

        Args:
            kwargs: Any arguments to be passed to the prompt template.

        Returns:
            A formatted string.

        Example:
        .. code-block:: python

            prompt.format(variable1="foo")
        """
        # NOTE(review): this base implementation has no body beyond the
        # docstring, so it returns None. Subclasses are presumably expected
        # to override it -- confirm whether @abstractmethod was intended.

    def save(self, file_path: Union[Path, str]) -> None:
        """Save the prompt to a JSON or YAML file.

        Args:
            file_path: Path of the file to write. Its suffix must be
                ``.json`` or ``.yaml``; missing parent directories are
                created.

        Raises:
            ValueError: If the prompt has partial variables, or if the file
                suffix is neither ``.json`` nor ``.yaml``.

        Example:
        .. code-block:: python

            prompt.save(file_path="path/prompt.yaml")
        """
        if self.partial_variables:
            raise ValueError("Cannot save prompt with partial variables.")

        save_path = Path(file_path)
        suffix = save_path.suffix
        # Reject unsupported formats before touching the filesystem so a bad
        # path does not leave freshly created directories behind.
        if suffix not in (".json", ".yaml"):
            raise ValueError(f"{save_path} must be json or yaml")

        save_path.parent.mkdir(parents=True, exist_ok=True)

        # Fetch dictionary to save
        prompt_dict = self.dict()

        with open(save_path, "w", encoding="utf-8") as f:
            if suffix == ".json":
                json.dump(prompt_dict, f, indent=4)
            else:
                yaml.dump(prompt_dict, f, default_flow_style=False)

    @classmethod
    def from_template(cls, template: str, **kwargs: Any) -> BasePromptTemplate:
        """Create a prompt from a template."""
        # NOTE(review): stub with no body beyond the docstring (returns
        # None); subclasses are presumably expected to override it.

    @classmethod
    def from_config(cls, config: Dict) -> BasePromptTemplate:
        """Create a prompt from config."""
        # NOTE(review): stub with no body beyond the docstring (returns
        # None); subclasses are presumably expected to override it.