File size: 3,020 Bytes
35b65a4 5c75869 35b65a4 d432bdc 35b65a4 ab226f6 5c75869 24d11e8 5c75869 df3e218 d432bdc f77dac9 4fb1cd4 f77dac9 24d11e8 f77dac9 df3e218 35b65a4 df3e218 35b65a4 df3e218 d432bdc 35b65a4 d432bdc 347b6c6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 |
from flask import Flask, request, jsonify
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
from PIL import Image
import torch
import numpy as np
import io
import base64
import threading
import time
app = Flask(__name__)

# Hub checkpoint shared by the CLIPSeg processor and model below.
_CLIPSEG_CHECKPOINT = "CIDAS/clipseg-rd64-refined"

# Load the CLIPSeg processor and model once, at import time, so every request
# handled by this module reuses the same instances.
processor = CLIPSegProcessor.from_pretrained(_CLIPSEG_CHECKPOINT)
model = CLIPSegForImageSegmentation.from_pretrained(_CLIPSEG_CHECKPOINT)
@app.route('/')
def hello_world():
    """Return a fixed greeting for the root URL."""
    greeting = 'Hello, World!'
    return greeting
def process_image(image, prompt):
    """Return a min-max normalized CLIPSeg heatmap for *prompt* on *image*.

    Args:
        image: PIL image to segment.
        prompt: Text describing the region to locate.

    Returns:
        A float64 numpy array of shape ``(image.height, image.width)`` with
        values in ``[0, 1]``. The model's sigmoid scores are quantized to
        8 bits, resized to the image size, and rescaled so the lowest score
        maps to 0 and the highest to 1. If every pixel has the same score,
        the result is all zeros instead of NaN.
    """
    inputs = processor(
        text=prompt, images=image, padding="max_length", return_tensors="pt"
    )
    with torch.no_grad():
        outputs = model(**inputs)
    pred = torch.sigmoid(outputs.logits)
    # Drop singleton dims so the heatmap is 2-D. Presumably logits are
    # (H, W) or (1, H, W) for a single prompt/image -- TODO confirm.
    heatmap = np.squeeze(pred.cpu().numpy())
    # Resize directly in "L" mode. The previous L -> RGB -> channel-0
    # round-trip only replicated the channel and read the same values back.
    resized = Image.fromarray(np.uint8(heatmap * 255), "L").resize(image.size)
    mask = np.asarray(resized, dtype=np.float64)
    mask_min = mask.min()
    span = mask.max() - mask_min
    if span == 0:
        # A constant heatmap has nothing to normalize. Dividing by zero here
        # used to yield an all-NaN array plus a RuntimeWarning.
        return np.zeros_like(mask)
    return (mask - mask_min) / span
def get_masks(prompts, img, threshold):
    """Return one boolean mask per prompt in *prompts*.

    Args:
        prompts: Comma-separated prompt string; a list of prompts is also
            accepted. Whitespace around each prompt is stripped and empty
            entries are skipped, so ``""`` or ``None`` yields no masks.
        img: PIL image handed to ``process_image``.
        threshold: Pixels whose normalized score exceeds this become True.

    Returns:
        A list of boolean arrays, one per non-empty prompt.
    """
    if not prompts:
        return []
    parts = prompts.split(",") if isinstance(prompts, str) else prompts
    cleaned = [str(part).strip() for part in parts]
    # "".split(",") == [""]. Without this filter, an absent negative_prompts
    # ran the model on an empty prompt and subtracted whatever it matched.
    return [process_image(img, prompt) > threshold for prompt in cleaned if prompt]


def _union_masks(masks, shape):
    """Return the pixel-wise OR of boolean *masks*, or all-False if there are none."""
    if not masks:
        # np.stack([]) raises ValueError, so handle "no prompts" explicitly.
        return np.zeros(shape, dtype=bool)
    return np.any(np.stack(masks), axis=0)


# Route for processing requests
@app.route('/api', methods=['POST'])
def process_request():
    """Segment a base64-encoded image using positive and negative text prompts.

    Expects a JSON object with:
      - ``image``: base64 image data. It may be raw base64 or a data URL
        such as ``data:image/png;base64,...``.
      - ``positive_prompts`` / ``negative_prompts``: comma-separated prompts
        (a JSON list of strings is also accepted). Both are optional.
      - ``threshold``: cut-off applied to each normalized mask (default 0.4).

    Returns:
        JSON ``{'final_mask_base64': ...}`` holding a base64 PNG. A pixel is
        white when any positive prompt selected it and no negative prompt did.
        Malformed requests get HTTP 400 with an ``error`` message.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'request body must be a JSON object'}), 400

    base64_image = data.get('image')
    if not isinstance(base64_image, str) or not base64_image:
        return jsonify({'error': "missing 'image' field"}), 400
    try:
        # Take everything after the first comma when a data-URL header is
        # present; raw base64 (no comma) is used as-is. The old
        # split(',')[1] raised IndexError on raw base64.
        image_data = base64.b64decode(base64_image.split(',', 1)[-1])
        img = Image.open(io.BytesIO(image_data))
        # convert() forces decoding here, inside the try, and turns RGBA,
        # palette, and grayscale uploads into the RGB input the processor
        # expects.
        img = img.convert('RGB')
    except (ValueError, OSError):
        # binascii.Error is a ValueError; PIL decode failures are OSErrors.
        return jsonify({'error': "'image' is not valid base64 image data"}), 400

    pos_prompts = data.get('positive_prompts', '')
    neg_prompts = data.get('negative_prompts', '')
    try:
        threshold = float(data.get('threshold', 0.4))
    except (TypeError, ValueError):
        return jsonify({'error': "'threshold' must be a number"}), 400

    # Perform image segmentation without caching. The requested threshold is
    # now honored; it used to be parsed and then ignored in favor of 0.5.
    positive_masks = get_masks(pos_prompts, img, threshold)
    negative_masks = get_masks(neg_prompts, img, threshold)

    # Mask arrays are (height, width), while PIL's size is (width, height).
    shape = (img.height, img.width)
    pos_mask = _union_masks(positive_masks, shape)
    neg_mask = _union_masks(negative_masks, shape)
    final_mask = pos_mask & ~neg_mask
    final_mask = Image.fromarray(final_mask.astype(np.uint8) * 255, "L")

    # Convert final mask to base64
    buffered = io.BytesIO()
    final_mask.save(buffered, format="PNG")
    final_mask_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
    return jsonify({'final_mask_base64': final_mask_base64})
def keep_alive(url='http://127.0.0.1:7860/', interval=300):
    """Send a GET request to *url* every *interval* seconds, forever.

    The original comment says this exists to keep the server alive.
    NOTE(review): the default target is the loopback address, so the traffic
    never leaves the machine. Confirm this actually prevents whatever idle
    shutdown it is meant to prevent.

    Args:
        url: Address to ping. The default matches the host/port served by
            the ``__main__`` block.
        interval: Seconds to sleep before each ping (default 300 = 5 min).

    A failed ping is reported and retried on the next cycle. It does not end
    the loop.
    """
    # The original called requests.get without importing requests, so the
    # thread died with NameError after the first sleep. Use the stdlib
    # instead of adding a dependency.
    import urllib.request

    while True:
        time.sleep(interval)
        try:
            with urllib.request.urlopen(url, timeout=30) as response:
                response.read()
        except OSError as exc:
            # URLError, HTTPError, and socket timeouts are all OSError
            # subclasses. A transient failure must not kill this thread.
            print(f"keep_alive: ping to {url} failed: {exc}")
if __name__ == '__main__':
    import os

    print("Server starting. Verify it is running by visiting http://0.0.0.0:7860/")
    # Start the keep-alive thread. daemon=True lets the interpreter exit when
    # the server stops. A non-daemon thread running an infinite loop blocks
    # shutdown, e.g. on Ctrl+C.
    keep_alive_thread = threading.Thread(target=keep_alive, daemon=True)
    keep_alive_thread.start()
    # Debug mode is opt-in (FLASK_DEBUG=1), for two reasons:
    #  - On host 0.0.0.0 it exposes the interactive Werkzeug debugger to the
    #    network.
    #  - Its reloader re-runs this module in a child process, which loads
    #    the model and starts the keep-alive thread a second time.
    debug = os.environ.get('FLASK_DEBUG') == '1'
    app.run(host='0.0.0.0', port=7860, debug=debug)
|