File size: 3,020 Bytes
35b65a4
5c75869
 
 
 
35b65a4
 
d432bdc
 
35b65a4
ab226f6
5c75869
24d11e8
5c75869
 
 
df3e218
 
 
d432bdc
f77dac9
 
 
 
4fb1cd4
f77dac9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24d11e8
 
 
 
 
 
 
 
 
 
f77dac9
df3e218
35b65a4
 
 
 
 
 
 
 
 
 
 
 
 
 
df3e218
 
 
 
 
 
 
 
 
 
 
 
 
 
35b65a4
df3e218
d432bdc
 
 
 
 
 
35b65a4
 
 
d432bdc
 
 
 
 
347b6c6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
from flask import Flask, request, jsonify
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
from PIL import Image
import torch
import numpy as np
import io
import base64
import threading
import time

# Flask application object that the route decorators in this module attach to.
app = Flask(__name__)

# Single source of truth for the Hugging Face checkpoint, so the processor and
# the model are always loaded from the same weights.
_CLIPSEG_CHECKPOINT = "CIDAS/clipseg-rd64-refined"

# Loaded once at import time and shared by every call to process_image().
processor = CLIPSegProcessor.from_pretrained(_CLIPSEG_CHECKPOINT)
model = CLIPSegForImageSegmentation.from_pretrained(_CLIPSEG_CHECKPOINT)

@app.route('/')
def hello_world():
    """Answer requests for the root URL with a fixed greeting string."""
    greeting = 'Hello, World!'
    return greeting

# Function to process image and generate mask
def process_image(image, prompt):
    """Run CLIPSeg on one image/prompt pair and return a normalised mask.

    Args:
        image: PIL image to segment; its ``size`` determines the output size.
        prompt: Text describing the region to look for.

    Returns:
        A 2-D float ``numpy`` array with the same height and width as
        ``image``, min-max scaled to the range ``[0, 1]``. If the resized
        model output is constant (no contrast at all), an all-zero array is
        returned instead of dividing by zero.
    """
    inputs = processor(
        text=prompt, images=image, padding="max_length", return_tensors="pt"
    )
    with torch.no_grad():
        logits = model(**inputs).logits

    # The logits for a single image/prompt pair may carry a leading batch axis
    # of size 1 (this appears to vary between transformers releases -- TODO
    # confirm). Dropping singleton axes leaves the 2-D map that
    # Image.fromarray needs in either case.
    probabilities = np.squeeze(torch.sigmoid(logits).cpu().numpy())

    # Quantise to 8-bit grayscale and resize to the input image's size. A 2-D
    # uint8 array is interpreted as mode "L" without naming the mode.
    low_res = Image.fromarray(np.uint8(probabilities * 255))
    mask = np.asarray(low_res.resize(image.size), dtype=np.float64)

    mask_min = mask.min()
    value_range = mask.max() - mask_min
    if value_range == 0:
        # A flat mask has nothing to stretch; normalising would yield NaNs.
        return np.zeros_like(mask)

    return (mask - mask_min) / value_range

# Function to get masks from positive or negative prompts
def get_masks(prompts, img, threshold):
    """Build one boolean mask per comma-separated prompt.

    Args:
        prompts: Comma-separated prompt string. Whitespace around each entry
            is stripped and blank entries are ignored; ``None`` is treated as
            an empty string.
        img: PIL image handed to ``process_image`` for every prompt.
        threshold: Cut-off applied to each normalised mask.

    Returns:
        A list of boolean arrays (``mask > threshold``), one per non-blank
        prompt. The list is empty when no non-blank prompt was given.
    """
    cleaned = (part.strip() for part in (prompts or "").split(","))
    return [process_image(img, prompt) > threshold for prompt in cleaned if prompt]


def _union_of_masks(masks, shape):
    """Return the element-wise OR of ``masks``; all-False of ``shape`` if empty."""
    if not masks:
        return np.zeros(shape, dtype=bool)
    return np.any(np.stack(masks), axis=0)


# Route for processing requests
@app.route('/api', methods=['POST'])
def process_request():
    """Segment an uploaded image according to positive/negative prompts.

    Reads a JSON object with these keys:
        image: Base64 image data, either bare or as a ``data:`` URL.
        positive_prompts: Comma-separated prompts whose regions are kept.
        negative_prompts: Comma-separated prompts whose regions are removed.
        threshold: Optional mask cut-off, defaults to 0.4.

    Returns:
        JSON ``{'final_mask_base64': <PNG as base64>}`` on success, or a JSON
        ``{'error': ...}`` body with status 400 when the request is malformed.
    """
    # silent=True yields None instead of raising on a missing or invalid body.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'request body must be a JSON object'}), 400

    # Convert base64 image to PIL Image
    base64_image = data.get('image')
    if not isinstance(base64_image, str) or not base64_image:
        return jsonify({'error': "missing 'image' field"}), 400

    # A data URL looks like "data:image/png;base64,<payload>"; taking what
    # follows the last comma also accepts a bare base64 string.
    encoded = base64_image.rpartition(',')[2]
    try:
        image_data = base64.b64decode(encoded)
        # Normalise to three channels. NOTE(review): presumably the CLIPSeg
        # processor expects RGB input; RGBA/palette uploads were previously
        # passed through unconverted -- confirm.
        img = Image.open(io.BytesIO(image_data)).convert("RGB")
    except (ValueError, OSError):
        return jsonify({'error': "'image' is not a decodable image"}), 400

    # Get other parameters
    pos_prompts = data.get('positive_prompts', '')
    neg_prompts = data.get('negative_prompts', '')
    try:
        threshold = float(data.get('threshold', 0.4))
    except (TypeError, ValueError):
        return jsonify({'error': "'threshold' must be a number"}), 400

    # Perform image segmentation without caching, honouring the requested
    # threshold rather than a fixed value.
    positive_masks = get_masks(pos_prompts, img, threshold)
    negative_masks = get_masks(neg_prompts, img, threshold)

    # numpy arrays built from a PIL image are indexed (row, column).
    mask_shape = (img.height, img.width)
    pos_mask = _union_of_masks(positive_masks, mask_shape)
    neg_mask = _union_of_masks(negative_masks, mask_shape)
    final_mask = pos_mask & ~neg_mask

    final_mask = Image.fromarray(final_mask.astype(np.uint8) * 255, "L")

    # Convert final mask to base64
    buffered = io.BytesIO()
    final_mask.save(buffered, format="PNG")
    final_mask_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

    return jsonify({'final_mask_base64': final_mask_base64})

# Keep the server alive using a periodic task
def keep_alive():
    """Request the local root URL every five minutes, forever.

    Meant to run in a background thread; this function never returns. A
    failed request is reported on stdout and the loop carries on.
    """
    # Standard-library HTTP client, imported here because the module's import
    # block does not provide one.
    import http.client
    import urllib.request

    while True:
        time.sleep(300)  # 5 minutes
        try:
            # Send a request to keep the server alive; the timeout stops a
            # hung connection from stalling the loop indefinitely.
            with urllib.request.urlopen('http://127.0.0.1:7860/', timeout=10):
                pass
        except (OSError, http.client.HTTPException) as exc:
            print(f"Keep-alive request failed: {exc}")

if __name__ == '__main__':
    import os

    print("Server starting. Verify it is running by visiting http://0.0.0.0:7860/")

    # Start the keep-alive thread. It is a daemon so that its endless loop
    # cannot keep the process running once the server itself has stopped.
    keep_alive_thread = threading.Thread(target=keep_alive, daemon=True)
    keep_alive_thread.start()

    # The server listens on every interface, and Flask's debug mode exposes an
    # interactive debugger that can execute arbitrary code. Debug mode is
    # therefore opt-in through the FLASK_DEBUG environment variable.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in {
        "1", "true", "yes", "on",
    }
    app.run(host='0.0.0.0', port=7860, debug=debug_enabled)