File size: 3,419 Bytes
35b65a4
5c75869
 
 
 
35b65a4
 
d432bdc
 
35b65a4
ab226f6
5c75869
24d11e8
5c75869
 
 
d432bdc
 
 
f77dac9
 
 
 
4fb1cd4
f77dac9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24d11e8
 
 
 
 
 
 
 
 
 
f77dac9
 
 
d432bdc
 
 
 
 
f77dac9
 
 
 
 
 
 
 
 
fffe605
 
 
 
 
d432bdc
 
 
 
fffe605
 
 
 
f77dac9
35b65a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d432bdc
35b65a4
d432bdc
 
 
 
 
 
 
35b65a4
 
 
d432bdc
 
 
 
 
347b6c6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import base64
import hashlib
import io
import threading
import time
import urllib.request

import numpy as np
import torch
from flask import Flask, request, jsonify
from PIL import Image
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation

app = Flask(__name__)

# Hugging Face checkpoint shared by the CLIPSeg processor and model.
CLIPSEG_CHECKPOINT = "CIDAS/clipseg-rd64-refined"

# Load the processor and model once at import time; every request reuses them.
processor = CLIPSegProcessor.from_pretrained(CLIPSEG_CHECKPOINT)
model = CLIPSegForImageSegmentation.from_pretrained(CLIPSEG_CHECKPOINT)

# Process-wide result cache shared by all requests (filled by extract_image).
cache = {}

# Function to process image and generate mask
def process_image(image, prompt):
    """Run CLIPSeg on *image* for one text *prompt* and return a soft mask.

    Args:
        image: PIL image to segment.
        prompt: a single text prompt.

    Returns:
        A float64 numpy array of shape ``(image.height, image.width)``,
        min-max normalised to ``[0, 1]``. A constant prediction (zero range)
        yields an all-zero array instead of NaNs from a 0/0 division.
    """
    inputs = processor(
        text=prompt, images=image, padding="max_length", return_tensors="pt"
    )
    with torch.no_grad():
        logits = model(**inputs).logits

    # NOTE(review): for a single prompt the logits are presumably (H, W) or
    # (1, H, W) depending on the transformers version; squeeze so PIL always
    # receives a 2-D array.
    probs = torch.sigmoid(logits).squeeze().cpu().numpy()
    low_res = Image.fromarray(np.uint8(probs * 255), "L")
    # Upsample the low-resolution prediction to the input size. PIL sizes are
    # (width, height), so the resulting array is (height, width).
    mask = np.asarray(low_res.resize(image.size), dtype=np.float64)

    mask_min = mask.min()
    mask_range = mask.max() - mask_min
    if mask_range == 0:
        return np.zeros_like(mask)
    return (mask - mask_min) / mask_range

def _split_prompts(prompts):
    """Return the stripped, non-empty entries of a comma-separated string.

    ``None`` and ``""`` yield ``[]`` so "no prompts" is handled uniformly.
    """
    if not prompts:
        return []
    return [part.strip() for part in prompts.split(",") if part.strip()]


# Function to get masks from positive or negative prompts
def get_masks(prompts, img, threshold):
    """Return one boolean mask per prompt in a comma-separated string.

    Each prompt is stripped of surrounding whitespace and empty entries are
    skipped, so ``""`` yields ``[]`` rather than a mask for the empty prompt.
    A pixel is True where the normalised ``process_image`` score for that
    prompt exceeds *threshold*.
    """
    return [
        process_image(img, prompt) > threshold
        for prompt in _split_prompts(prompts)
    ]


def _image_fingerprint(img):
    """Return a hex digest identifying *img*'s mode, size and pixel data."""
    digest = hashlib.sha256()
    digest.update(f"{img.mode}:{img.size}".encode())
    digest.update(img.tobytes())
    return digest.hexdigest()


# Function to extract image using positive and negative prompts
def extract_image(pos_prompts, neg_prompts, img, threshold,
                  max_cache_entries=128):
    """Build a binary mask: union of positive prompts minus negative ones.

    Args:
        pos_prompts: comma-separated prompts whose regions are kept (union).
        neg_prompts: comma-separated prompts whose regions are removed
            (union); empty means nothing is removed.
        img: PIL image to segment.
        threshold: per-pixel cut-off applied to every normalised prompt mask.
        max_cache_entries: maximum number of results kept in the module-level
            ``cache``; the oldest entries are evicted first. ``0`` disables
            caching.

    Returns:
        ``{'final_mask_base64': <base64 PNG>}``. The PNG is an "L"-mode image
        with 255 inside the selection and 0 elsewhere. With no positive
        prompts the mask is all zeros.
    """
    # The key must include the pixels, otherwise the same prompts on a new
    # image would return the mask computed for a previous image.
    cache_key = (_image_fingerprint(img), pos_prompts, neg_prompts, threshold)
    cached = cache.get(cache_key)
    if cached is not None:
        return cached

    # PIL sizes are (width, height); the masks are (height, width) arrays.
    selection = np.zeros((img.height, img.width), dtype=bool)
    positive_masks = get_masks(pos_prompts, img, threshold)
    if positive_masks:
        selection = np.any(np.stack(positive_masks), axis=0)
        # Negatives only matter when something was selected.
        negative_masks = get_masks(neg_prompts, img, threshold)
        if negative_masks:
            selection &= ~np.any(np.stack(negative_masks), axis=0)

    mask_image = Image.fromarray(selection.astype(np.uint8) * 255, "L")

    # Convert final mask to base64
    buffered = io.BytesIO()
    mask_image.save(buffered, format="PNG")
    encoded = base64.b64encode(buffered.getvalue()).decode("utf-8")
    result = {'final_mask_base64': encoded}

    if max_cache_entries > 0:
        # dicts preserve insertion order, so the first key is the oldest.
        while len(cache) >= max_cache_entries:
            cache.pop(next(iter(cache), None), None)
        cache[cache_key] = result

    return result

@app.route('/')
def hello_world():
    """Answer ``GET /`` with a fixed greeting (a cheap liveness check)."""
    greeting = 'Hello, World!'
    return greeting

@app.route('/api', methods=['POST'])
def process_request():
    """Segment the posted image and return the combined mask as JSON.

    Expected JSON object:
        image: base64 image data, either raw or as a data URL
            (``data:<mime>;base64,<payload>``). Required.
        positive_prompts: comma-separated prompts to select (default ``''``).
        negative_prompts: comma-separated prompts to exclude (default ``''``).
        threshold: per-pixel cut-off, parsed with ``float`` (default ``0.4``).

    Returns:
        The dict from ``extract_image`` serialised as JSON, or an
        ``{'error': ...}`` body with HTTP 400 when the request is malformed.
    """
    # silent=True: a missing or unparsable JSON body yields None, not an error.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'request body must be a JSON object'}), 400

    base64_image = data.get('image')
    if not isinstance(base64_image, str) or not base64_image:
        return jsonify({'error': "'image' must be a non-empty base64 string"}), 400
    try:
        # Strip a data-URL header if present; raw base64 contains no comma.
        image_data = base64.b64decode(base64_image.split(',', 1)[-1])
        # Normalise RGBA / paletted / grayscale uploads to 3-channel RGB,
        # which the CLIPSeg preprocessing presumably expects.
        img = Image.open(io.BytesIO(image_data)).convert("RGB")
    except (ValueError, OSError) as err:
        # binascii.Error is a ValueError; PIL's UnidentifiedImageError is an
        # OSError.
        return jsonify({'error': f'could not decode image: {err}'}), 400

    # Get other parameters (an explicit JSON null is treated as "no prompts").
    pos_prompts = data.get('positive_prompts') or ''
    neg_prompts = data.get('negative_prompts') or ''
    try:
        threshold = float(data.get('threshold', 0.4))
    except (TypeError, ValueError):
        return jsonify({'error': "'threshold' must be a number"}), 400

    # Perform image segmentation
    result = extract_image(pos_prompts, neg_prompts, img, threshold)

    return jsonify(result)

# Keep the server alive using a periodic task
def keep_alive(url='http://127.0.0.1:7860/', interval=300):
    """Request *url* every *interval* seconds, forever.

    Args:
        url: endpoint to ping (defaults to this server's root route).
        interval: seconds to sleep between pings (default 300 = 5 minutes).

    Each ping is best-effort: a network or HTTP error is printed and the
    loop continues, so a single failure does not end the thread.
    """
    while True:
        time.sleep(interval)
        try:
            # Send a request to keep the server alive; the body is not needed.
            with urllib.request.urlopen(url, timeout=10):
                pass
        except OSError as err:
            # URLError, HTTPError and socket timeouts all derive from OSError.
            print(f"keep-alive ping to {url} failed: {err}")

if __name__ == '__main__':
    print("Server starting. Verify it is running by visiting http://0.0.0.0:7860/")

    # Start the keep-alive thread. keep_alive never returns, so mark the
    # thread as a daemon to let the interpreter exit once the server stops.
    keep_alive_thread = threading.Thread(target=keep_alive, daemon=True)
    keep_alive_thread.start()

    # Debug mode stays off by default. The Werkzeug debugger allows arbitrary
    # code execution and must not be reachable on 0.0.0.0. Its reloader also
    # re-runs this script, loading the model twice. Flask.run still honours
    # FLASK_DEBUG=1 when debug is not passed explicitly.
    app.run(host='0.0.0.0', port=7860)