File size: 2,402 Bytes
dfdcd97
a3ee867
c95f3e0
fd55cab
 
 
 
 
c95f3e0
fd55cab
 
 
3cd1243
fd55cab
 
e0d4d2f
fd55cab
 
 
 
 
 
 
 
 
 
 
 
 
4f39124
fd55cab
 
72f4c5c
fd55cab
 
 
3ba1061
fd55cab
 
 
 
 
 
 
 
 
 
e0d4d2f
e9cd6fd
fd55cab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e0d4d2f
fd55cab
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import gradio as gr
import torch
from PIL import Image
import cv2
import numpy as np
from transformers import CLIPProcessor, CLIPModel
from ultralytics import FastSAM
from ultralytics.models.fastsam import FastSAMPrompt

# CLIP ViT-B/32 checkpoint on the Hugging Face Hub; the model and its paired
# processor (tokenizer + image preprocessing) must come from the same id.
CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(CLIP_CHECKPOINT)
processor = CLIPProcessor.from_pretrained(CLIP_CHECKPOINT)

# FastSAM weights used by the segmentation tab; loaded from the local file
# 'FastSAM-x.pt' (ultralytics may try to download it if missing — verify).
fast_sam = FastSAM('FastSAM-x.pt')

def process_image_clip(image, text_input, contrast_text="a photo of something else"):
    """Score how well ``text_input`` describes ``image`` using CLIP.

    CLIP's ``logits_per_image`` are only meaningful relative to other
    candidate texts: a softmax over a single prompt is always exactly 1.0,
    which made the previous version report 100% for every query. The query
    is therefore scored against a generic ``contrast_text`` prompt and the
    returned confidence is the softmax probability of the query versus that
    baseline. This is a heuristic, not a calibrated detection probability.

    Args:
        image: Image from the ``gr.Image(type="pil")`` component, or ``None``
            when nothing has been uploaded.
        text_input: Free-text description to look for in the image.
        contrast_text: Baseline prompt the query is compared against.

    Returns:
        A human-readable result string for the output textbox.
    """
    if image is None:
        return "Please upload an image."
    query = (text_input or "").strip()
    if not query:
        return "Please enter some text to check for."

    inputs = processor(
        images=image,
        text=[query, contrast_text],
        return_tensors="pt",
        padding=True,
    )

    # Inference only: avoid building an autograd graph on every request.
    with torch.no_grad():
        outputs = model(**inputs)

    # logits_per_image is (num_images, num_texts), here (1, 2); softmax over
    # the text axis splits probability between the query and the baseline.
    probs = outputs.logits_per_image.softmax(dim=1)

    confidence = float(probs[0][0])
    return f"Confidence that the image contains '{query}': {confidence:.2%}"

def process_image_fastsam(image, device="cpu", imgsz=1024, conf=0.4, iou=0.9):
    """Segment everything in ``image`` with FastSAM and return a mask overlay.

    Runs FastSAM over the whole image and draws every predicted mask on top
    of the input using ultralytics' ``Results.plot()``.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` component, or
            ``None`` when nothing has been uploaded.
        device: Torch device string used for inference.
        imgsz: Inference image size passed to FastSAM.
        conf: Confidence threshold for keeping a predicted mask.
        iou: IoU threshold used for non-maximum suppression.

    Returns:
        An RGB PIL image with the masks drawn, or ``None`` if ``image`` is
        ``None`` (Gradio then clears the output component).
    """
    if image is None:
        return None

    # ultralytics treats raw numpy input as BGR (OpenCV order) while PIL is
    # RGB, so convert explicitly. convert("RGB") also drops an alpha channel
    # (e.g. PNG uploads) so the array is always H x W x 3.
    rgb = np.asarray(image.convert("RGB"))
    bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)

    results = fast_sam(
        bgr, device=device, retina_masks=True, imgsz=imgsz, conf=conf, iou=iou
    )

    # Results.plot() renders the masks onto the source image and returns a
    # BGR numpy array. It replaces the FastSAMPrompt.everything() /
    # plot_to_result() calls, which ultralytics' FastSAMPrompt does not provide.
    annotated_bgr = results[0].plot()
    return Image.fromarray(cv2.cvtColor(annotated_bgr, cv2.COLOR_BGR2RGB))

# Two-tab Gradio UI; component handles are kept as module-level names.
with gr.Blocks() as demo:
    gr.Markdown("# CLIP and FastSAM Demo")

    # Tab 1: check an uploaded image against a free-text description (CLIP).
    with gr.Tab("CLIP Zero-Shot Classification"):
        with gr.Row():
            image_input = gr.Image(label="Input Image", type="pil")
            text_input = gr.Textbox(
                placeholder="Type here...",
                label="What do you want to check in the image?",
            )
        output_text = gr.Textbox(label="Result")
        classify_btn = gr.Button("Classify")
        classify_btn.click(
            process_image_clip,
            inputs=[image_input, text_input],
            outputs=output_text,
        )

    # Tab 2: segment the uploaded image with FastSAM; input and result sit
    # in one row.
    with gr.Tab("FastSAM Segmentation"):
        with gr.Row():
            image_input_sam = gr.Image(label="Input Image", type="pil")
            image_output = gr.Image(label="Segmentation Result", type="pil")
        segment_btn = gr.Button("Segment")
        segment_btn.click(
            process_image_fastsam,
            inputs=[image_input_sam],
            outputs=image_output,
        )

if __name__ == "__main__":
    # Serve the UI only when executed as a script, not on import.
    demo.launch()