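"""Gradio chat interface backed by a vLLM OpenAI-compatible API server.

Streams chat completions from the server and renders the partial response
incrementally in a gr.ChatInterface.
"""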
import argparse
import gradio as gr
from openai import OpenAI
# Argument parser setup
parser = argparse.ArgumentParser(
    description="Chatbot Interface with Customizable Parameters"
)
parser.add_argument(
    "--model-url", type=str, default="http://localhost:8000/v1", help="Model URL"
)
parser.add_argument(
    "-m",
    "--model",
    type=str,
    default="ServiceNow/starvector-1.4b-im2svg-v6",
    help="Model name for the chatbot",
)
parser.add_argument(
    "--temp", type=float, default=0.8, help="Temperature for text generation"
)
parser.add_argument(
    "--stop-token-ids", type=str, default="", help="Comma-separated stop token IDs"
)
parser.add_argument("--host", type=str, default=None, help="Host for the Gradio server")
parser.add_argument("--port", type=int, default=8001, help="Port for the Gradio server")
# Parse the arguments
args = parser.parse_args()
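# Example invocation (the filename app.py is illustrative):
#   python app.py --model-url http://localhost:8000/v1 \
#       -m ServiceNow/starvector-1.4b-im2svg-v6 --temp 0.8 --port 8001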
# Point the OpenAI client at vLLM's OpenAI-compatible API server. vLLM does
# not check the API key unless the server was started with one, so a
# placeholder value is enough.
openai_api_key = "EMPTY"
openai_api_base = args.model_url
# Create an OpenAI client to interact with the API server
client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)
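# The client above expects a running vLLM OpenAI-compatible server, e.g.
# (a sketch; exact flags may differ across vLLM versions):
#   vllm serve ServiceNow/starvector-1.4b-im2svg-v6 --port 8000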

def predict(message, history):
    # Convert the Gradio chat history (a list of [user, assistant] pairs)
    # into the OpenAI messages format, starting with a system prompt
    history_openai_format = [
        {"role": "system", "content": "You are a helpful AI assistant."}
    ]
    for human, assistant in history:
        history_openai_format.append({"role": "user", "content": human})
        history_openai_format.append({"role": "assistant", "content": assistant})
    history_openai_format.append({"role": "user", "content": message})
    # Create a chat completion request and send it to the API server
    stream = client.chat.completions.create(
        model=args.model,  # Model name to use
        messages=history_openai_format,  # Chat history
        temperature=args.temp,  # Temperature for text generation
        stream=True,  # Stream response
        extra_body={
            "repetition_penalty": 1,
            "stop_token_ids": (
                [
                    int(token_id.strip())
                    for token_id in args.stop_token_ids.split(",")
                    if token_id.strip()
                ]
                if args.stop_token_ids
                else []
            ),
        },
    )
    # Read generated text from the response stream and yield the running
    # output; delta.content can be None (e.g., on the role-only first
    # chunk), hence the `or ""`
    partial_message = ""
    for chunk in stream:
        partial_message += chunk.choices[0].delta.content or ""
        yield partial_message

# Create and launch a chat interface with Gradio; share=True additionally
# exposes a temporary public *.gradio.live URL
gr.ChatInterface(predict).queue().launch(
    server_name=args.host, server_port=args.port, share=True
)