from transformers import pipeline, SamModel, SamProcessor
import torch
import numpy as np
import gradio as gr
from PIL import Image

# check if cuda is available
device = "cuda" if torch.cuda.is_available() else "cpu"

# initialize the OWLv2 zero-shot object detector and the RobustSAM model and processor
checkpoint = "google/owlv2-base-patch16-ensemble"
detector = pipeline(model=checkpoint, task="zero-shot-object-detection", device=device)
sam_model = SamModel.from_pretrained("jadechoghari/robustsam-vit-base").to(device)
sam_processor = SamProcessor.from_pretrained("jadechoghari/robustsam-vit-base")

def apply_mask(image, mask, color):
    """Apply a mask to an image with a specific color."""
    for c in range(3):  # iterate over rgb channels
        image[:, :, c] = np.where(mask, color[c], image[:, :, c])
    return image

def query(image, texts, threshold):
    texts = [text.strip() for text in texts.split(",")]
    predictions = detector(
        image,
        candidate_labels=texts,
        threshold=threshold
    )
    
    image = np.array(image).copy()
    
    colors = [
        (255, 0, 0),  # Red
        (0, 255, 0),  # Green
        (0, 0, 255),  # Blue
        (255, 255, 0),  # Yellow
        (255, 165, 0),  # Orange
        (255, 0, 255)  # Magenta
    ]
    
    for i, pred in enumerate(predictions):
        score = pred["score"]
        # the detector already filters by the user-supplied threshold,
        # so honor it here instead of a hard-coded 0.5 cutoff
        if score >= threshold:
            box = [round(pred["box"]["xmin"], 2), round(pred["box"]["ymin"], 2),
                   round(pred["box"]["xmax"], 2), round(pred["box"]["ymax"], 2)]

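            # wrap the detected box in the nested-list format the SAM processor
            # expects for box prompts (hence the extra brackets around `box`)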
            inputs = sam_processor(
                image,
                input_boxes=[[[box]]],
                return_tensors="pt"
            ).to(device)

            with torch.no_grad():
                outputs = sam_model(**inputs)

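            # upscale the predicted masks to the original image size and keep the
            # first mask candidate for this box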
            mask = sam_processor.image_processor.post_process_masks(
                outputs.pred_masks.cpu(),
                inputs["original_sizes"].cpu(),
                inputs["reshaped_input_sizes"].cpu()
            )[0][0][0].numpy()
            
            # we apply the mask with the corresponding color
            color = colors[i % len(colors)]  # we cycle through colors
            image = apply_mask(image, mask > 0.5, color)

    result_image = Image.fromarray(image)
    
    return result_image
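
# query() can also be called directly, outside the Gradio UI; "street.jpg" below is
# only an illustrative filename, not a file shipped with this Space:
#   result = query(Image.open("street.jpg"), "car, person", 0.1)
#   result.save("segmented.jpg")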

description = (
    "Welcome to RobustSAM by Snap Research."
    "This Space uses RobustSAM, a robust version of the Segment Anything Model (SAM) with improved performance on low-quality images while maintaining zero-shot segmentation capabilities. "
    "Thanks to its integration with OWLv2, RobustSAM becomes text-promptable, allowing for flexible and accurate segmentation, even with degraded image quality. Try the example or input an image with comma-separated candidate labels to see the enhanced segmentation results."
)

demo = gr.Interface(
    query,
    inputs=[gr.Image(type="pil", label="Image Input"), gr.Textbox(label="Candidate Labels"), gr.Slider(0, 1, value=0.05, label="Confidence Threshold")],
    outputs=gr.Image(type="pil", label="Segmented Image"),
    title="RobustSAM",
    description=description,
    examples=[
        ["./blur.jpg", "insect", 0.1],
        ["./lowlight.jpg", "bus, window", 0.1],
        ["./rain.jpg", "tree, leafs", 0.1],
        ["./haze.jpg", "", 0.1],
    ],
    cache_examples=True
)

demo.launch()
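
# To run the demo locally (assuming torch, transformers, gradio and pillow are
# installed and the example images listed above sit next to this script), run the
# file with Python and open the local URL that Gradio prints.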