sagar007 committed
Commit fd55cab · verified · 1 Parent(s): d07eab6

Update app.py

Files changed (1):
  1. app.py (+57 -56)
app.py CHANGED
@@ -1,68 +1,69 @@
 import gradio as gr
 import torch
-import numpy as np
 from PIL import Image
-import matplotlib.pyplot as plt
-import io
-from sam2.sam2_image_predictor import SAM2ImagePredictor
-
-# Set up device
-device = "cuda" if torch.cuda.is_available() else "cpu"
-
-# Load SAM 2 model
-predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")
-
-def fig2img(fig):
-    buf = io.BytesIO()
-    fig.savefig(buf)
-    buf.seek(0)
-    img = Image.open(buf)
-    return img
-
-def plot_masks(image, masks):
-    fig, ax = plt.subplots(figsize=(10, 10))
-    ax.imshow(image)
-
-    for mask in masks:
-        masked = np.ma.masked_where(mask == 0, mask)
-        ax.imshow(masked, alpha=0.5, cmap=plt.cm.get_cmap('jet'))
-    ax.axis('off')
-    plt.close()
-    return fig2img(fig)
-
-def segment_everything(input_image):
-    try:
-        if input_image is None:
-            return None, "Please upload an image before submitting."
-
-        input_image = Image.fromarray(input_image).convert("RGB")
-
-        with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
-            predictor.set_image(input_image)
-            # Use 'everything' prompt
-            masks, _, _ = predictor.predict([])
-
-        # Plot the results
-        result_image = plot_masks(input_image, masks)
-
-        return result_image, f"Segmented everything in the image. Found {len(masks)} objects."
-
-    except Exception as e:
-        return None, f"An error occurred: {str(e)}"
+import cv2
+import numpy as np
+from transformers import CLIPProcessor, CLIPModel
+from ultralytics import FastSAM
+from ultralytics.models.fastsam import FastSAMPrompt
+
+# Load CLIP model
+model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+
+# Load FastSAM model
+fast_sam = FastSAM('FastSAM-x.pt')
+
+def process_image_clip(image, text_input):
+    # Process image for CLIP
+    inputs = processor(
+        images=image,
+        text=[text_input],
+        return_tensors="pt",
+        padding=True
+    )
+
+    # Get model predictions
+    outputs = model(**inputs)
+    logits_per_image = outputs.logits_per_image
+    probs = logits_per_image.softmax(dim=1)
+
+    confidence = float(probs[0][0])
+    return f"Confidence that the image contains '{text_input}': {confidence:.2%}"
+
+def process_image_fastsam(image):
+    # Convert PIL image to numpy array
+    image_np = np.array(image)
+
+    # Run FastSAM inference
+    everything_results = fast_sam(image_np, device='cpu', retina_masks=True, imgsz=1024, conf=0.4, iou=0.9)
+    prompt_process = FastSAMPrompt(image_np, everything_results, device='cpu')
+
+    # Get everything mask
+    ann = prompt_process.everything()
+
+    # Convert annotation to image
+    result_image = prompt_process.plot_to_result()
+    return Image.fromarray(result_image)
 
 # Create Gradio interface
-iface = gr.Interface(
-    fn=segment_everything,
-    inputs=[
-        gr.Image(type="numpy", label="Upload an image")
-    ],
-    outputs=[
-        gr.Image(type="pil", label="Segmented Image"),
-        gr.Textbox(label="Status")
-    ],
-    title="SAM 2 Everything Segmentation",
-    description="Upload an image to segment all objects using SAM 2."
-)
+with gr.Blocks() as demo:
+    gr.Markdown("# CLIP and FastSAM Demo")
+
+    with gr.Tab("CLIP Zero-Shot Classification"):
+        with gr.Row():
+            image_input = gr.Image(type="pil", label="Input Image")
+            text_input = gr.Textbox(label="What do you want to check in the image?", placeholder="Type here...")
+        output_text = gr.Textbox(label="Result")
+        classify_btn = gr.Button("Classify")
+        classify_btn.click(fn=process_image_clip, inputs=[image_input, text_input], outputs=output_text)
+
+    with gr.Tab("FastSAM Segmentation"):
+        with gr.Row():
+            image_input_sam = gr.Image(type="pil", label="Input Image")
+            image_output = gr.Image(type="pil", label="Segmentation Result")
+        segment_btn = gr.Button("Segment")
+        segment_btn.click(fn=process_image_fastsam, inputs=[image_input_sam], outputs=image_output)
 
-# Launch the interface
-iface.launch()
+if __name__ == "__main__":
+    demo.launch()
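
A note on the new `process_image_clip`: with a single text prompt, `logits_per_image` has shape (1, 1), so `softmax(dim=1)` always yields 1.0 and the reported confidence is 100% regardless of the image. Below is a minimal standalone sketch of a more informative check against a contrast prompt, assuming the same `openai/clip-vit-base-patch32` checkpoint; `photo.jpg` and the negative prompt are placeholders, not part of the commit.

```python
import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

image = Image.open("photo.jpg")  # placeholder path
# Score the user's prompt against an illustrative negative so the
# softmax over text prompts is meaningful.
texts = ["a photo of a cat", "a photo of something else"]
inputs = processor(images=image, text=texts, return_tensors="pt", padding=True)

with torch.no_grad():
    probs = model(**inputs).logits_per_image.softmax(dim=1)

print(f"Confidence for '{texts[0]}': {float(probs[0][0]):.2%}")
```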
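The `FastSAMPrompt` path in `process_image_fastsam` is version-sensitive: in the `ultralytics` releases I'm aware of, the prompt class exposes `everything_prompt()` rather than `everything()`, and `plot_to_result()` belongs to the original FastSAM repo's prompt class, so this function may raise an `AttributeError` depending on the installed version (the `cv2` import is also unused). Below is a minimal sketch of the same "segment everything" flow using only the stable `Results.plot()` API, with the same inference settings as the commit; `photo.jpg` is a placeholder path.

```python
from PIL import Image
from ultralytics import FastSAM

# Load the FastSAM checkpoint, as in the app
model = FastSAM('FastSAM-x.pt')

# Same inference settings as the committed code
results = model('photo.jpg', device='cpu', retina_masks=True,
                imgsz=1024, conf=0.4, iou=0.9)

# Results.plot() returns an annotated BGR array; flip channels for PIL
annotated_bgr = results[0].plot()
Image.fromarray(annotated_bgr[..., ::-1]).save('segmented.jpg')
```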