Spaces:
Sleeping
Sleeping
from flask import Flask, request, send_file, Response, jsonify | |
from flask_cors import CORS | |
import numpy as np | |
import io | |
import torch | |
import cv2 | |
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator | |
from PIL import Image | |
import zipfile | |
app = Flask(__name__)
CORS(app)  # enable cross-origin requests for this app's routes

# Run on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    cudaOrNah = "cuda"
else:
    cudaOrNah = "cpu"
print(cudaOrNah)  # log the chosen device at startup
# --- Global SAM model setup -------------------------------------------------
# The larger ViT-H variant ("sam_vit_h_4b8939.pth", model type "vit_h") was
# swapped for the ViT-L checkpoint below because of memory constraints.
model_type = "vit_l"
checkpoint = "sam_vit_l_0b3195.pth"
sam = sam_model_registry[model_type](checkpoint=checkpoint)
sam.to(device=cudaOrNah)  # move the weights onto the selected device

# NOTE(review): segment_anything documents `min_mask_region_area` as an area in
# pixels, and any value > 0 switches on its small-region post-processing. A
# value below 1 (like 0.0015) would then remove no regions at all. If a
# fraction of the image area was intended, it must be converted to pixels per
# image. Confirm.
mask_generator = SamAutomaticMaskGenerator(
    model=sam,
    min_mask_region_area=0.0015,
)
print('Setup SAM model')
def hello():
    """Return a fixed greeting payload as a plain dict.

    NOTE(review): no route decorator is visible for this view in the reviewed
    source. Confirm it is registered on ``app``.
    """
    greeting = {"hei": "Shredded to pieces"}
    return greeting
def health_check():
    """Liveness probe: always answer with a ``{"status": "ok"}`` JSON body and HTTP 200.

    NOTE(review): no route decorator is visible for this view in the reviewed
    source. Confirm it is registered on ``app``.
    """
    status_code = 200
    body = jsonify({"status": "ok"})
    return body, status_code
def get_masks():
    """Segment an uploaded image with SAM and return the masks as a zip of PNGs.

    The image must arrive as a multipart upload in the ``image`` form field.
    Masks are made disjoint: a pixel claimed by several masks goes to the one
    with the largest SAM-reported area. Masks left empty are dropped, as are
    masks covering any of four probe pixels near the image corners (treated as
    background). Each remaining mask is written as ``mask_<n>.png`` (8-bit,
    0 = outside, 255 = inside) into ``masks.zip``, largest SAM-reported area
    first.

    Returns:
        On success, a ``send_file`` response carrying the zip archive. The
        archive is empty when no mask survives filtering.

        On failure, a JSON body ``{"error": ..., ["details": ...]}``. The status
        is 400 when the image is missing or cannot be decoded, and 500 when
        segmentation or encoding fails on the server side.

    NOTE(review): no route decorator is visible for this view in the reviewed
    source. Confirm it is registered on ``app``.
    """
    print('received image from frontend')
    if 'image' not in request.files:
        return jsonify({"error": "No image file provided"}), 400
    image_file = request.files['image']
    if image_file.filename == '':
        return jsonify({"error": "No image file provided"}), 400

    try:
        # SamAutomaticMaskGenerator.generate takes an HWC uint8 array in RGB
        # order (SAM's predictor defaults to image_format="RGB"), which is
        # exactly what PIL yields here. Do NOT flip it to BGR.
        image = np.array(Image.open(image_file).convert("RGB"))
    except (OSError, Image.DecompressionBombError) as e:
        # Undecodable, truncated or oversized upload: a client-side problem.
        print(f"Error processing the image: {e}")
        return jsonify({"error": "Error processing the image", "details": str(e)}), 400

    try:
        masks = _remove_overlaps(_generate_masks(image))
        masks = [mask for mask in masks if not _is_background(mask['segmentation'])]
        zip_buffer = _masks_to_zip(masks)
    except Exception as e:
        # Request boundary: anything failing here (e.g. CUDA out of memory) is a
        # server-side error, so report 500 rather than 400 Bad Request.
        print(f"Error processing the image: {e}")
        return jsonify({"error": "Error processing the image", "details": str(e)}), 500

    return send_file(zip_buffer, mimetype='application/zip', as_attachment=True, download_name='masks.zip')


def _generate_masks(image):
    """Run the global SAM ``mask_generator`` on *image* and return its mask dicts.

    On the GPU, PyTorch's unused cached CUDA memory is released before the call
    and again afterwards. The second release also happens when generation
    raises, so a failed request does not leave the cache inflated.
    """
    if cudaOrNah == "cuda":
        torch.cuda.empty_cache()
    try:
        return mask_generator.generate(image)
    finally:
        if cudaOrNah == "cuda":
            torch.cuda.empty_cache()


def _remove_overlaps(masks):
    """Return *masks* made pairwise disjoint, largest first, with empty ones dropped.

    Masks are visited in descending order of their ``'area'`` value. Each one
    loses the pixels already claimed by a previously visited mask.
    ``'segmentation'`` and ``'area'`` are updated in place on the mask dicts;
    any other keys are left untouched.
    """
    ordered = sorted(masks, key=lambda m: m['area'], reverse=True)
    if not ordered:
        # Nothing to trim, and no mask to take the canvas shape from.
        return []
    covered = np.zeros_like(ordered[0]['segmentation'], dtype=bool)
    kept = []
    for mask in ordered:
        trimmed = np.logical_and(mask['segmentation'], np.logical_not(covered))
        mask['segmentation'] = trimmed
        mask['area'] = trimmed.sum()
        if mask['area'] > 0:
            kept.append(mask)
            covered |= trimmed
    return kept


def _is_background(segmentation, inset=10):
    """Return True if *segmentation* covers any of four probe pixels near the corners.

    The probes sit *inset* pixels in from the edges. For the default inset and
    images at least 11 px in each dimension, these are ``[10, 10]``,
    ``[-10, 10]``, ``[10, -10]`` and ``[-10, -10]``. On smaller images the
    probes are clamped into bounds instead of raising IndexError.
    """
    height, width = segmentation.shape[:2]
    top, left = min(inset, height - 1), min(inset, width - 1)
    bottom, right = max(height - inset, 0), max(width - inset, 0)
    return bool(
        segmentation[top, left] or segmentation[bottom, left]
        or segmentation[top, right] or segmentation[bottom, right]
    )


def _masks_to_zip(masks):
    """Encode each boolean mask as an 8-bit PNG (0/255) inside an in-memory zip.

    Entries are named ``mask_1.png``, ``mask_2.png``, ... in list order. The
    returned buffer is rewound to the start, ready to be streamed.
    """
    zip_buffer = io.BytesIO()
    with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
        for idx, mask in enumerate(masks, start=1):
            alpha = mask['segmentation'].astype('uint8') * 255
            png_io = io.BytesIO()
            Image.fromarray(alpha).save(png_io, format="PNG")
            zip_file.writestr(f'mask_{idx}.png', png_io.getvalue())
    zip_buffer.seek(0)
    return zip_buffer
if __name__ == '__main__':
    # Debug mode turns on Werkzeug's auto-reloader by default. The reloader
    # re-executes this script in a child process while the parent stays alive,
    # so the SAM model built at import time would sit in memory twice.
    # Keep debug mode but disable the reloader.
    # NOTE: debug mode also enables the interactive Werkzeug debugger; never
    # expose this development server publicly.
    app.run(debug=True, use_reloader=False)