import numpy as np
import cv2
import torch
import gradio as gr
from PIL import Image
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
from skimage.measure import label, regionprops

# CLIPSeg: prompt-conditioned segmentation model used here for zero-shot detection
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")

# Dropdown choices, seeded from the bundled example; the text box can still introduce new labels
classes = ["bed", "table", "plant"]
def rescale_bbox(bbox, orig_image_shape=(1024, 1024), model_shape=352):
    # regionprops boxes are (min_row, min_col, max_row, max_col) on the model's
    # square output grid; rescale them back to the original image resolution.
    bbox = np.asarray(bbox) / model_shape
    y1, y2 = bbox[::2] * orig_image_shape[0]
    x1, x2 = bbox[1::2] * orig_image_shape[1]
    return [int(y1), int(x1), int(y2), int(x2)]
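# Worked example (illustrative numbers, not from a real run): a regionprops box
# (min_row, min_col, max_row, max_col) = (35, 70, 176, 140) on the 352x352 CLIPSeg grid,
# mapped onto a 480x640 source image:
# rescale_bbox((35, 70, 176, 140), orig_image_shape=(480, 640), model_shape=352)
# -> [47, 127, 240, 254]   # [y1, x1, y2, x2] in original-image pixels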
def detect_using_clip(image, prompts=[], threshold=0.4):
    model_detections = dict()
    inputs = processor(
        text=prompts,
        images=[image] * len(prompts),
        padding="max_length",
        return_tensors="pt",
    )
    with torch.no_grad():  # inference only, no gradient computation needed
        outputs = model(**inputs)
    preds = outputs.logits.unsqueeze(1)  # shape: (num_prompts, 1, 352, 352)
    for i, prompt in enumerate(prompts):
        # binarise the per-prompt segmentation map at the given threshold
        predicted_image = torch.sigmoid(preds[i][0]).detach().cpu().numpy()
        predicted_image = np.where(predicted_image > threshold, 255, 0)
        # label connected regions in the binary mask and keep their bounding boxes
        lbl_0 = label(predicted_image)
        props = regionprops(lbl_0)
        model_detections[prompt] = [
            rescale_bbox(prop.bbox, orig_image_shape=image.shape[:2], model_shape=predicted_image.shape[0])
            for prop in props
        ]
    return model_detections
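# Minimal local sanity check (a sketch, commented out so the Space launches unchanged;
# assumes an RGB image exists at the bundled example path images/room.jpg):
# sample = np.array(Image.open("images/room.jpg").convert("RGB"))
# print(detect_using_clip(sample, prompts=["bed", "table"]))
# # -> {"bed": [[y1, x1, y2, x2], ...], "table": [...]} with boxes in original-image pixels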
def display_images(image, detections, prompt='traffic light'):
    # draw the bounding boxes detected for a single prompt on a copy of the image
    image_copy = image.copy()
    if prompt not in detections:
        print(f"No detections stored for prompt '{prompt}'")
        return image_copy
    for bbox in detections[prompt]:
        # bbox is [y1, x1, y2, x2]; cv2.rectangle expects (x, y) corner points
        cv2.rectangle(image_copy, (int(bbox[1]), int(bbox[0])), (int(bbox[3]), int(bbox[2])), (255, 0, 0), 2)
    return image_copy
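# Example end-to-end use outside Gradio (illustrative, commented out; reuses `sample`
# from the sketch above):
# boxes = detect_using_clip(sample, prompts=["bed", "table"])
# annotated = display_images(sample, boxes, prompt="bed")
# Image.fromarray(annotated).save("room_annotated.jpg")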
def shot(image, labels_text, category=None):
    # split the comma-separated labels into individual prompts
    prompts = [p.strip() for p in labels_text.split(',') if p.strip()]
    print("prompts:", prompts)
    print("image shape:", image.shape)
    detections = detect_using_clip(image, prompts=prompts)
    print("detections:", detections)
    # draw boxes for the selected category if one was chosen, otherwise for the first prompt
    selected = category if category in detections else prompts[0]
    return display_images(image, detections, prompt=selected)
def add_text(text):
    labels = text.split(',')
    return labels
# inputt = gr.Image(type="numpy", label="Input Image for Classification")
# with gr.Blocks(title="Zero Shot Object detection using Text Prompts") as demo :
# gr.Markdown(
# """
# <center>
# <h1>
# The CLIP Model
# </h1>
# A neural network called CLIP which efficiently learns visual concepts from natural language supervision. CLIP can be applied to any visual classification benchmark by simply providing the names of the visual categories to be recognized, similar to the “zero-shot” capabilities of GPT-2 and GPT-3.
# </center>
# """
# )
# with gr.Row():
# with gr.Column():
# inputt = gr.Image(type="numpy", label="Input Image for Classification")
# labels = gr.Textbox(label="Enter Label/ labels",placeholder="ex. car,person",scale=4)
# button = gr.Button(value="Locate objects")
# with gr.Column():
# outputs = gr.Image(type="numpy", label="Detected Objects with Selected Category")
# # dropdown = gr.Dropdown(labels,label="Select the category",info='Label selection panel')
# # labels.submit(add_text, inputs=labels)
# button.click(fn=shot,inputs=[inputt,labels],api_name='Get labels')
# demo.launch()
iface = gr.Interface(
    fn=shot,
    inputs=["image", "text", gr.Dropdown(classes, label="Category Label", info="Select the category to highlight")],
    outputs="image",
    description="Add a picture and a list of labels separated by commas",
    title="Zero-shot Object Detection with Text Prompts",
    examples=[["images/room.jpg", "bed,table,plant", "bed"]],
    # allow_flagging=False,
    # analytics_enabled=False,
)
iface.launch()