# CLIPSeg Space — app.py (by sigyllly, ~1.77 kB)
# Flask API that returns a base64-encoded segmentation mask for a text prompt.
from flask import Flask, request, jsonify
import base64
from PIL import Image
from io import BytesIO
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
import torch
import numpy as np
import matplotlib.pyplot as plt
import cv2
app = Flask(__name__)
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
@app.route('/api/mask_image', methods=['POST'])
def mask_image_api():
data = request.get_json()
base64_image = data.get('base64_image', '')
prompt = data.get('prompt', '')
threshold = data.get('threshold', 0.4)
alpha_value = data.get('alpha_value', 0.5)
draw_rectangles = data.get('draw_rectangles', False)
# Decode base64 image
image_data = base64.b64decode(base64_image)
# Process the image
image = Image.open(BytesIO(image_data))
inputs = processor(text=prompt, images=image, padding="max_length", return_tensors="pt")
with torch.no_grad():
outputs = model(**inputs)
preds = outputs.logits
pred = torch.sigmoid(preds)
mat = pred.cpu().numpy()
mask = Image.fromarray(np.uint8(mat * 255), "L")
mask = mask.convert("RGB")
mask = mask.resize(image.size)
mask = np.array(mask)[:, :, 0]
mask_min = mask.min()
mask_max = mask.max()
mask = (mask - mask_min) / (mask_max - mask_min)
bmask = mask > threshold
mask[mask < threshold] = 0
# Convert the output mask to base64
buffered_mask = BytesIO()
mask.save(buffered_mask, format="PNG")
base64_mask = base64.b64encode(buffered_mask.getvalue()).decode('utf-8')
return jsonify({'base64_mask': base64_mask})
if __name__ == '__main__':
app.run(debug=True)