# wound_detect / predict.py
# Ani14's picture
# Update predict.py
# ab3691c verified
# raw
# history blame
# 4.83 kB
from fastapi import FastAPI, File, UploadFile, Response
from fastapi.responses import FileResponse
from tensorflow.keras.preprocessing.image import img_to_array
import tensorflow as tf
import cv2
import numpy as np
import os
from scipy import fftpack
from scipy import ndimage
from ultralytics import YOLO
from PIL import Image
import io
# FastAPI application object that the route decorators below attach to.
app = FastAPI()

# Directory created at import time. It is not referenced elsewhere in the
# visible code. exist_ok=True replaces the exists()/makedirs() pair, which
# could raise if two processes start at the same moment.
uploads_dir = 'uploads'
os.makedirs(uploads_dir, exist_ok=True)

# Load the saved models once at import time so every request reuses them.
# Keras segmentation model stored in HDF5 format.
segmentation_model_path = 'segmentation_model.h5'
segmentation_model = tf.keras.models.load_model(segmentation_model_path)

# Ultralytics YOLO weights used for detection.
yolo_model_path = 'best.pt'
yolo_model = YOLO(yolo_model_path)
def calculate_moisture_and_texture(image):
    """Return the mean log-magnitude of the central band of the image's 2-D FFT.

    The image is converted to grayscale, transformed with a 2-D FFT, shifted
    so the zero-frequency term sits in the middle, and the mean of
    ``20 * log(|F|)`` is taken over a centred square whose half-width is a
    quarter of the shorter image side.

    The caller reports this value as a "moisture" score; how it relates to
    physical moisture is not established by this code.

    Args:
        image: A 3-channel BGR array (as produced by ``cv2.imdecode``) or an
            already-grayscale 2-D array.

    Returns:
        The mean of the selected spectrum region as a NumPy float.
    """
    # A 2-D input is already single-channel; only colour input needs OpenCV.
    if np.ndim(image) == 2:
        gray_image = np.asarray(image)
    else:
        gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    fft_shifted = fftpack.fftshift(fftpack.fft2(gray_image))

    # Offset by machine epsilon so exactly-zero coefficients (e.g. a uniform
    # image) do not produce log(0) = -inf and poison the mean.
    eps = np.finfo(np.float64).eps
    magnitude_spectrum = 20 * np.log(np.abs(fft_shifted) + eps)

    height, width = magnitude_spectrum.shape
    center_x, center_y = width // 2, height // 2

    # Keep the window at least one pixel wide so tiny images do not yield an
    # empty slice (whose mean would be nan), and never start before index 0.
    radius = max(1, min(center_x, center_y) // 2)
    top = max(0, center_y - radius)
    left = max(0, center_x - radius)
    moisture_region = magnitude_spectrum[top:center_y + radius,
                                         left:center_x + radius]

    return np.mean(moisture_region)
def calculate_wound_dimensions(mask, pixel_to_cm_ratio=0.1):
    """Measure the largest connected region of a segmentation mask.

    The mask is thresholded at 0.5, connected components are labelled, and
    the component with the most pixels is measured by its bounding box.

    Args:
        mask: Array of per-pixel scores. Assumed to be ``(H, W)`` or
            ``(H, W, 1)`` as returned by the segmentation model - TODO
            confirm against the model's output shape.
        pixel_to_cm_ratio: Centimetres represented by one pixel. Defaults to
            the previously hard-coded ``0.1``.

    Returns:
        ``(length_cm, breadth_cm, depth_cm, area_cm2)``. Length and breadth
        are the inclusive bounding-box extents along axis 0 and axis 1, each
        rounded to 3 decimals. ``area_cm2`` is their product (bounding-box
        area, not the region's pixel area). ``depth_cm`` is the mean mask
        value inside the region scaled by ``pixel_to_cm_ratio`` and rounded
        to 3 decimals. All four are ``0.0`` when no pixel exceeds the
        threshold.
    """
    mask = np.asarray(mask)
    labeled_mask, num_labels = ndimage.label(mask > 0.5)

    # Nothing above threshold: argmax over an empty count array would raise.
    if num_labels == 0:
        return 0.0, 0.0, 0.0, 0.0

    # Index 0 of the counts is background, so skip it and shift back by one.
    label_count = np.bincount(labeled_mask.ravel())
    wound_label = np.argmax(label_count[1:]) + 1
    wound_region = labeled_mask == wound_label

    row_idx = np.where(np.any(wound_region, axis=1))[0]
    col_idx = np.where(np.any(wound_region, axis=0))[0]

    # Inclusive extent: rows r0..r1 cover (r1 - r0 + 1) pixels. Without the
    # +1 a one-pixel-wide region would be reported as zero length.
    length_pixels = row_idx[-1] - row_idx[0] + 1
    breadth_pixels = col_idx[-1] - col_idx[0] + 1

    length_cm = round(float(length_pixels * pixel_to_cm_ratio), 3)
    breadth_cm = round(float(breadth_pixels * pixel_to_cm_ratio), 3)
    # NOTE(review): "depth" is derived from mask scores, not a real depth
    # measurement - confirm this is the intended proxy.
    depth_cm = round(float(np.mean(mask[wound_region]) * pixel_to_cm_ratio), 3)
    area_cm2 = length_cm * breadth_cm

    return length_cm, breadth_cm, depth_cm, area_cm2
def _union_bounding_box(boxes, width, height):
    """Return the integer box enclosing all ``boxes``, clamped to the image, or None."""
    if not boxes:
        return None

    xmin = max(0, int(min(box[0] for box in boxes)))
    ymin = max(0, int(min(box[1] for box in boxes)))
    xmax = min(width, int(max(box[2] for box in boxes)))
    ymax = min(height, int(max(box[3] for box in boxes)))

    # A degenerate box would give an empty crop, which cv2.resize rejects.
    if xmax <= xmin or ymax <= ymin:
        return None
    return xmin, ymin, xmax, ymax


def _colorize_mask(mask_gray):
    """Map a uint8 mask to a 3-band BGR colour image (red / green / blue)."""
    colored = np.zeros((mask_gray.shape[0], mask_gray.shape[1], 3), dtype=np.uint8)
    # Colours are written in BGR order because the overlay is blended into a
    # BGR image and only converted to RGB afterwards.
    colored[mask_gray > 200] = [0, 0, 255]  # Red
    colored[(mask_gray > 100) & (mask_gray <= 200)] = [0, 255, 0]  # Green
    colored[mask_gray <= 100] = [255, 0, 0]  # Blue
    return colored


@app.post("/analyze_wound")
async def analyze_wounds(file: UploadFile = File(...)):
    """Detect, segment and measure a wound in an uploaded image.

    The upload is decoded, the YOLO detections are merged into one bounding
    box, the crop is resized to 224x224 and passed to the segmentation model,
    and the predicted mask is blended over the crop.

    Args:
        file: Uploaded image; the filename must end in .png, .jpg or .jpeg.

    Returns:
        A PNG ``Response`` containing the blended overlay, with the
        measurements in the ``X-Length-Cm``, ``X-Breadth-Cm``, ``X-Depth-Cm``,
        ``X-Area-Cm2`` and ``X-Moisture`` headers. On failure a dict with an
        ``'error'`` key is returned instead (unsupported extension,
        undecodable image, or no usable detection).
    """
    filename = file.filename or ''
    if not filename.lower().endswith(('.png', '.jpg', '.jpeg')):
        return {'error': 'Invalid file format'}

    contents = await file.read()
    # np.fromstring is deprecated for binary data; frombuffer is the
    # supported way to view raw bytes as uint8.
    nparr = np.frombuffer(contents, np.uint8)
    # imdecode rejects an empty buffer and returns None for corrupt data.
    img = cv2.imdecode(nparr, cv2.IMREAD_COLOR) if nparr.size else None
    if img is None:
        return {'error': 'Could not decode image'}

    results = yolo_model(img)
    box = _union_bounding_box(
        results[0].boxes.xyxy.tolist(), img.shape[1], img.shape[0]
    )
    if box is None:
        # Without this guard the original code hit int(float('inf')).
        return {'error': 'No wound detected'}
    combined_xmin, combined_ymin, combined_xmax, combined_ymax = box

    combined_img = img[combined_ymin:combined_ymax, combined_xmin:combined_xmax]
    combined_img_resized = cv2.resize(combined_img, (224, 224))

    # Scale to [0, 1] and add the batch axis expected by predict().
    img_array = img_to_array(combined_img_resized) / 255.0
    img_array = np.expand_dims(img_array, axis=0)
    output = segmentation_model.predict(img_array)
    predicted_mask = output[0]

    mask_overlay = (predicted_mask.squeeze() * 255).astype(np.uint8)
    mask_overlay_colored = _colorize_mask(mask_overlay)
    # Match the crop size in case the model output is not 224x224.
    mask_overlay_colored = cv2.resize(mask_overlay_colored, (224, 224))

    blended_image = cv2.addWeighted(
        combined_img_resized.astype(np.uint8), 0.6, mask_overlay_colored, 0.4, 0
    )
    segmented_image = Image.fromarray(cv2.cvtColor(blended_image, cv2.COLOR_BGR2RGB))
    img_byte_arr = io.BytesIO()
    segmented_image.save(img_byte_arr, format='PNG')
    img_byte_arr.seek(0)

    length_cm, breadth_cm, depth_cm, area_cm2 = calculate_wound_dimensions(predicted_mask)
    # Moisture is computed on the full-resolution crop, not the resized one.
    moisture = calculate_moisture_and_texture(combined_img)

    response = Response(img_byte_arr.getvalue(), media_type='image/png')
    response.headers['X-Length-Cm'] = str(length_cm)
    response.headers['X-Breadth-Cm'] = str(breadth_cm)
    response.headers['X-Depth-Cm'] = str(depth_cm)
    response.headers['X-Area-Cm2'] = str(area_cm2)
    response.headers['X-Moisture'] = str(moisture)
    return response