# wound_detect / predict.py
# Header pasted from the repository's web file view: latest commit 801ff85
# ("Update predict.py", verified) by Pamudu13; file size 6.34 kB.
from fastapi import FastAPI, File, UploadFile, Response
from fastapi.responses import FileResponse
from tensorflow.keras.preprocessing.image import img_to_array
import tensorflow as tf
import cv2
import numpy as np
import os
from scipy import fftpack
from scipy import ndimage
from ultralytics import YOLO
from PIL import Image
import io
import threading
# True while the live camera preview (start_camera) is running; checked by
# the /live_landmarks endpoint to avoid starting a second preview.
live_view_running = False

app = FastAPI()

uploads_dir = 'uploads'
# exist_ok=True replaces the exists()-then-makedirs() check, which could
# raise FileExistsError if another process created the directory in between.
os.makedirs(uploads_dir, exist_ok=True)
# Load the saved models once, at import time; the request handlers and the
# live preview below reuse these module-level instances.
segmentation_model_path = 'segmentation_model.h5'
yolo_model_path = 'best.pt'

# Keras network that produces the per-pixel mask for a detected crop.
segmentation_model = tf.keras.models.load_model(segmentation_model_path)
# Ultralytics YOLO detector whose boxes locate the region to segment.
yolo_model = YOLO(yolo_model_path)
def calculate_moisture_and_texture(image):
    """Return a frequency-domain "moisture" score for an image.

    The image is converted to grayscale (when it has three BGR channels),
    its 2-D FFT is shifted so the zero frequency sits in the centre, and the
    mean of the log-magnitude spectrum ``20 * ln|F|`` over a central square
    of half-width ``min(h, w) // 4`` (at least 1 pixel) is returned.

    NOTE(review): "moisture" is a heuristic label; nothing here is calibrated
    against physical moisture, and despite the name no separate texture value
    is returned.

    Args:
        image: BGR image of shape (H, W, 3), as produced by ``cv2.imdecode``,
            or an already-grayscale 2-D array.

    Returns:
        The mean log-magnitude over the central region, as a NumPy float.
    """
    image = np.asarray(image)
    gray_image = image if image.ndim == 2 else cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    fft_shifted = fftpack.fftshift(fftpack.fft2(gray_image))
    # Tiny floor: exactly-zero frequency bins (common in flat or synthetic
    # images) would otherwise make log() return -inf and the mean -inf.
    magnitude_spectrum = 20 * np.log(np.abs(fft_shifted) + 1e-8)
    height, width = magnitude_spectrum.shape
    center_x, center_y = width // 2, height // 2
    # Keep at least one pixel so inputs under 4 px on a side do not yield an
    # empty slice (and therefore a NaN mean).
    radius = max(min(center_x, center_y) // 2, 1)
    moisture_region = magnitude_spectrum[
        max(center_y - radius, 0):center_y + radius,
        max(center_x - radius, 0):center_x + radius,
    ]
    return np.mean(moisture_region)
def calculate_wound_dimensions(mask, threshold=0.5, pixel_to_cm_ratio=0.1):
    """Estimate wound size from a segmentation mask.

    Pixels scoring above *threshold* are grouped into connected components
    and the largest component is taken as the wound. Its bounding box gives
    the length (row extent) and breadth (column extent).

    Args:
        mask: Array of per-pixel scores, 2-D (H, W) or with a trailing
            singleton channel (H, W, 1).
        threshold: Score above which a pixel counts as wound.
        pixel_to_cm_ratio: Centimetres per pixel. NOTE(review): the 0.1
            default is a hard-coded placeholder, not a camera calibration.

    Returns:
        ``(length_cm, breadth_cm, depth_cm, area_cm2)`` as floats.
        Length, breadth and depth are rounded to 3 decimals; area is the
        product of the rounded length and breadth. ``depth_cm`` is the mean
        mask score inside the wound scaled by *pixel_to_cm_ratio*, which is a
        heuristic rather than a measured depth. All four are 0.0 when no
        pixel exceeds *threshold*.
    """
    mask = np.asarray(mask)
    labeled_mask, num_labels = ndimage.label(mask > threshold)
    if num_labels == 0:
        # Nothing segmented: np.argmax on an empty count array would raise.
        return 0.0, 0.0, 0.0, 0.0

    label_count = np.bincount(labeled_mask.ravel())
    wound_label = np.argmax(label_count[1:]) + 1  # skip background label 0
    wound_region = labeled_mask == wound_label

    rows = np.any(wound_region, axis=1)
    cols = np.any(wound_region, axis=0)
    rmin, rmax = np.where(rows)[0][[0, -1]]
    cmin, cmax = np.where(cols)[0][[0, -1]]
    # Inclusive extent: rows rmin..rmax cover rmax - rmin + 1 pixels.
    length_pixels = rmax - rmin + 1
    breadth_pixels = cmax - cmin + 1

    length_cm = round(float(length_pixels * pixel_to_cm_ratio), 3)
    breadth_cm = round(float(breadth_pixels * pixel_to_cm_ratio), 3)
    depth_cm = round(float(np.mean(mask[wound_region]) * pixel_to_cm_ratio), 3)
    # The product of two 3-decimal values has at most 6 decimals, so rounding
    # to 6 removes float noise (e.g. 0.12000000000000001) without losing data.
    area_cm2 = round(length_cm * breadth_cm, 6)
    return length_cm, breadth_cm, depth_cm, area_cm2
# Draw YOLO detection landmarks (square boxes) on the image.
def draw_square_landmarks(frame):
    """Draw a green square landmark around each YOLO detection in *frame*.

    Each detected box is replaced by a square whose side equals the box's
    longer edge, centred on the box centre. The top-left corner is clamped at
    (0, 0), which shifts the square rather than shrinking it. *frame* is
    modified in place and also returned.
    """
    detections = yolo_model(frame)[0]
    for coords in detections.boxes.xyxy.tolist():
        left, top, right, bottom = (int(v) for v in coords)
        box_w, box_h = right - left, bottom - top
        side = max(box_w, box_h)
        half = side // 2
        sq_left = max(left + box_w // 2 - half, 0)
        sq_top = max(top + box_h // 2 - half, 0)
        cv2.rectangle(frame, (sq_left, sq_top), (sq_left + side, sq_top + side), (0, 255, 0), 2)
    return frame
def _union_box(boxes):
    """Return the int (xmin, ymin, xmax, ymax) box enclosing all *boxes*, or None if empty.

    Coordinates are truncated with int(); the minimum corner is clamped at 0.
    """
    if not boxes:
        return None
    xmins, ymins, xmaxs, ymaxs = zip(*boxes)
    return (max(int(min(xmins)), 0), max(int(min(ymins)), 0),
            int(max(xmaxs)), int(max(ymaxs)))


def _colorize_mask(mask_overlay):
    """Map an 8-bit mask to a BGR colour image: >200 red, 101-200 green, <=100 blue."""
    colored = np.zeros((mask_overlay.shape[0], mask_overlay.shape[1], 3), dtype=np.uint8)
    # The pipeline stays in OpenCV's BGR order until the final BGR->RGB
    # conversion, so the colours are given as (B, G, R).
    colored[mask_overlay > 200] = [0, 0, 255]  # Red
    colored[(mask_overlay > 100) & (mask_overlay <= 200)] = [0, 255, 0]  # Green
    colored[mask_overlay <= 100] = [255, 0, 0]  # Blue
    return colored


@app.post("/analyze_wound")
async def analyze_wounds(file: UploadFile = File(...)):
    """Detect, segment and measure a wound in an uploaded PNG/JPEG image.

    All YOLO detections are merged into one bounding box. That crop is
    resized to 224x224 and segmented. The response body is a PNG of the crop
    (with landmark squares) blended with a colour-coded mask. Measurements
    are returned in the X-Length-Cm, X-Breadth-Cm, X-Depth-Cm, X-Area-Cm2 and
    X-Moisture headers.

    Returns an ``{'error': ...}`` JSON body instead when the file is not
    .png/.jpg/.jpeg, cannot be decoded, or contains no detection.

    NOTE(review): the model calls are synchronous inside this async handler,
    so they block the event loop for the duration of inference.
    """
    # UploadFile.filename may be None, so guard before calling .lower().
    if not (file.filename or '').lower().endswith(('.png', '.jpg', '.jpeg')):
        return {'error': 'Invalid file format'}

    contents = await file.read()
    img = cv2.imdecode(np.frombuffer(contents, np.uint8), cv2.IMREAD_COLOR)
    if img is None:  # imdecode returns None for corrupt / non-image data
        return {'error': 'Could not decode image'}

    results = yolo_model(img)
    # Draw landmarks on a copy so the segmentation model and the moisture
    # estimate see the raw pixels, not the green squares. The squares are
    # still shown in the returned image. (draw_square_landmarks runs its own
    # YOLO pass.)
    annotated = draw_square_landmarks(img.copy())

    box = _union_box(results[0].boxes.xyxy.tolist())
    if box is None:
        return {'error': 'No wound detected'}
    xmin, ymin, xmax, ymax = box
    crop = img[ymin:ymax, xmin:xmax]
    if crop.size == 0:  # degenerate box; cv2.resize would fail
        return {'error': 'No wound detected'}

    crop_resized = cv2.resize(crop, (224, 224))
    # NOTE(review): the crop is in BGR order; confirm the segmentation model
    # was trained on BGR rather than RGB input.
    img_array = np.expand_dims(img_to_array(crop_resized) / 255.0, axis=0)
    predicted_mask = segmentation_model.predict(img_array)[0]

    mask_overlay = (predicted_mask.squeeze() * 255).astype(np.uint8)
    mask_colored = cv2.resize(_colorize_mask(mask_overlay), (224, 224))
    display = cv2.resize(annotated[ymin:ymax, xmin:xmax], (224, 224))
    blended = cv2.addWeighted(display.astype(np.uint8), 0.6, mask_colored, 0.4, 0)

    png = io.BytesIO()
    Image.fromarray(cv2.cvtColor(blended, cv2.COLOR_BGR2RGB)).save(png, format='PNG')

    length_cm, breadth_cm, depth_cm, area_cm2 = calculate_wound_dimensions(predicted_mask)
    moisture = calculate_moisture_and_texture(crop)

    return Response(
        png.getvalue(),
        media_type='image/png',
        headers={
            'X-Length-Cm': str(length_cm),
            'X-Breadth-Cm': str(breadth_cm),
            'X-Depth-Cm': str(depth_cm),
            'X-Area-Cm2': str(area_cm2),
            'X-Moisture': str(moisture),
        },
    )
def start_camera():
    """Show the default camera (index 0) with YOLO landmark squares until 'q' or read failure.

    Blocks the calling thread; the /live_landmarks endpoint runs it on a
    background thread. Sets the module-level ``live_view_running`` flag while
    active, and always clears it and releases the camera and windows on exit,
    including on error, so the preview can be started again.

    NOTE(review): OpenCV HighGUI windows opened from a non-main thread may not
    work on every platform; confirm on the deployment OS.
    """
    global live_view_running
    live_view_running = True
    cap = None
    try:
        cap = cv2.VideoCapture(0)
        while live_view_running:
            ok, frame = cap.read()
            if not ok:  # no camera, or the stream ended
                break
            frame = draw_square_landmarks(frame)
            cv2.imshow('Live Landmarks - Press Q to stop', frame)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Previously a failed read or an exception left the flag True forever,
        # making the endpoint report "already running" permanently.
        live_view_running = False
        if cap is not None:
            cap.release()
        cv2.destroyAllWindows()
# Serializes the check-and-claim of live_view_running across concurrent
# requests; FastAPI runs sync endpoints like the one below in a threadpool.
_live_view_lock = threading.Lock()


@app.get("/live_landmarks")
def live_camera_with_landmarks():
    """Start the live landmark preview (start_camera) on a background thread if not already running."""
    global live_view_running
    with _live_view_lock:
        if live_view_running:
            return {"message": "Live camera already running."}
        # Claim the flag before starting the thread: start_camera sets it only
        # once its thread is running, so a request arriving in between would
        # otherwise start a second camera thread.
        live_view_running = True
    try:
        threading.Thread(target=start_camera).start()
    except RuntimeError:
        # The thread never started; release the claim so a retry can succeed.
        live_view_running = False
        raise
    return {"message": "Live camera started. Check your system's display window."}