import gradio as gr
from huggingface_hub import InferenceClient
from openai import OpenAI

# Local project modules: prompt templating and the model-agnostic assistant wrapper
from prompt_template import PromptTemplate, PromptLoader
from assistant import AIAssistant
# Load prompts from YAML
prompts = PromptLoader.load_prompts("prompts.yaml")
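# NOTE (assumption): the exact schema of prompts.yaml is defined by PromptLoader
# in prompt_template.py and is not shown here. A minimal assumed shape, matching
# how the templates are used below (a format string plus default generation
# parameters), might look like:
#
#   system_context:
#     template: "{prompt_strategy}"
#     parameters:
#       max_tokens: 512
#       temperature: 0.7
#       top_p: 0.95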
# Available models and their configurations
MODELS = {
    "Zephyr 7B Beta": {
        "name": "HuggingFaceH4/zephyr-7b-beta",
        "provider": "huggingface",
    },
    "Mistral 7B": {
        "name": "mistralai/Mistral-7B-v0.1",
        "provider": "huggingface",
    },
    "GPT-3.5 Turbo": {
        "name": "gpt-3.5-turbo",
        "provider": "openai",
    },
}
# Available prompt strategies
PROMPT_STRATEGIES = {
    "Default": "system_context",
    "Chain of Thought": "cot_prompt",
    "Knowledge-based": "knowledge_prompt",
    "Few-shot Learning": "few_shot_prompt",
    "Meta-prompting": "meta_prompt",
}
def create_assistant(model_name):
    """Instantiate an AIAssistant backed by the selected model's provider."""
    model_info = MODELS[model_name]
    if model_info["provider"] == "huggingface":
        client = InferenceClient(model_info["name"])
    else:  # OpenAI
        client = OpenAI()
    return AIAssistant(
        client=client,
        model=model_info["name"],
    )
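# NOTE (assumption): AIAssistant is a local wrapper whose generate_response()
# is expected to accept the keyword arguments used in respond() below and,
# with stream=True, to yield progressively longer partial responses that
# Gradio's ChatInterface can render incrementally. Its implementation lives
# in assistant.py and is not shown here.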
def respond(
    message,
    history: list[tuple[str, str]],
    model_name,
    prompt_strategy,
    system_message,
    override_params: bool,
    max_tokens,
    temperature,
    top_p,
):
    """Stream a chat completion for the given message, history, and settings."""
    assistant = create_assistant(model_name)

    # Look up the prompt template for the selected strategy
    prompt_template: PromptTemplate = prompts[PROMPT_STRATEGIES[prompt_strategy]]

    # Render the system message through the template (the template is assumed
    # to expose a "prompt_strategy" placeholder for the user-supplied text)
    formatted_system_message = prompt_template.format(prompt_strategy=system_message)

    # Build the OpenAI-style message list: system prompt, prior turns, new message
    messages = [{"role": "system", "content": formatted_system_message}]
    for user_msg, assistant_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})

    # Use the template's default generation parameters unless the user
    # has opted to override them from the UI sliders
    generation_params = prompt_template.parameters if not override_params else {
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
    }

    # Stream partial responses back to the chat UI as they are generated
    for response in assistant.generate_response(
        prompt_template=prompt_template,
        generation_params=generation_params,
        stream=True,
        messages=messages,
    ):
        yield response
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            model_dropdown = gr.Dropdown(
                choices=list(MODELS.keys()),
                value=list(MODELS.keys())[0],
                label="Select Model",
            )
            prompt_strategy_dropdown = gr.Dropdown(
                choices=list(PROMPT_STRATEGIES.keys()),
                value=list(PROMPT_STRATEGIES.keys())[0],
                label="Select Prompt Strategy",
            )
            system_message = gr.Textbox(
                value="You are a friendly and helpful AI assistant.",
                label="System Message",
            )

    with gr.Row():
        override_params = gr.Checkbox(
            label="Override Template Parameters",
            value=False,
        )

    with gr.Row():
        with gr.Column(visible=False) as param_controls:
            max_tokens = gr.Slider(
                minimum=1,
                maximum=2048,
                value=512,
                step=1,
                label="Max new tokens",
            )
            temperature = gr.Slider(
                minimum=0.1,
                maximum=4.0,
                value=0.7,
                step=0.1,
                label="Temperature",
            )
            top_p = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.95,
                step=0.05,
                label="Top-p (nucleus sampling)",
            )

    chatbot = gr.ChatInterface(
        fn=respond,
        additional_inputs=[
            model_dropdown,
            prompt_strategy_dropdown,
            system_message,
            override_params,
            max_tokens,
            temperature,
            top_p,
        ],
    )

    # Show the parameter sliders only when the override checkbox is ticked
    def toggle_param_controls(override):
        return gr.Column(visible=override)

    override_params.change(
        toggle_param_controls,
        inputs=[override_params],
        outputs=[param_controls],
    )
if __name__ == "__main__":
    demo.launch()
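# Usage note (assumptions): run this script directly to launch the Gradio app.
# The OpenAI path requires OPENAI_API_KEY to be set in the environment, and
# huggingface_hub's InferenceClient will use HF_TOKEN (or a locally saved
# token) if authentication is required for the selected model.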