# Apriel-Chat / app.py
import datetime
from openai import OpenAI
import gradio as gr
from theme import apriel
from utils import COMMUNITY_POSTFIX_URL, get_model_config, log_message, check_format, models_config, DEBUG_MODE
MODEL_TEMPERATURE = 0.8
BUTTON_WIDTH = 160
DEFAULT_MODEL_NAME = "Apriel-Nemotron-15b-Thinker" if not DEBUG_MODE else "Apriel-5b"
# DEFAULT_MODEL_NAME = "Apriel-5b"
print(f"Gradio version: {gr.__version__}")
BUTTON_ENABLED = gr.update(interactive=True)
BUTTON_DISABLED = gr.update(interactive=False)
INPUT_ENABLED = gr.update(interactive=True)
INPUT_DISABLED = gr.update(interactive=False)
DROPDOWN_ENABLED = gr.update(interactive=True)
DROPDOWN_DISABLED = gr.update(interactive=False)
SEND_BUTTON_ENABLED = gr.update(interactive=True, visible=True)
SEND_BUTTON_DISABLED = gr.update(interactive=True, visible=False)
STOP_BUTTON_ENABLED = gr.update(interactive=True, visible=True)
STOP_BUTTON_DISABLED = gr.update(interactive=True, visible=False)
chat_start_count = 0
model_config = {}
openai_client = None
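# Dropdown handler: strip the "Model: " prefix, switch the backing model, and clear the chat.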
def update_model_and_clear_chat(model_name):
actual_model_name = model_name.replace("Model: ", "")
desc = setup_model(actual_model_name)
return desc, []
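# Configure the OpenAI-compatible client for the selected model's vLLM endpoint and
# build the community-feedback link shown above the chat.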
def setup_model(model_name, initial=False):
global model_config, openai_client
model_config = get_model_config(model_name)
log_message(f"update_model() --> Model config: {model_config}")
openai_client = OpenAI(
api_key=model_config.get('AUTH_TOKEN'),
base_url=model_config.get('VLLM_API_URL')
)
_model_hf_name = model_config.get("MODEL_HF_URL").split('https://huggingface.co/')[1]
_link = f"<a href='{model_config.get('MODEL_HF_URL')}{COMMUNITY_POSTFIX_URL}' target='_blank'>{_model_hf_name}</a>"
_description = f"We'd love to hear your thoughts on the model. Click here to provide feedback - {_link}"
print(f"Switched to model {_model_hf_name}")
    if initial:
        return None
    return _description
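# chat_started/chat_finished lock and unlock the controls around a streaming run.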
def chat_started():
# outputs: model_dropdown, user_input, send_btn, stop_btn, clear_btn
return (DROPDOWN_DISABLED, gr.update(value="", interactive=False),
SEND_BUTTON_DISABLED, STOP_BUTTON_ENABLED, BUTTON_DISABLED)
def chat_finished():
# outputs: model_dropdown, user_input, send_btn, stop_btn, clear_btn
return DROPDOWN_ENABLED, INPUT_ENABLED, SEND_BUTTON_ENABLED, STOP_BUTTON_DISABLED, BUTTON_ENABLED
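# Stop handler: raise the shared flag that the streaming loop checks on each chunk.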
def stop_chat(state):
state["stop_flag"] = True
gr.Info("Chat stopped")
return state
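# Generator event handler: each yield pushes (history, UI updates, state) to the outputs
# wired up below, so the chatbot re-renders incrementally as chunks arrive.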
def run_chat_inference(history, message, state):
global chat_start_count
state["is_streaming"] = True
state["stop_flag"] = False
# outputs: model_dropdown, user_input, send_btn, stop_btn, clear_btn, session_state
log_message(f"{'-' * 80}")
log_message(f"chat_fn() --> Message: {message}")
log_message(f"chat_fn() --> History: {history}")
try:
# Check if the message is empty
if not message.strip():
gr.Info("Please enter a message before sending")
yield history, INPUT_ENABLED, SEND_BUTTON_ENABLED, STOP_BUTTON_DISABLED, BUTTON_ENABLED, state
            return  # a generator's return value is discarded; the yield above is the final update
chat_start_count = chat_start_count + 1
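        # Rough turn count: a completed reasoning turn adds ~3 messages (user, thought, answer).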
        print(f"{datetime.datetime.now()}: chat_start_count: {chat_start_count}, "
              f"turns: {len(history or []) // 3}")
is_reasoning = model_config.get("REASONING")
        # On later turns, assistant "Thought" messages (those carrying metadata) are
        # filtered out of the request payload below (see history_no_thoughts).
log_message(f"Initial History: {history}")
check_format(history, "messages")
history.append({"role": "user", "content": message})
log_message(f"History with user message: {history}")
check_format(history, "messages")
# Create the streaming response
try:
            history_no_thoughts = [
                item for item in history
                if not (isinstance(item, dict)
                        and item.get("role") == "assistant"
                        and isinstance(item.get("metadata"), dict)
                        and item.get("metadata", {}).get("title") is not None)
            ]
            log_message(f"History without thoughts: {history_no_thoughts}")
            check_format(history_no_thoughts, "messages")
stream = openai_client.chat.completions.create(
model=model_config.get('MODEL_NAME'),
messages=history_no_thoughts,
temperature=MODEL_TEMPERATURE,
stream=True
)
except Exception as e:
print(f"Error: {e}")
yield ([{"role": "assistant",
"content": "😔 The model is unavailable at the moment. Please try again later."}],
INPUT_ENABLED, SEND_BUTTON_ENABLED, STOP_BUTTON_DISABLED, BUTTON_ENABLED, state)
            return  # nothing more to stream after the error message
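        # Seed the assistant reply: reasoning models stream thoughts first, shown
        # under a collapsible block via the metadata title.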
if is_reasoning:
history.append(gr.ChatMessage(
role="assistant",
content="Thinking...",
metadata={"title": "🧠 Thought"}
))
log_message(f"History added thinking: {history}")
check_format(history, "messages")
else:
history.append(gr.ChatMessage(
role="assistant",
content="",
))
log_message(f"History added empty assistant: {history}")
check_format(history, "messages")
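        # Accumulate the streamed text and re-render the tail of the history on every chunk.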
output = ""
completion_started = False
for chunk in stream:
if state["stop_flag"]:
log_message(f"chat_fn() --> Stopping streaming...")
break # Exit the loop if the stop flag is set
# Extract the new content from the delta field
                content = chunk.choices[0].delta.content or ""  # content may be None on some chunks
                output += content
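            # Reasoning output looks like "<thoughts>[BEGIN FINAL RESPONSE]<answer>[END FINAL RESPONSE]<|end|>";
            # split on the BEGIN marker: parts[0] is the thought trace, parts[1] the final answer.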
if is_reasoning:
parts = output.split("[BEGIN FINAL RESPONSE]")
                if len(parts) > 1:
                    # Strip the closing markers from the final answer, longest first.
                    for marker in ("[END FINAL RESPONSE]\n<|end|>",
                                   "[END FINAL RESPONSE]", "<|end|>"):
                        if parts[1].endswith(marker):
                            parts[1] = parts[1].replace(marker, "")
history[-1 if not completion_started else -2] = gr.ChatMessage(
role="assistant",
content=parts[0],
metadata={"title": "🧠 Thought"}
)
if completion_started:
history[-1] = gr.ChatMessage(
role="assistant",
content=parts[1]
)
                elif len(parts) > 1:  # first sight of the final answer: give it its own message
completion_started = True
history.append(gr.ChatMessage(
role="assistant",
content=parts[1]
))
else:
if output.endswith("<|end|>"):
output = output.replace("<|end|>", "")
history[-1] = gr.ChatMessage(
role="assistant",
content=output
)
# log_message(f"Yielding messages: {history}")
yield history, INPUT_DISABLED, SEND_BUTTON_DISABLED, STOP_BUTTON_ENABLED, BUTTON_DISABLED, state
log_message(f"Final History: {history}")
check_format(history, "messages")
yield history, INPUT_ENABLED, SEND_BUTTON_ENABLED, STOP_BUTTON_DISABLED, BUTTON_ENABLED, state
finally:
state["is_streaming"] = False
state["stop_flag"] = False
log_message(f"chat_fn() --> Finished streaming. {chat_start_count} chats started.")
        return  # a generator's return value is discarded
title = None
description = None
theme = apriel
with open('styles.css', 'r') as f:
custom_css = f.read()
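# Build the UI: model picker, description, chatbot, input row, and Send/Stop/New Chat buttons.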
with gr.Blocks(theme=theme, css=custom_css) as demo:
session_state = gr.State(value={"is_streaming": False, "stop_flag": False}) # Store session state as a dictionary
gr.HTML(f"""
<style>
@media (min-width: 1024px) {{
.send-button-container, .clear-button-container {{
max-width: {BUTTON_WIDTH}px;
}}
}}
</style>
""", elem_classes="css-styles")
with gr.Row(variant="panel", elem_classes="responsive-row"):
with gr.Column(scale=1, min_width=400, elem_classes="model-dropdown-container"):
model_dropdown = gr.Dropdown(
choices=[f"Model: {model}" for model in models_config.keys()],
value=f"Model: {DEFAULT_MODEL_NAME}",
label=None,
interactive=True,
container=False,
scale=0,
min_width=400
)
with gr.Column(scale=4, min_width=0):
description_html = gr.HTML(description, elem_classes="model-message")
chatbot = gr.Chatbot(
type="messages",
height="calc(100dvh - 280px)",
elem_classes="chatbot",
)
with gr.Row():
with gr.Column(scale=10, min_width=400, elem_classes="user-input-container"):
user_input = gr.Textbox(
show_label=False,
placeholder="Type your message here and press Enter",
container=False
)
with gr.Column(scale=1, min_width=BUTTON_WIDTH * 2 + 20):
with gr.Row():
with gr.Column(scale=1, min_width=BUTTON_WIDTH, elem_classes="send-button-container"):
send_btn = gr.Button("Send", variant="primary")
stop_btn = gr.Button("Stop", variant="cancel", visible=False)
with gr.Column(scale=1, min_width=BUTTON_WIDTH, elem_classes="clear-button-container"):
clear_btn = gr.ClearButton(chatbot, value="New Chat", variant="secondary")
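    # Send (or Enter) streams run_chat_inference into the chatbot; chat_finished then
    # re-enables the controls once the generator completes.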
gr.on(
triggers=[send_btn.click, user_input.submit],
fn=run_chat_inference, # this generator streams results. do not use logged_event_handler wrapper
inputs=[chatbot, user_input, session_state],
outputs=[chatbot, user_input, send_btn, stop_btn, clear_btn, session_state]
).then(
fn=chat_finished, inputs=None, outputs=[model_dropdown, user_input, send_btn, stop_btn, clear_btn], queue=False)
# In parallel, disable or update the UI controls
gr.on(
triggers=[send_btn.click, user_input.submit],
fn=chat_started,
inputs=None,
outputs=[model_dropdown, user_input, send_btn, stop_btn, clear_btn],
queue=False,
show_progress='hidden'
)
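    # Stop flips the session's stop_flag; the streaming loop exits on the next chunk.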
stop_btn.click(
fn=stop_chat,
inputs=[session_state],
outputs=[session_state]
)
# Ensure the model is reset to default on page reload
    demo.load(lambda: setup_model(DEFAULT_MODEL_NAME, initial=False), [], [description_html])
model_dropdown.change(
fn=update_model_and_clear_chat,
inputs=[model_dropdown],
outputs=[description_html, chatbot]
)
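# ssr_mode=False disables server-side rendering; show_api=False hides the API docs page.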
demo.launch(ssr_mode=False, show_api=False)