# Apriel-Chat / app.py
import os
import sys
from openai import OpenAI
import gradio as gr
from gradio.components.chatbot import ChatMessage, Message
from typing import Any, Literal

title = None  # e.g. "ServiceNow-AI Chat" or model_config.get('MODE_DISPLAY_NAME')
description = None
model_config = {
"MODEL_NAME": os.environ.get("MODEL_NAME"),
"MODE_DISPLAY_NAME": os.environ.get("MODE_DISPLAY_NAME"),
"MODEL_HF_URL": os.environ.get("MODEL_HF_URL"),
"VLLM_API_URL": os.environ.get("VLLM_API_URL"),
"AUTH_TOKEN": os.environ.get("AUTH_TOKEN")
}
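
# The five settings above all come from environment variables. A hypothetical
# configuration (values are illustrative, not the real deployment):
#
#   MODEL_NAME=ServiceNow-AI/Apriel-5B-Instruct
#   MODE_DISPLAY_NAME=Apriel Chat
#   MODEL_HF_URL=https://huggingface.co/ServiceNow-AI/Apriel-5B-Instruct
#   VLLM_API_URL=https://my-vllm-host.example.com/v1
#   AUTH_TOKEN=<secret bearer token>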
# Initialize the OpenAI client with the vLLM API URL and token
client = OpenAI(
api_key=model_config.get('AUTH_TOKEN'),
base_url=model_config.get('VLLM_API_URL')
)
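
# vLLM serves an OpenAI-compatible API, so the stock OpenAI client works as-is;
# base_url is expected to point at the server's /v1 root.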
def _check_format(messages: Any, type: Literal["messages", "tuples"] = "messages") -> None:
    """Validate that `messages` matches the requested Gradio chat format."""
    if type == "messages":
        all_valid = all(
            (isinstance(message, dict)
             and "role" in message
             and "content" in message)
            or isinstance(message, ChatMessage | Message)
            for message in messages
        )
if not all_valid:
# Display which message is not valid
for i, message in enumerate(messages):
if not (isinstance(message, dict) and
"role" in message and
"content" in message) and not isinstance(message, ChatMessage | Message):
print(f"_check_format() --> Invalid message at index {i}: {message}\n", file=sys.stderr)
break
raise Exception(
"Data incompatible with messages format. Each message should be a dictionary with 'role' and 'content' keys or a ChatMessage object."
)
else:
print("_check_format() --> All messages are valid.")
elif not all(
isinstance(message, (tuple, list)) and len(message) == 2
for message in messages
):
raise Exception(
"Data incompatible with tuples format. Each message should be a list of length 2."
)
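
# Illustrative usage of _check_format (not executed here): plain dicts or
# gr.ChatMessage objects pass in "messages" mode, anything else raises.
#
#   _check_format([{"role": "user", "content": "hi"}], "messages")  # passes
#   _check_format([("hi", "hello")], "tuples")                      # passes
#   _check_format(["hi"], "messages")                               # raises Exception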
def chat_fn(message, history):
    """Stream a chat completion, splitting model reasoning from the final answer."""
print(f"{'-' * 80}\nchat_fn() --> Message: {message}")
# Remove any assistant messages with metadata from history for multiple turns
print(f"Original History: {history}")
_check_format(history, "messages")
history = [item for item in history if
not (isinstance(item, dict) and
item.get("role") == "assistant" and
isinstance(item.get("metadata"), dict) and
item.get("metadata", {}).get("title") is not None)]
print(f"Updated History: {history}")
_check_format(history, "messages")
history.append({"role": "user", "content": message})
print(f"History with user message: {history}")
_check_format(history, "messages")
# Create the streaming response
stream = client.chat.completions.create(
model=model_config.get('MODEL_NAME'),
messages=history,
temperature=0.8,
stream=True
)
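    # The response arrives as incremental deltas; chunks are accumulated into
    # `output` below and re-split on every iteration of the streaming loop.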
history.append(gr.ChatMessage(
role="assistant",
content="Thinking...",
metadata={"title": "🧠 Thought"}
))
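    # From here on, history[-1] is the "Thought" bubble; once the final answer
    # starts streaming, a second message is appended so the thought moves to
    # history[-2] and the answer lives in history[-1].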
print(f"History added thinking: {history}")
_check_format(history, "messages")
output = ""
completion_started = False
for chunk in stream:
        # Extract the new content from the delta field. A delta can carry
        # content=None (e.g. on the final chunk), and getattr's default only
        # covers a missing attribute, so coalesce None to an empty string.
        content = chunk.choices[0].delta.content or ""
        output += content
parts = output.split("[BEGIN FINAL RESPONSE]")
if len(parts) > 1:
if parts[1].endswith("[END FINAL RESPONSE]"):
parts[1] = parts[1].replace("[END FINAL RESPONSE]", "")
if parts[1].endswith("[END FINAL RESPONSE]\n<|end|>"):
parts[1] = parts[1].replace("[END FINAL RESPONSE]\n<|end|>", "")
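        # The model is assumed to emit its reasoning first, then the answer
        # wrapped in markers, e.g. (illustrative):
        #   ...chain of thought...[BEGIN FINAL RESPONSE]answer[END FINAL RESPONSE]
        # parts[0] is therefore the thought text and parts[1] the final answer.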
history[-1 if not completion_started else -2] = gr.ChatMessage(
role="assistant",
content=parts[0],
metadata={"title": "🧠 Thought"}
)
if completion_started:
history[-1] = gr.ChatMessage(
role="assistant",
content=parts[1]
)
elif len(parts) > 1 and not completion_started:
completion_started = True
history.append(gr.ChatMessage(
role="assistant",
content=parts[1]
))
# only yield the most recent assistant messages
messages_to_yield = history[-1:] if not completion_started else history[-2:]
# _check_format(messages_to_yield, "messages")
yield messages_to_yield
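        # gr.ChatInterface replaces the in-progress bot response with whatever
        # is yielded, so yielding one or two messages renders one or two bubbles.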
print(f"Final History: {history}")
_check_format(history, "messages")
# Add the model display name and Hugging Face URL to the description
# description = f"### Model: [{model_config.get('MODE_DISPLAY_NAME')}]({model_config.get('MODEL_HF_URL')})"
print(f"Running model {model_config.get('MODE_DISPLAY_NAME')} ({model_config.get('MODEL_NAME')})")
gr.ChatInterface(
chat_fn,
title=title,
description=description,
theme=gr.themes.Default(primary_hue="green"),
type="messages",
).launch()
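
# To run locally against your own endpoint (hypothetical values):
#   MODEL_NAME=my-model VLLM_API_URL=http://localhost:8000/v1 \
#   AUTH_TOKEN=token python app.py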