Spaces:
Running
Running
File size: 2,627 Bytes
54f5afe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
"""Extract structured data from unstructured text with an LLM."""
from typing import Any, TypeVar
from litellm import completion
from pydantic import BaseModel, ValidationError
from raglite._config import RAGLiteConfig
# Type variable bound to pydantic's BaseModel: it ties the `return_type` class passed to
# `extract_with_llm` to the type of the instance that function returns.
T = TypeVar("T", bound=BaseModel)
def extract_with_llm(
    return_type: type[T],
    user_prompt: str | list[str],
    config: RAGLiteConfig | None = None,
    **kwargs: Any,
) -> T:
    """Extract structured data from unstructured text with an LLM.

    This function expects a `return_type.system_prompt: ClassVar[str]` that contains the system
    prompt to use. Example:

        from typing import ClassVar
        from pydantic import BaseModel, Field

        class MyNameResponse(BaseModel):
            my_name: str = Field(..., description="The user's name.")
            system_prompt: ClassVar[str] = "The system prompt to use (excluded from JSON schema)."

        my_name_response = extract_with_llm(MyNameResponse, "My name is Thomas A. Anderson.")

    Args:
        return_type: The pydantic model class to extract. It must expose a `system_prompt`
            class attribute.
        user_prompt: The unstructured input. A list of strings is joined into a single prompt
            in which each stripped item is wrapped in an indexed `<context>` tag.
        config: The configuration providing `llm` and `llm_max_tries`. A default
            `RAGLiteConfig()` is created when omitted.
        **kwargs: Extra keyword arguments forwarded unchanged to `completion`.

    Returns:
        An instance of `return_type` validated from the LLM's JSON response.

    Raises:
        ValueError: If no valid instance could be extracted within `config.llm_max_tries`
            attempts. The last parsing/validation error, if any, is chained as the cause.
    """
    # Imported locally so this function does not depend on a module-level import of `json`.
    import json

    # Load the default config if not provided.
    config = config or RAGLiteConfig()
    # Compute the JSON schema once: it is used in the system prompt and in the response format.
    json_schema = return_type.model_json_schema()
    # Update the system prompt with the JSON schema of the return type to help the LLM. The
    # message content must be a single string, so the parts are joined rather than left as a
    # tuple, and the schema is serialized as real JSON.
    system_prompt = "\n".join(
        (
            return_type.system_prompt.strip(),  # type: ignore[attr-defined]
            "Format your response according to this JSON schema:",
            json.dumps(json_schema),
        )
    )
    # Concatenate the user prompt if it is a list of strings.
    if isinstance(user_prompt, list):
        user_prompt = "\n\n".join(
            f'<context index="{i}">\n{chunk.strip()}\n</context>'
            for i, chunk in enumerate(user_prompt)
        )
    # Keep the name bound even when the loop body never runs (e.g. `llm_max_tries == 0`), so
    # the failure path below raises the intended ValueError rather than a NameError.
    last_exception: Exception | None = None
    # Extract structured data from the unstructured input.
    for _ in range(config.llm_max_tries):
        response = completion(
            model=config.llm,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            response_format={"type": "json_object", "schema": json_schema},
            **kwargs,
        )
        try:
            instance = return_type.model_validate_json(response["choices"][0]["message"]["content"])
        except (KeyError, IndexError, ValueError, ValidationError) as e:
            # Malformed response (missing key or empty `choices`), not a JSON string, or not a
            # valid instance of the return type: remember the error and try again.
            last_exception = e
            continue
        else:
            break
    else:
        # Every attempt failed (or none was made).
        error_message = f"Failed to extract {return_type} from input {user_prompt}."
        raise ValueError(error_message) from last_exception
    return instance
|