Ani14 commited on
Commit
97562fa
·
verified ·
1 Parent(s): ecbbf4c

Update predict.py

Browse files
Files changed (1) hide show
  1. predict.py +97 -145
predict.py CHANGED
@@ -1,57 +1,37 @@
1
- from fastapi import FastAPI, File, UploadFile, HTTPException, Response
2
  import cv2
3
  import numpy as np
4
  from ultralytics import YOLO
5
  import tensorflow as tf
6
- import io
7
  from typing import Union
8
 
9
- # --- Configuration ---
10
- PIXELS_PER_CM = 50.0
11
-
12
- # --- App Initialization ---
13
- app = FastAPI(
14
- title="Wound Analysis API",
15
- description="An API to analyze wound images and return an annotated image with data in headers.",
16
- version="3.4.0" # Version updated for prediction output fix
17
- )
18
-
19
- # --- Model Loading ---
20
- def load_models():
21
- segmentation_model, yolo_model = None, None
22
- try:
23
- segmentation_model = tf.keras.models.load_model("segmentation_model.h5")
24
- print("Segmentation model 'segmentation model.h5' loaded successfully.")
25
- except Exception as e:
26
- print(f"Warning: Could not load segmentation model. Using fallback. Error: {e}")
27
 
28
- try:
29
- yolo_model = YOLO("best.pt")
30
- print("YOLO model 'best.pt' loaded successfully.")
31
- except Exception as e:
32
- print(f"Warning: Could not load YOLO model. Using fallback. Error: {e}")
33
-
34
- return segmentation_model, yolo_model
35
 
36
- segmentation_model, yolo_model = load_models()
 
 
 
 
 
37
 
38
- # --- Helper Functions ---
 
 
 
39
 
 
40
  def preprocess_image(image: np.ndarray) -> np.ndarray:
41
- img_denoised = cv2.medianBlur(image, 3)
42
- lab = cv2.cvtColor(img_denoised, cv2.COLOR_BGR2LAB)
43
- l_channel, a_channel, b_channel = cv2.split(lab)
44
  clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
45
- l_clahe = clahe.apply(l_channel)
46
- lab_clahe = cv2.merge([l_clahe, a_channel, b_channel])
47
- img_clahe = cv2.cvtColor(lab_clahe, cv2.COLOR_LAB2BGR)
48
- gamma = 1.2
49
- img_float = img_clahe.astype(np.float32) / 255.0
50
- img_gamma = np.power(img_float, gamma)
51
- return (img_gamma * 255).astype(np.uint8)
52
 
53
- def detect_wound_region_yolo(image: np.ndarray) -> Union[tuple, None]:
54
- if not yolo_model: return None
55
  try:
56
  results = yolo_model.predict(image, verbose=False)
57
  if results and results[0].boxes:
@@ -59,122 +39,94 @@ def detect_wound_region_yolo(image: np.ndarray) -> Union[tuple, None]:
59
  coords = best_box.xyxy[0].cpu().numpy()
60
  return tuple(map(int, coords))
61
  except Exception as e:
62
- print(f"YOLO prediction failed: {e}")
63
- return None
64
-
65
- def segment_wound_with_model(image: np.ndarray) -> Union[np.ndarray, None]:
66
- if not segmentation_model:
67
- return None
68
- try:
69
- input_shape = segmentation_model.input_shape[1:3]
70
- img_resized = cv2.resize(image, (input_shape[1], input_shape[0]))
71
- img_norm = np.expand_dims(img_resized.astype(np.float32) / 255.0, axis=0)
72
-
73
- prediction = segmentation_model.predict(img_norm, verbose=0)
74
-
75
- # FIX: Handle nested list output or Tensor
76
- while isinstance(prediction, list):
77
- prediction = prediction[0]
78
- if isinstance(prediction, tf.Tensor):
79
- prediction = prediction.numpy()
80
-
81
- pred_mask = prediction[0]
82
- pred_mask_resized = cv2.resize(pred_mask, (image.shape[1], image.shape[0]))
83
- return (pred_mask_resized.squeeze() >= 0.5).astype(np.uint8) * 255
84
- except Exception as e:
85
- print(f"Segmentation model prediction failed: {e}")
86
  return None
87
 
88
-
89
- def segment_wound_with_fallback(image: np.ndarray) -> np.ndarray:
90
- pixels = image.reshape((-1, 3)).astype(np.float32)
91
  criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
92
- _, labels, centers = cv2.kmeans(pixels, 2, None, criteria, 5, cv2.KMEANS_PP_CENTERS)
93
- centers_lab = cv2.cvtColor(centers.reshape(1, -1, 3).astype(np.uint8), cv2.COLOR_BGR2LAB)[0]
94
- wound_cluster_idx = np.argmax(centers_lab[:, 1])
95
- mask = (labels.reshape(image.shape[:2]) == wound_cluster_idx).astype(np.uint8) * 255
96
- contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
97
- if contours:
98
- largest_contour = max(contours, key=cv2.contourArea)
99
- refined_mask = np.zeros_like(mask)
100
- cv2.drawContours(refined_mask, [largest_contour], -1, 255, cv2.FILLED)
101
- return refined_mask
102
- return mask
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
  def calculate_metrics(mask: np.ndarray, image: np.ndarray) -> dict:
105
- wound_pixels = cv2.countNonZero(mask)
106
- if wound_pixels == 0:
107
- return {"area_cm2": 0.0, "length_cm": 0.0, "breadth_cm": 0.0, "depth_score": 0.0, "moisture_score": 0.0}
108
-
109
- area_cm2 = wound_pixels / (PIXELS_PER_CM ** 2)
110
  contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
111
- largest_contour = max(contours, key=cv2.contourArea)
112
- (_, (width, height), _) = cv2.minAreaRect(largest_contour)
113
- length_cm = max(width, height) / PIXELS_PER_CM
114
- breadth_cm = min(width, height) / PIXELS_PER_CM
115
-
116
- mask_bool = mask.astype(bool)
117
- lab_img = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
118
- gray_img = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
119
-
120
- mean_a = np.mean(lab_img[:, :, 1][mask_bool])
121
- depth_score = mean_a - 128.0
122
- texture_std = np.std(gray_img[mask_bool])
123
- moisture_score = max(0.0, 100.0 * (1.0 - texture_std / 127.0))
124
-
125
- return {"area_cm2": area_cm2, "length_cm": length_cm, "breadth_cm": breadth_cm, "depth_score": depth_score, "moisture_score": moisture_score}
 
 
 
 
126
 
127
- def create_visual_overlay(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
128
- if cv2.countNonZero(mask) == 0: return image
129
- dist = cv2.distanceTransform(mask, cv2.DIST_L2, 5)
130
- cv2.normalize(dist, dist, 0, 1.0, cv2.NORM_MINMAX)
131
- overlay = np.zeros_like(image)
132
- overlay[dist >= 0.66] = (0, 0, 255)
133
- overlay[(dist >= 0.33) & (dist < 0.66)] = (255, 0, 0)
134
- overlay[(dist > 0) & (dist < 0.33)] = (0, 255, 0)
135
- blended = cv2.addWeighted(image, 0.7, overlay, 0.3, 0)
136
- annotated_img = image.copy()
137
- annotated_img[mask.astype(bool)] = blended[mask.astype(bool)]
138
- return annotated_img
139
 
140
- # --- Main API Endpoint ---
141
  @app.post("/analyze_wound")
142
  async def analyze_wound(file: UploadFile = File(...)):
143
  contents = await file.read()
144
- image_array = np.frombuffer(contents, np.uint8)
145
- original_image = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
146
- if original_image is None:
147
- raise HTTPException(status_code=400, detail="Invalid or corrupt image file")
148
-
149
- processed_image = preprocess_image(original_image)
150
- bbox = detect_wound_region_yolo(processed_image)
 
 
 
 
 
151
  if bbox:
152
- xmin, ymin, xmax, ymax = bbox
153
- cropped_image = processed_image[ymin:ymax, xmin:xmax]
154
- else:
155
- cropped_image = processed_image
156
-
157
- mask = segment_wound_with_model(cropped_image)
158
- if mask is None:
159
- mask = segment_wound_with_fallback(cropped_image)
160
-
161
- metrics = calculate_metrics(mask, cropped_image)
162
- full_mask = np.zeros(original_image.shape[:2], dtype=np.uint8)
163
- if bbox:
164
- full_mask[ymin:ymax, xmin:xmax] = mask
165
  else:
166
  full_mask = mask
167
-
168
- annotated_image = create_visual_overlay(original_image, full_mask)
169
- success, png_data = cv2.imencode(".png", annotated_image)
170
- if not success:
171
- raise HTTPException(status_code=500, detail="Failed to encode output image")
172
 
173
- headers = {
174
- "Wound-Area-cm2": f"{metrics['area_cm2']:.2f}",
175
- "Wound-Length-cm": f"{metrics['length_cm']:.2f}",
176
- "Wound-Breadth-cm": f"{metrics['breadth_cm']:.2f}",
177
- "Wound-Depth-Score": f"{metrics['depth_score']:.1f}",
178
- "Wound-Moisture-Score": f"{metrics['moisture_score']:.0f}"
179
- }
180
- return Response(content=png_data.tobytes(), media_type="image/png", headers=headers)
 
 
 
1
+ from fastapi import FastAPI, File, UploadFile, Response, HTTPException
2
  import cv2
3
  import numpy as np
4
  from ultralytics import YOLO
5
  import tensorflow as tf
6
+ import os
7
  from typing import Union
8
 
9
+ app = FastAPI()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
+ PIXELS_PER_CM = 50.0
 
 
 
 
 
 
12
 
13
+ # --- Model loading ---
14
+ segmentation_model, yolo_model = None, None
15
+ try:
16
+ segmentation_model = tf.keras.models.load_model("segmentation_model.h5")
17
+ except Exception as e:
18
+ print(f"Segmentation model not loaded: {e}")
19
 
20
+ try:
21
+ yolo_model = YOLO("best.pt")
22
+ except Exception as e:
23
+ print(f"YOLO model not loaded: {e}")
24
 
25
+ # --- Helpers ---
26
  def preprocess_image(image: np.ndarray) -> np.ndarray:
27
+ lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
28
+ l, a, b = cv2.split(lab)
 
29
  clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
30
+ cl = clahe.apply(l)
31
+ limg = cv2.merge((cl, a, b))
32
+ return cv2.cvtColor(limg, cv2.COLOR_LAB2BGR)
 
 
 
 
33
 
34
+ def detect_with_yolo(image: np.ndarray) -> Union[tuple, None]:
 
35
  try:
36
  results = yolo_model.predict(image, verbose=False)
37
  if results and results[0].boxes:
 
39
  coords = best_box.xyxy[0].cpu().numpy()
40
  return tuple(map(int, coords))
41
  except Exception as e:
42
+ print(f"YOLO error: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  return None
44
 
45
+ def fallback_segmentation(image: np.ndarray) -> np.ndarray:
46
+ Z = image.reshape((-1, 3)).astype(np.float32)
 
47
  criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
48
+ _, label, center = cv2.kmeans(Z, 2, None, criteria, 10, cv2.KMEANS_RANDOM_CENTERS)
49
+ label = label.reshape(image.shape[:2])
50
+ unique_vals = np.unique(label)
51
+ if len(unique_vals) > 1:
52
+ wound_label = np.argmax([np.sum(label == val) for val in unique_vals])
53
+ else:
54
+ wound_label = unique_vals[0]
55
+ return (label == wound_label).astype(np.uint8) * 255
56
+
57
+ def segment(image: np.ndarray) -> np.ndarray:
58
+ if segmentation_model is not None:
59
+ try:
60
+ input_shape = segmentation_model.input.shape[1:3]
61
+ resized = cv2.resize(image, (input_shape[1], input_shape[0]))
62
+ norm = np.expand_dims(resized / 255.0, axis=0)
63
+ prediction = segmentation_model.predict(norm)
64
+ if isinstance(prediction, list):
65
+ prediction = prediction[0]
66
+ mask = (prediction[0].squeeze() >= 0.5).astype(np.uint8) * 255
67
+ return cv2.resize(mask, (image.shape[1], image.shape[0]))
68
+ except Exception as e:
69
+ print(f"Segmentation model failed: {e}")
70
+ return fallback_segmentation(image)
71
 
72
  def calculate_metrics(mask: np.ndarray, image: np.ndarray) -> dict:
73
+ area_px = cv2.countNonZero(mask)
74
+ area_cm2 = area_px / (PIXELS_PER_CM ** 2)
 
 
 
75
  contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
76
+ if not contours:
77
+ return {"length": 0, "breadth": 0, "area": 0, "depth": 0, "moisture": 0}
78
+ c = max(contours, key=cv2.contourArea)
79
+ rect = cv2.minAreaRect(c)
80
+ length, breadth = max(rect[1]) / PIXELS_PER_CM, min(rect[1]) / PIXELS_PER_CM
81
+ gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
82
+ texture_std = np.std(gray[mask.astype(bool)])
83
+ lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
84
+ mean_a = np.mean(lab[:, :, 1][mask.astype(bool)])
85
+ depth = mean_a - 128
86
+ moisture = max(0, 100 * (1.0 - texture_std / 127.0))
87
+ return {
88
+ "area": area_cm2,
89
+ "length": length,
90
+ "breadth": breadth,
91
+ "depth": depth,
92
+ "moisture": moisture,
93
+ "contour": c
94
+ }
95
 
96
+ def annotate(image: np.ndarray, mask: np.ndarray, contour) -> np.ndarray:
97
+ poly_image = image.copy()
98
+ if contour is not None:
99
+ cv2.drawContours(poly_image, [contour], -1, (0, 255, 0), 2)
100
+ return poly_image
 
 
 
 
 
 
 
101
 
102
+ # --- API ---
103
  @app.post("/analyze_wound")
104
  async def analyze_wound(file: UploadFile = File(...)):
105
  contents = await file.read()
106
+ arr = np.frombuffer(contents, np.uint8)
107
+ image = cv2.imdecode(arr, cv2.IMREAD_COLOR)
108
+ if image is None:
109
+ raise HTTPException(status_code=400, detail="Invalid image")
110
+
111
+ image = preprocess_image(image)
112
+ bbox = detect_with_yolo(image)
113
+ cropped = image[bbox[1]:bbox[3], bbox[0]:bbox[2]] if bbox else image
114
+ mask = segment(cropped)
115
+
116
+ metrics = calculate_metrics(mask, cropped)
117
+ full_mask = np.zeros(image.shape[:2], dtype=np.uint8)
118
  if bbox:
119
+ full_mask[bbox[1]:bbox[3], bbox[0]:bbox[2]] = mask
 
 
 
 
 
 
 
 
 
 
 
 
120
  else:
121
  full_mask = mask
 
 
 
 
 
122
 
123
+ final_image = annotate(image, full_mask, metrics['contour'])
124
+ _, buf = cv2.imencode(".png", final_image)
125
+
126
+ response = Response(content=buf.tobytes(), media_type="image/png")
127
+ response.headers['X-Length-Cm'] = str(metrics['length'])
128
+ response.headers['X-Breadth-Cm'] = str(metrics['breadth'])
129
+ response.headers['X-Depth-Cm'] = str(metrics['depth'])
130
+ response.headers['X-Area-Cm2'] = str(metrics['area'])
131
+ response.headers['X-Moisture'] = str(metrics['moisture'])
132
+ return response