sigyllly committed on
Commit
6e7b1c7
·
verified ·
1 Parent(s): 6bde246

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -93
app.py CHANGED
@@ -8,102 +8,76 @@ import threading
8
  processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
9
  model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
10
 
11
- # Gradio UI
12
- with gr.Blocks() as demo:
13
- gr.Markdown("# CLIPSeg: Image Segmentation Using Text and Image Prompts")
14
-
15
- # Add your article and description here
16
- gr.Markdown("Your article goes here")
17
- gr.Markdown("Your description goes here")
18
-
19
- with gr.Row():
20
- with gr.Column():
21
- input_image = gr.Image(type="pil")
22
- positive_prompts = gr.Textbox(
23
- label="Please describe what you want to identify (comma separated)"
24
- )
25
- negative_prompts = gr.Textbox(
26
- label="Please describe what you want to ignore (comma separated)"
27
- )
28
-
29
- input_slider_T = gr.Slider(
30
- minimum=0, maximum=1, value=0.4, label="Threshold"
31
- )
32
- btn_process = gr.Button(label="Process")
33
-
34
- with gr.Column():
35
- output_image = gr.Image(label="Result")
36
- output_mask = gr.Image(label="Mask")
37
-
38
- def process_image(image, prompt):
39
- inputs = processor(
40
- text=prompt, images=image, padding="max_length", return_tensors="pt"
41
- )
42
-
43
- with torch.no_grad():
44
- outputs = model(**inputs)
45
- preds = outputs.logits
46
-
47
- pred = torch.sigmoid(preds)
48
- mat = pred.cpu().numpy()
49
- mask = Image.fromarray(np.uint8(mat * 255), "L")
50
- mask = mask.convert("RGB")
51
- mask = mask.resize(image.size)
52
- mask = np.array(mask)[:, :, 0]
53
-
54
- mask_min = mask.min()
55
- mask_max = mask.max()
56
- mask = (mask - mask_min) / (mask_max - mask_min)
57
- return mask
58
-
59
- def get_masks(prompts, img, threshold):
60
- prompts = prompts.split(",")
61
- masks = []
62
- for prompt in prompts:
63
- mask = process_image(img, prompt)
64
- mask = mask > threshold
65
- masks.append(mask)
66
- return masks
67
-
68
- def extract_image(pos_prompts, neg_prompts, img, threshold):
69
- positive_masks = get_masks(pos_prompts, img, 0.5)
70
- negative_masks = get_masks(neg_prompts, img, 0.5)
71
-
72
- pos_mask = np.any(np.stack(positive_masks), axis=0)
73
- neg_mask = np.any(np.stack(negative_masks), axis=0)
74
- final_mask = pos_mask & ~neg_mask
75
-
76
- final_mask = Image.fromarray(final_mask.astype(np.uint8) * 255, "L")
77
- output_image = Image.new("RGBA", img.size, (0, 0, 0, 0))
78
- output_image.paste(img, mask=final_mask)
79
- return output_image, final_mask
80
-
81
- btn_process.click(
82
- extract_image,
83
- inputs=[
84
- positive_prompts,
85
- negative_prompts,
86
- input_image,
87
- input_slider_T,
88
- ],
89
- outputs=[output_image, output_mask],
90
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
- # API endpoint definition
93
  iface = gr.Interface(
94
- extract_image,
95
- [
96
- gr.Textbox(label="Positive prompts"),
97
- gr.Textbox(label="Negative prompts"),
98
- gr.Image(type="pil"),
99
- gr.Slider(minimum=0, maximum=1, value=0.4, label="Threshold"),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  ],
101
- [gr.Image(label="Result"), gr.Image(label="Mask")],
102
- "textbox,textbox,image,slider",
103
- "image,image",
104
- title="CLIPSeg API",
105
  )
106
 
107
- # Launch both UI and API
108
- demo.launch()
109
- iface.launch(share=True)
 
8
  processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
9
  model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
10
 
11
def process_image(image, prompt):
    """Run CLIPSeg on *image* for a single text *prompt*.

    Returns a float mask normalized to [0, 1] with the same spatial size
    as the input image (numpy array, shape (H, W)).
    """
    inputs = processor(
        text=prompt, images=image, padding="max_length", return_tensors="pt"
    )
    # Inference only — no gradients needed.
    with torch.no_grad():
        outputs = model(**inputs)
        preds = outputs.logits
    pred = torch.sigmoid(preds)
    mat = pred.cpu().numpy()
    # Convert the low-resolution logit map to a grayscale PIL image so it
    # can be resized back to the original image dimensions, then take one
    # channel as a 2-D uint8 array.
    mask = Image.fromarray(np.uint8(mat * 255), "L")
    mask = mask.convert("RGB")
    mask = mask.resize(image.size)
    mask = np.array(mask)[:, :, 0]
    # Min-max normalize to [0, 1]. Guard against a constant mask, which
    # previously caused a division by zero.
    mask_min = mask.min()
    mask_max = mask.max()
    if mask_max == mask_min:
        return np.zeros(mask.shape, dtype=float)
    mask = (mask - mask_min) / (mask_max - mask_min)
    return mask
28
+
29
def get_masks(prompts, img, threshold):
    """Binarize one CLIPSeg mask per comma-separated prompt.

    Whitespace around each prompt is stripped and empty entries are
    skipped, so "cat, dog," behaves the same as "cat,dog" (previously
    " dog" and "" were sent to the model verbatim).

    Returns a list of boolean numpy arrays, one per non-empty prompt.
    """
    masks = []
    for prompt in prompts.split(","):
        prompt = prompt.strip()
        if not prompt:
            # Skip blanks produced by trailing or doubled commas.
            continue
        mask = process_image(img, prompt)
        masks.append(mask > threshold)
    return masks
37
+
38
def extract_image(pos_prompts, neg_prompts, img, threshold):
    """Cut out the regions matching *pos_prompts* and not *neg_prompts*.

    Returns (RGBA cut-out image, L-mode mask). *threshold* is now actually
    used to binarize the per-prompt masks — previously a hard-coded 0.5
    was passed, silently ignoring the UI slider.
    """
    positive_masks = get_masks(pos_prompts, img, threshold)
    negative_masks = get_masks(neg_prompts, img, threshold)

    # Union of positive masks; with no positive prompts, select nothing.
    # (np.stack([]) would raise on an empty list.)
    if positive_masks:
        pos_mask = np.any(np.stack(positive_masks), axis=0)
    else:
        pos_mask = np.zeros((img.size[1], img.size[0]), dtype=bool)
    # Union of negative masks; with no negative prompts, exclude nothing.
    if negative_masks:
        neg_mask = np.any(np.stack(negative_masks), axis=0)
    else:
        neg_mask = np.zeros_like(pos_mask)
    final_mask = pos_mask & ~neg_mask

    # Paste the original image through the mask onto a transparent canvas.
    final_mask = Image.fromarray(final_mask.astype(np.uint8) * 255, "L")
    output_image = Image.new("RGBA", img.size, (0, 0, 0, 0))
    output_image.paste(img, mask=final_mask)
    return output_image, final_mask
50
 
 
51
# Main Gradio UI.
iface = gr.Interface(
    fn=extract_image,
    # Input order must match extract_image(pos_prompts, neg_prompts, img,
    # threshold); the previous order passed the image as the positive prompts.
    inputs=[
        gr.Textbox(label="Positive Prompts (comma separated)"),
        gr.Textbox(label="Negative Prompts (comma separated)"),
        gr.Image(type="pil", label="Input Image"),
        # Gradio 3+ takes the initial value as `value=`; `default=` is not
        # a Slider keyword and raises TypeError.
        gr.Slider(minimum=0, maximum=1, value=0.4, label="Threshold"),
    ],
    outputs=[
        gr.Image(type="pil", label="Output Image"),
        gr.Image(type="pil", label="Output Mask"),
    ],
)
64
+
65
# Secondary interface intended for API-style access; wraps the same function.
api_interface = gr.Interface(
    fn=extract_image,
    # Order matches extract_image(pos_prompts, neg_prompts, img, threshold);
    # the previous order passed the image as the positive prompts.
    inputs=[
        gr.Textbox(label="Positive Prompts (comma separated)"),
        gr.Textbox(label="Negative Prompts (comma separated)"),
        gr.Image(type="pil", label="Input Image"),
        # Gradio 3+ takes the initial value as `value=`; `default=` is not
        # a Slider keyword and raises TypeError.
        gr.Slider(minimum=0, maximum=1, value=0.4, label="Threshold"),
    ],
    outputs=[
        gr.Image(type="pil", label="Output Image"),
        gr.Image(type="pil", label="Output Mask"),
    ],
    live=True,  # re-run automatically when an input changes
)

# Launch the UI without blocking so the second launch is actually reached;
# a plain iface.launch() blocks, and api_interface.launch() would never run.
iface.launch(prevent_thread_lock=True)
api_interface.launch(share=True)  # share=True exposes a public link