File size: 2,981 Bytes
35b65a4
5c75869
 
 
 
35b65a4
 
 
ab226f6
5c75869
24d11e8
5c75869
 
 
f77dac9
 
 
 
4fb1cd4
f77dac9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24d11e8
 
 
 
 
 
 
 
 
 
f77dac9
ab226f6
 
 
35b65a4
f77dac9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35b65a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0c6914c
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
from flask import Flask, request, jsonify
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
from PIL import Image
import torch
import numpy as np
import io
import base64

app = Flask(__name__)

# Hugging Face checkpoint shared by the processor and the model below.
_CLIPSEG_CHECKPOINT = "CIDAS/clipseg-rd64-refined"

# Load the CLIPSeg processor and segmentation model once, at import time;
# process_image() reads these module-level objects on every call.
processor = CLIPSegProcessor.from_pretrained(_CLIPSEG_CHECKPOINT)
model = CLIPSegForImageSegmentation.from_pretrained(_CLIPSEG_CHECKPOINT)

def process_image(image, prompt):
    """Compute a min-max normalised CLIPSeg relevance map for one text prompt.

    Args:
        image: PIL image to segment. It is converted to RGB for the model
            input only, so RGBA, greyscale and palette images are accepted.
        prompt: text describing the region of interest.

    Returns:
        A float64 numpy array of shape ``(image.height, image.width)`` with
        values rescaled to ``[0, 1]``. When the model's output is uniform
        (nothing to rescale) an all-zero array is returned instead of NaNs.
    """
    # The processor normalises exactly three channels; other modes (e.g. an
    # RGBA PNG upload) would fail, so feed it an RGB view of the image.
    inputs = processor(
        text=prompt,
        images=image.convert("RGB"),
        padding="max_length",
        return_tensors="pt",
    )
    with torch.no_grad():
        probs = torch.sigmoid(model(**inputs).logits)

    # Quantise to an 8-bit greyscale heatmap and scale it back to the input
    # resolution. Assumes one prompt/image pair yields a single 2-D logit
    # map, which Image.fromarray(..., "L") requires.
    heatmap = Image.fromarray(np.uint8(probs.cpu().numpy() * 255), "L")
    mask = np.asarray(heatmap.resize(image.size), dtype=np.float64)

    mask_min, mask_max = mask.min(), mask.max()
    if mask_max == mask_min:
        # Uniform map: min-max scaling would be 0/0 -> NaN.
        return np.zeros_like(mask)
    return (mask - mask_min) / (mask_max - mask_min)

def get_masks(prompts, img, threshold):
    """Build one boolean mask per comma-separated prompt.

    Args:
        prompts: comma-separated prompt string, e.g. ``"cat, dog"``.
            Surrounding whitespace is stripped, empty entries are skipped,
            and ``None`` is treated like ``""``.
        img: PIL image to segment.
        threshold: a pixel is set when its normalised score is strictly
            greater than this value.

    Returns:
        A non-empty list of boolean arrays shaped ``(height, width)`` of
        *img*. When no usable prompt is given, the list holds a single
        all-False mask, so callers can always ``np.stack``/``np.any`` it
        and an absent prompt list selects (or removes) nothing.
    """
    # An empty string must not reach the model: it would still yield a
    # normalised map whose brightest pixels pass the threshold.
    names = [name.strip() for name in (prompts or "").split(",")]
    masks = [process_image(img, name) > threshold for name in names if name]

    if not masks:
        width, height = img.size
        masks.append(np.zeros((height, width), dtype=bool))
    return masks

@app.route('/')
def hello_world():
    """Root endpoint: reply with a fixed greeting to show the server is up."""
    greeting = 'Hello, World!'
    return greeting

def extract_image(pos_prompts, neg_prompts, img, threshold):
    """Cut out the regions matching the positive but not the negative prompts.

    Args:
        pos_prompts: comma-separated prompts whose regions are kept.
        neg_prompts: comma-separated prompts whose regions are removed.
        img: PIL image to segment.
        threshold: score cut-off forwarded to ``get_masks`` for both sets.

    Returns:
        ``(output_image, final_mask)``: an RGBA image the size of *img* that
        is fully transparent outside the selected region, and the ``"L"``
        mask (0 or 255 per pixel) used to produce it.
    """
    # Use the caller's threshold rather than a fixed cut-off.
    positive_masks = get_masks(pos_prompts, img, threshold)
    negative_masks = get_masks(neg_prompts, img, threshold)

    # Union within each set, then subtract the negatives from the positives.
    keep = np.any(np.stack(positive_masks), axis=0)
    drop = np.any(np.stack(negative_masks), axis=0)
    final_mask = Image.fromarray((keep & ~drop).astype(np.uint8) * 255, "L")

    output_image = Image.new("RGBA", img.size, (0, 0, 0, 0))
    output_image.paste(img, mask=final_mask)

    return output_image, final_mask

@app.route('/api', methods=['POST'])
def process_request():
    """Segment a base64-encoded image with text prompts.

    Expects a JSON object with:
        image: base64 image bytes, either bare or as a data URL
            (``"data:image/png;base64,..."``). Required.
        positive_prompts: comma-separated prompts to keep (default ``""``).
        negative_prompts: comma-separated prompts to remove (default ``""``).
        threshold: numeric score cut-off (default ``0.4``).

    Returns:
        JSON ``{"result_image_base64": <PNG, base64>}`` on success, or
        ``{"error": <message>}`` with HTTP 400 when the input is malformed.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'request body must be a JSON object'}), 400

    base64_image = data.get('image')
    if not isinstance(base64_image, str) or not base64_image:
        return jsonify({'error': "missing or invalid 'image' field"}), 400

    try:
        # Take everything after the last comma: this strips a data-URL
        # header and leaves bare base64 (which contains no comma) intact.
        image_data = base64.b64decode(base64_image.rpartition(',')[2])
        img = Image.open(io.BytesIO(image_data))
        # Image.open is lazy; decode now so corrupt data is a client error.
        img.load()
    except (ValueError, OSError):  # binascii.Error is a ValueError
        return jsonify({'error': "'image' is not a valid base64-encoded image"}), 400

    # `or ''` also maps explicit JSON nulls to "no prompts".
    pos_prompts = data.get('positive_prompts') or ''
    neg_prompts = data.get('negative_prompts') or ''
    try:
        threshold = float(data.get('threshold', 0.4))
    except (TypeError, ValueError):
        return jsonify({'error': "'threshold' must be a number"}), 400

    output_image, _ = extract_image(pos_prompts, neg_prompts, img, threshold)

    # Encode the RGBA cut-out as PNG (preserves transparency) and base64 it.
    buffered = io.BytesIO()
    output_image.save(buffered, format="PNG")
    result_image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

    return jsonify({'result_image_base64': result_image_base64})

if __name__ == '__main__':
    import os

    # Build the banner from the port actually bound so they cannot disagree.
    port = int(os.environ.get('PORT', '80'))
    # The Werkzeug debugger allows arbitrary code execution from the browser,
    # so on a server listening on every interface it is opt-in (FLASK_DEBUG=1).
    debug = os.environ.get('FLASK_DEBUG', '').strip().lower() in ('1', 'true', 'yes')
    print(f"Server starting. Verify it is running by visiting http://0.0.0.0:{port}/")
    app.run(host='0.0.0.0', port=port, debug=debug)