# testingtodeploy / app.py
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from huggingface_hub import InferenceClient
from pydantic import BaseModel
from typing import Optional
import os
import uvicorn

app = FastAPI()

# Default model used when a request does not specify one
DEFAULT_MODEL = "meta-llama/Meta-Llama-3-8B-Instruct"


class QueryRequest(BaseModel):
    query: str
    stream: bool = False
    model_name: Optional[str] = None  # If not provided, DEFAULT_MODEL is used
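
# Illustrative /chat request body for the model above (the values are examples,
# not defaults defined elsewhere in this file):
# {
#     "query": "What is FastAPI?",
#     "stream": false,
#     "model_name": "meta-llama/Meta-Llama-3-8B-Instruct"
# }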


def get_client(model_name: Optional[str] = None):
    """Get an inference client for the specified model, or for the default model."""
    # Use the provided model_name if it is a non-empty string, otherwise fall back to DEFAULT_MODEL
    model_path = model_name if model_name and model_name.strip() else DEFAULT_MODEL
    # Note: gated models such as Meta-Llama-3 require authentication; InferenceClient
    # picks up a token from the local Hugging Face login or the HF_TOKEN environment variable.
    try:
        return InferenceClient(model_path)
    except Exception as e:
        raise HTTPException(
            status_code=400,
            detail=f"Error initializing model {model_path}: {str(e)}"
        )


def generate_response(query: str, model_name: Optional[str] = None):
    """Yield response tokens for the given query, streaming from the model."""
    messages = [{
        "role": "user",
        "content": f"[SYSTEM] You are an ASSISTANT who answers the user's questions in a short and concise manner. [USER] {query}"
    }]
    try:
        client = get_client(model_name)
        for message in client.chat_completion(
            messages,
            max_tokens=2048,
            stream=True
        ):
            token = message.choices[0].delta.content
            # Some stream chunks carry no content (e.g. the final chunk); skip them
            if token:
                yield token
    except Exception as e:
        yield f"Error generating response: {str(e)}"


@app.get("/")
async def root():
    return {"message": "Welcome to the FastAPI server!"}


@app.post("/chat")
async def chat(request: QueryRequest):
    try:
        if request.stream:
            # Stream tokens back to the client as server-sent events
            return StreamingResponse(
                generate_response(request.query, request.model_name),
                media_type="text/event-stream"
            )
        else:
            # Collect the full response before returning it
            response = ""
            for chunk in generate_response(request.query, request.model_name):
                response += chunk
            return {"response": response}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
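

# Entry point for running the server directly; uvicorn is imported above but was
# never invoked. This is a minimal sketch: the host and port are assumptions
# (7860 is the conventional Hugging Face Spaces port), not values defined elsewhere
# in this file.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)

# Example usage (illustrative):
#   curl -X POST http://localhost:7860/chat \
#        -H "Content-Type: application/json" \
#        -d '{"query": "Hello", "stream": false}'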