xcurvnubaim commited on
Commit
d9c0ea9
·
1 Parent(s): 829102d

fix: sanitize image route input

Browse files
Files changed (1) hide show
  1. main.py +17 -4
main.py CHANGED
@@ -1,5 +1,5 @@
1
  import numpy as np
2
- from fastapi import FastAPI, File, UploadFile
3
  import tensorflow as tf
4
  from PIL import Image
5
  from io import BytesIO
@@ -8,6 +8,7 @@ import cv2
8
  from datetime import datetime
9
  from fastapi.responses import FileResponse
10
  from fastapi.middleware.cors import CORSMiddleware
 
11
  app = FastAPI()
12
 
13
  app.add_middleware(
@@ -129,14 +130,26 @@ async def predict_v2(file: UploadFile = File(...)):
129
  plot_detected_rectangles(cv2.imread("input/" + filename), detections, "output/" + filename)
130
  return {
131
  "message": "Detection and classification completed successfully",
132
- "data": "output/" + filename,
133
  "class_names": class_names
134
  }
135
 
 
 
136
  @app.get("/image/")
137
  async def get_image(image_name: str):
138
- # Assume the images are stored in a directory named "images"
139
- image_path = f"output/{image_name}"
 
 
 
 
 
 
 
 
 
 
140
  return FileResponse(image_path)
141
 
142
  @app.post("/predict")
 
1
  import numpy as np
2
+ from fastapi import FastAPI, File, UploadFile, HTTPException
3
  import tensorflow as tf
4
  from PIL import Image
5
  from io import BytesIO
 
8
  from datetime import datetime
9
  from fastapi.responses import FileResponse
10
  from fastapi.middleware.cors import CORSMiddleware
11
+ from pathlib import Path
12
  app = FastAPI()
13
 
14
  app.add_middleware(
 
130
  plot_detected_rectangles(cv2.imread("input/" + filename), detections, "output/" + filename)
131
  return {
132
  "message": "Detection and classification completed successfully",
133
+ "out": filename,
134
  "class_names": class_names
135
  }
136
 
137
+ IMAGE_DIR = Path("output")
138
+
139
  @app.get("/image/")
140
  async def get_image(image_name: str):
141
+ # Sanitize the image_name to prevent directory traversal attacks
142
+ if "../" in image_name:
143
+ raise HTTPException(status_code=400, detail="Invalid image name")
144
+
145
+ # Construct the image path
146
+ image_path = IMAGE_DIR / image_name
147
+
148
+ # Check if the image exists
149
+ if not image_path.exists() or not image_path.is_file():
150
+ raise HTTPException(status_code=404, detail="Image not found")
151
+
152
+ # Return the image file
153
  return FileResponse(image_path)
154
 
155
  @app.post("/predict")