"""Flask service that segments an image from text prompts using CLIPSeg.

POST ``/api`` with a JSON body containing a base64 ``image`` and
comma-separated ``positive_prompts`` / ``negative_prompts``.  The response
holds a base64-encoded PNG mask that is white where any positive prompt
matched and no negative prompt matched.
"""
import base64
import hashlib
import io
import os
import threading
import time
import urllib.request

import numpy as np
import torch
from flask import Flask, jsonify, request
from PIL import Image
from transformers import CLIPSegForImageSegmentation, CLIPSegProcessor

app = Flask(__name__)

# Load CLIPSeg processor and model once at import time; every request reuses
# them.
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")

# Result cache keyed by (image digest, mode, size, prompts, threshold).
cache = {}
# Upper bound on cached results so a long-running server cannot grow forever.
MAX_CACHE_ENTRIES = 128


def process_image(image, prompt):
    """Return a min-max normalised heatmap for *prompt* over *image*.

    Args:
        image: PIL image to segment.
        prompt: Text describing the region to highlight.

    Returns:
        A float array with one value per pixel of *image* (rows = height,
        columns = width), scaled to [0, 1].  If the model output is
        constant, an all-zero array is returned instead of dividing by zero.
    """
    inputs = processor(
        text=prompt, images=image, padding="max_length", return_tensors="pt"
    )
    with torch.no_grad():
        outputs = model(**inputs)

    probabilities = torch.sigmoid(outputs.logits).cpu().numpy()
    # NOTE(review): for a single prompt the logits are assumed to be 2-D once
    # singleton dimensions are dropped -- confirm against the installed
    # transformers version.
    heatmap = np.squeeze(probabilities)

    # Round-trip through an 8-bit greyscale image so PIL can resize the
    # heatmap to the input resolution (mode is inferred from the uint8 array).
    resized = Image.fromarray(np.uint8(heatmap * 255)).resize(image.size)
    mask = np.asarray(resized, dtype=np.float64)

    lowest, highest = mask.min(), mask.max()
    if highest == lowest:
        # A flat heatmap carries no signal; avoid the 0/0 that would give NaN.
        return np.zeros_like(mask)
    return (mask - lowest) / (highest - lowest)


def get_masks(prompts, img, threshold):
    """Return one boolean mask per non-empty prompt in *prompts*.

    Args:
        prompts: Comma-separated prompt string; blank entries are skipped.
        img: PIL image to segment.
        threshold: Cut-off applied to the normalised heatmap; pixels strictly
            above it are ``True``.

    Returns:
        A list of boolean arrays, empty when *prompts* contains no text.
    """
    cleaned = (part.strip() for part in prompts.split(","))
    return [process_image(img, text) > threshold for text in cleaned if text]


def _union_masks(masks, shape):
    """Return the element-wise OR of *masks*, or all-False of *shape* if empty."""
    if not masks:
        return np.zeros(shape, dtype=bool)
    return np.any(np.stack(masks), axis=0)


def extract_image(pos_prompts, neg_prompts, img, threshold):
    """Build the final mask: any positive prompt AND NOT any negative prompt.

    Args:
        pos_prompts: Comma-separated prompts for regions to keep.
        neg_prompts: Comma-separated prompts for regions to exclude.
        img: PIL image to segment.
        threshold: Heatmap cut-off forwarded to :func:`get_masks`.

    Returns:
        ``{'final_mask_base64': <base64 PNG string>}``.  Results are cached
        per (image content, prompts, threshold).
    """
    # The image content is part of the key; otherwise two different images
    # sent with the same prompts would share one cached mask.
    digest = hashlib.sha256(img.tobytes()).hexdigest()
    cache_key = (digest, img.mode, img.size, pos_prompts, neg_prompts, threshold)
    cached = cache.get(cache_key)
    if cached is not None:
        return cached

    shape = (img.height, img.width)
    pos_mask = _union_masks(get_masks(pos_prompts, img, threshold), shape)
    neg_mask = _union_masks(get_masks(neg_prompts, img, threshold), shape)
    final_mask = Image.fromarray((pos_mask & ~neg_mask).astype(np.uint8) * 255)

    # Convert final mask to base64-encoded PNG.
    buffered = io.BytesIO()
    final_mask.save(buffered, format="PNG")
    final_mask_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

    # Evict oldest entries first (dicts preserve insertion order).
    while len(cache) >= MAX_CACHE_ENTRIES:
        cache.pop(next(iter(cache)), None)

    result = {'final_mask_base64': final_mask_base64}
    cache[cache_key] = result
    return result


@app.route('/')
def hello_world():
    """Liveness endpoint; also the target of the keep-alive ping."""
    return 'Hello, World!'


@app.route('/api', methods=['POST'])
def process_request():
    """Segment the posted image and return the mask as JSON.

    Expected JSON fields: ``image`` (base64, with or without a ``data:`` URL
    prefix), optional ``positive_prompts``, ``negative_prompts`` and
    ``threshold`` (default 0.4).  Malformed input yields a 400 response.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object.'}), 400

    base64_image = data.get('image')
    if not isinstance(base64_image, str) or not base64_image:
        return jsonify({'error': "Missing 'image' field."}), 400

    # Accept both "data:image/png;base64,<payload>" and a bare payload.
    _, separator, payload = base64_image.partition(',')
    if not separator:
        payload = base64_image

    try:
        image_data = base64.b64decode(payload)
        # Normalise to RGB so images with alpha or a palette are handled too.
        img = Image.open(io.BytesIO(image_data)).convert("RGB")
    except (ValueError, OSError):
        return jsonify({'error': 'Could not decode image.'}), 400

    # Get other parameters.
    pos_prompts = data.get('positive_prompts', '')
    neg_prompts = data.get('negative_prompts', '')
    if not isinstance(pos_prompts, str) or not isinstance(neg_prompts, str):
        return jsonify({'error': 'Prompts must be strings.'}), 400
    try:
        threshold = float(data.get('threshold', 0.4))
    except (TypeError, ValueError):
        return jsonify({'error': "'threshold' must be a number."}), 400

    # Perform image segmentation.
    result = extract_image(pos_prompts, neg_prompts, img, threshold)
    return jsonify(result)


def keep_alive():
    """Ping the local server every five minutes, forever.

    Failures are logged and ignored so one failed ping does not end the loop.
    """
    while True:
        time.sleep(300)  # 5 minutes
        try:
            with urllib.request.urlopen('http://127.0.0.1:7860/', timeout=10):
                pass
        except OSError as exc:  # urllib.error.URLError is an OSError subclass
            app.logger.warning("Keep-alive ping failed: %s", exc)


if __name__ == '__main__':
    print("Server starting. Verify it is running by visiting http://0.0.0.0:7860/")
    # Daemon thread: it must not keep the process alive after the server stops.
    keep_alive_thread = threading.Thread(target=keep_alive, daemon=True)
    keep_alive_thread.start()
    # The interactive debugger allows arbitrary code execution, so it is only
    # enabled on explicit opt-in while the server listens on all interfaces.
    app.run(host='0.0.0.0', port=7860, debug=os.environ.get("FLASK_DEBUG") == "1")