# ------------------------------------------------------------------------------
# Copyright (c) 2023, Alaa lab, UC Berkeley. All rights reserved.
#
# Written by Yulu Gan.
# ------------------------------------------------------------------------------
from __future__ import annotations

import math
import cv2
import random
from fnmatch import fnmatch
import numpy as np
import gradio as gr
import torch
from PIL import Image, ImageOps
from diffusers import StableDiffusionInstructPix2PixPipeline

title = "InstructCV: Instruction-Tuned Text-to-Image Diffusion Models as Vision Generalists"

description = """
<p style='text-align: center'> Yulu Gan, Sungwoo Park, Alex Schubert, Anthony Philippakis, Ahmed Alaa <br>
<a href='https://arxiv.org/abs/2310.00390'>arXiv</a> | <a href='https://github.com/AlaaLab/InstructCV' target='_blank'>Code</a></p>
We develop a <b>unified language interface</b> for computer vision tasks that abstracts away task-specific design choices and enables task execution by following natural language instructions. \n\n
<b>Tips for using this demo</b>: Please upload a new image and provide an instruction outlining the specific vision task you wish InstructCV to perform (e.g., “Segment the dog”, “Detect the dog”, “Estimate the depth map of this image”, etc.). \n
<div style="display: flex; justify-content: center; align-items: center;">
<img src="https://i.postimg.cc/hjtwgCDr/Fig1-Instruct-CV-1.png" alt="Application of InstructCV to new test images and user-written instructions" width="600">
</div>
"""  # noqa

example_instructions = [
    "Please help me detect Buzz.",
    "Please help me detect Woody's face.",
    "Create a monocular depth map.",
]

model_id = "alaa-lab/InstructCV"

def main():
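    # Load the InstructCV checkpoint into the InstructPix2Pix pipeline in half
    # precision, with the safety checker disabled, and move it to the GPU.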
    # pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(model_id, torch_dtype=torch.float16, safety_checker=None).to("cpu")
    pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(model_id, torch_dtype=torch.float16, safety_checker=None).to("cuda")
    example_image = Image.open("imgs/example2.jpg").convert("RGB")

    def load_example(
        seed: int,
        randomize_seed: bool,
        text_cfg_scale: float,
        image_cfg_scale: float,
    ):
        example_instruction = random.choice(example_instructions)
        return [example_image, example_instruction] + generate(
            example_image,
            example_instruction,
            seed,
            0,
            text_cfg_scale,
            image_cfg_scale,
        )
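
    # Run a single editing pass: resize the input, call the diffusion pipeline with
    # the instruction, then post-process the output depending on the requested task.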
    def generate(
        input_image: Image.Image,
        instruction: str,
        seed: int,
        randomize_seed: bool,
        text_cfg_scale: float,
        image_cfg_scale: float,
    ):
        seed = random.randint(0, 100000) if randomize_seed else seed
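        # Resize so that the longer side is about 512 px and both dimensions are
        # multiples of 64, as required by the diffusion model's UNet.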
        width, height = input_image.size
        factor = 512 / max(width, height)
        factor = math.ceil(min(width, height) * factor / 64) * 64 / min(width, height)
        width = int((width * factor) // 64) * 64
        height = int((height * factor) // 64) * 64
        input_image = ImageOps.fit(input_image, (width, height), method=Image.Resampling.LANCZOS)

        if instruction == "":
            return [input_image]

        generator = torch.manual_seed(seed)
        edited_image = pipe(
            instruction, image=input_image,
            guidance_scale=text_cfg_scale, image_guidance_scale=image_cfg_scale,
            num_inference_steps=25, generator=generator,
        ).images[0]
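
        # Post-process the raw diffusion output based on keywords in the instruction.
        # For segmentation-style prompts, threshold the generated mask, draw its
        # contours, and blend a filled overlay onto the input image.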
        instruction_ = instruction.lower()
        if fnmatch(instruction_, "*segment*") or fnmatch(instruction_, "*split*") or fnmatch(instruction_, "*divide*"):
            input_image = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR)  # PIL image -> BGR numpy array
            edited_image = cv2.cvtColor(np.array(edited_image), cv2.COLOR_RGB2GRAY)
            ret, thresh = cv2.threshold(edited_image, 127, 255, cv2.THRESH_BINARY)
            img2 = input_image.copy()
            seed_seg = np.random.randint(0, 10000)
            np.random.seed(seed_seg)
            colors = np.random.randint(0, 255, (3))
            colors2 = np.random.randint(0, 255, (3))
            contours, _ = cv2.findContours(thresh, cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE)
            edited_image = cv2.drawContours(input_image, contours, -1, (int(colors[0]), int(colors[1]), int(colors[2])), 3)
            for j in range(len(contours)):
                edited_image_2 = cv2.fillPoly(img2, [contours[j]], (int(colors2[0]), int(colors2[1]), int(colors2[2])))
            img_merge = cv2.addWeighted(edited_image, 0.5, edited_image_2, 0.5, 0)
            edited_image = Image.fromarray(cv2.cvtColor(img_merge, cv2.COLOR_BGR2RGB))
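
        # For depth instructions, normalize the grayscale prediction to [0, 255] and
        # render it with a JET colormap.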
        if fnmatch(instruction_, "*depth*"):
            edited_image = cv2.cvtColor(np.array(edited_image), cv2.COLOR_RGB2GRAY)
            n_min = np.min(edited_image)
            n_max = np.max(edited_image)
            edited_image = (edited_image - n_min) / (n_max - n_min + 1e-8)
            edited_image = (255 * edited_image).astype(np.uint8)
            edited_image = cv2.applyColorMap(edited_image, cv2.COLORMAP_JET)
            edited_image = Image.fromarray(cv2.cvtColor(edited_image, cv2.COLOR_BGR2RGB))

        # text_cfg_scale = 7.5
        # image_cfg_scale = 1.5
        return [seed, text_cfg_scale, image_cfg_scale, edited_image]
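
    # Build the Gradio interface: buttons, instruction box, image panels, and the
    # seed / guidance-scale controls, wired up to load_example and generate.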
    with gr.Blocks() as demo:
        # gr.HTML("""<h1 style="font-weight: 900; margin-bottom: 7px;">
        #     InstructCV: Towards Universal Text-to-Image Vision Generalists
        # </h1>""")
        gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>" + title + "</h1>")
        gr.Markdown(description)

        with gr.Row():
            with gr.Column(scale=1.5, min_width=100):
                generate_button = gr.Button("Generate result")
            with gr.Column(scale=1.5, min_width=100):
                load_button = gr.Button("Load example")
            with gr.Column(scale=3):
                instruction = gr.Textbox(lines=1, label="Instruction", interactive=True)

        with gr.Row():
            input_image = gr.Image(label="Input Image", type="pil", interactive=True)
            edited_image = gr.Image(label="Output Image", type="pil", interactive=False)
            input_image.style(height=512, width=512)
            edited_image.style(height=512, width=512)

        with gr.Row():
            randomize_seed = gr.Radio(
                ["Fix Seed", "Randomize Seed"],
                value="Randomize Seed",
                type="index",
                show_label=False,
                interactive=True,
            )
            seed = gr.Number(value=90, precision=0, label="Seed", interactive=True)
            text_cfg_scale = gr.Number(value=7.5, label="Text weight", interactive=True)
            image_cfg_scale = gr.Number(value=1.5, label="Image weight", interactive=True)

        # gr.Markdown(Intro_text)

        load_button.click(
            fn=load_example,
            inputs=[
                seed,
                randomize_seed,
                text_cfg_scale,
                image_cfg_scale,
            ],
            outputs=[input_image, instruction, seed, text_cfg_scale, image_cfg_scale, edited_image],
        )
        generate_button.click(
            fn=generate,
            inputs=[
                input_image,
                instruction,
                seed,
                randomize_seed,
                text_cfg_scale,
                image_cfg_scale,
            ],
            outputs=[seed, text_cfg_scale, image_cfg_scale, edited_image],
        )

    demo.queue(concurrency_count=1)
    demo.launch(share=False)

if __name__ == "__main__":
    main()