# api-test / app.py
# Last commit: dff7757 "Update app.py" by OjciecTadeusz (verified), 3.76 kB
import asyncio
import datetime
import json
import threading
import uuid

import gradio as gr
import torch
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from transformers import AutoTokenizer, AutoModelForCausalLM
# --- Application and model setup -------------------------------------------
# The FastAPI app serves the OpenAI-compatible REST route defined below.
app = FastAPI()

# Hugging Face Hub id of the checkpoint; it is also reported back in every
# API response payload.
model_name = "Qwen/Qwen2.5-Coder-32B"

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# Options for from_pretrained: spread the weights over the available devices
# (device_map="auto"), load them as float16, and keep peak CPU RAM low while
# the checkpoint is being materialised.
_model_load_options = {
    "device_map": "auto",
    "trust_remote_code": True,
    "torch_dtype": torch.float16,
    "low_cpu_mem_usage": True,
}
model = AutoModelForCausalLM.from_pretrained(model_name, **_model_load_options)
def format_chat_response(
    response_text,
    prompt_tokens,
    completion_tokens,
    model_id=None,
    finish_reason="stop",
):
    """Build an OpenAI-style ``chat.completion`` payload for one assistant reply.

    Args:
        response_text: Decoded assistant message content.
        prompt_tokens: Number of tokens in the prompt fed to the model.
        completion_tokens: Number of tokens the model generated.
        model_id: Model identifier reported in the payload. Defaults to the
            module-level ``model_name`` when omitted.
        finish_reason: Why generation stopped, e.g. ``"stop"`` or ``"length"``.

    Returns:
        A JSON-serialisable dict with ``id``, ``object``, ``created``,
        ``model``, ``choices`` and ``usage`` keys.
    """
    if model_id is None:
        model_id = model_name
    return {
        # A random suffix keeps ids unique; a seconds-resolution timestamp
        # gave identical ids to requests arriving in the same second.
        "id": f"chatcmpl-{uuid.uuid4().hex}",
        "object": "chat.completion",
        # Unix time in whole seconds, taken once.
        "created": int(datetime.datetime.now().timestamp()),
        "model": model_id,
        "choices": [{
            "index": 0,
            "message": {
                "role": "assistant",
                "content": response_text,
            },
            "finish_reason": finish_reason,
        }],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        },
    }
# Generation runs in worker threads (see chat_completion). This lock keeps at
# most one model.generate call on the shared model at a time, which is the
# same one-at-a-time behaviour the old event-loop-blocking code had.
_generate_lock = threading.Lock()


def _request_param(data, key, default):
    """Return ``data[key]``, falling back to *default* when missing or null."""
    value = data.get(key)
    return default if value is None else value


def _run_generation(inputs, generation_kwargs):
    """Call ``model.generate`` under the lock; blocking, so run it off the event loop."""
    with _generate_lock:
        return model.generate(**inputs, **generation_kwargs)


@app.post("/v1/chat/completions")
async def chat_completion(request: Request):
    """Handle an OpenAI-compatible chat completion request.

    Reads ``messages`` plus optional ``max_tokens`` (default 2048),
    ``temperature`` (default 0.7) and ``top_p`` (default 0.95) from the JSON
    body. It generates one reply and returns it in ``chat.completion`` format.

    Returns:
        A JSONResponse with one of these statuses:
        - 200: the completion payload.
        - 400: malformed input.
        - 500: ``{"error": ...}`` for any failure during generation.
    """
    try:
        data = await request.json()
    except ValueError as exc:  # json.JSONDecodeError subclasses ValueError
        return JSONResponse(status_code=400, content={"error": f"Invalid JSON body: {exc}"})
    if not isinstance(data, dict):
        return JSONResponse(status_code=400, content={"error": "Request body must be a JSON object"})

    messages = data.get("messages")
    if not isinstance(messages, list) or not messages:
        return JSONResponse(status_code=400, content={"error": "'messages' must be a non-empty list"})

    try:
        # Explicit nulls (which OpenAI clients may send) fall back to defaults.
        max_new_tokens = int(_request_param(data, "max_tokens", 2048))
        temperature = float(_request_param(data, "temperature", 0.7))
        top_p = float(_request_param(data, "top_p", 0.95))
    except (TypeError, ValueError) as exc:
        return JSONResponse(status_code=400, content={"error": f"Invalid sampling parameter: {exc}"})

    generation_kwargs = {"max_new_tokens": max_new_tokens}
    if temperature > 0:
        generation_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
    else:
        # temperature <= 0 (commonly 0) means greedy decoding; transformers
        # rejects a non-positive temperature when do_sample=True.
        generation_kwargs["do_sample"] = False

    try:
        # Convert messages to the model's chat prompt format.
        prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        # Count exactly the tokens the model sees instead of re-encoding.
        prompt_tokens = inputs["input_ids"].shape[1]

        # model.generate is blocking; run it in the default executor so the
        # event loop (shared with the Gradio UI) keeps serving other requests.
        loop = asyncio.get_running_loop()
        outputs = await loop.run_in_executor(None, _run_generation, inputs, generation_kwargs)

        generated_ids = outputs[0][prompt_tokens:]
        completion_tokens = int(generated_ids.shape[0])
        response_text = tokenizer.decode(generated_ids, skip_special_tokens=True)

        payload = format_chat_response(response_text, prompt_tokens, completion_tokens)
        if completion_tokens >= max_new_tokens:
            # The output was cut off by the token budget, not a natural stop.
            payload["choices"][0]["finish_reason"] = "length"
        return JSONResponse(content=payload)
    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={"error": str(e)},
        )
# Gradio interface for testing
async def chat_interface(message, history):
    """Gradio ChatInterface callback that routes the chat through the REST handler.

    Converts the Gradio chat history into OpenAI-style messages and sends them
    through ``chat_completion`` using an in-memory request. Returns the
    assistant's reply text, or an error string if the handler did not succeed.

    This must be ``async def``: it awaits ``chat_completion``, and ``await``
    inside a plain ``def`` is a SyntaxError.

    Args:
        message: The new user message.
        history: Prior turns, either as ``[user, assistant]`` pairs or as
            ``{"role": ..., "content": ...}`` dicts (depends on the
            ChatInterface ``type`` setting). Both forms are accepted.
    """
    messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            # "messages"-style history is already in role/content form.
            messages.append({"role": turn["role"], "content": turn["content"]})
        else:
            user_msg, assistant_msg = turn
            # Skip empty slots, e.g. a pending assistant reply.
            if user_msg is not None:
                messages.append({"role": "user", "content": user_msg})
            if assistant_msg is not None:
                messages.append({"role": "assistant", "content": assistant_msg})
    # Add current message
    messages.append({"role": "user", "content": message})

    # Build a real Request whose body is the JSON payload, so chat_completion
    # can `await request.json()` exactly as it does for HTTP calls. Patching
    # `request.json` with a plain lambda returned a non-awaitable dict.
    body = json.dumps({"messages": messages}).encode("utf-8")

    async def receive():
        return {"type": "http.request", "body": body, "more_body": False}

    request = Request(scope={"type": "http", "method": "POST", "headers": []}, receive=receive)
    response = await chat_completion(request)

    payload = json.loads(response.body.decode("utf-8"))
    # Error responses carry {"error": ...} instead of "choices".
    if response.status_code == 200 and "choices" in payload:
        return payload["choices"][0]["message"]["content"]
    return f"Error generating response: {payload.get('error', 'unknown error')}"
# Browser chat UI for manually exercising the model; it shares this process
# (and therefore the loaded model) with the REST API.
interface = gr.ChatInterface(
    fn=chat_interface,
    title="Qwen2.5-Coder-32B Chat",
    description="Chat with Qwen2.5-Coder-32B model. This Space also provides a /v1/chat/completions endpoint.",
)

# Attach the Gradio UI to the FastAPI app at the root path. The REST route
# above was registered first, so presumably it keeps precedence over this
# catch-all mount — confirm if more routes are added after this line.
# NOTE(review): nothing visible here starts a server (no launch()/uvicorn.run);
# confirm the host serves `app` with an ASGI server.
app = gr.mount_gradio_app(app, interface, path="/")