Saarthak2002 commited on
Commit
ace5a98
·
verified ·
1 Parent(s): 7990efc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -12
app.py CHANGED
@@ -5,7 +5,8 @@ import requests
5
  import numpy as np
6
  import gradio as gr
7
  from io import BytesIO
8
- from torchvision.models.segmentation import deeplabv3_resnet50
 
9
 
10
  # Step 1: Load the Image from URL
11
  def load_image(url):
@@ -18,8 +19,8 @@ def crop_image(image, bounding_box):
18
  x_min, y_min, x_max, y_max = bounding_box.values()
19
  return image.crop((x_min, y_min, x_max, y_max))
20
 
21
- # Step 3: Preprocessing for U-Net (DeepLabV3 in this case)
22
- def preprocess_image(image, size=(256, 256)):
23
  preprocess = transforms.Compose([
24
  transforms.Resize(size),
25
  transforms.ToTensor(),
@@ -27,29 +28,38 @@ def preprocess_image(image, size=(256, 256)):
27
  ])
28
  return preprocess(image).unsqueeze(0) # Add batch dimension
29
 
30
- # Step 4: Load Pre-trained Segmentation Model (DeepLabV3)
31
  def load_model():
32
- model = deeplabv3_resnet50(pretrained=True)
33
  model.eval() # Set the model to evaluation mode
 
 
34
  return model
35
 
36
  # Step 5: Perform Segmentation
37
  def segment_image(model, input_tensor):
 
 
38
  with torch.no_grad():
39
  output = model(input_tensor)['out'] # Model output
40
  mask = output.argmax(dim=1).squeeze().cpu().numpy() # Get segmentation mask
41
  return mask
42
 
43
- # Step 6: Postprocess and Extract Object
44
- def apply_mask(image, mask, threshold=1):
45
- mask_resized = Image.fromarray((mask * 255).astype(np.uint8)).resize(image.size, Image.NEAREST)
46
- mask_resized = np.array(mask_resized) > threshold
47
- image_np = np.array(image)
 
 
 
48
 
49
- # Create RGBA image with transparency
 
50
  rgba_image = np.zeros((image_np.shape[0], image_np.shape[1], 4), dtype=np.uint8)
51
  rgba_image[..., :3] = image_np # Copy RGB channels
52
- rgba_image[..., 3] = mask_resized.astype(np.uint8) * 255 # Alpha channel based on mask
 
53
  return Image.fromarray(rgba_image)
54
 
55
  # Gradio Interface to handle input and output
@@ -85,3 +95,4 @@ iface = gr.Interface(
85
 
86
  # Launch the interface
87
  iface.launch()
 
 
5
  import numpy as np
6
  import gradio as gr
7
  from io import BytesIO
8
+ from torchvision.models.segmentation import deeplabv3_resnet101
9
+ import cv2
10
 
11
  # Step 1: Load the Image from URL
12
  def load_image(url):
 
19
  x_min, y_min, x_max, y_max = bounding_box.values()
20
  return image.crop((x_min, y_min, x_max, y_max))
21
 
22
+ # Step 3: Preprocessing for Segmentation Model
23
+ def preprocess_image(image, size=(512, 512)):
24
  preprocess = transforms.Compose([
25
  transforms.Resize(size),
26
  transforms.ToTensor(),
 
28
  ])
29
  return preprocess(image).unsqueeze(0) # Add batch dimension
30
 
31
+ # Step 4: Load a More Robust Pre-trained Model
32
  def load_model():
33
+ model = deeplabv3_resnet101(pretrained=True) # Switch to ResNet101 for better feature extraction
34
  model.eval() # Set the model to evaluation mode
35
+ if torch.cuda.is_available():
36
+ model = model.to("cuda")
37
  return model
38
 
39
  # Step 5: Perform Segmentation
40
  def segment_image(model, input_tensor):
41
+ if torch.cuda.is_available():
42
+ input_tensor = input_tensor.to("cuda")
43
  with torch.no_grad():
44
  output = model(input_tensor)['out'] # Model output
45
  mask = output.argmax(dim=1).squeeze().cpu().numpy() # Get segmentation mask
46
  return mask
47
 
48
+ # Step 6: Refine Mask and Extract Object
49
+ def apply_mask(image, mask):
50
+ mask = cv2.resize(mask.astype(np.uint8), image.size, interpolation=cv2.INTER_NEAREST)
51
+
52
+ # Apply morphological operations for cleaner mask
53
+ kernel = np.ones((5, 5), np.uint8)
54
+ mask = cv2.dilate(mask, kernel, iterations=1)
55
+ mask = cv2.erode(mask, kernel, iterations=1)
56
 
57
+ # Create RGBA image
58
+ image_np = np.array(image)
59
  rgba_image = np.zeros((image_np.shape[0], image_np.shape[1], 4), dtype=np.uint8)
60
  rgba_image[..., :3] = image_np # Copy RGB channels
61
+ rgba_image[..., 3] = mask * 255 # Alpha channel based on refined mask
62
+
63
  return Image.fromarray(rgba_image)
64
 
65
  # Gradio Interface to handle input and output
 
95
 
96
  # Launch the interface
97
  iface.launch()
98
+