Update prompt_template.py
Browse files- prompt_template.py +34 -2
prompt_template.py
CHANGED
@@ -3,8 +3,6 @@ import yaml
|
|
3 |
from dataclasses import dataclass, field
|
4 |
from typing import Dict, List, Set, Optional
|
5 |
|
6 |
-
from openai import OpenAI
|
7 |
-
|
8 |
@dataclass
|
9 |
class PromptTemplate:
|
10 |
"""
|
@@ -84,3 +82,37 @@ class PromptTemplate:
|
|
84 |
{"role": "system", "content": self.system_prompt},
|
85 |
{"role": "user", "content": formatted_user_message}
|
86 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
from dataclasses import dataclass, field
|
4 |
from typing import Dict, List, Set, Optional
|
5 |
|
|
|
|
|
6 |
@dataclass
|
7 |
class PromptTemplate:
|
8 |
"""
|
|
|
82 |
{"role": "system", "content": self.system_prompt},
|
83 |
{"role": "user", "content": formatted_user_message}
|
84 |
]
|
85 |
+
|
86 |
+
|
87 |
+
def load_prompt(yaml_path: str, version: str = None) -> tuple[PromptTemplate, dict]:
|
88 |
+
"""
|
89 |
+
Load prompt configuration from YAML file.
|
90 |
+
|
91 |
+
Args:
|
92 |
+
yaml_path: Path to YAML configuration file
|
93 |
+
version: Specific version to load (defaults to 'current_version')
|
94 |
+
|
95 |
+
Returns:
|
96 |
+
tuple: (PromptTemplate instance, generation parameters dictionary)
|
97 |
+
|
98 |
+
Example:
|
99 |
+
prompt, params = load_prompt('prompts.yaml', version='v2')
|
100 |
+
"""
|
101 |
+
with open(yaml_path, 'r') as f:
|
102 |
+
data = yaml.safe_load(f)
|
103 |
+
|
104 |
+
# Use specified version or fall back to current_version
|
105 |
+
version_to_use = version or data.get('current_version')
|
106 |
+
if version_to_use not in data:
|
107 |
+
raise KeyError(f"Version '{version_to_use}' not found in {yaml_path}")
|
108 |
+
|
109 |
+
version_data = data[version_to_use]
|
110 |
+
|
111 |
+
prompt = PromptTemplate(
|
112 |
+
system_prompt=version_data['system_prompt'],
|
113 |
+
user_template=version_data['user_template']
|
114 |
+
)
|
115 |
+
|
116 |
+
generation_params = version_data.get('generation_params', {})
|
117 |
+
|
118 |
+
return prompt, generation_params
|