api-test / app.py
OjciecTadeusz's picture
Update app.py
2e36566 verified
raw
history blame
3.57 kB
import asyncio
import datetime
import json
import logging
import threading
import uuid

import gradio as gr
import torch
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from transformers import AutoModelForCausalLM, AutoTokenizer
# ASGI application: serves /v1/chat/completions, and the Gradio UI is mounted
# onto it at the bottom of this file.
app = FastAPI()

# Hugging Face Hub id of the checkpoint; also echoed back as the "model" field
# of API responses.
model_name = "Qwen/Qwen2.5-Coder-32B"

# Options shared by both Hub loads below.
# NOTE(review): trust_remote_code executes Python shipped in the model repo;
# recent transformers releases support Qwen2 natively, so this may be
# unnecessary — confirm before removing.
_HUB_OPTIONS = {"trust_remote_code": True}

# Loaded once, at import time. device_map="auto" lets transformers/accelerate
# place the weights on whatever devices are available; float16 halves the
# weight memory compared with float32.
tokenizer = AutoTokenizer.from_pretrained(model_name, **_HUB_OPTIONS)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.float16,
    **_HUB_OPTIONS,
)
def format_chat_response(response_text, prompt_tokens, completion_tokens, finish_reason="stop"):
    """Build an OpenAI-compatible ``chat.completion`` response payload.

    Args:
        response_text: Assistant message content to return.
        prompt_tokens: Number of tokens in the prompt fed to the model.
        completion_tokens: Number of tokens the model generated.
        finish_reason: Value for ``choices[0].finish_reason``. Defaults to
            ``"stop"``, the only value earlier versions produced; pass
            ``"length"`` when generation was cut off by the token budget.

    Returns:
        A JSON-serializable dict with ``id``, ``object``, ``created``,
        ``model``, ``choices`` and ``usage`` fields.
    """
    # Read the clock once so "id" and "created" describe the same instant.
    now = datetime.datetime.now()
    return {
        # The timestamp prefix keeps the previous, sortable format; the random
        # suffix keeps ids unique when several requests finish in one second.
        "id": f"chatcmpl-{now.strftime('%Y%m%d%H%M%S')}-{uuid.uuid4().hex[:12]}",
        "object": "chat.completion",
        "created": int(now.timestamp()),
        "model": model_name,
        "choices": [{
            "index": 0,
            "message": {
                "role": "assistant",
                "content": response_text,
            },
            "finish_reason": finish_reason,
        }],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        },
    }
# Serializes generation. The original handler got this for free by blocking the
# event loop; generation now runs in worker threads, so the model (a 32B
# checkpoint) and the shared tokenizer are still used by one request at a time.
_GENERATION_LOCK = threading.Lock()


def _generation_kwargs(data):
    """Translate the OpenAI sampling fields in *data* into ``model.generate`` kwargs.

    Missing or ``null`` fields fall back to this endpoint's defaults
    (max_tokens=2048, temperature=0.7, top_p=0.95).

    Raises:
        ValueError, TypeError: if a field is not numeric, ``max_tokens`` is not
            positive, ``temperature`` is negative, or ``top_p`` is outside [0, 1].
    """
    max_tokens = data.get("max_tokens")
    temperature = data.get("temperature")
    top_p = data.get("top_p")

    max_new_tokens = 2048 if max_tokens is None else int(max_tokens)
    temperature = 0.7 if temperature is None else float(temperature)
    top_p = 0.95 if top_p is None else float(top_p)

    if max_new_tokens <= 0:
        raise ValueError("max_tokens must be a positive integer")
    if temperature < 0:
        raise ValueError("temperature must be non-negative")
    if not 0 <= top_p <= 1:
        raise ValueError("top_p must be between 0 and 1")

    if temperature == 0:
        # OpenAI clients send temperature=0 to ask for deterministic output,
        # but transformers rejects temperature=0 when sampling — use greedy
        # decoding instead.
        return {"max_new_tokens": max_new_tokens, "do_sample": False}
    return {
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "temperature": temperature,
        "top_p": top_p,
    }


def _run_generation(messages, gen_kwargs):
    """Render *messages* with the chat template and generate a reply.

    Synchronous and slow (runs the model); the endpoint calls it from a worker
    thread.

    Returns:
        ``(reply_text, prompt_token_count, completion_token_count)``.
    """
    with _GENERATION_LOCK:
        prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        # The chat template already contains the special tokens the model
        # expects, so the tokenizer must not add its own on top.
        inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
        prompt_len = inputs["input_ids"].shape[1]
        outputs = model.generate(**inputs, **gen_kwargs)
        # generate() returns prompt + continuation; keep only the new ids.
        completion_ids = outputs[0][prompt_len:]
        reply = tokenizer.decode(completion_ids, skip_special_tokens=True)
    return reply, int(prompt_len), len(completion_ids)


@app.post("/v1/chat/completions")
async def chat_completion(request: Request):
    """OpenAI-compatible (non-streaming) chat completion endpoint.

    Expects a JSON object with a non-empty ``messages`` list (passed to the
    tokenizer's chat template) and optional ``max_tokens``, ``temperature``
    and ``top_p``.

    Returns:
        200 with a ``chat.completion`` payload; 400 ``{"error": ...}`` for a
        malformed request; 500 ``{"error": ...}`` if generation fails.
    """
    try:
        data = await request.json()
    except ValueError:  # invalid JSON (JSONDecodeError) or non-UTF-8 body
        return JSONResponse(status_code=400, content={"error": "Request body must be valid JSON"})

    messages = data.get("messages") if isinstance(data, dict) else None
    if not isinstance(messages, list) or not messages:
        return JSONResponse(status_code=400, content={"error": "'messages' must be a non-empty list"})

    try:
        gen_kwargs = _generation_kwargs(data)
    except (TypeError, ValueError) as exc:
        return JSONResponse(status_code=400, content={"error": str(exc)})

    try:
        # model.generate blocks for the whole generation; running it in a
        # worker thread keeps the event loop (and the mounted Gradio UI)
        # responsive while it works.
        response_text, prompt_tokens, completion_tokens = await asyncio.to_thread(
            _run_generation, messages, gen_kwargs
        )
    except Exception as exc:  # top-level boundary: report, don't crash the server
        logging.getLogger(__name__).exception("Chat completion failed")
        return JSONResponse(status_code=500, content={"error": str(exc)})

    payload = format_chat_response(response_text, prompt_tokens, completion_tokens)
    if completion_tokens >= gen_kwargs["max_new_tokens"]:
        # The token budget was exhausted. This is an approximation: it also
        # triggers if EOS happened to be the very last allowed token.
        payload["choices"][0]["finish_reason"] = "length"
    return JSONResponse(content=payload)
# Gradio interface for testing
def chat_interface(message, history):
    """Gradio ChatInterface callback: answer *message* given the earlier *history*.

    The conversation is sent through the same ``chat_completion`` handler that
    serves ``/v1/chat/completions``, so the UI exercises the real API path.

    Args:
        message: The user's new message.
        history: Earlier turns, either as ``[user, assistant]`` pairs (Gradio
            "tuples" format) or as ``{"role": ..., "content": ...}`` dicts
            (Gradio "messages" format); may be ``None``.

    Returns:
        The assistant's reply, or an error string if generation failed.
    """
    messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            messages.append({"role": turn["role"], "content": turn["content"]})
            continue
        user_msg, assistant_msg = turn
        # A pair can have a None side (e.g. a turn still in progress).
        if user_msg is not None:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg is not None:
            messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})

    body = json.dumps({"messages": messages}).encode("utf-8")

    async def receive():
        # Minimal ASGI receive channel: deliver the whole JSON body at once.
        return {"type": "http.request", "body": body, "more_body": False}

    # chat_completion is a coroutine function and must be awaited (the old code
    # called it directly and got back an un-awaited coroutine). Gradio runs
    # synchronous callbacks in a worker thread without a running event loop,
    # so asyncio.run can drive it here.
    response = asyncio.run(chat_completion(Request(scope={"type": "http"}, receive=receive)))
    response_json = json.loads(response.body)
    if response.status_code != 200:
        return f"Error generating response: {response_json.get('error', 'unknown error')}"
    return response_json["choices"][0]["message"]["content"]
# Browser chat UI for manual testing; each message is handled by chat_interface.
interface = gr.ChatInterface(
    fn=chat_interface,
    title="Qwen2.5-Coder-32B Chat",
    description="Chat with Qwen2.5-Coder-32B model. This Space also provides a /v1/chat/completions endpoint.",
)

# Serve the Gradio UI at "/" on the same FastAPI app. Starlette matches routes
# in registration order, so the POST /v1/chat/completions route registered
# earlier is still handled by FastAPI.
# NOTE(review): nothing in this file starts a server (no launch() or
# uvicorn.run); presumably the host serves `app` as an ASGI app — confirm.
app = gr.mount_gradio_app(app, interface, path="/")