import os
import gradio as gr
from gradio_client import Client, handle_file
from pathlib import Path
from gradio.utils import get_cache_folder
import torch
import torchvision.transforms as transforms
from PIL import Image
import cv2
import numpy as np
import ast

# from zerogpu import init_zerogpu
# init_zerogpu()


# thin wrapper around gr.helpers.Examples that allows pointing the examples
# cache at a pre-populated folder instead of regenerating it
class Examples(gr.helpers.Examples):
    def __init__(self, *args, cached_folder=None, **kwargs):
        super().__init__(*args, **kwargs, _initiated_directly=False)
        if cached_folder is not None:
            self.cached_folder = cached_folder
            # self.cached_file = Path(self.cached_folder) / "log.csv"
        self.create()
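
# note: the demo below instantiates gr.Examples directly with cache_examples=False,
# so this cached-folder subclass is currently unused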


def postprocess(output, prompt):
    # the backend returns a single wide image with one tile per task,
    # concatenated horizontally in prompt order; slice it into equal-width
    # crops and pair each crop with its task name for the results gallery
    result = []
    image = Image.open(output)
    w, h = image.size
    n = len(prompt)
    slice_width = w // n
    for i in range(n):
        left = i * slice_width
        right = (i + 1) * slice_width if i < n - 1 else w
        cropped_img = image.crop((left, 0, right, h))
        caption = prompt[i]
        result.append((cropped_img, caption))
    return result
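
# e.g. postprocess('strip.png', ['depth', 'normal']) on a 2048x1024 strip (a
# hypothetical file) yields [(left 1024x1024 crop, 'depth'), (right crop, 'normal')]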


# user clicks on the image to add a point prompt; redraw all selected points
def get_point(img, sel_pix, evt: gr.SelectData):
    if len(sel_pix) < 5:
        sel_pix.append((evt.index, 1))  # evt.index is the clicked (x, y); label 1 = foreground point
    img = cv2.imread(img)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # draw points
    for point, label in sel_pix:
        cv2.drawMarker(img, tuple(point), colors[label], markerType=markers[label], markerSize=20, thickness=5)
    print(sel_pix)
    return img, sel_pix


def set_point(img, checkbox_group, sel_pix, semantic_input):
    # invoked when an example is clicked: parse the stored point string,
    # e.g. '[([1003, 2163], 1)]', and draw the points on a copy of the image
    ori_img = img
    sel_pix = ast.literal_eval(sel_pix)
    img = cv2.imread(img)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    if 0 < len(sel_pix) <= 5:
        for point, label in sel_pix:
            cv2.drawMarker(img, tuple(point), colors[label], markerType=markers[label], markerSize=20, thickness=5)
    return ori_img, img, sel_pix


# undo the last selected point
def undo_points(orig_img, sel_pix):
    if isinstance(orig_img, int):  # an int means the image was selected from the examples
        temp = cv2.imread(image_examples[orig_img][0])
    else:
        temp = cv2.imread(orig_img)
    temp = cv2.cvtColor(temp, cv2.COLOR_BGR2RGB)
    # drop the last point and redraw the remaining ones
    if len(sel_pix) != 0:
        sel_pix.pop()
        for point, label in sel_pix:
            cv2.drawMarker(temp, tuple(point), colors[label], markerType=markers[label], markerSize=20, thickness=5)
    return temp, sel_pix


HF_TOKEN = os.environ.get('HF_KEY')

# client that forwards requests to the hosted Diception Space backend
client = Client("Canyu/Diception",
                max_workers=3,
                hf_token=HF_TOKEN)

colors = [(255, 0, 0), (0, 255, 0)]  # red = background point, green = foreground point
markers = [1, 5]  # cv2.MARKER_TILTED_CROSS, cv2.MARKER_TRIANGLE_UP

# example rows: [image path, tasks, point-prompt string, semantic category];
# defined at module level so `undo_points` can reload an example image by index
image_examples = [
    ["assets/woman.jpg", ['point segmentation', 'depth', 'normal', 'entity segmentation', 'human pose', 'semantic segmentation'], '[([2744, 975], 1), ([3440, 1954], 1), ([2123, 2405], 1), ([838, 1678], 1), ([4688, 1922], 1)]', 'person'],
    ["assets/woman2.jpg", ['point segmentation', 'depth', 'entity segmentation', 'semantic segmentation', 'human pose'], '[([687, 1416], 1), ([1021, 707], 1), ([1138, 1138], 1), ([1182, 1583], 1), ([1188, 2172], 1)]', 'person'],
    ["assets/board.jpg", ['point segmentation', 'depth', 'entity segmentation', 'normal'], '[([1003, 2163], 1)]', ''],
    ["assets/lion.jpg", ['point segmentation', 'depth', 'semantic segmentation'], '[([1287, 671], 1)]', 'lion'],
    ["assets/apple.jpg", ['point segmentation', 'depth', 'semantic segmentation', 'normal', 'entity segmentation'], '[([3367, 1950], 1)]', 'apple'],
    ["assets/room.jpg", ['point segmentation', 'depth', 'semantic segmentation', 'normal', 'entity segmentation'], '[([1308, 2215], 1)]', 'chair'],
    ["assets/car.jpg", ['point segmentation', 'depth', 'semantic segmentation', 'normal', 'entity segmentation'], '[([1276, 1369], 1)]', 'car'],
    ["assets/person.jpg", ['point segmentation', 'depth', 'semantic segmentation', 'normal', 'entity segmentation', 'human pose'], '[([3253, 1459], 1)]', 'tie'],
    ["assets/woman3.jpg", ['point segmentation', 'depth', 'entity segmentation'], '[([420, 692], 1)]', ''],
    ["assets/cat.jpg", ['point segmentation', 'depth', 'entity segmentation', 'semantic segmentation'], '[([756, 661], 1)]', 'cat'],
    ["assets/room2.jpg", ['point segmentation', 'depth', 'entity segmentation', 'semantic segmentation', 'normal'], '[([3946, 224], 1)]', 'laptop'],
    ["assets/cartoon_cat.png", ['point segmentation', 'depth', 'entity segmentation', 'semantic segmentation', 'normal'], '[([1478, 3048], 1)]', 'cat'],
    ["assets/sheep.jpg", ['point segmentation', 'depth', 'entity segmentation', 'semantic segmentation'], '[([1789, 1791], 1), ([1869, 1333], 1)]', 'sheep'],
    ["assets/cartoon_girl.jpeg", ['point segmentation', 'depth', 'entity segmentation', 'normal', 'human pose', 'semantic segmentation'], '[([1208, 2089], 1), ([635, 2731], 1), ([1070, 2888], 1), ([1493, 2350], 1)]', 'person'],
]


def process_image_check(path_input, prompt, sel_points, semantic):
    # quick pre-flight validation before the actual inference call
    if path_input is None:
        raise gr.Error(
            "Missing image in the left pane: please upload an image first."
        )
    if len(prompt) == 0:
        raise gr.Error(
            "At least 1 prediction type is needed."
        )


def inf(image_path, prompt, sel_points, semantic):
    # the point prompts arrive as a string when set via an example row
    if isinstance(sel_points, str):
        sel_points = ast.literal_eval(sel_points)

    print('=========== PROCESS IMAGE CHECK ===========')
    print(f"Image Path: {image_path}")
    print(f"Prompt: {prompt}")
    print(f"Selected Points (before processing): {sel_points}")
    print(f"Semantic Input: {semantic}")
    print('===========================================')

    if 'point segmentation' in prompt and len(sel_points) == 0:
        raise gr.Error(
            "At least 1 point is needed."
        )
    if 'point segmentation' not in prompt and len(sel_points) != 0:
        raise gr.Error(
            "You must select 'point segmentation' when providing point prompts."
        )
    if 'semantic segmentation' in prompt and semantic == '':
        raise gr.Error(
            "Target category is needed."
        )
    if 'semantic segmentation' not in prompt and semantic != '':
        raise gr.Error(
            "You must select 'semantic segmentation' when providing a target category."
        )

    prompt_str = str(sel_points)
    result = client.predict(
        input_image=handle_file(image_path),
        checkbox_group=prompt,
        selected_points=prompt_str,
        semantic_input=semantic,
        api_name="/inf"
    )
    result = postprocess(result, prompt)
    return result
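
# e.g. a depth + point-segmentation query on one of the bundled examples:
#   inf('assets/lion.jpg', ['point segmentation', 'depth'], '[([1287, 671], 1)]', '')
# returns a list of (PIL.Image, caption) pairs, one per selected task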


def clear_cache():
    return None, None


def dummy():
    pass


def run_demo_server():
    options = ['depth', 'normal', 'entity segmentation', 'human pose', 'point segmentation', 'semantic segmentation']
    gradio_theme = gr.themes.Default()

    with gr.Blocks(
        theme=gradio_theme,
        title="Diception",
        css="""
            #download {
                height: 118px;
            }
            .slider .inner {
                width: 5px;
                background: #FFF;
            }
            .viewport {
                aspect-ratio: 4/3;
            }
            .tabs button.selected {
                font-size: 20px !important;
                color: crimson !important;
            }
            h1 {
                text-align: center;
                display: block;
            }
            h2 {
                text-align: center;
                display: block;
            }
            h3 {
                text-align: center;
                display: block;
            }
            .md_feedback li {
                margin-bottom: 0px !important;
            }
            .hideme {
                display: none;
            }
        """,
        head="""
            <script async src="https://www.googletagmanager.com/gtag/js?id=G-1FWSVCGZTG"></script>
            <script>
                window.dataLayer = window.dataLayer || [];
                function gtag() {dataLayer.push(arguments);}
                gtag('js', new Date());
                gtag('config', 'G-1FWSVCGZTG');
            </script>
        """,
    ) as demo:
        selected_points = gr.State([])  # store points
        original_image = gr.State(value=None)  # store the original image without drawn points

        gr.HTML(
            """
            <h1>DICEPTION: A Generalist Diffusion Model for Vision Perception</h1>
            <h3>One single model solves multiple perception tasks, producing impressive results!</h3>
            <h3>Due to the GPU quota limit, if an error occurs, please wait for 5 minutes before retrying.</h3>
            <p align="center">
            <a title="arXiv" href="https://arxiv.org/abs/2502.17157" target="_blank" rel="noopener noreferrer"
                  style="display: inline-block;">
                <img src="https://www.obukhov.ai/img/badges/badge-pdf.svg">
            </a>
            <a title="Github" href="https://github.com/aim-uofa/Diception" target="_blank" rel="noopener noreferrer"
                  style="display: inline-block;">
                <img src="https://img.shields.io/github/stars/aim-uofa/Diception?label=GitHub%20%E2%98%85&logo=github&color=C8C"
                     alt="badge-github-stars">
            </a>
            </p>
            """
        )
        # hidden textbox used to pass example point strings into `set_point`
        selected_points_tmp = gr.Textbox(label="Points", elem_classes="hideme")

        with gr.Row():
            checkbox_group = gr.CheckboxGroup(choices=options, label="Task")
        with gr.Row():
            semantic_input = gr.Textbox(label="Category Name", placeholder="e.g. person/cat/dog/elephant... (for semantic segmentation only; COCO categories)")
        with gr.Row():
            gr.Markdown('For non-human image inputs, the pose results may have issues; the same applies to semantic segmentation with categories outside COCO.')
        with gr.Row():
            gr.Markdown('The results of semantic segmentation may be unstable because:')
        with gr.Row():
            gr.Markdown('- We only trained on COCO, whose quality and quantity are insufficient for this task.')
        with gr.Row():
            gr.Markdown('- Semantic segmentation is more complex than the other tasks, as it requires accurately learning the relationship between semantics and objects.')
        with gr.Row():
            gr.Markdown('However, we are still able to produce some high-quality semantic segmentation results, strongly demonstrating the potential of our approach.')
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(
                    label="Input Image",
                    type="filepath",
                )
            with gr.Column():
                with gr.Row():
                    gr.Markdown('You can click on the image to select point prompts (at most 5 points).')
                matting_image_submit_btn = gr.Button(
                    value="Run", variant="primary"
                )
                with gr.Row():
                    undo_button = gr.Button('Undo point')
                    matting_image_reset_btn = gr.Button(value="Reset")
            with gr.Column():
                matting_image_output = gr.Gallery(label="Results")
        # img_clear_button.click(clear_cache, outputs=[input_image, matting_image_output])

        matting_image_submit_btn.click(
            fn=process_image_check,
            inputs=[input_image, checkbox_group, selected_points, semantic_input],
            outputs=None,
            preprocess=False,
            queue=False,
        ).success(
            fn=inf,
            inputs=[original_image, checkbox_group, selected_points, semantic_input],
            outputs=[matting_image_output],
            concurrency_limit=1,
        )

        matting_image_reset_btn.click(
            fn=lambda: (
                None,
                None,
                []
            ),
            inputs=[],
            outputs=[
                input_image,
                matting_image_output,
                selected_points
            ],
            queue=False,
        )
        # once the user uploads an image, store the clean copy in `original_image`
        # and clear any previously selected points
        def store_img(img):
            return img, []

        input_image.upload(
            store_img,
            [input_image],
            [original_image, selected_points]
        )

        input_image.select(
            get_point,
            [original_image, selected_points],
            [input_image, selected_points],
        )

        undo_button.click(
            undo_points,
            [original_image, selected_points],
            [input_image, selected_points]
        )
        examples = gr.Examples(
            fn=set_point,
            run_on_click=True,
            examples=image_examples,
            inputs=[input_image, checkbox_group, selected_points_tmp, semantic_input],
            outputs=[original_image, input_image, selected_points],
            cache_examples=False,
        )
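
        # clicking an example row runs `set_point` (run_on_click=True), which only
        # draws the stored points; inference still requires pressing Run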

    demo.queue(
        api_open=False,
    ).launch()


if __name__ == '__main__':
    run_demo_server()