import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
import datetime

# Initialize FastAPI
app = FastAPI()

# Load model and tokenizer. At float16, a 32B model needs roughly 64 GB of
# GPU memory; device_map="auto" shards it across the available devices.
model_name = "Qwen/Qwen2.5-Coder-32B"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    trust_remote_code=True,
    torch_dtype=torch.float16
)

def format_chat_response(response_text, prompt_tokens, completion_tokens):
    return {
        "id": f"chatcmpl-{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}",
        "object": "chat.completion",
        "created": int(datetime.datetime.now().timestamp()),
        "model": model_name,
        "choices": [{
            "index": 0,
            "message": {
                "role": "assistant",
                "content": response_text
            },
            "finish_reason": "stop"
        }],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens
        }
    }

def generate_response(messages, max_new_tokens=2048, temperature=0.7, top_p=0.95):
    """Generate a reply for a list of chat messages.

    Returns (response_text, prompt_tokens, completion_tokens)."""
    # Convert messages to the model's chat format
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_tokens = inputs["input_ids"].shape[1]

    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True
    )

    # Decode only the newly generated tokens, and count them directly
    # instead of re-encoding the decoded text
    new_tokens = outputs[0][prompt_tokens:]
    response_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
    return response_text, prompt_tokens, len(new_tokens)

@app.post("/v1/chat/completions")
async def chat_completion(request: Request):
    try:
        data = await request.json()
        messages = data.get("messages", [])

        response_text, prompt_tokens, completion_tokens = generate_response(
            messages,
            max_new_tokens=data.get("max_tokens", 2048),
            temperature=data.get("temperature", 0.7),
            top_p=data.get("top_p", 0.95)
        )

        return JSONResponse(
            content=format_chat_response(response_text, prompt_tokens, completion_tokens)
        )
    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={"error": str(e)}
        )
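
# Example client call against the endpoint above (a sketch for local testing;
# assumes the server is reachable on localhost:7860 and `requests` is installed):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/v1/chat/completions",
#       json={"messages": [{"role": "user", "content": "Write hello world in C"}]},
#   )
#   print(resp.json()["choices"][0]["message"]["content"])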

# Gradio interface for testing
def chat_interface(message, history):
    history = history or []
    messages = []
    
    # Convert history to messages format
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": assistant_msg})
    
    # Add current message
    messages.append({"role": "user", "content": message})
    
    # Call the shared generation helper directly; routing through the HTTP
    # endpoint with a hand-built Request object does not work from here
    try:
        response_text, _, _ = generate_response(messages)
        return response_text
    except Exception as e:
        return f"Error generating response: {e}"

interface = gr.ChatInterface(
    chat_interface,
    title="Qwen2.5-Coder-32B Chat",
    description="Chat with Qwen2.5-Coder-32B model. This Space also provides a /v1/chat/completions endpoint."
)

# Mount both FastAPI and Gradio
app = gr.mount_gradio_app(app, interface, path="/")
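
# Run the combined app when the file is executed directly (a minimal sketch;
# the host/port below are assumptions for local testing, and a hosted
# environment such as a Hugging Face Space may launch the app differently):
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)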