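"""Gradio demo: chat with Hugging Face or OpenAI models using selectable
prompt-engineering strategies (default, chain-of-thought, knowledge-based,
few-shot, meta-prompting) loaded from prompts.yaml."""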
import gradio as gr
from huggingface_hub import InferenceClient
from openai import OpenAI

from prompt_template import PromptTemplate, PromptLoader
from assistant import AIAssistant

# Load prompt templates from YAML
prompts = PromptLoader.load_prompts("prompts.yaml")
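# Assumed prompts.yaml layout (the file itself is not shown here): each
# template key below ("system_context", "cot_prompt", ...) maps to a
# PromptTemplate definition, optionally carrying default generation
# parameters used when "Override Template Parameters" is left unchecked.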

# Available models and their configurations
MODELS = {
    "Zephyr 7B Beta": {
        "name": "HuggingFaceH4/zephyr-7b-beta",
        "provider": "huggingface"
    },
    "Mistral 7B": {
        "name": "mistralai/Mistral-7B-v0.1",
        "provider": "huggingface"
    },
    "GPT-3.5 Turbo": {
        "name": "gpt-3.5-turbo",
        "provider": "openai"
    }
}

# Available prompt strategies: display name -> template key in prompts.yaml
PROMPT_STRATEGIES = {
    "Default": "system_context",
    "Chain of Thought": "cot_prompt",
    "Knowledge-based": "knowledge_prompt",
    "Few-shot Learning": "few_shot_prompt",
    "Meta-prompting": "meta_prompt"
}


def create_assistant(model_name):
    """Build an AIAssistant backed by the selected provider's client."""
    model_info = MODELS[model_name]
    if model_info["provider"] == "huggingface":
        client = InferenceClient(model_info["name"])
    else:  # OpenAI
        client = OpenAI()
    return AIAssistant(
        client=client,
        model=model_info["name"]
    )
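
# Note: both providers read credentials from the environment. OpenAI()
# requires OPENAI_API_KEY; InferenceClient picks up HF_TOKEN (or a cached
# `huggingface-cli login`) when the Inference API requires authentication.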


def respond(
    message,
    history: list[tuple[str, str]],
    model_name,
    prompt_strategy,
    system_message,
    override_params: bool,
    max_tokens,
    temperature,
    top_p,
):
    assistant = create_assistant(model_name)

    # Look up the template for the chosen strategy
    prompt_template: PromptTemplate = prompts[PROMPT_STRATEGIES[prompt_strategy]]

    # Fill the template with the user-supplied system message
    formatted_system_message = prompt_template.format(prompt_strategy=system_message)

    # Prepare messages: system prompt first, then the running conversation
    messages = [{"role": "system", "content": formatted_system_message}]
    for user_msg, assistant_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})

    # Use the UI sliders only when the user checked "Override Template
    # Parameters"; otherwise fall back to the template's own defaults
    generation_params = {
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p
    } if override_params else prompt_template.parameters
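
    # Assumption: AIAssistant.generate_response yields incremental response
    # strings; Gradio's ChatInterface renders them as a streaming reply.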
# Generate response using the assistant
for response in assistant.generate_response(
prompt_template=prompt_template,
generation_params=generation_params,
stream=True,
messages=messages
):
yield response


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            model_dropdown = gr.Dropdown(
                choices=list(MODELS.keys()),
                value=list(MODELS.keys())[0],
                label="Select Model"
            )
            prompt_strategy_dropdown = gr.Dropdown(
                choices=list(PROMPT_STRATEGIES.keys()),
                value=list(PROMPT_STRATEGIES.keys())[0],
                label="Select Prompt Strategy"
            )
            system_message = gr.Textbox(
                value="You are a friendly and helpful AI assistant.",
                label="System Message"
            )

    with gr.Row():
        override_params = gr.Checkbox(
            label="Override Template Parameters",
            value=False
        )

    # Generation sliders stay hidden until the override checkbox is ticked
    with gr.Row():
        with gr.Column(visible=False) as param_controls:
            max_tokens = gr.Slider(
                minimum=1,
                maximum=2048,
                value=512,
                step=1,
                label="Max new tokens"
            )
            temperature = gr.Slider(
                minimum=0.1,
                maximum=4.0,
                value=0.7,
                step=0.1,
                label="Temperature"
            )
            top_p = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.95,
                step=0.05,
                label="Top-p (nucleus sampling)"
            )

    chatbot = gr.ChatInterface(
        fn=respond,
        additional_inputs=[
            model_dropdown,
            prompt_strategy_dropdown,
            system_message,
            override_params,
            max_tokens,
            temperature,
            top_p,
        ]
    )

    def toggle_param_controls(override):
        # Returning a layout component with updated props applies it as an
        # update in Gradio 4
        return gr.Column(visible=override)

    override_params.change(
        toggle_param_controls,
        inputs=[override_params],
        outputs=[param_controls]
    )


if __name__ == "__main__":
    demo.launch()
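
# Example local run (assumes prompts.yaml, prompt_template.py and assistant.py
# live alongside this file, with OPENAI_API_KEY / HF_TOKEN set as needed):
#   pip install gradio openai huggingface_hub pyyaml
#   python app.py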