File size: 2,756 Bytes
1e58367
bc5dfe0
1e58367
 
 
 
95255f1
 
 
 
 
 
 
 
 
1e58367
 
 
ba97523
 
1e58367
5a3f926
 
 
 
 
1e58367
 
fbfcc0e
1e58367
12d1976
5a3f926
9808945
 
150c578
1e58367
 
bc5dfe0
1e58367
 
 
 
 
bc5dfe0
 
 
 
 
 
 
 
 
1e58367
 
 
fbfcc0e
1e58367
 
 
 
ba97523
a3268d8
fbfcc0e
1e58367
 
 
5a3f926
1e58367
 
fbfcc0e
4777db1
1e58367
88465f9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import torch
import cv2
import gradio as gr
import numpy as np
from transformers import OwlViTProcessor, OwlViTForObjectDetection


# Pick the compute device once at import time: CUDA when present, else CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The detector and its pre/post-processor are loaded from the same checkpoint.
_CHECKPOINT = "google/owlvit-base-patch32"

model = OwlViTForObjectDetection.from_pretrained(_CHECKPOINT).to(device)
model.eval()  # inference only: put the module in evaluation mode
processor = OwlViTProcessor.from_pretrained(_CHECKPOINT)


def query_image(img, text_queries, score_threshold):
    """Run OWL-ViT zero-shot detection on ``img`` and draw the matches on it.

    Args:
        img: Image array supplied by ``gr.Image()``; ``img.shape[:2]`` is read
            as (height, width). Presumably RGB uint8 — TODO confirm.
        text_queries: Comma-separated text descriptions of objects to find.
            Surrounding whitespace is stripped and empty entries are ignored.
        score_threshold: Minimum detection score for a box to be drawn.

    Returns:
        ``img`` with a box and its query label drawn for every detection whose
        score is at least ``score_threshold``. ``img`` is returned unchanged
        when it is ``None`` or no non-empty query was given.
    """
    # Normalise the queries so " rocket" is labelled "rocket", and drop the
    # empty strings that stray or trailing commas would otherwise produce.
    queries = [q.strip() for q in text_queries.split(",")]
    queries = [q for q in queries if q]
    if img is None or not queries:
        return img

    img_h, img_w = img.shape[:2]
    # Boxes are rescaled to these (height, width) values during
    # post-processing, so they are expressed in original-image coordinates
    # rather than in the 768x768 model input.
    target_sizes = torch.Tensor([[img_h, img_w]])
    img_input = cv2.resize(img, (768, 768), interpolation=cv2.INTER_AREA)

    inputs = processor(text=queries, images=img_input, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)

    # target_sizes lives on the CPU, so bring the model outputs there as well
    # before post-processing.
    outputs.logits = outputs.logits.cpu()
    outputs.pred_boxes = outputs.pred_boxes.cpu()
    # NOTE(review): newer transformers releases deprecate `post_process` in
    # favour of `post_process_object_detection`; confirm the pinned version.
    results = processor.post_process(outputs=outputs, target_sizes=target_sizes)
    boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]

    font = cv2.FONT_HERSHEY_SIMPLEX
    for box, score, label in zip(boxes, scores, labels):
        if score < score_threshold:
            continue
        x0, y0, x1, y1 = (int(v) for v in box.tolist())
        img = cv2.rectangle(img, (x0, y0), (x1, y1), (255, 0, 0), 5)

        # Place the label just below the box. If that would run past the
        # bottom of the image, tuck it inside the box's bottom edge instead.
        # The boxes are in original-image coordinates, so the bound is the
        # real image height, not the 768px model input size.
        y = y1 - 10 if y1 + 25 > img_h else y1 + 25
        img = cv2.putText(
            img, queries[int(label)], (x0, y), font, 1, (255, 0, 0), 2, cv2.LINE_AA
        )
    return img


description = """
Gradio demo for <a href="https://huggingface.co/docs/transformers/main/en/model_doc/owlvit">OWL-ViT</a>, 
introduced in <a href="https://arxiv.org/abs/2205.06230">Simple Open-Vocabulary Object Detection
with Vision Transformers</a>. 
\n\nYou can use OWL-ViT to query images with text descriptions of any object. 
To use it, simply upload an image and enter comma separated text descriptions of objects you want to query the image for. You
can also use the score threshold slider to set a threshold to filter out low probability predictions.
\n\n<a href="https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/zeroshot_object_detection_with_owlvit.ipynb">Colab demo</a>
"""

# Inputs map positionally onto query_image(img, text_queries, score_threshold).
demo = gr.Interface(
    query_image,
    inputs=[gr.Image(), "text", gr.Slider(0, 1, value=0.1)],
    outputs="image",
    title="Zero-Shot Object Detection with OWL-ViT",
    description=description,
    examples=[
        ["assets/astronaut.png", "human face, rocket, flag, nasa badge", 0.11],
        ["assets/coffee.png", "coffee mug, spoon, plate", 0.1],
    ],
)

# Start the web server only when executed as a script, so importing this
# module (to reuse `demo` or `query_image`) has no server side effect.
if __name__ == "__main__":
    demo.launch()