Charan5775 committed
Commit b3901a6 · verified · 1 Parent(s): 4040b1d

Update app.py

Files changed (1): app.py +66 -11
app.py CHANGED
@@ -1,18 +1,73 @@
- from fastapi import FastAPI
- # Use a pipeline as a high-level helper
- from transformers import pipeline
-
- pipe = pipeline("text2text-generation", model="google/flan-t5-base")
-
  app = FastAPI()
-
-
- @app.get('/')
- def home():
-     return {"hello": "Bitfumes"}
-
-
- @app.get('/ask')
- def ask(prompt: str):
-     result = pipe(prompt)
-     return result[0]
+ from fastapi import FastAPI, HTTPException
+ from typing import Optional
+ from fastapi.responses import StreamingResponse
+ from huggingface_hub import InferenceClient
+ from pydantic import BaseModel
+ import os
+ import uvicorn
+
  app = FastAPI()
+
+ # Default model
+ DEFAULT_MODEL = "meta-llama/Meta-Llama-3-8B-Instruct"
+
+ class QueryRequest(BaseModel):
+     query: str
+     stream: bool = False
+     model_name: Optional[str] = None  # If not provided, falls back to DEFAULT_MODEL
+
+ def get_client(model_name: Optional[str] = None):
+     """Get an inference client for the specified model, or the default model."""
+     # Use the provided model_name if it is non-empty; otherwise use DEFAULT_MODEL
+     model_path = model_name if model_name and model_name.strip() else DEFAULT_MODEL
+     try:
+         return InferenceClient(model_path)
+     except Exception as e:
+         raise HTTPException(
+             status_code=400,
+             detail=f"Error initializing model {model_path}: {str(e)}"
+         )
+
+ def generate_response(query: str, model_name: Optional[str] = None):
+     messages = [{
+         "role": "user",
+         "content": f"[SYSTEM] You are an ASSISTANT who answers the user's questions in a short and concise manner. [USER] {query}"
+     }]
+     try:
+         client = get_client(model_name)
+         # Stream tokens from the chat-completion endpoint as they arrive
+         for message in client.chat_completion(
+             messages,
+             max_tokens=2048,
+             stream=True
+         ):
+             token = message.choices[0].delta.content
+             if token:  # the final chunk of a stream may carry no content
+                 yield token
+     except Exception as e:
+         yield f"Error generating response: {str(e)}"
+
+ @app.get("/")
+ async def root():
+     return {"message": "Welcome to FastAPI server!"}
+
+ @app.post("/chat")
+ async def chat(request: QueryRequest):
+     try:
+         if request.stream:
+             # Stream tokens back to the client as server-sent events
+             return StreamingResponse(
+                 generate_response(request.query, request.model_name),
+                 media_type="text/event-stream"
+             )
+         else:
+             # Accumulate the full response before returning it as JSON
+             response = ""
+             for chunk in generate_response(request.query, request.model_name):
+                 response += chunk
+             return {"response": response}
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=str(e))
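
For reference, a minimal client-side sketch of exercising the new /chat endpoint, both non-streaming and streaming. The base URL, port, and the `requests` dependency are illustrative assumptions, not part of this commit; the server is assumed to be running locally (e.g. via `uvicorn app:app`).

# Minimal client sketch; http://localhost:8000 and `requests` are assumptions.
import requests

BASE_URL = "http://localhost:8000"  # assumed local dev address

# Non-streaming request: the whole reply comes back as one JSON object.
r = requests.post(f"{BASE_URL}/chat", json={"query": "What is FastAPI?"})
print(r.json()["response"])

# Streaming request: tokens arrive incrementally over the event stream.
with requests.post(
    f"{BASE_URL}/chat",
    json={"query": "What is FastAPI?", "stream": True},
    stream=True,
) as resp:
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)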