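"""Gradio demo combining CLIP zero-shot image classification with FastSAM segmentation."""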
import gradio as gr
import torch
from PIL import Image
import cv2
import numpy as np
from transformers import CLIPProcessor, CLIPModel
from ultralytics import FastSAM
import supervision as sv
import os

# Load CLIP model
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Initialize FastSAM model
FASTSAM_WEIGHTS = "FastSAM-s.pt"
if not os.path.exists(FASTSAM_WEIGHTS):
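    # Download the FastSAM weights on first run (assumes `wget` is available on the host)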
    os.system(f"wget https://huggingface.co/spaces/An-619/FastSAM/resolve/main/weights/{FASTSAM_WEIGHTS}")
fast_sam = FastSAM(FASTSAM_WEIGHTS)

def process_image_clip(image, text_input):
    if image is None:
        return "Please upload an image first."
    if not text_input:
        return "Please enter some text to check in the image."
    
    try:
        # Convert numpy array to PIL Image if needed
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        
        # Compare the prompt against its negation; the softmax over these two
        # labels gives a relative confidence for the positive prompt
        candidate_labels = [text_input, f"not {text_input}"]
        
        # Process image and text
        inputs = processor(
            images=image,
            text=candidate_labels,
            return_tensors="pt",
            padding=True
        )
        
        # Get model predictions (no gradients needed for inference)
        with torch.no_grad():
            outputs = model(**inputs)
        logits_per_image = outputs.logits_per_image
        probs = logits_per_image.softmax(dim=1)
        
        # Get confidence for the positive label
        confidence = float(probs[0][0])
        return f"Confidence that the image contains '{text_input}': {confidence:.2%}"
    except Exception as e:
        return f"Error processing image: {str(e)}"

def process_image_fastsam(image):
    if image is None:
        return None
        
    try:
        # Convert PIL image to numpy array if needed
        if isinstance(image, Image.Image):
            image_np = np.array(image)
        else:
            image_np = image
            
        # Run FastSAM inference on CPU; conf/iou thresholds control which masks are kept
        results = fast_sam(image_np, device='cpu', retina_masks=True, imgsz=1024, conf=0.4, iou=0.9)
        
        # Get detections
        detections = sv.Detections.from_ultralytics(results[0])
        
        # Create box and mask annotators
        box_annotator = sv.BoxAnnotator()
        mask_annotator = sv.MaskAnnotator()
        
        # Annotate image
        annotated_image = mask_annotator.annotate(scene=image_np.copy(), detections=detections)
        annotated_image = box_annotator.annotate(scene=annotated_image, detections=detections)
        
        return Image.fromarray(annotated_image)
    except Exception as e:
        # The output component is an Image, so surface failures as a Gradio error
        # instead of returning a string the Image component cannot render
        raise gr.Error(f"Error processing image: {str(e)}")

# Create Gradio interface
with gr.Blocks(css="footer {visibility: hidden}") as demo:
    gr.Markdown("""
    # CLIP and FastSAM Demo
    This demo combines two powerful AI models:
    - **CLIP**: For zero-shot image classification
    - **FastSAM**: For automatic image segmentation
    
    Upload an image and try either of the tabs below!
    """)
    
    with gr.Tab("CLIP Zero-Shot Classification"):
        with gr.Row():
            image_input = gr.Image(label="Input Image")
            text_input = gr.Textbox(
                label="What do you want to check in the image?", 
                placeholder="e.g., 'a dog', 'sunset', 'people playing'",
                info="Enter any concept you want to check in the image"
            )
        output_text = gr.Textbox(label="Result")
        classify_btn = gr.Button("Classify")
        classify_btn.click(fn=process_image_clip, inputs=[image_input, text_input], outputs=output_text)
        
        gr.Examples(
            examples=[
                ["https://raw.githubusercontent.com/gradio-app/gradio/main/demo/kitchen/kitchen.png", "kitchen"],
                ["https://raw.githubusercontent.com/gradio-app/gradio/main/demo/calculator/calculator.jpg", "calculator"],
            ],
            inputs=[image_input, text_input],
        )
    
    with gr.Tab("FastSAM Segmentation"):
        with gr.Row():
            image_input_sam = gr.Image(label="Input Image")
            image_output = gr.Image(label="Segmentation Result")
        segment_btn = gr.Button("Segment")
        segment_btn.click(fn=process_image_fastsam, inputs=[image_input_sam], outputs=image_output)
        
        gr.Examples(
            examples=[
                ["https://raw.githubusercontent.com/gradio-app/gradio/main/demo/kitchen/kitchen.png"],
                ["https://raw.githubusercontent.com/gradio-app/gradio/main/demo/calculator/calculator.jpg"],
            ],
            inputs=[image_input_sam],
        )
    
    gr.Markdown("""
    ### How to use:
    1. **CLIP Classification**: Upload an image and enter text to check if that concept exists in the image
    2. **FastSAM Segmentation**: Upload an image to get automatic segmentation with bounding boxes and masks
    
    ### Note:
    - The models run on CPU, so processing might take a few seconds
    - For best results, use clear images with good lighting
    """)

demo.launch(share=True)