File size: 4,186 Bytes
60e5cfd
 
 
 
 
 
 
 
 
84fd7cf
60e5cfd
 
 
 
 
 
05e5c9d
 
 
 
b026005
 
60e5cfd
 
 
 
 
 
 
 
 
 
05e5c9d
60e5cfd
 
 
 
 
 
 
 
 
05e5c9d
60e5cfd
 
 
05e5c9d
60e5cfd
 
 
 
6c13449
 
 
 
 
 
 
60e5cfd
 
 
05e5c9d
dce6539
 
05e5c9d
be6042a
dce6539
 
 
 
a4acab9
60e5cfd
05e5c9d
 
 
 
a4acab9
6c13449
a4acab9
6c13449
 
 
 
 
 
 
05e5c9d
6c13449
 
05e5c9d
6c13449
a4acab9
60e5cfd
05e5c9d
60e5cfd
 
 
 
 
 
 
 
05e5c9d
60e5cfd
 
05e5c9d
60e5cfd
 
 
 
 
 
 
 
05e5c9d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
from flask import Flask, request, send_file, Response, jsonify
from flask_cors import CORS
import numpy as np
import io
import torch
import cv2
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
from PIL import Image
import zipfile

# Flask application object; CORS handling is enabled on it (no options passed).
app = Flask(__name__)
CORS(app)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
cudaOrNah = "cuda" if torch.cuda.is_available() else "cpu"
print(cudaOrNah)

# Global model setup: the SAM model is loaded once at import time and the
# resulting mask generator is shared by the request handlers in this module.
# The ViT-H checkpoint is kept below for reference; presumably it was replaced
# by ViT-L because it ran out of memory.
# checkpoint = "sam_vit_h_4b8939.pth"
# model_type = "vit_h"
checkpoint = "sam_vit_l_0b3195.pth"
model_type = "vit_l"
sam = sam_model_registry[model_type](checkpoint=checkpoint)
sam.to(device=cudaOrNah)
mask_generator = SamAutomaticMaskGenerator(
    model=sam,
    # NOTE(review): 0.0015 looks like a fraction of the image area (the same
    # value is used as an area ratio inside get_masks), but this parameter is
    # presumably a pixel count in segment_anything -- confirm against the
    # library documentation before relying on it.
    min_mask_region_area=0.0015  # Adjust this value as needed
)
print('Setup SAM model')

@app.route('/')
def hello():
    """Root endpoint: reply with a fixed greeting payload."""
    greeting = {"hei": "Shredded to peices"}
    return greeting

@app.route('/health', methods=['GET'])
def health_check():
    """Liveness probe: always answers HTTP 200 with a JSON status of "ok"."""
    body = jsonify({"status": "ok"})
    return body, 200

# Masks covering less than this fraction of the image are discarded.
_MIN_MASK_AREA_FRACTION = 0.0015
# Distance (in pixels) from each image corner at which a mask is probed to
# decide whether it belongs to the background.
_CORNER_PROBE_OFFSET = 10


def _decode_image(raw_bytes):
    """Decode uploaded image bytes into an RGB array.

    Args:
        raw_bytes: the raw contents of the uploaded file.

    Returns:
        The decoded image converted from OpenCV's BGR order to RGB.

    Raises:
        ValueError: if the payload is empty or OpenCV cannot decode it.
    """
    if not raw_bytes:
        raise ValueError("Image not found or unable to read.")
    # np.frombuffer replaces the deprecated binary mode of np.fromstring.
    file_bytes = np.frombuffer(raw_bytes, dtype=np.uint8)
    image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
    # imdecode reports failure by returning None, so this must be checked
    # before the colour conversion (which would fail with an opaque error).
    if image is None:
        raise ValueError("Image not found or unable to read.")
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


def _is_background(segmentation, offset=_CORNER_PROBE_OFFSET):
    """Return True when the mask covers any of the four near-corner pixels.

    The probes sit ``offset`` pixels in from each corner; the offset is
    clamped so that images smaller than the offset do not raise IndexError.
    """
    height, width = segmentation.shape[:2]
    near_row, far_row = min(offset, height - 1), max(-offset, -height)
    near_col, far_col = min(offset, width - 1), max(-offset, -width)
    return bool(
        segmentation[near_row, near_col]
        or segmentation[far_row, near_col]
        or segmentation[near_row, far_col]
        or segmentation[far_row, far_col]
    )


def _remove_overlaps(masks):
    """Make masks mutually exclusive, in place.

    ``masks`` must be ordered from largest to smallest. Every mask loses the
    pixels claimed by any mask that comes after it in the list, and its
    ``'area'`` is recomputed. The last (smallest) mask is left untouched.
    A running union of the later masks gives the same result as subtracting
    each later mask individually, in a single pass.
    """
    if not masks:
        return
    covered = masks[-1]['segmentation']
    for mask in reversed(masks[:-1]):
        original = mask['segmentation']
        trimmed = np.logical_and(original, np.logical_not(covered))
        covered = np.logical_or(covered, original)
        mask['segmentation'] = trimmed
        mask['area'] = trimmed.sum()


def _is_too_small(segmentation, min_fraction=_MIN_MASK_AREA_FRACTION):
    """Return True when the mask covers less than ``min_fraction`` of the image."""
    return segmentation.sum() / segmentation.size < min_fraction


def _zip_masks(masks):
    """Encode each mask as a PNG (pixel values 0 or 255) inside an in-memory zip.

    Entries are named ``mask_1.png``, ``mask_2.png``, ... in list order.

    Returns:
        A ``io.BytesIO`` positioned at the start of the zip archive.
    """
    zip_buffer = io.BytesIO()
    with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
        for idx, mask in enumerate(masks, start=1):
            alpha = mask['segmentation'].astype('uint8') * 255
            png_io = io.BytesIO()
            Image.fromarray(alpha).save(png_io, format="PNG")
            zip_file.writestr(f'mask_{idx}.png', png_io.getvalue())
    zip_buffer.seek(0)
    return zip_buffer


@app.route('/get-masks', methods=['POST'])
def get_masks():
    """Segment an uploaded image and return the masks as a zip of PNGs.

    Expects a multipart POST with the image in the ``image`` file field.
    The image is run through the shared SAM mask generator; masks touching
    a near-corner pixel are treated as background and dropped, the rest are
    made non-overlapping (smaller masks win), masks below the minimum area
    fraction are dropped, and the survivors are returned largest-first as
    ``masks.zip``.

    Returns:
        The zip attachment on success, or a JSON error body with HTTP 400
        when no file is provided or when processing fails.
    """
    try:
        print('received image from frontend')
        # Get the image file from the request
        if 'image' not in request.files:
            return jsonify({"error": "No image file provided"}), 400

        image_file = request.files['image']
        if image_file.filename == '':
            return jsonify({"error": "No image file provided"}), 400

        image = _decode_image(image_file.read())

        if cudaOrNah == "cuda":
            torch.cuda.empty_cache()
        try:
            masks = mask_generator.generate(image)
        finally:
            # Release cached GPU memory even when generation fails.
            if cudaOrNah == "cuda":
                torch.cuda.empty_cache()

        # Largest first, so that overlap removal lets smaller masks win.
        masks = sorted(masks, key=(lambda x: x['area']), reverse=True)
        masks = [mask for mask in masks if not _is_background(mask['segmentation'])]

        _remove_overlaps(masks)

        masks = [mask for mask in masks if not _is_too_small(mask['segmentation'])]
        # Areas changed during overlap removal, so the order must be rebuilt.
        masks = sorted(masks, key=(lambda x: x['area']), reverse=True)

        zip_buffer = _zip_masks(masks)

        return send_file(zip_buffer, mimetype='application/zip', as_attachment=True, download_name='masks.zip')
    except Exception as e:
        # Top-level boundary: every failure is reported to the client as a
        # 400 with the error text, matching the established response contract.
        print(f"Error processing the image: {e}")
        return jsonify({"error": "Error processing the image", "details": str(e)}), 400

# Start Flask's built-in server only when this file is executed directly.
# NOTE(review): debug=True turns on Flask's debug mode (interactive debugger
# and reloader); make sure this entry point is not used for a publicly
# reachable deployment.
if __name__ == '__main__':
    app.run(debug=True)