# CLIPSeg2 / app.py
# sigyllly's picture
# Update app.py
# 347b6c6 verified
# raw
# history blame
# 2.98 kB
from flask import Flask, request, jsonify
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
from PIL import Image
import torch
import numpy as np
import io
import base64
# Single checkpoint identifier shared by the processor and the model so the
# two can never drift apart.
_CLIPSEG_CHECKPOINT = "CIDAS/clipseg-rd64-refined"

# Flask application object; the route handlers below register themselves on it.
app = Flask(__name__)

# Load the CLIPSeg processor and segmentation model once, at import time, so
# every request reuses the same instances.
processor = CLIPSegProcessor.from_pretrained(_CLIPSEG_CHECKPOINT)
model = CLIPSegForImageSegmentation.from_pretrained(_CLIPSEG_CHECKPOINT)
# Function to process image and generate mask
def process_image(image, prompt):
    """Run CLIPSeg on ``image`` for one text ``prompt`` and return a soft mask.

    Args:
        image: PIL image to segment. The mask is resized to ``image.size``.
        prompt: Text describing the region that should light up in the mask.

    Returns:
        A 2-D float NumPy array with one value per pixel of ``image``,
        min-max normalised to the range [0, 1]. When the prediction is
        completely flat (maximum equals minimum) an all-zero array is
        returned instead of dividing by zero.
    """
    inputs = processor(
        text=prompt, images=image, padding="max_length", return_tensors="pt"
    )
    with torch.no_grad():
        outputs = model(**inputs)
    probabilities = torch.sigmoid(outputs.logits)

    # Drop singleton axes so PIL always receives a 2-D array.
    # NOTE(review): some transformers releases appear to keep a leading batch
    # axis of size 1 on the logits for a single prompt - confirm against the
    # installed version. Squeezing is a no-op when the array is already 2-D.
    heat = np.squeeze(probabilities.cpu().numpy())

    # A 2-D uint8 array becomes an 8-bit greyscale image. It is resized
    # directly: the previous RGB round-trip only ever read back channel 0,
    # which is the same greyscale data.
    resized = Image.fromarray(np.uint8(heat * 255)).resize(image.size)
    mask = np.asarray(resized, dtype=np.float64)

    lowest = mask.min()
    spread = mask.max() - lowest
    if spread == 0:
        # Flat prediction: nothing to stretch, and dividing would yield NaNs.
        return np.zeros_like(mask)
    return (mask - lowest) / spread
# Function to get masks from positive or negative prompts
def get_masks(prompts, img, threshold):
    """Build one boolean mask per comma-separated prompt.

    Args:
        prompts: Comma-separated prompt texts in a single string.
        img: Image handed unchanged to ``process_image`` for every prompt.
        threshold: Cut-off; a pixel is selected when its mask value is
            strictly greater than this number.

    Returns:
        A list with one thresholded mask per prompt, in prompt order.
    """
    return [process_image(img, text) > threshold for text in prompts.split(",")]
@app.route('/')
def hello_world():
    """Serve a fixed greeting at the root URL."""
    greeting = 'Hello, World!'
    return greeting
# Function to extract image using positive and negative prompts
def extract_image(pos_prompts, neg_prompts, img, threshold):
    """Cut out the parts of ``img`` matching the positive but not the negative prompts.

    Args:
        pos_prompts: Comma-separated prompts for regions to keep.
        neg_prompts: Comma-separated prompts for regions to remove. May be
            empty, in which case nothing is removed.
        img: PIL image to cut from.
        threshold: Cut-off forwarded to ``get_masks`` for every prompt.

    Returns:
        A tuple ``(output_image, final_mask)``: an RGBA image the size of
        ``img`` that is fully transparent outside the selected region, and
        the selection itself as an 8-bit greyscale image (255 = selected).
    """

    def _union_mask(prompts):
        # Blank entries (an empty string, "a,,b", a trailing comma) are
        # skipped so that an empty prompt is never sent to the model.
        kept = [text for text in (prompts or "").split(",") if text.strip()]
        if not kept:
            # No prompts means nothing is selected. PIL's size is
            # (width, height) while the mask array is (height, width).
            return np.zeros(img.size[::-1], dtype=bool)
        # Honour the caller's threshold instead of a hard-coded 0.5.
        masks = get_masks(",".join(kept), img, threshold)
        return np.any(np.stack(masks), axis=0)

    selected = _union_mask(pos_prompts) & ~_union_mask(neg_prompts)
    final_mask = Image.fromarray(selected.astype(np.uint8) * 255)

    # Start from a fully transparent canvas and copy over only the pixels
    # that the mask selects.
    output_image = Image.new("RGBA", img.size, (0, 0, 0, 0))
    output_image.paste(img, mask=final_mask)
    return output_image, final_mask
@app.route('/api', methods=['POST'])
def process_request():
    """Segment an uploaded image according to text prompts.

    Expects a JSON object with:
        image: Base64-encoded image, with or without a ``data:...;base64,``
            prefix.
        positive_prompts: Comma-separated prompts to keep (default ``''``).
        negative_prompts: Comma-separated prompts to remove (default ``''``).
        threshold: Number forwarded to ``extract_image`` (default ``0.4``).

    Returns:
        On success, JSON ``{'result_image_base64': <PNG as base64>}``.
        On malformed input, JSON ``{'error': <message>}`` with status 400.
    """
    # silent=True yields None for a missing or non-JSON body, so bad input is
    # answered with a 400 below rather than an unhandled exception.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object.'}), 400

    # Convert base64 image to PIL Image
    base64_image = data.get('image')
    if not isinstance(base64_image, str) or not base64_image:
        return jsonify({'error': "Field 'image' must be a base64 string."}), 400
    # Keep whatever follows the first comma of a data URI. A bare base64
    # string contains no comma, so it passes through unchanged.
    payload = base64_image.split(',', 1)[-1]
    try:
        image_data = base64.b64decode(payload)
        # Normalise to RGB so alpha, palette or greyscale uploads reach the
        # model as a plain 3-channel image. convert() also forces the image
        # to be decoded here, inside the try block.
        img = Image.open(io.BytesIO(image_data)).convert("RGB")
    except (ValueError, OSError):
        return jsonify({'error': "Field 'image' is not a decodable image."}), 400

    # Get other parameters
    pos_prompts = data.get('positive_prompts', '')
    neg_prompts = data.get('negative_prompts', '')
    if not isinstance(pos_prompts, str) or not isinstance(neg_prompts, str):
        return jsonify({'error': 'Prompts must be comma-separated strings.'}), 400
    try:
        threshold = float(data.get('threshold', 0.4))
    except (TypeError, ValueError):
        return jsonify({'error': "Field 'threshold' must be a number."}), 400

    # Perform image segmentation
    output_image, _ = extract_image(pos_prompts, neg_prompts, img, threshold)

    # Convert result to base64 for response
    buffered = io.BytesIO()
    output_image.save(buffered, format="PNG")
    result_image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
    return jsonify({'result_image_base64': result_image_base64})
if __name__ == '__main__':
    import os

    # The server listens on all interfaces, and Flask's debug mode enables an
    # interactive debugger that can execute code for anyone who can reach the
    # port. Debug is therefore opt-in through FLASK_DEBUG, not always on.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").lower() in ("1", "true", "yes")
    print("Server starting. Verify it is running by visiting http://0.0.0.0:7860/")
    app.run(host='0.0.0.0', port=7860, debug=debug_enabled)