File size: 4,110 Bytes
aae2aac
 
 
 
 
 
 
 
788697a
aae2aac
 
 
 
 
c043972
aae2aac
ee4ce79
aae2aac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94299ff
aae2aac
53e7451
c043972
797b9ba
aae2aac
797b9ba
 
aae2aac
7c1b1e5
 
 
 
797b9ba
 
 
 
 
 
 
 
 
 
 
 
 
aae2aac
797b9ba
 
 
 
 
 
 
 
4f4dd22
797b9ba
 
 
 
 
 
 
 
7d41601
797b9ba
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
from turtle import title
import os 
import gradio as gr
from transformers import pipeline
import numpy as np
from PIL import Image
import torch 
import cv2 
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation,AutoProcessor,AutoConfig
from skimage.measure import label, regionprops

# Load the CLIPSeg checkpoint once at import time; the processor and model are
# module-level globals shared by the detection code below.
_CLIPSEG_CHECKPOINT = "CIDAS/clipseg-rd64-refined"
processor = CLIPSegProcessor.from_pretrained(_CLIPSEG_CHECKPOINT)
model = CLIPSegForImageSegmentation.from_pretrained(_CLIPSEG_CHECKPOINT)


images_dir = 'images/'
# Candidate example images: the regular files among the first ten directory
# entries. os.listdir() order is arbitrary, so this selection is not stable.
_first_entries = os.listdir(images_dir)[:10]
random_images = [
    path
    for path in (os.path.join(images_dir, entry) for entry in _first_entries)
    if os.path.isfile(path)
]


def rescale_bbox(bbox,orig_image_shape=(1024,1024),model_shape=352):
    """Map a bounding box from mask coordinates to original-image coordinates.

    Args:
        bbox: ``(min_row, min_col, max_row, max_col)``, the layout of
            ``skimage.measure.regionprops(...).bbox`` on a 2-D mask.
        orig_image_shape: ``(height, width, ...)`` of the original image; only
            the first two entries are used.
        model_shape: Size of the mask the box was measured on. Either a single
            number (square mask) or a ``(height, width)`` pair for non-square
            masks.

    Returns:
        ``[y1, x1, y2, x2]`` as ints (truncated toward zero), in the same
        row-first order as the input box.
    """
    bbox = np.asarray(bbox, dtype=float)
    if np.ndim(model_shape) == 0:
        model_h = model_w = model_shape
    else:
        model_h, model_w = model_shape[:2]
    # Rows scale with image height, columns with image width. Divide first,
    # then multiply, to keep the arithmetic identical to the scalar form.
    y1, y2 = bbox[::2] / model_h * orig_image_shape[0]
    x1, x2 = bbox[1::2] / model_w * orig_image_shape[1]
    return [int(y1), int(x1), int(y2), int(x2)]

def detect_using_clip(image,prompts=None,threshould=0.4):
    """Locate regions matching each text prompt using the CLIPSeg model.

    Args:
        image: Input image with a ``.shape`` whose first two entries are
            (height, width); presumably an RGB numpy array from Gradio -- TODO
            confirm.
        prompts: Text prompts (a list, or a single string). Defaults to none.
        threshould: Cut-off applied to the sigmoid of the mask logits; pixels
            above it count as part of a detection. (Spelling kept for callers.)

    Returns:
        dict mapping each prompt to a list of ``[y1, x1, y2, x2]`` boxes in
        original-image pixels, one box per connected region of the
        thresholded mask. An empty dict when no prompts are given.
    """
    if prompts is None:
        prompts = []
    elif isinstance(prompts, str):
        # A bare string would otherwise be iterated character by character.
        prompts = [prompts]
    else:
        prompts = list(prompts)
    if not prompts:
        # Nothing to query; avoid calling the processor with an empty batch.
        return {}

    model_detections = dict()
    inputs = processor(
        text=prompts,
        images=[image] * len(prompts),
        padding="max_length",
        return_tensors="pt",
    )
    with torch.no_grad():  # inference only, no gradients needed
        outputs = model(**inputs)
    logits = outputs.logits
    # NOTE(review): depending on the transformers version, a single prompt may
    # yield (H, W) logits instead of (1, H, W). Normalise to
    # (num_prompts, H, W) so every prompt indexes a full 2-D mask.
    logits = logits.reshape(len(prompts), *logits.shape[-2:])
    for prompt, prompt_logits in zip(prompts, logits):
        probs = torch.sigmoid(prompt_logits).cpu().numpy()
        mask = np.where(probs > threshould, 255, 0)
        # Connected-component labelling -> one region (and box) per blob.
        regions = regionprops(label(mask))
        model_detections[prompt] = [
            rescale_bbox(
                region.bbox,
                orig_image_shape=image.shape[:2],
                model_shape=mask.shape[0],
            )
            for region in regions
        ]

    return model_detections

def display_images(image,detections,prompt='traffic light', *, color=(255, 0, 0), thickness=2):
    """Draw the boxes detected for one prompt onto a copy of ``image``.

    Args:
        image: Image array; it is copied, never modified in place.
        detections: Mapping of prompt -> list of ``[y1, x1, y2, x2]`` boxes.
        prompt: Which prompt's boxes to draw.
        color: Colour tuple passed to ``cv2.rectangle``. The default fills the
            first channel: red for RGB input, blue for BGR.
        thickness: Rectangle line thickness in pixels.

    Returns:
        The annotated copy. If ``prompt`` has no entry in ``detections``, a
        message is printed and the unannotated copy is returned.
    """
    image_copy = image.copy()
    if prompt not in detections:
        print("prompt not in query ..")
        return image_copy
    for y1, x1, y2, x2 in detections[prompt]:
        # Boxes are stored row-first, but cv2 points are (x, y).
        cv2.rectangle(image_copy, (int(x1), int(y1)), (int(x2), int(y2)), color, thickness)
    return image_copy


def shot(image, labels_text):
    """Gradio callback: detect every comma-separated label in ``image``.

    Args:
        image: Image from the Gradio image input.
        labels_text: Comma-separated prompts, e.g. ``"bed, table,plant"``.
            Whitespace around each prompt is trimmed and empty entries
            (``"a,,b"``, trailing commas) are dropped.

    Returns:
        dict mapping each prompt to its list of ``[y1, x1, y2, x2]`` boxes, or
        an empty dict when no non-empty prompt was given.

    Side effects:
        Prints ``labels_text`` and stores the parsed prompts in the module
        global ``classes``.
    """
    print(labels_text)
    prompts = [part.strip() for part in (labels_text or "").split(',') if part.strip()]
    global classes
    classes = prompts

    if not prompts:
        # Avoid running the model on an empty prompt list.
        return {}
    detections = detect_using_clip(image, prompts=prompts)
    return detections

def add_text(text):
    """Return the comma-separated pieces of ``text`` (whitespace untouched)."""
    return text.split(',')

# Pre-built Gradio image input that yields numpy arrays.
# NOTE(review): `inputt` is not referenced by any live code in this file (the
# Interface uses the "image" shortcut); confirm it is still needed.
inputt = gr.Image(type="numpy", label="Input Image for Classification")

# with gr.Blocks(title="Zero Shot Object ddetection using Text Prompts") as demo :
#     gr.Markdown(
#     """ 
#     <center>
#     <h1>
#     The CLIP Model  
#     </h1>
#     A neural network called CLIP which efficiently learns visual concepts from natural language supervision. CLIP can be applied to any visual classification benchmark by simply providing the names of the visual categories to be recognized, similar to the “zero-shot” capabilities of GPT-2 and GPT-3.   
#     </center>
#     """
#     )

#     with gr.Row():
#         with gr.Column():
#             inputt = gr.Image(type="numpy", label="Input Image for Classification")
#             labels = gr.Textbox(label="Enter Label/ labels",placeholder="ex. car,person",scale=4)
#             button = gr.Button(value="Locate objects")
#         with gr.Column():
#             outputs = gr.Image(type="numpy", label="Detected Objects with Selected Category")
#             # dropdown = gr.Dropdown(labels,label="Select the category",info='Label selection panel')
        
#         # labels.submit(add_text, inputs=labels)
#         button.click(fn=shot,inputs=[inputt,labels],api_name='Get labels')


# demo.launch()
# Two-input app: an image plus a comma-separated prompt string.
# `shot` returns {prompt: [[y1, x1, y2, x2], ...]}. That is plain JSON, not a
# {label: confidence} mapping, so it is shown with a JSON output instead of a
# Label output.
iface = gr.Interface(fn=shot,
                    inputs=["image", "text"],
                    outputs="json",
                    # One inner list per sample, one value per input component.
                    examples=[["images/room.jpg", "bed,table,plant"]],
                    # allow_flagging expects "never"/"auto"/"manual", not a bool.
                    allow_flagging="never",
                    analytics_enabled=False,
                )
iface.launch()