Spaces:
Running
Running
"""Extract structured data from unstructured text with an LLM.""" | |
from typing import Any, TypeVar | |
from litellm import completion | |
from pydantic import BaseModel, ValidationError | |
from raglite._config import RAGLiteConfig | |
T = TypeVar("T", bound=BaseModel) | |
def extract_with_llm( | |
return_type: type[T], | |
user_prompt: str | list[str], | |
config: RAGLiteConfig | None = None, | |
**kwargs: Any, | |
) -> T: | |
"""Extract structured data from unstructured text with an LLM. | |
This function expects a `return_type.system_prompt: ClassVar[str]` that contains the system | |
prompt to use. Example: | |
from typing import ClassVar | |
from pydantic import BaseModel, Field | |
class MyNameResponse(BaseModel): | |
my_name: str = Field(..., description="The user's name.") | |
system_prompt: ClassVar[str] = "The system prompt to use (excluded from JSON schema)." | |
my_name_response = extract_with_llm(MyNameResponse, "My name is Thomas A. Anderson.") | |
""" | |
# Load the default config if not provided. | |
config = config or RAGLiteConfig() | |
# Update the system prompt with the JSON schema of the return type to help the LLM. | |
system_prompt = ( | |
return_type.system_prompt.strip() + "\n", # type: ignore[attr-defined] | |
"Format your response according to this JSON schema:\n", | |
return_type.model_json_schema(), | |
) | |
# Concatenate the user prompt if it is a list of strings. | |
if isinstance(user_prompt, list): | |
user_prompt = "\n\n".join( | |
f'<context index="{i}">\n{chunk.strip()}\n</context>' | |
for i, chunk in enumerate(user_prompt) | |
) | |
# Extract structured data from the unstructured input. | |
for _ in range(config.llm_max_tries): | |
response = completion( | |
model=config.llm, | |
messages=[ | |
{"role": "system", "content": system_prompt}, | |
{"role": "user", "content": user_prompt}, | |
], | |
response_format={"type": "json_object", "schema": return_type.model_json_schema()}, | |
**kwargs, | |
) | |
try: | |
instance = return_type.model_validate_json(response["choices"][0]["message"]["content"]) | |
except (KeyError, ValueError, ValidationError) as e: | |
# Malformed response, not a JSON string, or not a valid instance of the return type. | |
last_exception = e | |
continue | |
else: | |
break | |
else: | |
error_message = f"Failed to extract {return_type} from input {user_prompt}." | |
raise ValueError(error_message) from last_exception | |
return instance | |