import gradio as gr
import numpy as np
import torch
import cv2
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
from skimage.measure import label, regionprops

# Load the CLIPSeg processor and segmentation model
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
def rescale_bbox(bbox, orig_image_shape=(1024, 1024), model_shape=352):
    # Map a bounding box from the model's output resolution back to the original image size
    bbox = np.asarray(bbox) / model_shape
    y1, y2 = bbox[::2] * orig_image_shape[0]
    x1, x2 = bbox[1::2] * orig_image_shape[1]
    return [int(y1), int(x1), int(y2), int(x2)]
def detect_using_clip(image, prompts=[], threshold=0.4):
    model_detections = dict()
    inputs = processor(
        text=prompts,
        images=[image] * len(prompts),
        padding="max_length",
        return_tensors="pt",
    )
    with torch.no_grad():  # disable gradient computation for inference
        outputs = model(**inputs)
    preds = outputs.logits.unsqueeze(1)
    for i, prompt in enumerate(prompts):
        # Threshold the sigmoid activation map into a binary mask
        predicted_image = torch.sigmoid(preds[i][0]).detach().cpu().numpy()
        predicted_image = np.where(predicted_image > threshold, 255, 0)
        # Extract connected regions from the mask and convert each to a bounding box
        lbl_0 = label(predicted_image)
        props = regionprops(lbl_0)
        model_detections[prompt] = [
            rescale_bbox(prop.bbox, orig_image_shape=image.shape[:2], model_shape=predicted_image.shape[0])
            for prop in props
        ]
    return model_detections
def visualize_images(image, detections, prompt):
    # Draw the detected bounding boxes for the selected prompt on a copy of the image
    image_copy = image.copy()
    if prompt not in detections:
        print("Prompt not found in detections ..")
        return image_copy
    for bbox in detections[prompt]:
        # bbox is [y1, x1, y2, x2]; OpenCV expects (x, y) corner points
        cv2.rectangle(image_copy, (int(bbox[1]), int(bbox[0])), (int(bbox[3]), int(bbox[2])), (255, 0, 0), 2)
    return image_copy
def shot(image, labels_text, selected_category):
    # Parse the comma-separated labels into a list of prompts
    prompts = [p.strip() for p in labels_text.split(',')]
    print("Prompts:", prompts)
    print("Image shape:", image.shape)
    model_detections = detect_using_clip(image, prompts=prompts)
    print("Detections:", model_detections)
    print("Selected category:", selected_category)
    return visualize_images(image=image, detections=model_detections, prompt=selected_category)
iface = gr.Interface(
    fn=shot,
    inputs=["image", "text", "text"],
    outputs="image",
    description="Add a picture and a list of labels separated by commas",
    title="Zero-shot Image Classification with Prompt",
    examples=[["images/room.jpg", "bed,table,plant", "plant"]],
    # allow_flagging=False,
    # analytics_enabled=False,
)
iface.launch()