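The script below is a small Gradio chat front end for a StarVector model served behind vLLM's OpenAI-compatible API. It parses a handful of CLI flags (model URL, model name, temperature, stop token IDs, host/port), converts Gradio's chat history into OpenAI-style messages, and streams the model's reply back into the UI chunk by chunk.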
```python
import argparse

import gradio as gr
from openai import OpenAI

# Argument parser setup
parser = argparse.ArgumentParser(
    description="Chatbot Interface with Customizable Parameters"
)
parser.add_argument(
    "--model-url", type=str, default="http://localhost:8000/v1", help="Model URL"
)
parser.add_argument(
    "-m",
    "--model",
    type=str,
    default="ServiceNow/starvector-1.4b-im2svg-v6",
    help="Model name for the chatbot",
)
parser.add_argument(
    "--temp", type=float, default=0.8, help="Temperature for text generation"
)
parser.add_argument(
    "--stop-token-ids", type=str, default="", help="Comma-separated stop token IDs"
)
parser.add_argument("--host", type=str, default=None)
parser.add_argument("--port", type=int, default=8001)

# Parse the arguments
args = parser.parse_args()

# Set OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY"
openai_api_base = args.model_url

# Create an OpenAI client to interact with the API server
client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)


def predict(message, history):
    # Convert chat history to OpenAI format
    history_openai_format = [
{"role": "system", "content": "You are a great ai assistant."} | |
    ]
    for human, assistant in history:
        history_openai_format.append({"role": "user", "content": human})
        history_openai_format.append({"role": "assistant", "content": assistant})
    history_openai_format.append({"role": "user", "content": message})

    # Create a chat completion request and send it to the API server
    stream = client.chat.completions.create(
        model=args.model,  # Model name to use
        messages=history_openai_format,  # Chat history
        temperature=args.temp,  # Temperature for text generation
        stream=True,  # Stream response
        extra_body={
            "repetition_penalty": 1,
            "stop_token_ids": (
                [
                    int(token_id.strip())
                    for token_id in args.stop_token_ids.split(",")
                    if token_id.strip()
                ]
                if args.stop_token_ids
                else []
            ),
        },
    )

    # Read and return generated text from response stream
    partial_message = ""
    for chunk in stream:
        partial_message += chunk.choices[0].delta.content or ""
        yield partial_message


# Create and launch a chat interface with Gradio
gr.ChatInterface(predict).queue().launch(
    server_name=args.host, server_port=args.port, share=True
)
```
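The demo assumes the model is already exposed through vLLM's OpenAI-compatible API at the URL passed via `--model-url` (the default is `http://localhost:8000/v1`); if you serve it yourself, `vllm serve ServiceNow/starvector-1.4b-im2svg-v6` is the usual way to start such a server, though the exact setup depends on your deployment. The snippet below is a minimal sketch for sanity-checking that endpoint before launching the Gradio UI. It reuses the same placeholder key and defaults as the script above; adjust the base URL and model name to match your server.

```python
from openai import OpenAI

# Same placeholder key and base URL as the Gradio script above;
# vLLM's OpenAI-compatible server does not require a real API key.
client = OpenAI(api_key="EMPTY", base_url="http://localhost:8000/v1")

# Send a single non-streaming request to verify the server responds.
response = client.chat.completions.create(
    model="ServiceNow/starvector-1.4b-im2svg-v6",
    messages=[{"role": "user", "content": "Hello!"}],
    temperature=0.8,
)
print(response.choices[0].message.content)
```

Once that returns a completion, launching the script connects the same endpoint to the Gradio chat UI on the configured host and port.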