Saarthak2002 committed on
Commit
975418f
·
verified ·
1 Parent(s): ce52e71

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -14
app.py CHANGED
@@ -14,13 +14,22 @@ def load_image(url):
14
  image = Image.open(BytesIO(response.content)).convert("RGB")
15
  return image
16
 
17
- # Step 2: Crop Image Based on Bounding Box
 
 
 
 
 
 
 
 
 
18
  def crop_image(image, bounding_box):
19
  x_min, y_min, x_max, y_max = bounding_box.values()
20
  return image.crop((x_min, y_min, x_max, y_max))
21
 
22
- # Step 3: Preprocessing for Segmentation Model
23
- def preprocess_image(image, size=(512, 512)):
24
  preprocess = transforms.Compose([
25
  transforms.Resize(size),
26
  transforms.ToTensor(),
@@ -28,7 +37,7 @@ def preprocess_image(image, size=(512, 512)):
28
  ])
29
  return preprocess(image).unsqueeze(0) # Add batch dimension
30
 
31
- # Step 4: Load a More Robust Pre-trained Model
32
  def load_model():
33
  model = deeplabv3_resnet101(pretrained=True) # Switch to ResNet101 for better feature extraction
34
  model.eval() # Set the model to evaluation mode
@@ -36,22 +45,27 @@ def load_model():
36
  model = model.to("cuda")
37
  return model
38
 
39
- # Step 5: Perform Segmentation
40
  def segment_image(model, input_tensor):
41
  if torch.cuda.is_available():
42
  input_tensor = input_tensor.to("cuda")
43
  with torch.no_grad():
44
  output = model(input_tensor)['out'] # Model output
45
- mask = output.argmax(dim=1).squeeze().cpu().numpy() # Get segmentation mask
 
46
  return mask
47
 
48
- # Step 6: Refine Mask and Extract Object
49
- def apply_mask(image, mask):
50
- mask = cv2.resize(mask.astype(np.uint8), image.size, interpolation=cv2.INTER_NEAREST)
51
-
52
- # Apply morphological operations for cleaner mask
 
 
 
 
53
  kernel = np.ones((5, 5), np.uint8)
54
- mask = cv2.dilate(mask, kernel, iterations=1)
55
  mask = cv2.erode(mask, kernel, iterations=1)
56
 
57
  # Create RGBA image
@@ -64,7 +78,7 @@ def apply_mask(image, mask):
64
 
65
  # Gradio Interface to handle input and output
66
  def segment_object(image_url, x_min, y_min, x_max, y_max):
67
- bounding_box = {"x_min": x_min, "y_min": y_min, "x_max": x_max, "y_max": y_max}
68
 
69
  # Load and process the image
70
  image = load_image(image_url)
@@ -95,4 +109,3 @@ iface = gr.Interface(
95
 
96
  # Launch the interface
97
  iface.launch()
98
-
 
14
  image = Image.open(BytesIO(response.content)).convert("RGB")
15
  return image
16
 
17
+ # Step 2: Adjust Bounding Box to Add Margin
18
+ def adjust_bounding_box(bounding_box, margin=20):
19
+ return {
20
+ "x_min": max(0, bounding_box["x_min"] - margin),
21
+ "y_min": max(0, bounding_box["y_min"] - margin),
22
+ "x_max": bounding_box["x_max"] + margin,
23
+ "y_max": bounding_box["y_max"] + margin,
24
+ }
25
+
26
+ # Step 3: Crop Image Based on Bounding Box
27
  def crop_image(image, bounding_box):
28
  x_min, y_min, x_max, y_max = bounding_box.values()
29
  return image.crop((x_min, y_min, x_max, y_max))
30
 
31
+ # Step 4: Preprocessing for Segmentation Model
32
+ def preprocess_image(image, size=(1024, 1024)):
33
  preprocess = transforms.Compose([
34
  transforms.Resize(size),
35
  transforms.ToTensor(),
 
37
  ])
38
  return preprocess(image).unsqueeze(0) # Add batch dimension
39
 
40
+ # Step 5: Load a More Robust Pre-trained Model
41
  def load_model():
42
  model = deeplabv3_resnet101(pretrained=True) # Switch to ResNet101 for better feature extraction
43
  model.eval() # Set the model to evaluation mode
 
45
  model = model.to("cuda")
46
  return model
47
 
48
+ # Step 6: Perform Segmentation with Soft Masking
49
  def segment_image(model, input_tensor):
50
  if torch.cuda.is_available():
51
  input_tensor = input_tensor.to("cuda")
52
  with torch.no_grad():
53
  output = model(input_tensor)['out'] # Model output
54
+ probabilities = torch.softmax(output, dim=1) # Get class probabilities
55
+ mask = probabilities[0, 1].cpu().numpy() # Assuming 1 corresponds to the object class
56
  return mask
57
 
58
+ # Step 7: Refine Mask and Extract Object
59
+ def apply_mask(image, mask, threshold=0.5):
60
+ # Threshold the mask
61
+ mask = (mask > threshold).astype(np.uint8)
62
+
63
+ # Resize mask to the original image size
64
+ mask = cv2.resize(mask, image.size, interpolation=cv2.INTER_NEAREST)
65
+
66
+ # Apply morphological operations for a cleaner mask
67
  kernel = np.ones((5, 5), np.uint8)
68
+ mask = cv2.dilate(mask, kernel, iterations=2)
69
  mask = cv2.erode(mask, kernel, iterations=1)
70
 
71
  # Create RGBA image
 
78
 
79
  # Gradio Interface to handle input and output
80
  def segment_object(image_url, x_min, y_min, x_max, y_max):
81
+ bounding_box = adjust_bounding_box({"x_min": x_min, "y_min": y_min, "x_max": x_max, "y_max": y_max})
82
 
83
  # Load and process the image
84
  image = load_image(image_url)
 
109
 
110
  # Launch the interface
111
  iface.launch()