# Hugging Face Space status: Sleeping
from turtle import title | |
import os | |
import gradio as gr | |
from transformers import pipeline | |
import numpy as np | |
from PIL import Image | |
import torch | |
import cv2 | |
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation,AutoProcessor,AutoConfig | |
from skimage.measure import label, regionprops | |
# CLIPSeg checkpoint shared by the processor and the segmentation model.
_CHECKPOINT = "CIDAS/clipseg-rd64-refined"

processor = CLIPSegProcessor.from_pretrained(_CHECKPOINT)
model = CLIPSegForImageSegmentation.from_pretrained(_CHECKPOINT)

# Choices for the "Category Label" dropdown; starts out empty.
classes = []
def rescale_bbox(bbox, orig_image_shape=(1024, 1024), model_shape=352):
    """Map a box from model-output coordinates to original-image pixels.

    Args:
        bbox: 4-element box ``(row1, col1, row2, col2)``, e.g. the ``bbox``
            of a ``skimage.measure.regionprops`` region on a 2-D label image.
        orig_image_shape: ``(height, width)`` of the original image.
        model_shape: size of the map the box was measured on. Either a single
            number (square map, the original behaviour) or a
            ``(height, width)`` pair for non-square maps.

    Returns:
        ``[y1, x1, y2, x2]`` as Python ints (fractions truncated by ``int()``).
    """
    if np.ndim(model_shape) == 0:
        model_h = model_w = model_shape
    else:
        # Non-square map: scale rows and columns independently.
        model_h, model_w = model_shape[:2]
    bbox = np.asarray(bbox)
    # Even indices of the box are rows (y), odd indices are columns (x).
    y1, y2 = bbox[::2] / model_h * orig_image_shape[0]
    x1, x2 = bbox[1::2] / model_w * orig_image_shape[1]
    return [int(y1), int(x1), int(y2), int(x2)]
def detect_using_clip(image, prompts=None, threshould=0.4):
    """Locate each text prompt in *image* with CLIPSeg and return bounding boxes.

    Args:
        image: image array fed to the CLIPSeg processor; ``image.shape[:2]``
            is used as (height, width) when mapping boxes back to pixels.
        prompts: iterable of text prompts (a single string is treated as one
            prompt). One segmentation map is predicted per prompt.
        threshould: cut-off applied to the sigmoid of the logits; pixels above
            it count as the prompt. (Misspelled name kept for compatibility.)

    Returns:
        dict mapping each prompt to a list of ``[y1, x1, y2, x2]`` boxes in
        original-image coordinates, one per connected region of the mask.
        An empty dict when no prompts are given.
    """
    if prompts is None:
        prompts = []
    elif isinstance(prompts, str):
        prompts = [prompts]
    else:
        prompts = list(prompts)
    if not prompts:
        # Nothing to query; avoid calling the processor with empty batches.
        return {}

    inputs = processor(
        text=prompts,
        images=[image] * len(prompts),
        padding="max_length",
        return_tensors="pt",
    )
    with torch.no_grad():  # inference only, no gradients needed
        outputs = model(**inputs)

    logits = outputs.logits
    # NOTE(review): with a single prompt some transformers versions squeeze
    # the batch dimension away, leaving (H, W). Normalising to (N, H, W) makes
    # per-prompt indexing work for both shapes.
    logits = logits.reshape(-1, *logits.shape[-2:])
    prob_maps = torch.sigmoid(logits).cpu().numpy()

    orig_shape = image.shape[:2]
    model_detections = {}
    for prompt, prob_map in zip(prompts, prob_maps):
        mask = prob_map > threshould
        # Each connected foreground region becomes one box.
        regions = regionprops(label(mask))
        model_detections[prompt] = [
            rescale_bbox(
                region.bbox,
                orig_image_shape=orig_shape,
                model_shape=prob_map.shape[0],
            )
            for region in regions
        ]
    return model_detections
def display_images(image, detections, prompt='traffic light', color=(255, 0, 0), thickness=2):
    """Return a copy of *image* with the boxes found for *prompt* drawn on it.

    Args:
        image: image array; it is copied, never modified in place.
        detections: mapping of prompt -> list of ``[y1, x1, y2, x2]`` boxes.
        prompt: key in *detections* whose boxes are drawn.
        color: channel values written for the rectangle outline.
        thickness: outline thickness in pixels passed to ``cv2.rectangle``.

    Returns:
        The annotated copy. If *prompt* is missing from *detections*, a
        message is printed and an unmodified copy is returned.
    """
    annotated = image.copy()
    if prompt not in detections:
        print("prompt not in query ..")
        return annotated
    for y1, x1, y2, x2 in detections[prompt]:
        # cv2 expects (x, y) points, while the boxes are stored row-first.
        cv2.rectangle(annotated, (int(x1), int(y1)), (int(x2), int(y2)), color, thickness)
    return annotated
def shot(image, labels_text, category=None):
    """Gradio callback: detect comma-separated labels and draw their boxes.

    Args:
        image: uploaded image array, or None when nothing was uploaded.
        labels_text: comma-separated prompts such as ``"bed, table, plant"``;
            whitespace around each label and empty entries are ignored.
        category: optional dropdown selection. When it names one of the
            prompts, only that prompt's boxes are drawn; otherwise boxes for
            every prompt are drawn.

    Returns:
        An annotated image, or *image* unchanged when there is no image or
        no usable label.
    """
    print("Labels Text ", labels_text)
    # Trim so "bed, table" yields "table" rather than " table".
    prompts = [p.strip() for p in (labels_text or "").split(",") if p.strip()]
    print("prompts :", prompts)
    if image is None or not prompts:
        # Nothing to detect: echo the input back.
        return image
    print("Image shape ", image.shape)
    detections = detect_using_clip(image, prompts=prompts)
    print("detections :", detections)

    selected = [category] if category in detections else list(detections)
    annotated = image
    for prompt in selected:
        # Each call draws on top of the previous result.
        annotated = display_images(annotated, detections, prompt=prompt)
    return annotated
# Gradio UI: the image, the comma-separated label string and the category
# dropdown are passed to `shot` in this order; its return value is shown as
# an image.
iface = gr.Interface(
    fn=shot,
    inputs=[
        "image",
        "text",
        # NOTE: `classes` is read once, here. Rebinding it later does not
        # change the dropdown's choices.
        gr.Dropdown(classes, label="Category Label", info='Select Categories'),
    ],
    outputs="image",
    description="Add a picture and a list of labels separated by commas",
    title="Zero-shot Image Classification with Prompt ",
    # One value per input component: image, labels, category (none chosen).
    examples=[["images/room.jpg", "bed,table,plant", None]],
)

if __name__ == "__main__":
    iface.launch()