import os

import gradio as gr
from dotenv import load_dotenv
from openai import OpenAI

from assistant import AIAssistant
from prompt_template import PromptTemplate, PromptLoader
# Load .env file
load_dotenv()
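# The key is consumed below by an OpenAI-compatible client pointed at
# NVIDIA's https://integrate.api.nvidia.com/v1 endpoint.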
API_KEY = os.getenv('API_KEY')
# Load prompts from YAML
prompts = PromptLoader.load_prompts("prompts.yaml")
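# Each strategy key referenced in PROMPT_STRATEGIES below (system_context,
# cot_prompt, knowledge_prompt, few_shot_prompt, meta_prompt) is expected to
# resolve to a PromptTemplate exposing a `template` string and a `parameters`
# dict, since both attributes are accessed later in this file. An
# illustrative sketch of one prompts.yaml entry (the actual schema is
# whatever PromptLoader defines):
#
#   cot_prompt:
#     template: "Think through the problem step by step before answering."
#     parameters:
#       max_tokens: 512
#       temperature: 0.7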
# Available models and their configurations
MODELS = {
    "Llama 3.3 70B Instruct": {
        "name": "meta/llama-3.3-70b-instruct",
    },
    "Llama 3.1 405B Instruct": {
        "name": "meta/llama-3.1-405b-instruct",
    },
    "Llama 3.2 3B Instruct": {
        "name": "meta/llama-3.2-3b-instruct",
    },
    "Falcon 3 7B Instruct": {
        "name": "tiiuae/falcon3-7b-instruct",
    },
    "Granite 3.0 8B Instruct": {
        "name": "ibm/granite-3.0-8b-instruct",
    }
}
# Available prompt strategies
PROMPT_STRATEGIES = {
    "Default": "system_context",
    "Chain of Thought": "cot_prompt",
    "Knowledge-based": "knowledge_prompt",
    "Few-shot Learning": "few_shot_prompt",
    "Meta-prompting": "meta_prompt"
}

def create_assistant(model_name):
    client = OpenAI(
        base_url="https://integrate.api.nvidia.com/v1",
        api_key=API_KEY
    )
    # Map the display name selected in the UI to the actual model identifier
    model_name = MODELS[model_name]["name"]
    return AIAssistant(
        client=client,
        model=model_name
    )

def respond(
    message,
    history: list[tuple[str, str]],
    model_name,
    prompt_strategy,
    override_params: bool,
    max_tokens,
    temperature,
    top_p,
):
    assistant = create_assistant(model_name)

    # Get the selected prompt template and the system context template
    prompt_template: PromptTemplate = prompts[PROMPT_STRATEGIES[prompt_strategy]]
    system_context: PromptTemplate = prompts["system_context"]

    # Format the system context with the selected prompt strategy
    formatted_system_message = system_context.format(prompt_strategy=prompt_template.template)

    # Build the message list in OpenAI chat format
    messages = [{"role": "system", "content": formatted_system_message}]

    # Add conversation history
    for user_msg, assistant_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": str(user_msg)})
        if assistant_msg:
            messages.append({"role": "assistant", "content": str(assistant_msg)})

    # Add the current message
    messages.append({"role": "user", "content": str(message)})

    # Use the template's own parameters unless the user overrides them in the UI
    generation_params = prompt_template.parameters if not override_params else {
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p
    }

    try:
        for response in assistant.generate_response(
            prompt_template=prompt_template,
            generation_params=generation_params,
            stream=True,
            messages=messages
        ):
            yield response
    except Exception as e:
        yield f"Error: {str(e)}"

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            model_dropdown = gr.Dropdown(
                choices=list(MODELS.keys()),
                value=list(MODELS.keys())[0],
                label="Select Model"
            )
            prompt_strategy_dropdown = gr.Dropdown(
                choices=list(PROMPT_STRATEGIES.keys()),
                value=list(PROMPT_STRATEGIES.keys())[0],
                label="Select Prompt Strategy"
            )

    with gr.Row():
        override_params = gr.Checkbox(
            label="Override Template Parameters",
            value=False
        )

    with gr.Row():
        with gr.Column(visible=False) as param_controls:
            max_tokens = gr.Slider(
                minimum=1,
                maximum=2048,
                value=512,
                step=1,
                label="Max new tokens"
            )
            temperature = gr.Slider(
                minimum=0.1,
                maximum=4.0,
                value=0.7,
                step=0.1,
                label="Temperature"
            )
            top_p = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.95,
                step=0.05,
                label="Top-p (nucleus sampling)"
            )

    chatbot = gr.ChatInterface(
        fn=respond,
        additional_inputs=[
            model_dropdown,
            prompt_strategy_dropdown,
            override_params,
            max_tokens,
            temperature,
            top_p,
        ]
    )

    # Parameters and Prompt Details section below the chat
    with gr.Row(equal_height=True):
        with gr.Column(scale=1, min_width=300):
            with gr.Accordion("Current Prompt Details", open=False):
                system_prompt_display = gr.TextArea(
                    label="System Prompt",
                    interactive=False,
                    lines=20
                )
                current_messages_display = gr.JSON(
                    label="Full Conversation Context",
                )
    def toggle_param_controls(override):
        return gr.Column(visible=override)

    def update_prompt_display(prompt_strategy):
        prompt_template = prompts[PROMPT_STRATEGIES[prompt_strategy]]
        system_context = prompts["system_context"]
        formatted_system_message = system_context.format(prompt_strategy=prompt_template.template)
        return (
            formatted_system_message,
            {
                "Template Parameters": prompt_template.parameters,
                "Prompt Strategy": prompt_template.template
            }
        )

    # Update the prompt display when the strategy changes
    prompt_strategy_dropdown.change(
        update_prompt_display,
        inputs=[prompt_strategy_dropdown],
        outputs=[system_prompt_display, current_messages_display]
    )

    # Show or hide the manual parameter controls when the checkbox toggles
    override_params.change(
        toggle_param_controls,
        inputs=[override_params],
        outputs=[param_controls]
    )

if __name__ == "__main__":
    demo.launch()