novamysticX commited on
Commit
c195a75
·
verified ·
1 Parent(s): 751254b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -23
app.py CHANGED
@@ -1,24 +1,61 @@
1
- from fastapi import FastAPI
 
2
  from transformers import pipeline
3
-
4
- ## create a new FASTAPI app instance
5
- app=FastAPI()
6
-
7
- # Initialize the text generation pipeline
8
- pipe = pipeline("text-generation", model="defog/llama-3-sqlcoder-8b")
9
-
10
-
11
- @app.get("/")
12
- def home():
13
- return {"message":"Hello World"}
14
-
15
- # Define a function to handle the GET request at `/generate`
16
-
17
-
18
- @app.get("/generate")
19
- def generate(text:str):
20
- ## use the pipeline to generate text from given input text
21
- output=pipe(text)
22
-
23
- ## return the generate text in Json reposne
24
- return {"output":output[0]['generated_text']}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
  from transformers import pipeline
4
+ import logging
5
+
6
+ # Configure logging
7
+ logging.basicConfig(level=logging.INFO)
8
+ logger = logging.getLogger(__name__)
9
+
10
+ app = FastAPI(title="SQL Coder API")
11
+
12
+ # Initialize pipeline
13
+ try:
14
+ pipe = pipeline("text-generation",
15
+ model="defog/llama-3-sqlcoder-8b",
16
+ device_map="auto",
17
+ torch_dtype="auto")
18
+ logger.info("Pipeline initialized successfully")
19
+ except Exception as e:
20
+ logger.error(f"Error initializing pipeline: {str(e)}")
21
+ raise
22
+
23
+ class ChatMessage(BaseModel):
24
+ role: str
25
+ content: str
26
+
27
+ class QueryRequest(BaseModel):
28
+ messages: list[ChatMessage]
29
+ max_length: int = 1024
30
+ temperature: float = 0.7
31
+
32
+ class QueryResponse(BaseModel):
33
+ generated_text: str
34
+
35
+ @app.post("/generate", response_model=QueryResponse)
36
+ async def generate(request: QueryRequest):
37
+ try:
38
+ # Format messages into a single string
39
+ formatted_prompt = "\n".join([f"{msg.role}: {msg.content}" for msg in request.messages])
40
+
41
+ # Generate response using pipeline
42
+ response = pipe(
43
+ formatted_prompt,
44
+ max_length=request.max_length,
45
+ temperature=request.temperature,
46
+ do_sample=True,
47
+ num_return_sequences=1
48
+ )
49
+
50
+ # Extract generated text
51
+ generated_text = response[0]['generated_text']
52
+
53
+ return QueryResponse(generated_text=generated_text)
54
+
55
+ except Exception as e:
56
+ logger.error(f"Error generating response: {str(e)}")
57
+ raise HTTPException(status_code=500, detail=str(e))
58
+
59
+ @app.get("/health")
60
+ async def health_check():
61
+ return {"status": "healthy"}