sagar007 committed
Commit a3ee867 · verified · 1 Parent(s): 7a7f5c3

Update app.py

Files changed (1)
  1. app.py +55 -58
app.py CHANGED
@@ -1,70 +1,67 @@
  import gradio as gr
- import torch
  import numpy as np
- from PIL import Image
+ from PIL import Image, ImageDraw
+ import torch
  from transformers import AutoProcessor, CLIPSegForImageSegmentation
- import traceback

  # Load the CLIPSeg model and processor
  processor = AutoProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
  model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")

- def segment_image(input_image, text_prompt):
-     try:
-         # Ensure input_image is a PIL Image
-         if not isinstance(input_image, Image.Image):
-             input_image = Image.fromarray(input_image)
-
-         # Resize image if it's too large
-         max_size = 1024
-         if max(input_image.size) > max_size:
-             input_image.thumbnail((max_size, max_size))
-
-         # Preprocess the image
-         inputs = processor(text=[text_prompt], images=[input_image], padding="max_length", return_tensors="pt")
-
-         # Perform segmentation
-         with torch.no_grad():
-             outputs = model(**inputs)
-
-         # Get the predicted segmentation
-         preds = outputs.logits.squeeze().sigmoid()
-
-         # Convert the prediction to a numpy array and scale to 0-255
-         segmentation = (preds.numpy() * 255).astype(np.uint8)
-
-         # Resize segmentation to match input image size
-         segmentation = Image.fromarray(segmentation).resize(input_image.size)
-         segmentation = np.array(segmentation)
-
-         # Create a colored heatmap
-         heatmap = np.zeros((segmentation.shape[0], segmentation.shape[1], 3), dtype=np.uint8)
-         heatmap[:, :, 0] = segmentation  # Red channel
-         heatmap[:, :, 2] = 255 - segmentation  # Blue channel
-
-         # Blend the heatmap with the original image
-         original_image = np.array(input_image)
-         blended = (0.7 * original_image + 0.3 * heatmap).astype(np.uint8)
-
-         return Image.fromarray(blended)
-     except Exception as e:
-         error_msg = f"An error occurred: {str(e)}\n\nStacktrace:\n{traceback.format_exc()}"
-         return Image.new('RGB', (400, 200), color=(255, 0, 0))  # Red image to indicate error
-
- # Create Gradio interface
- iface = gr.Interface(
-     fn=segment_image,
-     inputs=[
-         gr.Image(type="pil", label="Input Image"),
-         gr.Textbox(label="Text Prompt", placeholder="Enter a description of what to segment...")
-     ],
-     outputs=[
-         gr.Image(type="pil", label="Segmentation Result"),
-         gr.Textbox(label="Error Message", visible=False)
-     ],
-     title="CLIPSeg Image Segmentation",
-     description="Upload an image and provide a text prompt to segment objects."
- )
-
- # Launch the interface
- iface.launch()
 
+ def segment_everything(image):
+     # Segment the whole image against a generic text prompt
+     inputs = processor(text=["object"], images=[image], padding="max_length", return_tensors="pt")
+     with torch.no_grad():
+         outputs = model(**inputs)
+     preds = outputs.logits.squeeze().sigmoid()
+     segmentation = (preds.numpy() * 255).astype(np.uint8)
+     return Image.fromarray(segmentation)
+ def segment_box(image, box):
+     # CLIPSeg accepts only text (or image) prompts, so a box prompt is
+     # approximated here: segment the full image, then keep only the
+     # predictions that fall inside the user's box
+     if box is None:
+         return None
+     x1, y1, x2, y2 = map(int, box)
+     mask = Image.new('L', image.size, 0)
+     draw = ImageDraw.Draw(mask)
+     draw.rectangle([x1, y1, x2, y2], fill=255)
+
+     inputs = processor(text=["object"], images=[image], padding="max_length", return_tensors="pt")
+     with torch.no_grad():
+         outputs = model(**inputs)
+     preds = outputs.logits.squeeze().sigmoid()
+     segmentation = Image.fromarray((preds.numpy() * 255).astype(np.uint8)).resize(image.size)
+     return Image.composite(segmentation, Image.new('L', image.size, 0), mask)
+ def update_image(image, segmentation):
+     # Overlay the predicted mask on the original image for display
+     if segmentation is None:
+         return image
+     overlay = segmentation.convert('RGBA').resize(image.size)
+     return Image.blend(image.convert('RGBA'), overlay, 0.5)
+ with gr.Blocks() as demo:
+     gr.Markdown("# Segment Anything-like Demo")
+     with gr.Row():
+         with gr.Column(scale=1):
+             # type="pil" so the handlers receive PIL images
+             input_image = gr.Image(label="Input Image", type="pil")
+             with gr.Row():
+                 everything_btn = gr.Button("Everything")
+                 box_btn = gr.Button("Box")
+         with gr.Column(scale=1):
+             output_image = gr.Image(label="Segmentation Result")
+
+     # gr.Image has no box-selection output, so a click on the image sets
+     # the centre of a fixed-size box, approximating a box prompt
+     box_state = gr.State(None)
+
+     def record_click(evt: gr.SelectData):
+         x, y = evt.index
+         half = 100  # half-width of the assumed box, in pixels
+         return (x - half, y - half, x + half, y + half)
+
+     input_image.select(fn=record_click, inputs=None, outputs=[box_state])
+
+     # Blend inside the click handlers; attaching update_image to
+     # output_image.change would re-trigger on its own output
+     everything_btn.click(
+         fn=lambda img: update_image(img, segment_everything(img)),
+         inputs=[input_image],
+         outputs=[output_image]
+     )
+
+     box_btn.click(
+         fn=lambda img, box: update_image(img, segment_box(img, box)),
+         inputs=[input_image, box_state],
+         outputs=[output_image]
+     )

+ demo.launch()
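
A quick way to sanity-check the model path outside Gradio, as a minimal sketch run in a Python shell after importing the functions from app.py (example.jpg is a placeholder file name, and 352x352 is the raw map size this checkpoint typically predicts):

    from PIL import Image

    img = Image.open("example.jpg").convert("RGB")  # any local test image
    seg = segment_everything(img)
    print(seg.size)  # expected: (352, 352) before any resizing
    seg.save("segmentation.png")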