sigyllly committed on
Commit
f77dac9
·
verified ·
1 Parent(s): 2415688

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -87
app.py CHANGED
@@ -8,91 +8,74 @@ import threading
8
  processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
9
  model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
10
 
11
- # Gradio UI
12
- with gr.Blocks() as demo:
13
- gr.Markdown("# CLIPSeg: Image Segmentation Using Text and Image Prompts")
14
-
15
- # Add your article and description here
16
- gr.Markdown("Your article goes here")
17
- gr.Markdown("Your description goes here")
18
-
19
- with gr.Row():
20
- with gr.Column():
21
- input_image = gr.Image(type="pil")
22
- positive_prompts = gr.Textbox(
23
- label="Please describe what you want to identify (comma separated)"
24
- )
25
- negative_prompts = gr.Textbox(
26
- label="Please describe what you want to ignore (comma separated)"
27
- )
28
- input_slider_T = gr.Slider(
29
- minimum=0, maximum=1, value=0.4, label="Threshold"
30
- )
31
- btn_process = gr.Button(label="Process")
32
-
33
- with gr.Column():
34
- output_image = gr.Image(label="Result")
35
- output_mask = gr.Image(label="Mask")
36
-
37
- def process_image(image, prompt):
38
- inputs = processor(
39
- text=prompt, images=image, padding="max_length", return_tensors="pt"
40
- )
41
- with torch.no_grad():
42
- outputs = model(**inputs)
43
- preds = outputs.logits
44
-
45
- pred = torch.sigmoid(preds)
46
- mat = pred.cpu().numpy()
47
- mask = Image.fromarray(np.uint8(mat * 255), "L")
48
- mask = mask.convert("RGB")
49
- mask = mask.resize(image.size)
50
- mask = np.array(mask)[:, :, 0]
51
-
52
- mask_min = mask.min()
53
- mask_max = mask.max()
54
- mask = (mask - mask_min) / (mask_max - mask_min)
55
-
56
- return mask
57
-
58
- def get_masks(prompts, img, threshold):
59
- prompts = prompts.split(",")
60
- masks = []
61
- for prompt in prompts:
62
- mask = process_image(img, prompt)
63
- mask = mask > threshold
64
- masks.append(mask)
65
-
66
- return masks
67
-
68
- def extract_image(pos_prompts, neg_prompts, img, threshold):
69
- positive_masks = get_masks(pos_prompts, img, 0.5)
70
- negative_masks = get_masks(neg_prompts, img, 0.5)
71
-
72
- pos_mask = np.any(np.stack(positive_masks), axis=0)
73
- neg_mask = np.any(np.stack(negative_masks), axis=0)
74
- final_mask = pos_mask & ~neg_mask
75
-
76
- final_mask = Image.fromarray(final_mask.astype(np.uint8) * 255, "L")
77
- output_image = Image.new("RGBA", img.size, (0, 0, 0, 0))
78
- output_image.paste(img, mask=final_mask)
79
-
80
- return output_image, final_mask
81
-
82
- btn_process.click(
83
- extract_image,
84
- inputs=[
85
- positive_prompts,
86
- negative_prompts,
87
- input_image,
88
- input_slider_T,
89
- ],
90
- outputs=[output_image, output_mask],
91
  )
92
-
93
- # Gradio API Endpoint
94
- gr.Interface(
95
- [positive_prompts, negative_prompts, input_image, input_slider_T],
96
- [output_image, output_mask],
97
- live=True,
98
- ).launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
9
  model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
10
 
11
+ # Function to process image and generate mask
12
+ def process_image(image, prompt):
13
+ inputs = processor(
14
+ text=prompt, images=image, padding="max_length", return_tensors="pt"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  )
16
+ with torch.no_grad():
17
+ outputs = model(**inputs)
18
+ preds = outputs.logits
19
+
20
+ pred = torch.sigmoid(preds)
21
+ mat = pred.cpu().numpy()
22
+ mask = Image.fromarray(np.uint8(mat * 255), "L")
23
+ mask = mask.convert("RGB")
24
+ mask = mask.resize(image.size)
25
+ mask = np.array(mask)[:, :, 0]
26
+
27
+ mask_min = mask.min()
28
+ mask_max = mask.max()
29
+ mask = (mask - mask_min) / (mask_max - mask_min)
30
+
31
+ return mask
32
+
33
+ # Function to get masks from positive or negative prompts
34
+ def get_masks(prompts, img, threshold):
35
+ prompts = prompts.split(",")
36
+ masks = []
37
+ for prompt in prompts:
38
+ mask = process_image(img, prompt)
39
+ mask = mask > threshold
40
+ masks.append(mask)
41
+
42
+ return masks
43
+
44
+ # Function to extract image using positive and negative prompts
45
+ def extract_image(pos_prompts, neg_prompts, img, threshold):
46
+ positive_masks = get_masks(pos_prompts, img, 0.5)
47
+ negative_masks = get_masks(neg_prompts, img, 0.5)
48
+
49
+ pos_mask = np.any(np.stack(positive_masks), axis=0)
50
+ neg_mask = np.any(np.stack(negative_masks), axis=0)
51
+ final_mask = pos_mask & ~neg_mask
52
+
53
+ final_mask = Image.fromarray(final_mask.astype(np.uint8) * 255, "L")
54
+ output_image = Image.new("RGBA", img.size, (0, 0, 0, 0))
55
+ output_image.paste(img, mask=final_mask)
56
+
57
+ return output_image, final_mask
58
+
59
+ # Define Gradio interface
60
+ iface = gr.Interface(
61
+ fn=extract_image,
62
+ inputs=[
63
+ gr.Textbox(
64
+ label="Please describe what you want to identify (comma separated)",
65
+ key="pos_prompts",
66
+ ),
67
+ gr.Textbox(
68
+ label="Please describe what you want to ignore (comma separated)",
69
+ key="neg_prompts",
70
+ ),
71
+ gr.Image(type="pil", label="Input Image", key="img"),
72
+ gr.Slider(minimum=0, maximum=1, default=0.4, label="Threshold", key="threshold"),
73
+ ],
74
+ outputs=[
75
+ gr.Image(label="Result", key="output_image"),
76
+ gr.Image(label="Mask", key="output_mask"),
77
+ ],
78
+ )
79
+
80
+ # Launch Gradio API
81
+ iface.launch()