import os
from typing import Optional

import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from huggingface_hub import InferenceClient
from pydantic import BaseModel


app = FastAPI()


# Default model, used when a request does not specify model_name
DEFAULT_MODEL = "meta-llama/Meta-Llama-3-8B-Instruct"

class QueryRequest(BaseModel):
    query: str
    stream: bool = False
    model_name: Optional[str] = None  # Falls back to DEFAULT_MODEL when omitted
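    # Example payload (model_name may be any model id served by the HF Inference API;
    # the id below is illustrative):
    #   {"query": "What is FastAPI?", "stream": true,
    #    "model_name": "mistralai/Mistral-7B-Instruct-v0.2"}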

def get_client(model_name: Optional[str] = None) -> InferenceClient:
    """Return an inference client for the requested model, or the default model."""
    # Resolve the model outside the try block so the except clause can always
    # reference it. Use model_name only if it is provided and non-empty.
    model_path = model_name if model_name and model_name.strip() else DEFAULT_MODEL
    try:
        # An HF_TOKEN environment variable is optional, but needed for gated models.
        return InferenceClient(model_path, token=os.getenv("HF_TOKEN"))
    except Exception as e:
        raise HTTPException(
            status_code=400,
            detail=f"Error initializing model {model_path}: {e}",
        )

def generate_response(query: str, model_name: Optional[str] = None):
    """Yield response tokens for the query, streaming from the model."""
    # Use a proper system message rather than inlining role tags into the user prompt.
    messages = [
        {
            "role": "system",
            "content": "You are an assistant who answers the user's questions in a short and concise manner.",
        },
        {"role": "user", "content": query},
    ]

    try:
        client = get_client(model_name)
        for message in client.chat_completion(
            messages,
            max_tokens=2048,
            stream=True,
        ):
            token = message.choices[0].delta.content
            # Some stream chunks (e.g. the final one) carry no content.
            if token:
                yield token
    except Exception as e:
        yield f"Error generating response: {e}"

@app.get("/")
async def root():
    return {"message": "Welcome to the FastAPI server!"}

@app.post("/chat")
async def chat(request: QueryRequest):
    try:
        if request.stream:
            # Stream raw tokens as they arrive. The chunks are plain text,
            # not `data:`-framed SSE events, so text/plain is the accurate media type.
            return StreamingResponse(
                generate_response(request.query, request.model_name),
                media_type="text/plain",
            )
        else:
            # Drain the generator and return the full text in one JSON body.
            response = "".join(generate_response(request.query, request.model_name))
            return {"response": response}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
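

# Example usage once the server is running (host and port are assumptions):
#   curl -N -X POST http://localhost:7860/chat \
#        -H "Content-Type: application/json" \
#        -d '{"query": "Explain streaming responses", "stream": true}'
if __name__ == "__main__":
    # Local entry point; port 7860 is the Hugging Face Spaces convention.
    uvicorn.run(app, host="0.0.0.0", port=7860)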