"""Flask service that segments an image from a text prompt using CLIPSeg.

Routes:
    GET  /                 renders ``index.html``.
    POST /api/mask_image   accepts a JSON body with a base64-encoded image
                           and a prompt, and returns the binary mask and
                           the masked cut-out as base64-encoded PNGs.
"""

import base64
import os
from io import BytesIO

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from flask import Flask, jsonify, render_template, request
from PIL import Image
from transformers import CLIPSegForImageSegmentation, CLIPSegProcessor

MODEL_NAME = "CIDAS/clipseg-rd64-refined"

app = Flask(__name__)

# Loaded once at import time so every request reuses the same weights.
processor = CLIPSegProcessor.from_pretrained(MODEL_NAME)
model = CLIPSegForImageSegmentation.from_pretrained(MODEL_NAME)


def _predict_mask(image, prompt):
    """Return the prompt's activation map as floats in [0, 1], sized like *image*."""
    inputs = processor(
        text=prompt, images=image, padding="max_length", return_tensors="pt"
    )
    with torch.no_grad():
        outputs = model(**inputs)
    probs = torch.sigmoid(outputs.logits).cpu().numpy()

    # NOTE(review): the logits appear to arrive either as (H, W) or as
    # (1, H, W) depending on the installed transformers version -- accept
    # both instead of unconditionally squeezing axis 0, which raises on 2-D.
    if probs.ndim == 3 and probs.shape[0] == 1:
        probs = probs[0]
    if probs.ndim != 2:
        raise ValueError(
            f"expected a single 2-D segmentation map, got shape {probs.shape}"
        )

    # A 2-D uint8 array becomes an 8-bit greyscale ("L") image, which is then
    # upsampled to the input image's size.
    coarse = Image.fromarray(np.uint8(probs * 255))
    mask = np.asarray(coarse.resize(image.size), dtype=np.float64)

    lowest, highest = mask.min(), mask.max()
    if highest == lowest:
        # A flat map carries no contrast; dividing by zero here would fill
        # the mask with NaN and poison both the threshold and the overlay.
        return np.zeros_like(mask)
    return (mask - lowest) / (highest - lowest)


def _render_overlay(image, mask, bmask, alpha_value, draw_rectangles):
    """Draw *mask* as a heat map over *image* and return the pyplot figure."""
    fig, ax = plt.subplots()
    try:
        ax.imshow(image)
        ax.imshow(mask, alpha=alpha_value, cmap="jet")

        if draw_rectangles:
            # [-2] picks the contour list under both OpenCV 3 (three return
            # values) and OpenCV 4 (two return values).
            contours = cv2.findContours(
                bmask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
            )[-2]
            for contour in contours:
                x, y, w, h = cv2.boundingRect(contour)
                ax.add_patch(
                    plt.Rectangle(
                        (x, y), w, h, fill=False, edgecolor="yellow", linewidth=2
                    )
                )

        ax.axis("off")
        fig.tight_layout()
    except Exception:
        # pyplot keeps a reference to every figure it creates; release this
        # one before propagating so a failed render does not leak it.
        plt.close(fig)
        raise
    return fig


def _png_base64(img):
    """Return *img* encoded as PNG and then as a base64 text string."""
    buffer = BytesIO()
    img.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


def process_image(image, prompt, threshold, alpha_value, draw_rectangles):
    """Segment *image* according to *prompt*.

    Args:
        image: PIL image to segment.
        prompt: Text passed to the CLIPSeg processor.
        threshold: Cut-off applied to the normalised mask; pixels strictly
            above it belong to the binary mask.
        alpha_value: Opacity of the heat-map overlay in the figure.
        draw_rectangles: When truthy, draw a yellow bounding box around each
            contour of the binary mask.

    Returns:
        A ``(fig, result_mask, result_output)`` tuple: the matplotlib figure
        with the overlay, the binary mask as a base64 PNG string, and the
        RGBA cut-out of the image (transparent outside the mask) as a base64
        PNG string. The figure is owned by pyplot; callers that do not
        display it should ``plt.close`` it.
    """
    mask = _predict_mask(image, prompt)

    bmask = mask > threshold
    # Only the overlay uses the soft mask; hide everything under the cut-off.
    mask[mask < threshold] = 0

    binary_image = Image.fromarray(bmask.astype(np.uint8) * 255)
    output_image = Image.new("RGBA", image.size, (0, 0, 0, 0))
    output_image.paste(image, mask=binary_image)

    result_mask = _png_base64(binary_image)
    result_output = _png_base64(output_image)

    fig = _render_overlay(image, mask, bmask, alpha_value, draw_rectangles)
    return fig, result_mask, result_output


def _bad_request(message):
    """Build a JSON 400 response carrying *message*."""
    return jsonify({'error': message}), 400


@app.route('/')
def index():
    """Serve the front-end page."""
    return render_template('index.html')


@app.route('/api/mask_image', methods=['POST'])
def mask_image_api():
    """Run segmentation on a base64-encoded image sent as JSON.

    Expected JSON keys: ``base64_image`` (data URL or bare base64 string),
    ``prompt``, and optionally ``threshold`` (default 0.4), ``alpha_value``
    (default 0.5) and ``draw_rectangles`` (default False).

    Returns:
        JSON with ``result_mask`` and ``result_output`` on success, or a
        JSON ``error`` message with status 400 when the request is malformed.
    """
    # silent=True yields None instead of raising on a non-JSON body, so the
    # client gets a consistent JSON error.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return _bad_request('Request body must be a JSON object.')

    base64_image = data.get('base64_image', '')
    if not isinstance(base64_image, str) or not base64_image:
        return _bad_request("'base64_image' is required.")

    prompt = data.get('prompt', '')
    draw_rectangles = bool(data.get('draw_rectangles', False))
    try:
        threshold = float(data.get('threshold', 0.4))
        alpha_value = float(data.get('alpha_value', 0.5))
    except (TypeError, ValueError):
        return _bad_request("'threshold' and 'alpha_value' must be numbers.")
    if not 0.0 <= alpha_value <= 1.0:
        return _bad_request("'alpha_value' must be between 0 and 1.")

    # Accept both "data:image/png;base64,<payload>" and a bare payload.
    _, separator, payload = base64_image.partition(',')
    if not separator:
        payload = base64_image

    try:
        image = Image.open(BytesIO(base64.b64decode(payload)))
        # NOTE(review): normalising to RGB so palette, greyscale and alpha
        # uploads reach the model as 3-channel input; this drops any alpha
        # channel from the cut-out -- confirm that is acceptable.
        image = image.convert('RGB')
    except (ValueError, OSError):
        # binascii.Error is a ValueError; unreadable image data is an OSError.
        return _bad_request("'base64_image' is not a valid base64-encoded image.")

    fig, result_mask, result_output = process_image(
        image, prompt, threshold, alpha_value, draw_rectangles
    )
    # The API never returns the figure; close it so pyplot does not keep one
    # figure alive per request.
    plt.close(fig)

    return jsonify({'result_mask': result_mask, 'result_output': result_output})


if __name__ == '__main__':
    # The interactive debugger allows arbitrary code execution, so it must
    # not be on by default while listening on all interfaces. Opt in with
    # FLASK_DEBUG=1 for local development.
    debug_enabled = os.environ.get('FLASK_DEBUG', '').lower() in ('1', 'true', 'yes')
    app.run(host='0.0.0.0', port=7860, debug=debug_enabled)