Saarthak2002 commited on
Commit
b07764a
·
verified ·
1 Parent(s): c4d88ec

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +87 -0
main.py CHANGED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchvision import transforms
3
+ from PIL import Image
4
+ import requests
5
+ import numpy as np
6
+ import gradio as gr
7
+ from io import BytesIO
8
+ from torchvision.models.segmentation import deeplabv3_resnet50
9
+
10
+ # Step 1: Load the Image from URL
11
+ def load_image(url):
12
+ response = requests.get(url)
13
+ image = Image.open(BytesIO(response.content)).convert("RGB")
14
+ return image
15
+
16
+ # Step 2: Crop Image Based on Bounding Box
17
+ def crop_image(image, bounding_box):
18
+ x_min, y_min, x_max, y_max = bounding_box.values()
19
+ return image.crop((x_min, y_min, x_max, y_max))
20
+
21
+ # Step 3: Preprocessing for U-Net (DeepLabV3 in this case)
22
+ def preprocess_image(image, size=(256, 256)):
23
+ preprocess = transforms.Compose([
24
+ transforms.Resize(size),
25
+ transforms.ToTensor(),
26
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
27
+ ])
28
+ return preprocess(image).unsqueeze(0) # Add batch dimension
29
+
30
+ # Step 4: Load Pre-trained Segmentation Model (DeepLabV3)
31
+ def load_model():
32
+ model = deeplabv3_resnet50(pretrained=True)
33
+ model.eval() # Set the model to evaluation mode
34
+ return model
35
+
36
+ # Step 5: Perform Segmentation
37
+ def segment_image(model, input_tensor):
38
+ with torch.no_grad():
39
+ output = model(input_tensor)['out'] # Model output
40
+ mask = output.argmax(dim=1).squeeze().cpu().numpy() # Get segmentation mask
41
+ return mask
42
+
43
+ # Step 6: Postprocess and Extract Object
44
+ def apply_mask(image, mask, threshold=1):
45
+ mask_resized = Image.fromarray((mask * 255).astype(np.uint8)).resize(image.size, Image.NEAREST)
46
+ mask_resized = np.array(mask_resized) > threshold
47
+ image_np = np.array(image)
48
+
49
+ # Create RGBA image with transparency
50
+ rgba_image = np.zeros((image_np.shape[0], image_np.shape[1], 4), dtype=np.uint8)
51
+ rgba_image[..., :3] = image_np # Copy RGB channels
52
+ rgba_image[..., 3] = mask_resized.astype(np.uint8) * 255 # Alpha channel based on mask
53
+ return Image.fromarray(rgba_image)
54
+
55
+ # Gradio Interface to handle input and output
56
+ def segment_object(image_url, x_min, y_min, x_max, y_max):
57
+ bounding_box = {"x_min": x_min, "y_min": y_min, "x_max": x_max, "y_max": y_max}
58
+
59
+ # Load and process the image
60
+ image = load_image(image_url)
61
+ cropped_image = crop_image(image, bounding_box)
62
+ input_tensor = preprocess_image(cropped_image)
63
+
64
+ # Load model and perform segmentation
65
+ model = load_model()
66
+ mask = segment_image(model, input_tensor)
67
+
68
+ # Apply mask to extract object
69
+ result_image = apply_mask(cropped_image, mask)
70
+ return result_image
71
+
72
+ # Set up the Gradio Interface
73
+ iface = gr.Interface(
74
+ fn=segment_object,
75
+ inputs=[
76
+ gr.Textbox(label="Image URL", placeholder="Enter image URL..."),
77
+ gr.Number(label="x_min", value=100),
78
+ gr.Number(label="y_min", value=100),
79
+ gr.Number(label="x_max", value=600),
80
+ gr.Number(label="y_max", value=400),
81
+ ],
82
+ outputs=gr.Image(label="Segmented Image"),
83
+ live=True
84
+ )
85
+
86
+ # Launch the interface
87
+ iface.launch()