import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
from threading import Thread
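# Load both Phi-4 variants eagerly at startup so the dropdown can switch
# between them without reloading weights mid-session.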
phi4_model_path = "microsoft/phi-4"
phi4_mini_model_path = "microsoft/Phi-4-mini-instruct"

device = "cuda:0" if torch.cuda.is_available() else "cpu"

phi4_model = AutoModelForCausalLM.from_pretrained(phi4_model_path, torch_dtype="auto").to(device)
phi4_tokenizer = AutoTokenizer.from_pretrained(phi4_model_path)
phi4_mini_model = AutoModelForCausalLM.from_pretrained(phi4_mini_model_path, torch_dtype="auto").to(device)
phi4_mini_tokenizer = AutoTokenizer.from_pretrained(phi4_mini_model_path)
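# `spaces` is imported above but otherwise unused; on ZeroGPU hardware the
# generation function is normally gated behind the `spaces.GPU` decorator,
# which attaches a GPU for the duration of the call and is a no-op elsewhere
# (assumption: this Space runs on ZeroGPU, hence the import).
@spaces.GPU()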
def generate_response(user_message, model_name, max_tokens, temperature, top_k, top_p, repetition_penalty, history_state):
    if not user_message.strip():
        # This function is a generator, so yield rather than return a value
        # (Gradio discards a generator's return value) and stop early.
        yield history_state, history_state
        return
    # Select the requested model and its prompt tags
    if model_name == "Phi-4":
        model = phi4_model
        tokenizer = phi4_tokenizer
        start_tag = "<|im_start|>"
        sep_tag = "<|im_sep|>"
        end_tag = "<|im_end|>"
    elif model_name == "Phi-4-mini-instruct":
        model = phi4_mini_model
        tokenizer = phi4_mini_tokenizer
        start_tag = ""
        sep_tag = ""
        end_tag = "<|end|>"
    else:
        raise ValueError("Invalid model selected")
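    # Note: tokenizer.apply_chat_template(messages, tokenize=False,
    # add_generation_prompt=True) would render each model's bundled chat
    # template automatically; the prompt is assembled by hand below to keep
    # the tags explicit.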
    # Prompt format recommended by Microsoft
    system_message = "You are a friendly and knowledgeable assistant, here to help with any questions or tasks."

    if model_name == "Phi-4":
        prompt = f"{start_tag}system{sep_tag}{system_message}{end_tag}"
        for message in history_state:
            if message["role"] == "user":
                prompt += f"{start_tag}user{sep_tag}{message['content']}{end_tag}"
            elif message["role"] == "assistant" and message["content"]:
                prompt += f"{start_tag}assistant{sep_tag}{message['content']}{end_tag}"
        prompt += f"{start_tag}user{sep_tag}{user_message}{end_tag}{start_tag}assistant{sep_tag}"
    else:
        prompt = f"<|system|>{system_message}{end_tag}"
        for message in history_state:
            if message["role"] == "user":
                prompt += f"<|user|>{message['content']}{end_tag}"
            elif message["role"] == "assistant" and message["content"]:
                prompt += f"<|assistant|>{message['content']}{end_tag}"
        prompt += f"<|user|>{user_message}{end_tag}<|assistant|>"
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # Fall back to greedy decoding when every sampling control sits at a
    # neutral value (temperature 1.0, top-k >= 100, top-p 1.0).
    do_sample = not (temperature == 1.0 and top_k >= 100 and top_p == 1.0)

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)

    # Generation settings
    generation_kwargs = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "max_new_tokens": int(max_tokens),
        "do_sample": do_sample,
        "temperature": temperature,
        "top_k": int(top_k),
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
        "streamer": streamer,
    }
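    # model.generate() blocks until completion, so it runs in a worker thread
    # while this thread drains the streamer and yields incremental updates.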
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    # Stream the response, stripping any special tags the model emits
    assistant_response = ""
    new_history = history_state + [
        {"role": "user", "content": user_message},
        {"role": "assistant", "content": ""},
    ]
    special_tags = (
        "<|im_start|>", "<|im_sep|>", "<|im_end|>",
        "<|end|>", "<|system|>", "<|user|>", "<|assistant|>",
    )
    for new_token in streamer:
        for tag in special_tags:
            new_token = new_token.replace(tag, "")
        assistant_response += new_token
        # Update both the visible chat and the stored history on every token
        new_history[-1]["content"] = assistant_response.strip()
        yield new_history, new_history
    # Final yield delivers the completed message once the streamer is drained
    yield new_history, new_history
example_messages = {
    "Learn about physics": "Explain Newton’s laws of motion.",
    "Discover space facts": "What are some interesting facts about black holes?",
    "Write a factorial function": "Write a Python function to calculate the factorial of a number."
}
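# The example buttons below only prefill the textbox via gr.update(); the
# user still presses Send, so every generation goes through generate_response.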
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
# Phi-4 Chatbot Demo
Welcome to the Phi-4 Chatbot Demo! You can chat with Microsoft's Phi-4 or Phi-4-mini-instruct models. Adjust the settings on the left to customize the model's responses.
"""
    )
    history_state = gr.State([])
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Settings")
            model_dropdown = gr.Dropdown(
                choices=["Phi-4", "Phi-4-mini-instruct"],
                label="Select Model",
                value="Phi-4"
            )
            max_tokens_slider = gr.Slider(
                minimum=64,
                maximum=4096,
                step=50,
                value=512,
                label="Max Tokens"
            )
            with gr.Accordion("Advanced Settings", open=False):
                temperature_slider = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=1.0,
                    label="Temperature"
                )
                top_k_slider = gr.Slider(
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=50,
                    label="Top-k"
                )
                top_p_slider = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.9,
                    label="Top-p"
                )
                repetition_penalty_slider = gr.Slider(
                    minimum=1.0,
                    maximum=2.0,
                    value=1.0,
                    label="Repetition Penalty"
                )

        with gr.Column(scale=4):
            chatbot = gr.Chatbot(label="Chat", type="messages")
            with gr.Row():
                user_input = gr.Textbox(
                    label="Your message",
                    placeholder="Type your message here...",
                    scale=3
                )
                submit_button = gr.Button("Send", variant="primary", scale=1)
                clear_button = gr.Button("Clear", scale=1)
            gr.Markdown("**Try these examples:**")
            with gr.Row():
                example1_button = gr.Button("Learn about physics")
                example2_button = gr.Button("Discover space facts")
                example3_button = gr.Button("Write a factorial function")
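    # Event wiring: generate_response is a generator, so each yielded
    # (history, history) pair streams straight to the Chatbot while Gradio
    # keeps the session state in sync.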
    submit_button.click(
        fn=generate_response,
        inputs=[user_input, model_dropdown, max_tokens_slider, temperature_slider, top_k_slider, top_p_slider, repetition_penalty_slider, history_state],
        outputs=[chatbot, history_state]
    ).then(
        fn=lambda: gr.update(value=""),
        inputs=None,
        outputs=user_input
    )
    clear_button.click(
        fn=lambda: ([], []),
        inputs=None,
        outputs=[chatbot, history_state]
    )
    example1_button.click(
        fn=lambda: gr.update(value=example_messages["Learn about physics"]),
        inputs=None,
        outputs=user_input
    )
    example2_button.click(
        fn=lambda: gr.update(value=example_messages["Discover space facts"]),
        inputs=None,
        outputs=user_input
    )
    example3_button.click(
        fn=lambda: gr.update(value=example_messages["Write a factorial function"]),
        inputs=None,
        outputs=user_input
    )
demo.launch(ssr_mode=False)