Spaces:
Running
Running
from flask import Flask, request, jsonify, render_template, url_for | |
from flask_socketio import SocketIO | |
import threading | |
from ultralytics import YOLO | |
import numpy as np | |
import cv2 | |
import matplotlib.pyplot as plt | |
import importlib | |
from segment_anything import sam_model_registry, SamPredictor | |
import os | |
from werkzeug.utils import secure_filename | |
import logging | |
import json | |
import shutil | |
import sys | |
from sam2.build_sam import build_sam2 | |
from sam2.sam2_image_predictor import SAM2ImagePredictor | |
# Flask application plus the Socket.IO wrapper used to push events to clients.
app = Flask(__name__)
socketio = SocketIO(app)

# Configure logging: INFO level on the root logger, and a module-level logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Configuration | |
class Config:
    """Filesystem locations and upload limits for the application.

    Every path is anchored at the directory that contains this file.
    """

    BASE_DIR = os.path.abspath(os.path.dirname(__file__))

    # Scratch names for directories shared by several settings below. They
    # are deleted at the end of the class body so the class exposes exactly
    # the upper-case settings.
    _static = os.path.join(BASE_DIR, 'static')
    _sam = os.path.join(_static, 'sam')
    _yolo = os.path.join(_static, 'yolo')
    _train = os.path.join(_yolo, 'dataset_yolo', 'train')

    # Upload handling
    UPLOAD_FOLDER = os.path.join(_static, 'uploads')
    ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}
    MAX_CONTENT_LENGTH = 16 * 1024 * 1024  # 16MB max file size

    # Output folders
    SAM_RESULT_FOLDER = os.path.join(_sam, 'sam_results')
    YOLO_RESULT_FOLDER = os.path.join(_yolo, 'yolo_results')
    AREA_DATA_FOLDER = os.path.join(_yolo, 'area_data')

    # YOLO training dataset
    YOLO_TRAIN_IMAGE_FOLDER = os.path.join(_train, 'images')
    YOLO_TRAIN_LABEL_FOLDER = os.path.join(_train, 'labels')
    DATA_PATH = os.path.join(_yolo, 'dataset_yolo', 'data.yaml')

    # Model weight files
    SAM_CHECKPOINT = os.path.join(_sam, 'sam_vit_h_4b8939.pth')
    SAM_2 = os.path.join(_sam, 'sam2.1_hiera_large.pt')
    YOLO_PATH = os.path.join(_yolo, 'model_yolo.pt')
    RETRAINED_MODEL_PATH = os.path.join(_yolo, 'model_retrained.pt')

    del _static, _sam, _yolo, _train
app.config.from_object(Config)

# Ensure directories exist
for _folder_key in (
    'UPLOAD_FOLDER',
    'SAM_RESULT_FOLDER',
    'YOLO_RESULT_FOLDER',
    'YOLO_TRAIN_IMAGE_FOLDER',
    'YOLO_TRAIN_LABEL_FOLDER',
    'AREA_DATA_FOLDER',
):
    os.makedirs(app.config[_folder_key], exist_ok=True)

# Initialize Yolo model; a failure is logged and then aborts start-up.
try:
    model = YOLO(app.config['YOLO_PATH'])
except Exception as exc:
    logger.error(f"Failed to initialize YOLO model: {exc}")
    raise

# Build the SAM2 model on CPU and wrap it in an image predictor; a failure is
# logged and then aborts start-up.
try:
    sam2_checkpoint = app.config['SAM_2']
    model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
    sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cpu")
    predictor = SAM2ImagePredictor(sam2_model)
except Exception as exc:
    logger.error(f"Failed to initialize SAM model: {exc}")
    raise
def allowed_file(filename):
    """Return True when *filename* ends in an extension from ALLOWED_EXTENSIONS.

    The comparison is case-insensitive; a name without any dot is rejected.
    """
    _stem, dot, extension = filename.rpartition('.')
    if not dot:
        return False
    return extension.lower() in app.config['ALLOWED_EXTENSIONS']
def scale_coordinates(coords, original_dims, target_dims):
    """
    Map points from one dimension space into another.

    Args:
        coords: List of [x, y] coordinates
        original_dims: Tuple of (width, height) of original space
        target_dims: Tuple of (width, height) of target space
    Returns:
        List of [x, y] pairs, each component truncated to an int
    """
    factor_x = target_dims[0] / original_dims[0]
    factor_y = target_dims[1] / original_dims[1]

    scaled = []
    for point in coords:
        scaled.append([int(point[0] * factor_x), int(point[1] * factor_y)])
    return scaled
def scale_box(box, original_dims, target_dims):
    """
    Map a bounding box from one dimension space into another.

    Args:
        box: List of [x1, y1, x2, y2] coordinates
        original_dims: Tuple of (width, height) of original space
        target_dims: Tuple of (width, height) of target space
    Returns:
        [x1, y1, x2, y2] with each component truncated to an int
    """
    factor_x = target_dims[0] / original_dims[0]
    factor_y = target_dims[1] / original_dims[1]

    # x components use the horizontal factor, y components the vertical one.
    x1, y1, x2, y2 = box[0], box[1], box[2], box[3]
    return [
        int(x1 * factor_x),
        int(y1 * factor_y),
        int(x2 * factor_x),
        int(y2 * factor_y),
    ]
def retrain_model_fn():
    """Retrain the global YOLO model and report progress over Socket.IO.

    Training is run as several single-epoch ``model.train`` calls so that a
    'training_update' event can be emitted after each one. When every epoch
    has finished, the model is saved to ``YOLO_PATH`` and only then is
    'training_complete' emitted, so clients are never told training succeeded
    when the save failed.

    Any exception is logged with its traceback, recorded in the module-level
    ``retraining_status`` dict and announced with a 'training_error' event;
    it is not re-raised.
    """
    # Parameters for retraining
    data_path = app.config['DATA_PATH']
    epochs = 5
    img_size = 640
    batch_size = 8

    try:
        for epoch in range(1, epochs + 1):
            # One epoch per call so progress can be reported in between.
            model.train(
                data=data_path,
                epochs=1,
                imgsz=img_size,
                batch=batch_size,
                device="cpu"  # Adjust based on system capabilities
            )
            retraining_status['progress'] = f"Epoch {epoch}/{epochs} complete"
            socketio.emit('training_update', {
                'epoch': epoch,
                'status': f"Epoch {epoch} complete"
            })

        # Persist before announcing success.
        model.save(app.config['YOLO_PATH'])
    except Exception as exc:
        # Top-level boundary of the worker: report the failure instead of
        # letting it disappear with the thread.
        logger.exception("Model retraining failed")
        retraining_status['status'] = 'failed'
        retraining_status['message'] = str(exc)
        socketio.emit('training_error', {
            'status': "Retraining failed",
            'error': str(exc)
        })
        return

    retraining_status['status'] = 'complete'
    retraining_status['message'] = "Retraining complete"
    socketio.emit('training_complete', {'status': "Retraining complete"})
    logger.info("Model retrained successfully")
def index():
    """Render the ``index.html`` template.

    NOTE(review): no route decorator is visible for this view here; confirm
    how it is registered with the app.
    """
    page = render_template('index.html')
    return page
def yolo():
    """Render the ``yolo.html`` template.

    NOTE(review): no route decorator is visible for this view here; confirm
    how it is registered with the app.
    """
    page = render_template('yolo.html')
    return page
def upload_sam_file():
    """
    Save an uploaded image and embed it in the global SAM predictor.

    The image is read from the 'file' field of the request's uploaded files.

    Returns:
        JSON response with 'message', 'image_url', 'filename', and 'dimensions'
        (width and height in pixels) keys on success, or an 'error' key with an
        appropriate message and a 400/500 status on failure.
    """
    try:
        if 'file' not in request.files:
            return jsonify({'error': 'No file part'}), 400
        upload = request.files['file']
        if upload.filename == '':
            return jsonify({'error': 'No selected file'}), 400
        if not allowed_file(upload.filename):
            return jsonify({'error': 'Invalid file type. Allowed types: PNG, JPG, JPEG'}), 400

        # The sanitised name is the one written to disk, so it must pass the
        # extension check as well: sanitising can strip the whole stem (for
        # example a non-ASCII name collapses to "png") and leave a file
        # without an extension.
        filename = secure_filename(upload.filename)
        if not allowed_file(filename):
            return jsonify({'error': 'Invalid file name'}), 400

        filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
        upload.save(filepath)

        # Set the image for predictor right after upload
        image = cv2.imread(filepath)
        if image is None:
            return jsonify({'error': 'Failed to load uploaded image'}), 500
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # imread yields BGR
        predictor.set_image(image)
        logger.info("Image embedded successfully")

        # Get image dimensions
        height, width = image.shape[:2]
        image_url = url_for('static', filename=f'uploads/{filename}')
        logger.info(f"File uploaded successfully: {filepath}")
        return jsonify({
            'message': 'File uploaded successfully',
            'image_url': image_url,
            'filename': filename,
            'dimensions': {
                'width': width,
                'height': height
            }
        })
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception(f"Upload error: {str(e)}")
        return jsonify({'error': 'Server error during upload'}), 500
def upload_yolo_file():
    """
    Save an uploaded image for later YOLO classification.

    The image is read from the 'file' field of the request's uploaded files
    and written to the uploads folder; no model is run here.

    Returns a JSON response with the following keys:
    - message: a success message
    - image_url: the URL of the uploaded image
    - filename: the name under which the file was stored
    If an error occurs, the JSON response will contain an 'error' key with a
    descriptive error message and a 400/500 status.
    """
    try:
        if 'file' not in request.files:
            return jsonify({'error': 'No file part'}), 400
        upload = request.files['file']
        if upload.filename == '':
            return jsonify({'error': 'No selected file'}), 400
        if not allowed_file(upload.filename):
            return jsonify({'error': 'Invalid file type. Allowed types: PNG, JPG, JPEG'}), 400

        # The sanitised name is the one written to disk, so it must pass the
        # extension check as well: sanitising can strip the whole stem (for
        # example a non-ASCII name collapses to "png") and leave a file
        # without an extension.
        filename = secure_filename(upload.filename)
        if not allowed_file(filename):
            return jsonify({'error': 'Invalid file name'}), 400

        filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
        upload.save(filepath)
        image_url = url_for('static', filename=f'uploads/{filename}')
        logger.info(f"File uploaded successfully: {filepath}")
        return jsonify({
            'message': 'File uploaded successfully',
            'image_url': image_url,
            'filename': filename,
        })
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception(f"Upload error: {str(e)}")
        return jsonify({'error': 'Server error during upload'}), 500
def _mask_bounding_box(mask):
    """Return (x_min, y_min, x_max, y_max) of the non-zero pixels of a 2-D mask, or None if there are none."""
    rows, cols = np.where(mask)
    if rows.size == 0:
        return None
    return int(cols.min()), int(rows.min()), int(cols.max()), int(rows.max())


def _draw_mask_label(img, text, mask):
    """Draw *text* (black on a white box) near the top-right corner of the mask's bounding box."""
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.6
    font_thickness = 1
    font_color = (0, 0, 0)  # Black text
    background_color = (255, 255, 255)  # White background for text

    bbox = _mask_bounding_box(mask)
    if bbox is None:
        return  # nothing to label on an empty mask
    _x_min, y_min, x_max, _y_max = bbox

    # Measure the text that is actually drawn so it fits inside the image.
    (text_w, text_h), _ = cv2.getTextSize(text, font, font_scale, font_thickness)
    img_width = img.shape[1]
    x_pos = max(0, min(x_max, img_width - text_w - 10))  # Keep 10px margin from the right
    y_pos = max(y_min + text_h + 5, text_h + 5)  # Keep 5px margin from the top

    background_tl = (x_pos, y_pos - text_h - 2)
    background_br = (x_pos + text_w, y_pos + 2)
    cv2.rectangle(img, background_tl, background_br, background_color, -1)
    cv2.putText(img, text, (x_pos, y_pos), font, font_scale, font_color, font_thickness, cv2.LINE_AA)


def _tint_mask(canvas, source, region, channel):
    """Halve *source* inside *region* and add half-intensity of one colour channel, writing into *canvas*."""
    for ch in range(3):
        boost = 0.5 * 255 if ch == channel else 0
        canvas[region, ch] = np.clip(0.5 * source[region, ch] + boost, 0, 255)


def _mask_to_label_rows(mask, class_id, largest_only=False):
    """Convert a mask's external contours into rows of [class_id, x1, y1, x2, y2, ...] normalised to 0..1."""
    contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if largest_only and contours:
        contours = [max(contours, key=cv2.contourArea)]
    height, width = mask.shape
    rows = []
    for contour in contours:
        points = contour.reshape(-1, 2)  # (N, 2) array of x, y
        if len(points) < 3:
            continue  # a polygon needs at least three vertices
        normalized = points / [width, height]
        rows.append([class_id] + normalized.flatten().tolist())
    return rows


def generate_mask():
    """
    Generate void and component masks for an uploaded image with the SAM predictor.

    Request JSON keys:
    - filename: the name of a previously uploaded image file
    - void_points: a list of normalized (x, y) points, one per void
    - component_boxes: a list of normalized (x1, y1, x2, y2) boxes, one per component

    Side effects: writes a polygon label file and a copy of the image into the
    YOLO training folders, and an overlay image into the SAM result folder.

    @return: a JSON object containing the following keys:
    - status: a string indicating the status of the request
    - train_image_url: the URL of the saved train image
    - result_path: the URL of the saved result image
    On failure, an 'error' key with a 400/404/500 status.
    """
    try:
        data = request.get_json(silent=True) or {}
        normalized_void_points = data.get('void_points', [])
        normalized_component_boxes = data.get('component_boxes', [])
        # The name comes from the client and is used to build read and write
        # paths, so strip any directory components. Names returned by the
        # upload endpoints are already sanitised and pass through unchanged.
        filename = secure_filename(data.get('filename', ''))
        if not filename:
            return jsonify({'error': 'No filename provided'}), 400
        image_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
        if not os.path.exists(image_path):
            return jsonify({'error': 'Image file not found'}), 404

        # Read image
        image = cv2.imread(image_path)
        if image is None:
            return jsonify({'error': 'Failed to load image'}), 500
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image_height, image_width = image.shape[:2]

        # Denormalize coordinates back to pixel values
        void_points = [
            [int(point[0] * image_width), int(point[1] * image_height)]
            for point in normalized_void_points
        ]
        logger.info(f"Void points: {void_points}")
        component_boxes = [
            [
                int(box[0] * image_width),
                int(box[1] * image_height),
                int(box[2] * image_width),
                int(box[3] * image_height)
            ]
            for box in normalized_component_boxes
        ]
        logger.info(f"Component boxes: {component_boxes}")

        # NOTE(review): the predictor is not given `image` here; it is only
        # set in upload_sam_file. Presumably it still holds the image named
        # by `filename`; confirm that no other upload can happen in between.

        # One mask per void point: keep the highest-scoring candidate.
        void_masks = []
        for point in void_points:
            point_coord = np.array([[point[0], point[1]]])  # [N, 2] array
            point_label = np.array([1])  # Single label
            masks, scores, _ = predictor.predict(
                point_coords=point_coord,
                point_labels=point_label,
                multimask_output=True  # Get multiple masks
            )
            if len(masks) > 0:
                best_mask_idx = np.argmax(scores)
                void_masks.append(masks[best_mask_idx])
                logger.info(f"Processed void point {point} with score {scores[best_mask_idx]}")

        # One mask per component box: keep the highest-scoring candidate.
        component_masks = []
        for box in component_boxes:
            box_np = np.array([[box[0], box[1]], [box[2], box[3]]])  # [2, 2] array
            masks, scores, _ = predictor.predict(
                box=box_np,
                multimask_output=True
            )
            if len(masks) > 0:
                best_mask_idx = np.argmax(scores)
                component_masks.append(masks[best_mask_idx])
                logger.info(f"Processed component box {box}")

        # Overlay: voids tinted red, components tinted green.
        combined_image = image.copy()
        void_regions = [mask.astype(bool) for mask in void_masks]
        for region in void_regions:
            _tint_mask(combined_image, image, region, 0)
        logger.info("Mask Drawn")

        # Union of all voids, computed once; green is applied only outside it.
        if void_regions:
            void_union = np.any(void_regions, axis=0)
        else:
            void_union = np.zeros(image.shape[:2], dtype=bool)
        for mask in component_masks:
            non_red_area = mask.astype(bool) & ~void_union
            _tint_mask(combined_image, image, non_red_area, 1)
        logger.info("Mask Drawn")

        # Add labels on top of masks
        for number, mask in enumerate(void_masks, start=1):
            _draw_mask_label(combined_image, f"Void {number}", mask)
        for number, mask in enumerate(component_masks, start=1):
            _draw_mask_label(combined_image, f"Component {number}", mask)

        # Polygon rows: class 1 for every void contour, class 0 for the
        # largest contour of each component.
        mask_coordinates = []
        for mask in void_masks:
            mask_coordinates.extend(_mask_to_label_rows(mask, class_id=1))
        for mask in component_masks:
            mask_coordinates.extend(_mask_to_label_rows(mask, class_id=0, largest_only=True))

        # The label file shares the image's stem ("img.jpg" -> "img.txt") so
        # YOLO training can pair it with the image.
        mask_coordinates_filename = os.path.splitext(filename)[0] + '.txt'
        mask_coordinates_path = os.path.join(app.config['YOLO_TRAIN_LABEL_FOLDER'], mask_coordinates_filename)
        with open(mask_coordinates_path, "w") as label_file:
            label_file.writelines(" ".join(map(str, row)) + "\n" for row in mask_coordinates)

        # Save train image
        train_image_filepath = os.path.join(app.config['YOLO_TRAIN_IMAGE_FOLDER'], filename)
        shutil.copy(image_path, train_image_filepath)
        train_image_url = url_for('static', filename=f'yolo/dataset_yolo/train/images/{filename}')

        # Save result
        result_filename = f'segmented_{filename}'
        result_path = os.path.join(app.config['SAM_RESULT_FOLDER'], result_filename)
        plt.imsave(result_path, combined_image)
        logger.info("Mask generation completed successfully")
        return jsonify({
            'status': 'success',
            'train_image_url': train_image_url,
            'result_path': url_for('static', filename=f'sam/sam_results/{result_filename}')
        })
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception(f"Mask generation error: {str(e)}")
        return jsonify({'error': str(e)}), 500
def classify():
    """
    Run the YOLO model on an uploaded image and report per-component void statistics.

    Request body should contain a JSON object with a single key 'filename' specifying the image file to be classified.

    Side effects: writes the area data as JSON into the area-data folder and
    the annotated image into the YOLO result folder.

    Returns a JSON object with the following keys:
    - status: 'success' if the classification is successful.
    - result_path: URL of the annotated image.
    - area_data: a list of dictionaries containing the area and overlap statistics for each component
      (empty when the model detects nothing).
    - area_data_path: URL of the JSON file containing the area data.
    If there is an error, returns a JSON object with a single key 'error' containing the error message.
    """
    try:
        data = request.get_json(silent=True) or {}
        # The name comes from the client and is used to build read and write
        # paths, so strip any directory components. Names returned by the
        # upload endpoints are already sanitised and pass through unchanged.
        filename = secure_filename(data.get('filename', ''))
        if not filename:
            return jsonify({'error': 'No filename provided'}), 400
        image_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
        if not os.path.exists(image_path):
            return jsonify({'error': 'Image file not found'}), 404

        # Read image
        image = cv2.imread(image_path)
        if image is None:
            return jsonify({'error': 'Failed to load image'}), 500

        results = model(image)
        result = results[0]

        # Split the predicted masks by class. result.masks is None when the
        # model detects nothing, in which case both lists stay empty.
        component_masks = []
        void_masks = []
        if result.masks is not None:
            for mask, label in zip(result.masks.data, result.boxes.cls):
                mask_array = mask.cpu().numpy().astype(bool)  # binary mask
                class_id = int(label)
                if class_id == 1:  # Assuming label '1' represents void
                    void_masks.append(mask_array)
                elif class_id == 0:  # Assuming label '0' represents component
                    component_masks.append(mask_array)

        # Calculate area and overlap statistics
        area_data = []
        for i, component_mask in enumerate(component_masks):
            component_area = np.sum(component_mask).item()  # Total component area in pixels
            void_area_within_component = 0
            max_void_area_percentage = 0
            # Overlap of each void mask with this component
            for void_mask in void_masks:
                overlap_area = np.sum(void_mask & component_mask).item()
                void_area_within_component += overlap_area
                void_area_percentage = (overlap_area / component_area) * 100 if component_area > 0 else 0
                max_void_area_percentage = max(max_void_area_percentage, void_area_percentage)
            area_data.append({
                "Image": filename,
                'Component': f'Component {i+1}',
                'Area': component_area,
                'Void Area (pixels)': void_area_within_component,
                'Void Area %': void_area_within_component / component_area * 100 if component_area > 0 else 0,
                'Max Void Area %': max_void_area_percentage
            })

        area_data_filename = f'area_data_{filename}.json'
        area_data_path = os.path.join(app.config['AREA_DATA_FOLDER'], area_data_filename)
        with open(area_data_path, 'w') as json_file:
            json.dump(area_data, json_file, indent=4)

        # result.plot() returns a BGR array while plt.imsave expects RGB, so
        # convert first; otherwise the saved image has red and blue swapped.
        annotated_image = cv2.cvtColor(result.plot(), cv2.COLOR_BGR2RGB)
        output_filename = f'output_{filename}'
        output_image_path = os.path.join(app.config['YOLO_RESULT_FOLDER'], output_filename)
        plt.imsave(output_image_path, annotated_image)
        logger.info("Classification completed successfully")
        return jsonify({
            'status': 'success',
            'result_path': url_for('static', filename=f'yolo/yolo_results/{output_filename}'),
            'area_data': area_data,
            'area_data_path': url_for('static', filename=f'yolo/area_data/{area_data_filename}')
        })
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception(f"Classification error: {str(e)}")
        return jsonify({'error': str(e)}), 500
# Shared record of the retraining job: 'status' starts as 'idle', and
# 'progress' and 'message' are unset until something writes them.
retraining_status = dict(status='idle', progress=None, message=None)
def start_retraining():
    """
    Serve the retraining page, or kick off retraining.

    Any method other than POST renders ``retrain.html``. A POST marks the
    shared status as in progress, launches ``retrain_model_fn`` on a new
    thread and answers immediately.

    NOTE(review): no route decorator is visible for this view here; confirm
    which methods it is registered for.

    Returns:
        The rendered HTML page, or the JSON body {'status': 'started'}.
    """
    if request.method != 'POST':
        return render_template('retrain.html')

    # Reset status (the dict is mutated in place, never rebound).
    retraining_status['status'] = 'in_progress'
    retraining_status['progress'] = 'Initializing'

    # Train off the request thread so the response is not held up.
    worker = threading.Thread(target=retrain_model_fn)
    worker.start()
    return jsonify({'status': 'started'})
# Event handler for client connection | |
def handle_connect():
    """Print a one-line notice that a client has connected.

    NOTE(review): no event decorator is visible for this handler here;
    confirm how it is registered.
    """
    notice = 'Client connected'
    print(notice)
if __name__ == '__main__':
    # Script entry point: serve the app on port 5001 with debug mode on.
    app.run(debug=True, port=5001)