AItoolstack commited on
Commit
534b16a
·
verified ·
1 Parent(s): 9877732

Update app/routers/inference.py

Browse files
Files changed (1) hide show
  1. app/routers/inference.py +217 -218
app/routers/inference.py CHANGED
@@ -1,218 +1,217 @@
1
-
2
- from fastapi import APIRouter, Request, UploadFile, File, Form
3
- from fastapi.responses import HTMLResponse, FileResponse, JSONResponse
4
- from fastapi.templating import Jinja2Templates
5
- from starlette.background import BackgroundTask
6
- import shutil
7
- import os
8
- import uuid
9
- from pathlib import Path
10
- from typing import Optional
11
- import json
12
- import base64
13
- from ultralytics import YOLO
14
- import cv2
15
- import numpy as np
16
-
17
-
18
- # Templates directory
19
- TEMPLATES_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "templates")
20
- templates = Jinja2Templates(directory=TEMPLATES_DIR)
21
-
22
- router = APIRouter()
23
-
24
- UPLOAD_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "static", "uploads")
25
- RESULTS_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "static", "results")
26
-
27
- os.makedirs(UPLOAD_DIR, exist_ok=True)
28
- os.makedirs(RESULTS_DIR, exist_ok=True)
29
-
30
- ALLOWED_EXTENSIONS = {"jpg", "jpeg", "png", "tiff", "tif"}
31
-
32
- # Model paths
33
- BASE_DIR = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
34
- DAMAGE_MODEL_PATH = os.path.join(BASE_DIR, "models", "damage", "weights", "weights", "best.pt")
35
- PARTS_MODEL_PATH = os.path.join(BASE_DIR, "models", "parts", "weights", "weights", "best.pt")
36
-
37
- # Class names for parts
38
- PARTS_CLASS_NAMES = ['headlamp', 'front_bumper', 'hood', 'door', 'rear_bumper']
39
-
40
- # Helper: Run YOLO inference and return results
41
- def run_yolo_inference(model_path, image_path, task='segment'):
42
- model = YOLO(model_path)
43
- results = model.predict(source=image_path, imgsz=640, conf=0.25, save=False, task=task)
44
- return results[0]
45
-
46
- # Helper: Draw masks and confidence on image
47
- def draw_masks_and_conf(image_path, yolo_result, class_names=None):
48
- img = cv2.imread(image_path)
49
- overlay = img.copy()
50
- out_img = img.copy()
51
- colors = [(255,0,0), (0,255,0), (0,0,255), (255,255,0), (255,0,255), (0,255,255)]
52
- for i, box in enumerate(yolo_result.boxes):
53
- conf = float(box.conf[0])
54
- cls = int(box.cls[0])
55
- color = colors[cls % len(colors)]
56
- # Draw bbox
57
- x1, y1, x2, y2 = map(int, box.xyxy[0])
58
- cv2.rectangle(overlay, (x1, y1), (x2, y2), color, 2)
59
- label = f"{class_names[cls] if class_names else 'damage'}: {conf:.2f}"
60
- cv2.putText(overlay, label, (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)
61
- # Draw mask if available
62
- if hasattr(yolo_result, 'masks') and yolo_result.masks is not None:
63
- mask = yolo_result.masks.data[i].cpu().numpy()
64
- mask = (mask * 255).astype(np.uint8)
65
- mask = cv2.resize(mask, (x2-x1, y2-y1))
66
- roi = overlay[y1:y2, x1:x2]
67
- colored_mask = np.zeros_like(roi)
68
- colored_mask[mask > 127] = color
69
- overlay[y1:y2, x1:x2] = cv2.addWeighted(roi, 0.5, colored_mask, 0.5, 0)
70
- out_img = cv2.addWeighted(overlay, 0.7, img, 0.3, 0)
71
- return out_img
72
-
73
- # Helper: Generate JSON output
74
- def generate_json_output(filename, damage_result, parts_result):
75
- # Damage severity: use max confidence
76
- severity_score = float(max([float(box.conf[0]) for box in damage_result.boxes], default=0))
77
- damage_regions = []
78
- for box in damage_result.boxes:
79
- x1, y1, x2, y2 = map(float, box.xyxy[0])
80
- conf = float(box.conf[0])
81
- damage_regions.append({"bbox": [x1, y1, x2, y2], "confidence": conf})
82
- # Parts
83
- parts = []
84
- for i, box in enumerate(parts_result.boxes):
85
- x1, y1, x2, y2 = map(float, box.xyxy[0])
86
- conf = float(box.conf[0])
87
- cls = int(box.cls[0])
88
- # Damage %: use mask area / bbox area if available
89
- damage_percentage = None
90
- if hasattr(parts_result, 'masks') and parts_result.masks is not None:
91
- mask = parts_result.masks.data[i].cpu().numpy()
92
- mask_area = np.sum(mask > 0.5)
93
- bbox_area = (x2-x1)*(y2-y1)
94
- damage_percentage = float(mask_area / bbox_area) if bbox_area > 0 else None
95
- parts.append({
96
- "part": PARTS_CLASS_NAMES[cls] if cls < len(PARTS_CLASS_NAMES) else str(cls),
97
- "damaged": True,
98
- "confidence": conf,
99
- "damage_percentage": damage_percentage,
100
- "bbox": [x1, y1, x2, y2]
101
- })
102
- # Optionally, add base64 masks
103
- # (not implemented here for brevity)
104
- return {
105
- "filename": filename,
106
- "damage": {
107
- "severity_score": severity_score,
108
- "regions": damage_regions
109
- },
110
- "parts": parts,
111
- "cost_estimate": None
112
- }
113
-
114
- # Dummy login credentials
115
- def check_login(username: str, password: str) -> bool:
116
- return username == "demo" and password == "demo123"
117
-
118
- @router.get("/", response_class=HTMLResponse)
119
- def home(request: Request):
120
- return templates.TemplateResponse("index.html", {"request": request, "result": None})
121
-
122
- @router.post("/login", response_class=HTMLResponse)
123
- def login(request: Request, username: str = Form(...), password: str = Form(...)):
124
- if check_login(username, password):
125
- return templates.TemplateResponse("index.html", {"request": request, "result": None, "user": username})
126
- return templates.TemplateResponse("login.html", {"request": request, "error": "Invalid credentials"})
127
-
128
- @router.get("/login", response_class=HTMLResponse)
129
- def login_page(request: Request):
130
- return templates.TemplateResponse("login.html", {"request": request})
131
-
132
- @router.post("/upload", response_class=HTMLResponse)
133
- def upload_image(request: Request, file: UploadFile = File(...)):
134
- ext = file.filename.split(".")[-1].lower()
135
- if ext not in ALLOWED_EXTENSIONS:
136
- return templates.TemplateResponse("index.html", {"request": request, "error": "Unsupported file type."})
137
-
138
- # Save uploaded file
139
- session_id = str(uuid.uuid4())
140
- upload_path = os.path.join(UPLOAD_DIR, f"{session_id}.{ext}")
141
- with open(upload_path, "wb") as buffer:
142
- shutil.copyfileobj(file.file, buffer)
143
-
144
- # Run both inferences
145
- try:
146
- damage_result = run_yolo_inference(DAMAGE_MODEL_PATH, upload_path)
147
- parts_result = run_yolo_inference(PARTS_MODEL_PATH, upload_path)
148
-
149
- # Save annotated images
150
- damage_img_path = os.path.join(RESULTS_DIR, f"{session_id}_damage.png")
151
- parts_img_path = os.path.join(RESULTS_DIR, f"{session_id}_parts.png")
152
- json_path = os.path.join(RESULTS_DIR, f"{session_id}_result.json")
153
- damage_img_url = f"/static/results/{session_id}_damage.png"
154
- parts_img_url = f"/static/results/{session_id}_parts.png"
155
- json_url = f"/static/results/{session_id}_result.json"
156
-
157
- # Defensive: set to None by default
158
- damage_img = None
159
- parts_img = None
160
- json_output = None
161
-
162
- # Only save and set if inference returns boxes
163
- if hasattr(damage_result, 'boxes') and len(damage_result.boxes) > 0:
164
- damage_img = draw_masks_and_conf(upload_path, damage_result)
165
- cv2.imwrite(damage_img_path, damage_img)
166
- if hasattr(parts_result, 'boxes') and len(parts_result.boxes) > 0:
167
- parts_img = draw_masks_and_conf(upload_path, parts_result, class_names=PARTS_CLASS_NAMES)
168
- cv2.imwrite(parts_img_path, parts_img)
169
- if (hasattr(damage_result, 'boxes') and len(damage_result.boxes) > 0) or (hasattr(parts_result, 'boxes') and len(parts_result.boxes) > 0):
170
- json_output = generate_json_output(file.filename, damage_result, parts_result)
171
- with open(json_path, "w") as jf:
172
- json.dump(json_output, jf, indent=2)
173
-
174
- # Prepare URLs for download (only if files exist)
175
- result = {
176
- "filename": file.filename,
177
- "damage_image": damage_img_url if damage_img is not None else None,
178
- "parts_image": parts_img_url if parts_img is not None else None,
179
- "json": json_output,
180
- "json_download": json_url if json_output is not None else None
181
- }
182
- # Debug log
183
- print("[DEBUG] Result dict:", result)
184
- except Exception as e:
185
- result = {
186
- "filename": file.filename,
187
- "error": f"Inference failed: {str(e)}",
188
- "damage_image": None,
189
- "parts_image": None,
190
- "json": None,
191
- "json_download": None
192
- }
193
- print("[ERROR] Inference failed:", e)
194
-
195
- import threading
196
- import time
197
- def delayed_cleanup():
198
- time.sleep(300) # 5 minutes
199
- try:
200
- os.remove(upload_path)
201
- except Exception:
202
- pass
203
- for suffix in ["_damage.png", "_parts.png", "_result.json"]:
204
- try:
205
- os.remove(os.path.join(RESULTS_DIR, f"{session_id}{suffix}"))
206
- except Exception:
207
- pass
208
-
209
- threading.Thread(target=delayed_cleanup, daemon=True).start()
210
-
211
- return templates.TemplateResponse(
212
- "index.html",
213
- {
214
- "request": request,
215
- "result": result,
216
- "original_image": f"/static/uploads/{session_id}.{ext}"
217
- }
218
- )
 
1
+
2
+ from fastapi import APIRouter, Request, UploadFile, File, Form
3
+ from fastapi.responses import HTMLResponse, FileResponse, JSONResponse
4
+ from fastapi.templating import Jinja2Templates
5
+ from starlette.background import BackgroundTask
6
+ import shutil
7
+ import os
8
+ import uuid
9
+ from pathlib import Path
10
+ from typing import Optional
11
+ import json
12
+ import base64
13
+ from ultralytics import YOLO
14
+ import cv2
15
+ import numpy as np
16
+
17
+
18
+ # Templates directory
19
+ TEMPLATES_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "templates")
20
+ templates = Jinja2Templates(directory=TEMPLATES_DIR)
21
+
22
+ router = APIRouter()
23
+
24
+ UPLOAD_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "static", "uploads")
25
+ RESULTS_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "static", "results")
26
+
27
+ os.makedirs(UPLOAD_DIR, exist_ok=True)
28
+ os.makedirs(RESULTS_DIR, exist_ok=True)
29
+
30
+ ALLOWED_EXTENSIONS = {"jpg", "jpeg", "png", "tiff", "tif"}
31
+
32
+ # Model paths
33
+ DAMAGE_MODEL_PATH = os.path.join("/tmp", "models", "damage", "weights", "weights", "best.pt")
34
+ PARTS_MODEL_PATH = os.path.join("/tmp", "models", "parts", "weights", "weights", "best.pt")
35
+
36
+ # Class names for parts
37
+ PARTS_CLASS_NAMES = ['headlamp', 'front_bumper', 'hood', 'door', 'rear_bumper']
38
+
39
+ # Helper: Run YOLO inference and return results
40
+ def run_yolo_inference(model_path, image_path, task='segment'):
41
+ model = YOLO(model_path)
42
+ results = model.predict(source=image_path, imgsz=640, conf=0.25, save=False, task=task)
43
+ return results[0]
44
+
45
+ # Helper: Draw masks and confidence on image
46
+ def draw_masks_and_conf(image_path, yolo_result, class_names=None):
47
+ img = cv2.imread(image_path)
48
+ overlay = img.copy()
49
+ out_img = img.copy()
50
+ colors = [(255,0,0), (0,255,0), (0,0,255), (255,255,0), (255,0,255), (0,255,255)]
51
+ for i, box in enumerate(yolo_result.boxes):
52
+ conf = float(box.conf[0])
53
+ cls = int(box.cls[0])
54
+ color = colors[cls % len(colors)]
55
+ # Draw bbox
56
+ x1, y1, x2, y2 = map(int, box.xyxy[0])
57
+ cv2.rectangle(overlay, (x1, y1), (x2, y2), color, 2)
58
+ label = f"{class_names[cls] if class_names else 'damage'}: {conf:.2f}"
59
+ cv2.putText(overlay, label, (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)
60
+ # Draw mask if available
61
+ if hasattr(yolo_result, 'masks') and yolo_result.masks is not None:
62
+ mask = yolo_result.masks.data[i].cpu().numpy()
63
+ mask = (mask * 255).astype(np.uint8)
64
+ mask = cv2.resize(mask, (x2-x1, y2-y1))
65
+ roi = overlay[y1:y2, x1:x2]
66
+ colored_mask = np.zeros_like(roi)
67
+ colored_mask[mask > 127] = color
68
+ overlay[y1:y2, x1:x2] = cv2.addWeighted(roi, 0.5, colored_mask, 0.5, 0)
69
+ out_img = cv2.addWeighted(overlay, 0.7, img, 0.3, 0)
70
+ return out_img
71
+
72
+ # Helper: Generate JSON output
73
+ def generate_json_output(filename, damage_result, parts_result):
74
+ # Damage severity: use max confidence
75
+ severity_score = float(max([float(box.conf[0]) for box in damage_result.boxes], default=0))
76
+ damage_regions = []
77
+ for box in damage_result.boxes:
78
+ x1, y1, x2, y2 = map(float, box.xyxy[0])
79
+ conf = float(box.conf[0])
80
+ damage_regions.append({"bbox": [x1, y1, x2, y2], "confidence": conf})
81
+ # Parts
82
+ parts = []
83
+ for i, box in enumerate(parts_result.boxes):
84
+ x1, y1, x2, y2 = map(float, box.xyxy[0])
85
+ conf = float(box.conf[0])
86
+ cls = int(box.cls[0])
87
+ # Damage %: use mask area / bbox area if available
88
+ damage_percentage = None
89
+ if hasattr(parts_result, 'masks') and parts_result.masks is not None:
90
+ mask = parts_result.masks.data[i].cpu().numpy()
91
+ mask_area = np.sum(mask > 0.5)
92
+ bbox_area = (x2-x1)*(y2-y1)
93
+ damage_percentage = float(mask_area / bbox_area) if bbox_area > 0 else None
94
+ parts.append({
95
+ "part": PARTS_CLASS_NAMES[cls] if cls < len(PARTS_CLASS_NAMES) else str(cls),
96
+ "damaged": True,
97
+ "confidence": conf,
98
+ "damage_percentage": damage_percentage,
99
+ "bbox": [x1, y1, x2, y2]
100
+ })
101
+ # Optionally, add base64 masks
102
+ # (not implemented here for brevity)
103
+ return {
104
+ "filename": filename,
105
+ "damage": {
106
+ "severity_score": severity_score,
107
+ "regions": damage_regions
108
+ },
109
+ "parts": parts,
110
+ "cost_estimate": None
111
+ }
112
+
113
+ # Dummy login credentials
114
+ def check_login(username: str, password: str) -> bool:
115
+ return username == "demo" and password == "demo123"
116
+
117
+ @router.get("/", response_class=HTMLResponse)
118
+ def home(request: Request):
119
+ return templates.TemplateResponse("index.html", {"request": request, "result": None})
120
+
121
+ @router.post("/login", response_class=HTMLResponse)
122
+ def login(request: Request, username: str = Form(...), password: str = Form(...)):
123
+ if check_login(username, password):
124
+ return templates.TemplateResponse("index.html", {"request": request, "result": None, "user": username})
125
+ return templates.TemplateResponse("login.html", {"request": request, "error": "Invalid credentials"})
126
+
127
+ @router.get("/login", response_class=HTMLResponse)
128
+ def login_page(request: Request):
129
+ return templates.TemplateResponse("login.html", {"request": request})
130
+
131
+ @router.post("/upload", response_class=HTMLResponse)
132
+ def upload_image(request: Request, file: UploadFile = File(...)):
133
+ ext = file.filename.split(".")[-1].lower()
134
+ if ext not in ALLOWED_EXTENSIONS:
135
+ return templates.TemplateResponse("index.html", {"request": request, "error": "Unsupported file type."})
136
+
137
+ # Save uploaded file
138
+ session_id = str(uuid.uuid4())
139
+ upload_path = os.path.join(UPLOAD_DIR, f"{session_id}.{ext}")
140
+ with open(upload_path, "wb") as buffer:
141
+ shutil.copyfileobj(file.file, buffer)
142
+
143
+ # Run both inferences
144
+ try:
145
+ damage_result = run_yolo_inference(DAMAGE_MODEL_PATH, upload_path)
146
+ parts_result = run_yolo_inference(PARTS_MODEL_PATH, upload_path)
147
+
148
+ # Save annotated images
149
+ damage_img_path = os.path.join(RESULTS_DIR, f"{session_id}_damage.png")
150
+ parts_img_path = os.path.join(RESULTS_DIR, f"{session_id}_parts.png")
151
+ json_path = os.path.join(RESULTS_DIR, f"{session_id}_result.json")
152
+ damage_img_url = f"/static/results/{session_id}_damage.png"
153
+ parts_img_url = f"/static/results/{session_id}_parts.png"
154
+ json_url = f"/static/results/{session_id}_result.json"
155
+
156
+ # Defensive: set to None by default
157
+ damage_img = None
158
+ parts_img = None
159
+ json_output = None
160
+
161
+ # Only save and set if inference returns boxes
162
+ if hasattr(damage_result, 'boxes') and len(damage_result.boxes) > 0:
163
+ damage_img = draw_masks_and_conf(upload_path, damage_result)
164
+ cv2.imwrite(damage_img_path, damage_img)
165
+ if hasattr(parts_result, 'boxes') and len(parts_result.boxes) > 0:
166
+ parts_img = draw_masks_and_conf(upload_path, parts_result, class_names=PARTS_CLASS_NAMES)
167
+ cv2.imwrite(parts_img_path, parts_img)
168
+ if (hasattr(damage_result, 'boxes') and len(damage_result.boxes) > 0) or (hasattr(parts_result, 'boxes') and len(parts_result.boxes) > 0):
169
+ json_output = generate_json_output(file.filename, damage_result, parts_result)
170
+ with open(json_path, "w") as jf:
171
+ json.dump(json_output, jf, indent=2)
172
+
173
+ # Prepare URLs for download (only if files exist)
174
+ result = {
175
+ "filename": file.filename,
176
+ "damage_image": damage_img_url if damage_img is not None else None,
177
+ "parts_image": parts_img_url if parts_img is not None else None,
178
+ "json": json_output,
179
+ "json_download": json_url if json_output is not None else None
180
+ }
181
+ # Debug log
182
+ print("[DEBUG] Result dict:", result)
183
+ except Exception as e:
184
+ result = {
185
+ "filename": file.filename,
186
+ "error": f"Inference failed: {str(e)}",
187
+ "damage_image": None,
188
+ "parts_image": None,
189
+ "json": None,
190
+ "json_download": None
191
+ }
192
+ print("[ERROR] Inference failed:", e)
193
+
194
+ import threading
195
+ import time
196
+ def delayed_cleanup():
197
+ time.sleep(300) # 5 minutes
198
+ try:
199
+ os.remove(upload_path)
200
+ except Exception:
201
+ pass
202
+ for suffix in ["_damage.png", "_parts.png", "_result.json"]:
203
+ try:
204
+ os.remove(os.path.join(RESULTS_DIR, f"{session_id}{suffix}"))
205
+ except Exception:
206
+ pass
207
+
208
+ threading.Thread(target=delayed_cleanup, daemon=True).start()
209
+
210
+ return templates.TemplateResponse(
211
+ "index.html",
212
+ {
213
+ "request": request,
214
+ "result": result,
215
+ "original_image": f"/static/uploads/{session_id}.{ext}"
216
+ }
217
+ )