import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
import datetime

# Initialize FastAPI
app = FastAPI()

# Load model and tokenizer. Note: "Qwen/Qwen2.5-Coder-32B" is the base
# checkpoint; "Qwen/Qwen2.5-Coder-32B-Instruct" is the chat-tuned variant,
# which is usually the better fit for a chat endpoint like this one.
model_name = "Qwen/Qwen2.5-Coder-32B"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    trust_remote_code=True,
    torch_dtype=torch.float16
)
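
# A 32B model in float16 needs roughly 64 GB of accelerator memory, more
# than most single-GPU Spaces provide. A lower-memory loading sketch,
# assuming bitsandbytes is installed (4-bit quantization; not part of the
# original script):
#
#   from transformers import BitsAndBytesConfig
#   model = AutoModelForCausalLM.from_pretrained(
#       model_name,
#       device_map="auto",
#       trust_remote_code=True,
#       quantization_config=BitsAndBytesConfig(
#           load_in_4bit=True,
#           bnb_4bit_compute_dtype=torch.float16,
#       ),
#   )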

def format_chat_response(response_text, prompt_tokens, completion_tokens):
    """Build an OpenAI-compatible chat.completion response payload."""
    return {
        "id": f"chatcmpl-{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}",
        "object": "chat.completion",
        "created": int(datetime.datetime.now().timestamp()),
        "model": model_name,
        "choices": [{
            "index": 0,
            "message": {
                "role": "assistant",
                "content": response_text
            },
            "finish_reason": "stop"
        }],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens
        }
    }
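
# For example, format_chat_response("Hello!", 12, 3) returns an OpenAI-style
# payload whose usage block reads:
#   {"prompt_tokens": 12, "completion_tokens": 3, "total_tokens": 15}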

@app.post("/v1/chat/completions")
async def chat_completion(request: Request):
    try:
        data = await request.json()
        messages = data.get("messages", [])
        # Convert messages to the model's chat prompt format
        prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        # Generate response
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        prompt_tokens = inputs["input_ids"].shape[1]
        outputs = model.generate(
            **inputs,
            max_new_tokens=data.get("max_tokens", 2048),
            temperature=data.get("temperature", 0.7),
            top_p=data.get("top_p", 0.95),
            do_sample=True
        )
        # Count completion tokens from the generated ids directly rather
        # than re-encoding the decoded text
        completion_tokens = outputs[0].shape[0] - prompt_tokens
        response_text = tokenizer.decode(
            outputs[0][prompt_tokens:],
            skip_special_tokens=True
        )
        return JSONResponse(
            content=format_chat_response(response_text, prompt_tokens, completion_tokens)
        )
    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={"error": str(e)}
        )
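
# Example client call against the endpoint above (a sketch; the base URL is
# an assumption and depends on where this Space is hosted):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/v1/chat/completions",
#       json={
#           "messages": [{"role": "user", "content": "Write FizzBuzz in Python"}],
#           "max_tokens": 256,
#       },
#   )
#   print(resp.json()["choices"][0]["message"]["content"])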

# Gradio interface for testing
def chat_interface(message, history):
    history = history or []
    messages = []
    # Convert history pairs to the chat-messages format
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": assistant_msg})
    # Add the current message
    messages.append({"role": "user", "content": message})
    # Generate directly instead of calling the FastAPI handler with a
    # hand-built Request, which carries no JSON body and would fail
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=2048,
        temperature=0.7,
        top_p=0.95,
        do_sample=True
    )
    return tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[1]:],
        skip_special_tokens=True
    )

interface = gr.ChatInterface(
    chat_interface,
    title="Qwen2.5-Coder-32B Chat",
    description="Chat with the Qwen2.5-Coder-32B model. This Space also provides a /v1/chat/completions endpoint."
)

# Mount the Gradio UI onto the FastAPI app
app = gr.mount_gradio_app(app, interface, path="/")
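
# Run with uvicorn when executed directly (a sketch; the host and port are
# common Spaces defaults but are assumptions here, and a Space may launch
# the app differently):
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)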