sagar007 commited on
Commit
e0d4d2f
·
verified ·
1 Parent(s): 49e0cdb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -96
app.py CHANGED
@@ -1,99 +1,39 @@
1
  import gradio as gr
2
- import numpy as np
3
- from PIL import Image
4
  import torch
5
- from transformers import AutoProcessor, CLIPSegForImageSegmentation
6
-
7
- # Load the CLIPSeg model and processor
8
- processor = AutoProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
9
- model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
10
-
11
- # Ensure that the model uses GPU if available
12
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
13
- model.to(device)
14
-
15
- def segment_everything(image):
16
- if isinstance(image, list):
17
- image = image[0]
18
-
19
- if isinstance(image, np.ndarray):
20
- image = Image.fromarray(image)
21
-
22
- inputs = processor(text=["object"], images=image, padding="max_length", return_tensors="pt").to(device)
23
- with torch.no_grad():
24
- outputs = model(**inputs)
25
- preds = outputs.logits.squeeze().sigmoid().cpu()
26
- segmentation = (preds.numpy() * 255).astype(np.uint8)
27
- return Image.fromarray(segmentation)
28
-
29
- def segment_box(image, x1, y1, x2, y2):
30
- if isinstance(image, list):
31
- image = image[0]
32
-
33
- if isinstance(image, Image.Image):
34
- image = np.array(image)
35
-
36
- x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
37
- cropped_image = image[y1:y2, x1:x2]
38
- inputs = processor(text=["object"], images=Image.fromarray(cropped_image), padding="max_length", return_tensors="pt").to(device)
39
- with torch.no_grad():
40
- outputs = model(**inputs)
41
- preds = outputs.logits.squeeze().sigmoid().cpu()
42
- segmentation = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
43
- segmentation[y1:y2, x1:x2] = (preds.numpy() * 255).astype(np.uint8)
44
- return Image.fromarray(segmentation)
45
-
46
- def update_image(image, segmentation):
47
- if segmentation is None:
48
- return image
49
-
50
- if isinstance(image, list):
51
- image = image[0]
52
-
53
- if isinstance(image, np.ndarray):
54
- image_pil = Image.fromarray(image)
55
- else:
56
- image_pil = image
57
-
58
- seg_pil = Image.fromarray(segmentation).convert('RGBA')
59
-
60
- if image_pil.size!= seg_pil.size:
61
- seg_pil = seg_pil.resize(image_pil.size, Image.NEAREST)
62
-
63
- blended = Image.blend(image_pil.convert('RGBA'), seg_pil, 0.5)
64
-
65
- return np.array(blended)
66
-
67
- with gr.Blocks() as demo:
68
- gr.Markdown("# Segment Anything-like Demo")
69
- with gr.Row():
70
- with gr.Column(scale=1):
71
- input_image = gr.Image(label="Input Image")
72
- with gr.Row():
73
- x1_input = gr.Number(label="X1")
74
- y1_input = gr.Number(label="Y1")
75
- x2_input = gr.Number(label="X2")
76
- y2_input = gr.Number(label="Y2")
77
- with gr.Row():
78
- everything_btn = gr.Button("Everything")
79
- box_btn = gr.Button("Box")
80
- with gr.Column(scale=1):
81
- output_image = gr.Image(label="Segmentation Result")
82
-
83
- everything_btn.click(
84
- fn=segment_everything,
85
- inputs=[input_image],
86
- outputs=[output_image]
87
- )
88
- box_btn.click(
89
- fn=segment_box,
90
- inputs=[input_image, x1_input, y1_input, x2_input, y2_input],
91
- outputs=[output_image]
92
- )
93
- output_image.change(
94
- fn=update_image,
95
- inputs=[input_image, output_image],
96
- outputs=[output_image]
97
- )
98
-
99
  demo.launch()
 
1
  import gradio as gr
 
 
2
  import torch
3
+ from PIL import Image
4
+ from torchvision import transforms
5
+
6
+ # Load pre-trained U-Net model
7
+ model = torch.hub.load('nvidia/DeepLearningExamples:torchhub', 'unet', pretrained=True)
8
+
9
+ # Define a function to segment an image
10
+ def segment_image(image):
11
+ # Preprocess image
12
+ image = Image.fromarray(image)
13
+ image = transforms.Compose([
14
+ transforms.Resize((256, 256)),
15
+ transforms.ToTensor(),
16
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
17
+ ])(image)
18
+
19
+ # Run segmentation model
20
+ output = model(image.unsqueeze(0))
21
+ output = torch.argmax(output, dim=1)
22
+
23
+ # Postprocess output
24
+ output = output.squeeze(0).cpu().numpy()
25
+ output = Image.fromarray(output.astype('uint8'))
26
+
27
+ return output
28
+
29
+ # Create Gradio app
30
+ demo = gr.Interface(
31
+ fn=segment_image,
32
+ inputs=gr.Image(type="pil"),
33
+ outputs=gr.Image(type="pil"),
34
+ title="Segment Anything",
35
+ description="Segment any image using a pre-trained U-Net model"
36
+ )
37
+
38
+ # Launch Gradio app
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  demo.launch()