# api-test / app.py
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
import datetime

# Initialize FastAPI
app = FastAPI()

# Load model and tokenizer
model_name = "Qwen/Qwen2.5-Coder-32B"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
model_name,
device_map="auto",
trust_remote_code=True,
torch_dtype=torch.float16
)
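
# Note: at float16 the 32B checkpoint needs on the order of 64 GB of
# accelerator memory for the weights alone (32B params x 2 bytes);
# device_map="auto" shards or offloads across whatever devices accelerate finds.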
def format_chat_response(response_text, prompt_tokens, completion_tokens):
return {
"id": f"chatcmpl-{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}",
"object": "chat.completion",
"created": int(datetime.datetime.now().timestamp()),
"model": model_name,
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": response_text
},
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens
}
}
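
# The dict above mirrors OpenAI's chat.completion response schema, so
# OpenAI-style clients should be able to consume the endpoint below unchanged
# (non-streaming only; a "stream" parameter is not handled here).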
@app.post("/v1/chat/completions")
async def chat_completion(request: Request):
try:
data = await request.json()
messages = data.get("messages", [])
        # Format messages for Qwen's chat template, keeping only the expected fields
        conversation = [
            {"role": m["role"], "content": m["content"]} for m in messages
        ]
# Convert messages to model input format
prompt = tokenizer.apply_chat_template(
conversation,
tokenize=False,
add_generation_prompt=True
)
# Count prompt tokens
prompt_tokens = len(tokenizer.encode(prompt))
# Generate response
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(
**inputs,
max_new_tokens=data.get("max_tokens", 2048),
temperature=data.get("temperature", 0.7),
top_p=data.get("top_p", 0.95),
do_sample=True
)
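        # outputs[0] still contains the prompt tokens; slice past the prompt
        # length so only the newly generated text is decoded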
response_text = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
        # Count generated tokens directly; re-encoding the decoded text can
        # give a slightly different number than what was actually generated
        completion_tokens = outputs[0].shape[0] - inputs['input_ids'].shape[1]
return JSONResponse(
content=format_chat_response(response_text, prompt_tokens, completion_tokens)
)
except Exception as e:
return JSONResponse(
status_code=500,
content={"error": str(e)}
)
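
# Example client call (a sketch; assumes the Space is reachable on localhost:7860):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/v1/chat/completions",
#       json={
#           "messages": [{"role": "user", "content": "Write FizzBuzz in Python"}],
#           "max_tokens": 256,
#       },
#   )
#   print(resp.json()["choices"][0]["message"]["content"])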

# Gradio interface for testing
def chat_interface(message, history):
    history = history or []
    messages = []
    # Rebuild the conversation from Gradio's (user, assistant) history pairs
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})
    # Call the model directly; chat_completion is an async endpoint that
    # expects a real FastAPI Request, so it cannot be invoked synchronously here
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=2048, temperature=0.7, top_p=0.95, do_sample=True)
    return tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)

interface = gr.ChatInterface(
chat_interface,
title="Qwen2.5-Coder-32B Chat",
description="Chat with Qwen2.5-Coder-32B model. This Space also provides a /v1/chat/completions endpoint."
)
# Mount the Gradio UI onto the FastAPI app so both are served by one process
app = gr.mount_gradio_app(app, interface, path="/")
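
# Local run sketch (assumes uvicorn is installed):
#   uvicorn app:app --host 0.0.0.0 --port 7860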