sagar007 committed
Commit c95f3e0 · verified · 1 Parent(s): d1f9260

Update app.py

Files changed (1):
  app.py +31 -16
app.py CHANGED
@@ -2,28 +2,43 @@ import gradio as gr
 import torch
 import cv2
 import numpy as np
-from fastsam import FastSAM, FastSAMPrompt
+from transformers import SamModel, SamProcessor
+from PIL import Image
 
-# Load the FastSAM model
-model = FastSAM('FastSAM-x.pt')
+# Set up device
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# Load model and processor
+model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
+processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
 
 def segment_image(input_image, points):
-    # Prepare the image
-    input_image = cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB)
+    # Convert input_image to PIL Image
+    input_image = Image.fromarray(input_image)
+
+    # Prepare inputs
+    inputs = processor(input_image, input_points=[points], return_tensors="pt").to(device)
 
-    # Run the model
-    everything_results = model(input_image, device='cpu', retina_masks=True, imgsz=1024, conf=0.4, iou=0.9)
+    # Generate masks
+    with torch.no_grad():
+        outputs = model(**inputs)
 
-    # Prepare prompts
-    prompt_process = FastSAMPrompt(input_image, everything_results, device='cpu')
+    # Post-process masks
+    masks = processor.image_processor.post_process_masks(
+        outputs.pred_masks.cpu(),
+        inputs["original_sizes"].cpu(),
+        inputs["reshaped_input_sizes"].cpu()
+    )
+    scores = outputs.iou_scores
 
-    # Generate mask based on points
-    ann = prompt_process.point_prompt(points=points, pointlabel=[1] * len(points))
+    # Convert mask to numpy array
+    mask = masks[0][0].numpy()
 
     # Overlay the mask on the original image
-    result_image = input_image.copy()
-    mask = ann[0].astype(bool)
-    result_image[mask] = result_image[mask] * 0.5 + np.array([255, 0, 0]) * 0.5
+    result_image = np.array(input_image)
+    mask_rgb = np.zeros_like(result_image)
+    mask_rgb[mask > 0.5] = [255, 0, 0]  # Red color for the mask
+    result_image = cv2.addWeighted(result_image, 1, mask_rgb, 0.5, 0)
 
     return result_image
 
@@ -35,8 +50,8 @@ iface = gr.Interface(
         gr.Image(type="numpy", tool="sketch", brush_radius=5, label="Click on objects to segment")
     ],
     outputs=gr.Image(type="numpy"),
-    title="FastSAM Image Segmentation",
-    description="Click on objects in the image to segment them using FastSAM."
+    title="Segment Anything Model (SAM) Image Segmentation",
+    description="Click on objects in the image to segment them using SAM."
 )
 
 # Launch the interface
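
Note on the new SAM path: scores = outputs.iou_scores is assigned but never used, and masks[0][0] still holds all three candidate masks SAM predicts per point (a (3, H, W) tensor per the transformers documentation), so the boolean index into the (H, W, 3) image can fail with a shape mismatch. A minimal follow-up sketch, assuming the tensor shapes documented for transformers' SamModel, that keeps only the best-scoring candidate:

# Sketch only, not part of this commit: select the single best candidate
# mask before overlaying. Assumes masks[0] has shape (point_batch, 3, H, W)
# and outputs.iou_scores has shape (batch, point_batch, 3); SAM returns
# three masks per point, and iou_scores is the model's own ranking of them.
best_idx = outputs.iou_scores[0, 0].argmax().item()
mask = masks[0][0, best_idx].numpy()   # (H, W) boolean mask

# Overlay the chosen mask on the original image
result_image = np.array(input_image)
mask_rgb = np.zeros_like(result_image)
mask_rgb[mask] = [255, 0, 0]           # red where the mask is set
result_image = cv2.addWeighted(result_image, 1, mask_rgb, 0.5, 0)

With this in place, the otherwise-dead scores assignment becomes the mask-selection step.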