sreejith8100 committed
Commit ac2dd2f · verified · 1 Parent(s): abbeb5a

Create main.py

Files changed (1)
  1. main.py +90 -0
main.py ADDED
@@ -0,0 +1,90 @@
+ from fastapi import FastAPI
+ from fastapi.responses import JSONResponse, StreamingResponse
+ from pydantic import BaseModel, validator
+ import base64
+ import json
+ import types
+ from endpoint_handler import EndpointHandler  # your handler file
+
+ app = FastAPI()
+
+ handler = None
+
+ @app.on_event("startup")
+ async def load_handler():
+     global handler
+     handler = EndpointHandler()
+
+ class PredictInput(BaseModel):
+     image: str  # base64-encoded image string
+     question: str
+     stream: bool = False
+
+     @validator("question")
+     def question_not_empty(cls, v):
+         if not v.strip():
+             raise ValueError("Question must not be empty")
+         return v
+
+     @validator("image")
+     def valid_base64_and_size(cls, v):
+         try:
+             decoded = base64.b64decode(v, validate=True)
+         except Exception:
+             raise ValueError("`image` must be valid base64")
+         if len(decoded) > 10 * 1024 * 1024:  # 10 MB limit
+             raise ValueError("Image exceeds 10 MB after decoding")
+         return v
+
+ class PredictRequest(BaseModel):
+     inputs: PredictInput
+
+ @app.get("/")
+ async def root():
+     return {"message": "FastAPI app is running on Hugging Face"}
+
+ @app.post("/predict")
+ async def predict_endpoint(payload: PredictRequest):
+     """
+     Handle a prediction request by forwarding the validated payload to the handler.
+
+     Args:
+         payload (PredictRequest): Request body containing the base64-encoded image,
+             the question, and a stream flag.
+
+     Returns:
+         JSONResponse: The prediction result for non-streaming requests, or an error
+             message with status code 400 (validation error) / 500 (unexpected failure).
+         StreamingResponse: For streaming results (generators), a text/event-stream
+             response that yields prediction chunks as JSON and ends with a structured
+             JSON message indicating the end of the stream or an error during streaming.
+
+     Notes:
+         - Logs the received question for debugging purposes.
+     """
+     print(f"[Request] Received question: {payload.inputs.question}")
+
+     data = {
+         "inputs": {
+             "image": payload.inputs.image,
+             "question": payload.inputs.question,
+             "stream": payload.inputs.stream,
+         }
+     }
+
+     try:
+         result = handler.predict(data)
+     except ValueError as ve:
+         return JSONResponse({"error": str(ve)}, status_code=400)
+     except Exception:
+         return JSONResponse({"error": "Internal server error"}, status_code=500)
+
+     if isinstance(result, types.GeneratorType):
+         def event_stream():
+             try:
+                 for chunk in result:
+                     yield f"data: {json.dumps(chunk)}\n\n"
+                 # Structured JSON message to indicate the end of the stream
+                 yield f"data: {json.dumps({'end': True})}\n\n"
+             except Exception as e:
+                 # Structured JSON message to indicate an error during streaming
+                 yield f"data: {json.dumps({'error': str(e)})}\n\n"
+         return StreamingResponse(event_stream(), media_type="text/event-stream")
+
+     # Non-streaming path: return the handler result as JSON
+     # (assumes the handler returns a JSON-serializable object).
+     return JSONResponse(content=result)
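
For reference, a minimal client sketch for the /predict endpoint above. This is not part of the commit: it assumes the app is served locally at http://localhost:8000 (for example via "uvicorn main:app --port 8000"), uses a hypothetical sample.jpg as the input image, and the shape of each streamed chunk depends on what EndpointHandler yields.

# Client sketch (assumptions: local server on port 8000, placeholder image file).
import base64
import json

import requests

URL = "http://localhost:8000/predict"

with open("sample.jpg", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

payload = {
    "inputs": {
        "image": image_b64,
        "question": "What is shown in this image?",
        "stream": True,  # set to False for a single JSON response
    }
}

# Consume the server-sent events emitted by the streaming branch above.
with requests.post(URL, json=payload, stream=True) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue
        event = json.loads(line[len("data: "):])
        if isinstance(event, dict) and event.get("end"):
            break  # structured end-of-stream marker
        if isinstance(event, dict) and "error" in event:
            raise RuntimeError(event["error"])
        print(event)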