from flask import Flask, request, send_file, Response, jsonify
from flask_cors import CORS
import numpy as np
import io
import torch
import cv2
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
from PIL import Image
import zipfile
app = Flask(__name__)
CORS(app)
cudaOrNah = "cuda" if torch.cuda.is_available() else "cpu"
print(cudaOrNah)
# Global model setup
# Switched from the ViT-H checkpoint to ViT-L to avoid running out of GPU memory
# checkpoint = "sam_vit_h_4b8939.pth"
# model_type = "vit_h"
checkpoint = "sam_vit_l_0b3195.pth"
model_type = "vit_l"
sam = sam_model_registry[model_type](checkpoint=checkpoint)
sam.to(device=cudaOrNah)
mask_generator = SamAutomaticMaskGenerator(
    model=sam,
    min_mask_region_area=0.0015  # Adjust this value as needed
)
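# SamAutomaticMaskGenerator produces masks for the whole image without point or box
# prompts. Note: SAM documents min_mask_region_area as a pixel-area threshold used to
# remove small disconnected regions and holes, so a fractional value like 0.0015
# likely has no effect here; tiny masks are instead filtered out in /get-masks by the
# fraction-of-image check below.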
print('Setup SAM model')
@app.route('/')
def hello():
    return {"hei": "Shredded to pieces"}
@app.route('/health', methods=['GET'])
def health_check():
    # Simple health check endpoint
    return jsonify({"status": "ok"}), 200
@app.route('/get-masks', methods=['POST'])
def get_masks():
    try:
        print('received image from frontend')
        # Get the image file from the request
        if 'image' not in request.files:
            return jsonify({"error": "No image file provided"}), 400
        image_file = request.files['image']
        if image_file.filename == '':
            return jsonify({"error": "No image file provided"}), 400
        # Decode the uploaded bytes with OpenCV (equivalent to cv2.imread)
        file_bytes = np.frombuffer(image_file.read(), np.uint8)
        image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
        if image is None:
            raise ValueError("Image not found or unable to read.")
        # OpenCV decodes to BGR; convert to RGB for SAM
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if cudaOrNah == "cuda":
            torch.cuda.empty_cache()
        masks = mask_generator.generate(image)
        if cudaOrNah == "cuda":
            torch.cuda.empty_cache()
        masks = sorted(masks, key=(lambda x: x['area']), reverse=True)
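        # Post-processing: drop masks that are "on" near the image corners (assumed
        # to be background), subtract smaller masks from larger ones so the returned
        # masks do not overlap, then discard masks covering only a tiny fraction of
        # the image.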
        def is_background(segmentation):
            val = (segmentation[10, 10] or segmentation[-10, 10] or
                   segmentation[10, -10] or segmentation[-10, -10])
            return val

        masks = [mask for mask in masks if not is_background(mask['segmentation'])]
        for i in reversed(range(len(masks) - 1)):
            large_mask = masks[i]['segmentation']
            for j in range(i + 1, len(masks)):
                not_small_mask = np.logical_not(masks[j]['segmentation'])
                masks[i]['segmentation'] = np.logical_and(large_mask, not_small_mask)
                masks[i]['area'] = masks[i]['segmentation'].sum()
                large_mask = masks[i]['segmentation']
        def sum_under_threshold(segmentation, threshold):
            # Filter out masks whose area is below `threshold` as a fraction of the image
            return segmentation.sum() / segmentation.size < threshold

        masks = [mask for mask in masks if not sum_under_threshold(mask['segmentation'], 0.0015)]
        masks = sorted(masks, key=(lambda x: x['area']), reverse=True)
        # Create a zip file in memory with one PNG per mask
        zip_buffer = io.BytesIO()
        with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
            for idx, mask in enumerate(masks):
                alpha = mask['segmentation'].astype('uint8') * 255
                mask_image = Image.fromarray(alpha)
                mask_io = io.BytesIO()
                mask_image.save(mask_io, format="PNG")
                mask_io.seek(0)
                zip_file.writestr(f'mask_{idx+1}.png', mask_io.read())
        zip_buffer.seek(0)
        return send_file(zip_buffer, mimetype='application/zip', as_attachment=True, download_name='masks.zip')
    except Exception as e:
        # Log the error message if needed
        print(f"Error processing the image: {e}")
        # Return a JSON response with the error message and a 400 Bad Request status
        return jsonify({"error": "Error processing the image", "details": str(e)}), 400
if __name__ == '__main__':
    app.run(debug=True)
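# Example client call (a minimal sketch, assuming the server is reachable at
# http://localhost:5000 and 'photo.jpg' is a local image; the 'requests' package
# is an extra dependency not used by the server itself):
#
#   import io, zipfile, requests
#   with open('photo.jpg', 'rb') as f:
#       resp = requests.post('http://localhost:5000/get-masks', files={'image': f})
#   resp.raise_for_status()
#   with zipfile.ZipFile(io.BytesIO(resp.content)) as zf:
#       zf.extractall('masks')  # extracts mask_1.png, mask_2.png, ...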