sagar007 committed on
Commit
73989e5
·
verified ·
1 Parent(s): 3ba1061

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -46
app.py CHANGED
@@ -13,53 +13,63 @@ model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
13
  processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
14
 
15
def segment_image(input_image, segment_anything):
    """Segment an uploaded image with SAM and overlay the mask in red.

    Args:
        input_image: numpy array from the Gradio input (H, W, C), or None
            when nothing was uploaded.
        segment_anything: bool; when True, run SAM without point prompts to
            segment the whole image, otherwise prompt with the image centre.

    Returns:
        (result_image, status_message): numpy image with the mask overlaid
        in semi-transparent red (or None on bad input) and a status string.
    """
    if input_image is None:
        return None, "Please upload an image before submitting."

    # Convert input_image to PIL Image; PIL .size is (width, height).
    input_image = Image.fromarray(input_image)

    # Store original size
    original_size = input_image.size

    if segment_anything:
        # Segment everything in the image (no point prompt).
        inputs = processor(input_image, return_tensors="pt").to(device)
    else:
        # Use the center of the image as a point prompt
        width, height = input_image.size
        center_point = [[width // 2, height // 2]]
        inputs = processor(input_image, input_points=[center_point], return_tensors="pt").to(device)

    # Generate masks; inference only, so no gradients are needed.
    with torch.no_grad():
        outputs = model(**inputs)

    # Post-process masks back to the input resolution.
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu()
    )

    if segment_anything:
        # Combine all masks into their union.
        combined_mask = np.any(masks[0].numpy() > 0.5, axis=0)
    else:
        # Use the first mask
        combined_mask = masks[0][0].numpy() > 0.5

    # SAM can return a stack of mask proposals; collapse any extra leading
    # axes with a union so the mask is strictly 2-D before resizing.
    while combined_mask.ndim > 2:
        combined_mask = np.any(combined_mask, axis=0)

    # BUG FIX: cv2.resize takes dsize as (width, height). PIL's .size is
    # already (width, height), so the old original_size[::-1] transposed the
    # target shape for non-square images.
    combined_mask = cv2.resize(combined_mask.astype(np.uint8), original_size) > 0

    # Overlay the mask on the original image
    result_image = np.array(input_image)
    mask_rgb = np.zeros_like(result_image)
    mask_rgb[combined_mask] = [255, 0, 0]  # Red color for the mask
    result_image = cv2.addWeighted(result_image, 1, mask_rgb, 0.5, 0)

    return result_image, "Segmentation completed successfully."
63
 
64
  # Create Gradio interface
65
  iface = gr.Interface(
 
13
  processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
14
 
15
def segment_image(input_image, segment_anything):
    """Segment an uploaded image with SAM and overlay the mask in red.

    Args:
        input_image: numpy array from the Gradio input (H, W, C), or None
            when nothing was uploaded.
        segment_anything: bool; when True, run SAM without point prompts to
            segment the whole image, otherwise prompt with the image centre.

    Returns:
        (result_image, status_message): numpy image with the mask overlaid
        in semi-transparent red (or None on failure) and a status string.
    """
    try:
        if input_image is None:
            return None, "Please upload an image before submitting."

        # Normalise to 3-channel RGB so the red overlay below always applies.
        input_image = Image.fromarray(input_image).convert("RGB")

        # Store original size; PIL .size is (width, height), which is exactly
        # the dsize order cv2.resize expects.
        original_size = input_image.size
        if not original_size or 0 in original_size:
            return None, "Invalid image size. Please upload a different image."

        if segment_anything:
            # Segment everything in the image (no point prompt).
            inputs = processor(input_image, return_tensors="pt").to(device)
        else:
            # Use the center of the image as a point prompt
            width, height = original_size
            center_point = [[width // 2, height // 2]]
            inputs = processor(input_image, input_points=[center_point], return_tensors="pt").to(device)

        # Generate masks; inference only, so no gradients are needed.
        with torch.no_grad():
            outputs = model(**inputs)

        # Post-process masks back to the input resolution.
        masks = processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs["original_sizes"].cpu(),
            inputs["reshaped_input_sizes"].cpu()
        )

        if segment_anything:
            # Combine all masks into their union.
            combined_mask = np.any(masks[0].numpy() > 0.5, axis=0)
        else:
            # Use the first mask
            combined_mask = masks[0][0].numpy() > 0.5

        # BUG FIX: .squeeze() only removes size-1 axes, so a stack of several
        # mask proposals (e.g. shape (3, H, W)) stayed 3-D and made
        # cv2.resize raise inside the broad except below. Collapse any extra
        # leading axes with a union instead.
        while combined_mask.ndim > 2:
            combined_mask = np.any(combined_mask, axis=0)

        # Resize mask to match original image size (dsize is (width, height)).
        combined_mask = cv2.resize(combined_mask.astype(np.uint8), (original_size[0], original_size[1])) > 0

        # Overlay the mask on the original image
        result_image = np.array(input_image)
        mask_rgb = np.zeros_like(result_image)
        mask_rgb[combined_mask] = [255, 0, 0]  # Red color for the mask
        result_image = cv2.addWeighted(result_image, 1, mask_rgb, 0.5, 0)

        return result_image, "Segmentation completed successfully."

    except Exception as e:
        # Gradio UI boundary: surface failures as a status message rather
        # than crashing the interface.
        return None, f"An error occurred: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
  # Create Gradio interface
75
  iface = gr.Interface(