sagar007 committed on
Commit 3cd1243 · verified · 1 Parent(s): 564688d

Update app.py

Files changed (1)
  1. app.py +38 -38
app.py CHANGED
@@ -2,73 +2,73 @@ import gradio as gr
 import torch
 import cv2
 import numpy as np
-from transformers import SamModel, SamProcessor
+from transformers import SamModel, SamProcessor, BlipProcessor, BlipForConditionalGeneration
 from PIL import Image
 
 # Set up device
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
-# Load model and processor
-model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
-processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
+# Load SAM model and processor
+sam_model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
+sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
+
+# Load BLIP model and processor for image-to-text
+blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
+blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)
 
 def process_mask(mask, target_size):
-    # Ensure mask is 2D
     if mask.ndim > 2:
         mask = mask.squeeze()
-
-    # If mask is still not 2D, take the first 2D slice
     if mask.ndim > 2:
         mask = mask[0]
-
-    # Convert to binary
     mask = (mask > 0.5).astype(np.uint8) * 255
-
-    # Resize mask to match original image size using PIL
     mask_image = Image.fromarray(mask)
     mask_image = mask_image.resize(target_size, Image.NEAREST)
-
     return np.array(mask_image) > 0
 
-def segment_image(input_image, segment_anything):
+def segment_image(input_image, object_name):
     try:
         if input_image is None:
             return None, "Please upload an image before submitting."
 
-        # Convert input_image to PIL Image and ensure it's RGB
         input_image = Image.fromarray(input_image).convert("RGB")
-
-        # Store original size
         original_size = input_image.size
         if not original_size or 0 in original_size:
             return None, "Invalid image size. Please upload a different image."
 
-        # Process the image
-        if segment_anything:
-            inputs = processor(input_image, return_tensors="pt").to(device)
-        else:
-            width, height = original_size
-            center_point = [[width // 2, height // 2]]
-            inputs = processor(input_image, input_points=[center_point], return_tensors="pt").to(device)
+        # Generate image caption
+        blip_inputs = blip_processor(input_image, return_tensors="pt").to(device)
+        caption = blip_model.generate(**blip_inputs)
+        caption_text = blip_processor.decode(caption[0], skip_special_tokens=True)
+
+        # Process the image with SAM
+        sam_inputs = sam_processor(input_image, return_tensors="pt").to(device)
 
         # Generate masks
         with torch.no_grad():
-            outputs = model(**inputs)
+            sam_outputs = sam_model(**sam_inputs)
 
         # Post-process masks
-        masks = processor.image_processor.post_process_masks(
-            outputs.pred_masks.cpu(),
-            inputs["original_sizes"].cpu(),
-            inputs["reshaped_input_sizes"].cpu()
+        masks = sam_processor.image_processor.post_process_masks(
+            sam_outputs.pred_masks.cpu(),
+            sam_inputs["original_sizes"].cpu(),
+            sam_inputs["reshaped_input_sizes"].cpu()
         )
 
-        # Process the mask
-        if segment_anything:
-            combined_mask = np.any(masks[0].numpy() > 0.5, axis=0)
-        else:
-            combined_mask = masks[0][0].numpy()
+        # Find the mask that best matches the specified object
+        best_mask = None
+        best_score = -1
+        for mask in masks[0]:
+            mask_binary = mask.numpy() > 0.5
+            mask_area = mask_binary.sum()
+            if object_name.lower() in caption_text.lower() and mask_area > best_score:
+                best_mask = mask_binary
+                best_score = mask_area
+
+        if best_mask is None:
+            return input_image, f"Could not find '{object_name}' in the image."
 
-        combined_mask = process_mask(combined_mask, original_size)
+        combined_mask = process_mask(best_mask, original_size)
 
         # Overlay the mask on the original image
         result_image = np.array(input_image)
@@ -76,7 +76,7 @@ def segment_image(input_image, segment_anything):
         mask_rgb[combined_mask] = [255, 0, 0]  # Red color for the mask
         result_image = cv2.addWeighted(result_image, 1, mask_rgb, 0.5, 0)
 
-        return result_image, "Segmentation completed successfully."
+        return result_image, f"Segmented '{object_name}' in the image."
 
     except Exception as e:
         return None, f"An error occurred: {str(e)}"
@@ -86,14 +86,14 @@ iface = gr.Interface(
     fn=segment_image,
     inputs=[
         gr.Image(type="numpy", label="Upload an image"),
-        gr.Checkbox(label="Segment Everything")
+        gr.Textbox(label="Specify object to segment (e.g., dog, cat, grass)")
    ],
     outputs=[
         gr.Image(type="numpy", label="Segmented Image"),
         gr.Textbox(label="Status")
     ],
-    title="Segment Anything Model (SAM) Image Segmentation",
-    description="Upload an image and choose whether to segment everything or use a center point."
+    title="Segment Anything Model (SAM) with Object Specification",
+    description="Upload an image and specify an object to segment."
 )
 
 # Launch the interface
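
For reference, a minimal standalone sketch of the BLIP captioning step this commit introduces; it is not part of the diff above. The checkpoint name matches the one loaded in the commit, while the image path and the max_new_tokens cap are placeholder assumptions.

import torch
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration

device = "cuda" if torch.cuda.is_available() else "cpu"

# Same checkpoint as in the diff above
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)

image = Image.open("example.jpg").convert("RGB")  # placeholder path
inputs = blip_processor(image, return_tensors="pt").to(device)
with torch.no_grad():
    out = blip_model.generate(**inputs, max_new_tokens=30)  # length cap is an assumption
print(blip_processor.decode(out[0], skip_special_tokens=True))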
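
And a self-contained sketch of the mask-selection heuristic the new segment_image uses: if the requested object name appears in the BLIP caption, the largest SAM mask by pixel area is kept; otherwise nothing is returned. The helper name pick_mask_for_object and the toy arrays are hypothetical, for illustration only.

import numpy as np

def pick_mask_for_object(masks, caption_text, object_name, threshold=0.5):
    # Mirrors the loop added in this commit: the caption only gates whether
    # any mask is returned; among the masks, the largest area wins.
    if object_name.lower() not in caption_text.lower():
        return None
    best_mask, best_area = None, -1
    for mask in masks:
        binary = np.asarray(mask) > threshold
        area = int(binary.sum())
        if area > best_area:
            best_mask, best_area = binary, area
    return best_mask

# Toy check: the caption mentions "dog" and the second mask has the larger area (4 pixels).
toy_masks = [np.zeros((4, 4)), np.pad(np.ones((2, 2)), 1)]
picked = pick_mask_for_object(toy_masks, "a dog on the grass", "dog")
print(picked is not None, int(picked.sum()))  # True 4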