sagar007 committed
Commit 7a7f5c3 · verified · 1 Parent(s): 71905f5

Update app.py

Files changed (1)
  1. app.py +39 -18
app.py CHANGED
@@ -3,35 +3,53 @@ import torch
 import numpy as np
 from PIL import Image
 from transformers import AutoProcessor, CLIPSegForImageSegmentation
+import traceback
 
 # Load the CLIPSeg model and processor
 processor = AutoProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
 model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
 
 def segment_image(input_image, text_prompt):
-    # Preprocess the image
-    inputs = processor(text=[text_prompt], images=[input_image], padding="max_length", return_tensors="pt")
-
-    # Perform segmentation
-    with torch.no_grad():
-        outputs = model(**inputs)
-
-    # Get the predicted segmentation
-    preds = outputs.logits.squeeze().sigmoid()
-
-    # Convert the prediction to a numpy array and scale to 0-255
-    segmentation = (preds.numpy() * 255).astype(np.uint8)
-
-    # Create a colored heatmap
-    heatmap = np.zeros((segmentation.shape[0], segmentation.shape[1], 3), dtype=np.uint8)
-    heatmap[:, :, 0] = segmentation  # Red channel
-    heatmap[:, :, 2] = 255 - segmentation  # Blue channel
-
-    # Blend the heatmap with the original image
-    original_image = np.array(input_image)
-    blended = (0.7 * original_image + 0.3 * heatmap).astype(np.uint8)
-
-    return Image.fromarray(blended)
+    try:
+        # Ensure input_image is a PIL Image
+        if not isinstance(input_image, Image.Image):
+            input_image = Image.fromarray(input_image)
+
+        # Resize image if it's too large
+        max_size = 1024
+        if max(input_image.size) > max_size:
+            input_image.thumbnail((max_size, max_size))
+
+        # Preprocess the image
+        inputs = processor(text=[text_prompt], images=[input_image], padding="max_length", return_tensors="pt")
+
+        # Perform segmentation
+        with torch.no_grad():
+            outputs = model(**inputs)
+
+        # Get the predicted segmentation
+        preds = outputs.logits.squeeze().sigmoid()
+
+        # Convert the prediction to a numpy array and scale to 0-255
+        segmentation = (preds.numpy() * 255).astype(np.uint8)
+
+        # Resize segmentation to match input image size
+        segmentation = Image.fromarray(segmentation).resize(input_image.size)
+        segmentation = np.array(segmentation)
+
+        # Create a colored heatmap
+        heatmap = np.zeros((segmentation.shape[0], segmentation.shape[1], 3), dtype=np.uint8)
+        heatmap[:, :, 0] = segmentation  # Red channel
+        heatmap[:, :, 2] = 255 - segmentation  # Blue channel
+
+        # Blend the heatmap with the original image
+        original_image = np.array(input_image)
+        blended = (0.7 * original_image + 0.3 * heatmap).astype(np.uint8)
+
+        return Image.fromarray(blended), ""  # empty error message fills the second output
+    except Exception as e:
+        error_msg = f"An error occurred: {str(e)}\n\nStacktrace:\n{traceback.format_exc()}"
+        return Image.new('RGB', (400, 200), color=(255, 0, 0)), error_msg  # Red image to indicate error
 
 # Create Gradio interface
 iface = gr.Interface(
@@ -40,7 +58,10 @@ iface = gr.Interface(
         gr.Image(type="pil", label="Input Image"),
         gr.Textbox(label="Text Prompt", placeholder="Enter a description of what to segment...")
     ],
-    outputs=gr.Image(type="pil", label="Segmentation Result"),
+    outputs=[
+        gr.Image(type="pil", label="Segmentation Result"),
+        gr.Textbox(label="Error Message", visible=False)
+    ],
     title="CLIPSeg Image Segmentation",
     description="Upload an image and provide a text prompt to segment objects."
 )
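
Since the interface now declares two outputs, the handler must return a value for each; Gradio maps returned values to output components positionally, so a single-image return against two declared outputs would fail at runtime. A minimal smoke test for the updated handler, run outside the Gradio UI, is sketched below. It assumes segment_image is importable from app.py and that a sample image exists at the hypothetical path test.jpg; neither the import layout nor the path is part of this commit.

# Smoke test for the updated segment_image handler (no Gradio UI needed).
# Note: importing app loads/downloads the CLIPSeg weights on first run.
from PIL import Image

from app import segment_image  # hypothetical import; adjust to your file layout

result_image, error_msg = segment_image(Image.open("test.jpg"), "a cat")
if error_msg:
    # The handler caught an exception and returned the red placeholder image
    print("Segmentation failed:\n", error_msg)
else:
    result_image.save("segmentation_overlay.png")
    print("Saved blended overlay, size:", result_image.size)

On success the second value is an empty string, which leaves the hidden "Error Message" textbox blank; on failure the red placeholder image and the traceback text fill the two outputs respectively.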