import os
import sys
from typing import Any, Literal

import gradio as gr
from gradio.components.chatbot import ChatMessage, Message
from openai import OpenAI
# Optional UI text; set these to override the Gradio defaults.
title = None  # e.g. "ServiceNow-AI Chat"
description = None

# Deployment configuration, supplied entirely through environment variables.
model_config = {
    "MODEL_NAME": os.environ.get("MODEL_NAME"),
    "MODE_DISPLAY_NAME": os.environ.get("MODE_DISPLAY_NAME"),
    "MODEL_HF_URL": os.environ.get("MODEL_HF_URL"),
    "VLLM_API_URL": os.environ.get("VLLM_API_URL"),
    "AUTH_TOKEN": os.environ.get("AUTH_TOKEN"),
}
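# A minimal sketch of the environment this script expects. The values below are
# hypothetical placeholders for illustration, not real models, endpoints, or tokens:
#
#   export MODEL_NAME="my-org/my-model"             # model id served by vLLM
#   export MODE_DISPLAY_NAME="My Model"             # human-readable display name
#   export MODEL_HF_URL="https://huggingface.co/my-org/my-model"
#   export VLLM_API_URL="http://localhost:8000/v1"  # OpenAI-compatible endpoint
#   export AUTH_TOKEN="sk-placeholder"              # whatever token the server accepts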
# Initialize the OpenAI client against the vLLM server's OpenAI-compatible API.
client = OpenAI(
    api_key=model_config.get('AUTH_TOKEN'),
    base_url=model_config.get('VLLM_API_URL')
)
def _check_format(messages: Any, fmt: Literal["messages", "tuples"] = "messages") -> None:
    """Validate that `messages` matches the requested chat-history format."""
    if fmt == "messages":
        all_valid = all(
            (
                isinstance(message, dict)
                and "role" in message
                and "content" in message
            )
            or isinstance(message, (ChatMessage, Message))
            for message in messages
        )
        if not all_valid:
            # Report the first invalid message before raising
            for i, message in enumerate(messages):
                if not (isinstance(message, dict) and
                        "role" in message and
                        "content" in message) and not isinstance(message, (ChatMessage, Message)):
                    print(f"_check_format() --> Invalid message at index {i}: {message}\n", file=sys.stderr)
                    break
            raise ValueError(
                "Data incompatible with messages format. Each message should be a dictionary "
                "with 'role' and 'content' keys or a ChatMessage object."
            )
        else:
            print("_check_format() --> All messages are valid.")
    elif not all(
        isinstance(message, (tuple, list)) and len(message) == 2
        for message in messages
    ):
        raise ValueError(
            "Data incompatible with tuples format. Each message should be a list of length 2."
        )
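# Illustrative calls (the literals here are made-up examples, not app data):
#
#   _check_format([{"role": "user", "content": "hi"}], "messages")  # passes
#   _check_format([("hi there", "hello!")], "tuples")               # passes
#   _check_format([{"role": "user"}], "messages")                   # raises ValueError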
def chat_fn(message, history):
    print(f"{'-' * 80}\nchat_fn() --> Message: {message}")
    # Drop assistant messages that carry metadata (i.e. prior "Thought" bubbles)
    # from the history so they are not replayed to the model on later turns;
    # plain final answers are kept.
    print(f"Original History: {history}")
    _check_format(history, "messages")
    history = [item for item in history if
               not (isinstance(item, dict) and
                    item.get("role") == "assistant" and
                    isinstance(item.get("metadata"), dict) and
                    item.get("metadata", {}).get("title") is not None)]
    print(f"Updated History: {history}")
    _check_format(history, "messages")
    history.append({"role": "user", "content": message})
    print(f"History with user message: {history}")
    _check_format(history, "messages")
    # Create the streaming response
    stream = client.chat.completions.create(
        model=model_config.get('MODEL_NAME'),
        messages=history,
        temperature=0.8,
        stream=True
    )

    # Placeholder "Thought" message that the streaming loop below fills in.
    history.append(gr.ChatMessage(
        role="assistant",
        content="Thinking...",
        metadata={"title": "🧠 Thought"}
    ))
    print(f"History added thinking: {history}")
    _check_format(history, "messages")

    output = ""
    completion_started = False
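    # The parsing below assumes the model interleaves reasoning and answer with
    # sentinel markers, along these lines (illustrative, not a real transcript):
    #
    #   ...chain of thought...[BEGIN FINAL RESPONSE]The answer is X.[END FINAL RESPONSE]
    #
    # Everything before the marker is shown as the collapsible "Thought" message;
    # everything after it becomes the visible assistant reply.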
    for chunk in stream:
        # Extract the new content from the delta field; it can be None on
        # role-only or final chunks, so fall back to an empty string.
        content = getattr(chunk.choices[0].delta, "content", "") or ""
        output += content
        parts = output.split("[BEGIN FINAL RESPONSE]")
        if len(parts) > 1:
            # Strip the closing sentinel (and any trailing end-of-text token)
            # once the final response has finished streaming.
            if parts[1].endswith("[END FINAL RESPONSE]"):
                parts[1] = parts[1].replace("[END FINAL RESPONSE]", "")
            if parts[1].endswith("[END FINAL RESPONSE]\n<|end|>"):
                parts[1] = parts[1].replace("[END FINAL RESPONSE]\n<|end|>", "")
        # Everything before the sentinel is the reasoning; keep it in the "Thought"
        # bubble (last message until the answer starts, second-to-last after).
        history[-1 if not completion_started else -2] = gr.ChatMessage(
            role="assistant",
            content=parts[0],
            metadata={"title": "🧠 Thought"}
        )
        if completion_started:
            history[-1] = gr.ChatMessage(
                role="assistant",
                content=parts[1]
            )
        elif len(parts) > 1:
            completion_started = True
            history.append(gr.ChatMessage(
                role="assistant",
                content=parts[1]
            ))
        # Only yield the most recent assistant message(s): the thought alone,
        # or the thought plus the final response once it has started.
        messages_to_yield = history[-1:] if not completion_started else history[-2:]
        yield messages_to_yield
print(f"Final History: {history}") | |
_check_format(history, "messages") | |
# Add the model display name and Hugging Face URL to the description | |
# description = f"### Model: [{MODE_DISPLAY_NAME}]({MODEL_HF_URL})" | |
print(f"Running model {model_config.get('MODE_DISPLAY_NAME')} ({model_config.get('MODEL_NAME')})") | |
gr.ChatInterface(
    chat_fn,
    title=title,
    description=description,
    theme=gr.themes.Default(primary_hue="green"),
    type="messages",
).launch()
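# To run locally (assuming this file is saved as app.py and the environment
# variables sketched above are exported):
#
#   pip install gradio openai
#   python app.py
#
# On Hugging Face Spaces, the same variables are set as Space secrets/variables.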