from dotenv import load_dotenv
import os
import gradio as gr
from openai import OpenAI
from prompt_template import PromptTemplate, PromptLoader
from assistant import AIAssistant

# Load environment variables from a local .env file
load_dotenv()
API_KEY = os.getenv("API_KEY")

# Load prompt templates from YAML
prompts = PromptLoader.load_prompts("prompts.yaml")
# Available models (display name -> NVIDIA API catalog model ID)
MODELS = {
    "Llama 3.3 70B Instruct": {
        "name": "meta/llama-3.3-70b-instruct",
    },
    "Llama 3.1 405B Instruct": {
        "name": "meta/llama-3.1-405b-instruct",
    },
    "Llama 3.2 3B Instruct": {
        "name": "meta/llama-3.2-3b-instruct",
    },
    "Falcon 3 7B Instruct": {
        "name": "tiiuae/falcon3-7b-instruct",
    },
    "Granite 3.0 8B Instruct": {
        "name": "ibm/granite-3.0-8b-instruct",
    },
}

# Available prompt strategies (display name -> key in prompts.yaml)
PROMPT_STRATEGIES = {
    "Default": "system_context",
    "Chain of Thought": "cot_prompt",
    "Knowledge-based": "knowledge_prompt",
    "Few-shot Learning": "few_shot_prompt",
    "Meta-prompting": "meta_prompt",
}

def create_assistant(model_name):
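    """Create an AIAssistant backed by NVIDIA's OpenAI-compatible endpoint."""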
    client = OpenAI(
        base_url="https://integrate.api.nvidia.com/v1",
        api_key=API_KEY,
    )
    # Map the display name selected in the UI to the actual model ID
    model_id = MODELS[model_name]["name"]
    return AIAssistant(
        client=client,
        model=model_id,
    )

def respond(
    message,
    history: list[tuple[str, str]],
    model_name,
    prompt_strategy,
    override_params: bool,
    max_tokens,
    temperature,
    top_p,
):
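    """Stream the assistant's reply for the current message and history.

    The parameters after (message, history) arrive in the same order as the
    additional_inputs list wired into gr.ChatInterface below.
    """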
    assistant = create_assistant(model_name)

    # Get the selected prompt template and the base system context
    prompt_template: PromptTemplate = prompts[PROMPT_STRATEGIES[prompt_strategy]]
    system_context: PromptTemplate = prompts["system_context"]

    # Format the system context with the selected prompt strategy
    formatted_system_message = system_context.format(prompt_strategy=prompt_template.template)

    # Prepare messages in the OpenAI chat format
    messages = [{"role": "system", "content": formatted_system_message}]

    # Add conversation history
    for user_msg, assistant_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": str(user_msg)})
        if assistant_msg:
            messages.append({"role": "assistant", "content": str(assistant_msg)})

    # Add the current message
    messages.append({"role": "user", "content": str(message)})

    # Use the template's own parameters unless the user overrides them
    generation_params = prompt_template.parameters if not override_params else {
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
    }

    try:
        # Stream partial responses back to the chat UI as they arrive
        for response in assistant.generate_response(
            prompt_template=prompt_template,
            generation_params=generation_params,
            stream=True,
            messages=messages,
        ):
            yield response
    except Exception as e:
        yield f"Error: {str(e)}"

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            model_dropdown = gr.Dropdown(
                choices=list(MODELS.keys()),
                value=list(MODELS.keys())[0],
                label="Select Model",
            )
            prompt_strategy_dropdown = gr.Dropdown(
                choices=list(PROMPT_STRATEGIES.keys()),
                value=list(PROMPT_STRATEGIES.keys())[0],
                label="Select Prompt Strategy",
            )

    with gr.Row():
        override_params = gr.Checkbox(
            label="Override Template Parameters",
            value=False,
        )

    with gr.Row():
        # Hidden until "Override Template Parameters" is checked
        with gr.Column(visible=False) as param_controls:
            max_tokens = gr.Slider(
                minimum=1,
                maximum=2048,
                value=512,
                step=1,
                label="Max new tokens",
            )
            temperature = gr.Slider(
                minimum=0.1,
                maximum=4.0,
                value=0.7,
                step=0.1,
                label="Temperature",
            )
            top_p = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.95,
                step=0.05,
                label="Top-p (nucleus sampling)",
            )
    # additional_inputs are passed to respond() after the (message, history) pair
    chatbot = gr.ChatInterface(
        fn=respond,
        additional_inputs=[
            model_dropdown,
            prompt_strategy_dropdown,
            override_params,
            max_tokens,
            temperature,
            top_p,
        ],
    )
    # Parameters and Prompt Details section below the chat
    with gr.Row(equal_height=True):
        with gr.Column(scale=1, min_width=300):
            with gr.Accordion("Current Prompt Details", open=False):
                system_prompt_display = gr.TextArea(
                    label="System Prompt",
                    interactive=False,
                    lines=20,
                )
                current_messages_display = gr.JSON(
                    label="Full Conversation Context",
                )

    def toggle_param_controls(override):
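        """Show or hide the manual parameter sliders."""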
        return gr.Column(visible=override)

    def update_prompt_display(prompt_strategy):
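        """Render the formatted system prompt and template details for display."""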
        prompt_template = prompts[PROMPT_STRATEGIES[prompt_strategy]]
        system_context = prompts["system_context"]
        formatted_system_message = system_context.format(prompt_strategy=prompt_template.template)
        return (
            formatted_system_message,
            {
                "Template Parameters": prompt_template.parameters,
                "Prompt Strategy": prompt_template.template,
            },
        )
    # Update prompt display when strategy changes
    prompt_strategy_dropdown.change(
        update_prompt_display,
        inputs=[prompt_strategy_dropdown],
        outputs=[system_prompt_display, current_messages_display],
    )
    override_params.change(
        toggle_param_controls,
        inputs=[override_params],
        outputs=[param_controls],
    )

if __name__ == "__main__":
    demo.launch()
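
# To run locally (assuming this file is saved as app.py and API_KEY is set in .env):
#   python app.py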