from typing import Any, Dict, List, Optional

import yaml
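

# NOTE: AIAssistant is the wrapper class assumed to be defined earlier in this
# project and is used here unchanged. PromptLoader is only referenced in the
# original listing, so the class below is a minimal sketch of the assumed
# interface (a static load_prompts() that reads a name -> template mapping
# from YAML), not the original implementation.
class PromptLoader:
    """Minimal YAML prompt loader (sketch of the assumed interface)."""

    @staticmethod
    def load_prompts(path: str) -> Dict[str, str]:
        # Expects a YAML file mapping prompt names to template strings.
        with open(path, "r", encoding="utf-8") as f:
            return yaml.safe_load(f) or {}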


class SimplePromptChain:
    """A flexible prompt chain implementation using an AIAssistant wrapper."""

    def __init__(self, assistant: AIAssistant, prompts_path: str):
        """
        Initialize the chain with an AI assistant and prompts.

        Args:
            assistant: Configured AIAssistant instance
            prompts_path: Path to the YAML prompts file
        """
        self.assistant = assistant
        self.prompts = PromptLoader.load_prompts(prompts_path)

    def execute_step(self,
                     prompt_name: str,
                     generation_params: Optional[Dict[str, Any]] = None,
                     variables: Optional[Dict[str, Any]] = None) -> str:
        """
        Execute a single chain step using the AI assistant.

        Args:
            prompt_name: Name of the prompt template to use
            generation_params: Optional parameters for generation
            variables: Variables used to format the prompt

        Returns:
            Processed response content

        Raises:
            ValueError: If the prompt template is not found
            RuntimeError: If the underlying generation call fails
        """
        if prompt_name not in self.prompts:
            raise ValueError(f"Prompt '{prompt_name}' not found in loaded templates")

        prompt_template = self.prompts[prompt_name]
        try:
            # The AIAssistant wrapper is assumed to return an OpenAI-style
            # completion object (with .choices[0].message.content) even when
            # streaming output as it arrives.
            response = self.assistant.generate_response(
                prompt_template=prompt_template,
                generation_params=generation_params,
                stream=True,
                **(variables or {})
            )
            return response.choices[0].message.content

        except Exception as e:
            raise RuntimeError(f"Error executing step '{prompt_name}': {e}") from e

    def run_chain(self, steps: List[Dict[str, Any]]) -> Dict[str, str]:
        """
        Execute a chain of prompts using the AI assistant.

        Args:
            steps: List of steps to execute, each containing:
                - prompt_name: Name of the prompt template
                - variables: Variables for the prompt
                - output_key: Key under which to store the step output
                - generation_params: Optional generation parameters

        Returns:
            Dict of step outputs keyed by output_key

        Example:
            steps = [
                {
                    "prompt_name": "analyze",
                    "variables": {"text": "Sample text"},
                    "output_key": "analysis",
                    "generation_params": {"temperature": 0.7}
                },
                {
                    "prompt_name": "summarize",
                    "variables": {"text": "{analysis}"},
                    "output_key": "summary"
                }
            ]
        """
        results = {}

        for step in steps:
            prompt_name = step["prompt_name"]
            output_key = step["output_key"]
            generation_params = step.get("generation_params")

            # Resolve variables, substituting "{key}" references with the
            # outputs of earlier steps.
            variables = {}
            for key, value in step.get("variables", {}).items():
                if isinstance(value, str) and value.startswith("{") and value.endswith("}"):
                    ref_key = value[1:-1]
                    if ref_key not in results:
                        raise ValueError(
                            f"Referenced output '{ref_key}' not found in previous results"
                        )
                    variables[key] = results[ref_key]
                else:
                    variables[key] = value

            print(f"\nExecuting step: {prompt_name}...")
            result = self.execute_step(
                prompt_name=prompt_name,
                generation_params=generation_params,
                variables=variables
            )
            results[output_key] = result

        return results
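

# Usage sketch: this mirrors the steps shown in the run_chain docstring. The
# AIAssistant construction below is hypothetical -- build or import it however
# the earlier sections of your project do -- and "prompts.yaml" is assumed to
# define "analyze" and "summarize" templates.
if __name__ == "__main__":
    assistant = AIAssistant()  # hypothetical; use your configured instance
    chain = SimplePromptChain(assistant, "prompts.yaml")

    outputs = chain.run_chain([
        {
            "prompt_name": "analyze",
            "variables": {"text": "Sample text"},
            "output_key": "analysis",
            "generation_params": {"temperature": 0.7},
        },
        {
            "prompt_name": "summarize",
            "variables": {"text": "{analysis}"},
            "output_key": "summary",
        },
    ])

    print(outputs["summary"])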