# ------------------------------------------------------------------------------ # Copyright (c) 2023, Alaa lab, UC Berkeley. All rights reserved. # # Written by Yulu Gan. # ------------------------------------------------------------------------------ from __future__ import annotations import math import cv2 import random from fnmatch import fnmatch import numpy as np import gradio as gr import torch from PIL import Image, ImageOps from diffusers import StableDiffusionInstructPix2PixPipeline title = "InstructCV" description = """

Project Page | Paper | Code

Gradio demo for InstructCV: Instruction-Tuned Text-to-Image Diffusion Models As Vision Generalists. \n You may upload any images you like and try to let the model do vision tasks following your intent. \n Some examples: You could use "Segment the dog" for segmentation, "Detect the dog" for object detection, "Estimate the depth map of this image" for depth estimation, etc. """ # noqa example_instructions = [ "Please help me detect Buzz.", "Please help me detect Woody's face.", "Create a monocular depth map.", ] model_id = "alaa-lab/InstructCV" def main(): # pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(model_id, torch_dtype=torch.float16, safety_checker=None).to("cpu") pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(model_id, safety_checker=None).to("cuda") example_image = Image.open("imgs/example2.jpg").convert("RGB") def load_example(seed: int, randomize_seed:bool): example_instruction = random.choice(example_instructions) return [example_image, example_instruction] + generate( example_image, example_instruction, seed, 0, ) def generate( input_image: Image.Image, instruction: str, seed: int, randomize_seed:bool, ): seed = random.randint(0, 100000) if randomize_seed else seed width, height = input_image.size factor = 512 / max(width, height) factor = math.ceil(min(width, height) * factor / 64) * 64 / min(width, height) width = int((width * factor) // 64) * 64 height = int((height * factor) // 64) * 64 input_image = ImageOps.fit(input_image, (width, height), method=Image.Resampling.LANCZOS) if instruction == "": return [input_image] generator = torch.manual_seed(seed) edited_image = pipe( instruction, image=input_image, guidance_scale=7.5, image_guidance_scale=1.5, num_inference_steps=50, generator=generator, ).images[0] instruction_ = instruction.lower() if fnmatch(instruction_, "*segment*") or fnmatch(instruction_, "*split*") or fnmatch(instruction_, "*divide*"): input_image = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR) #numpy.ndarray edited_image = cv2.cvtColor(np.array(edited_image), cv2.COLOR_RGB2GRAY) ret, thresh = cv2.threshold(edited_image, 127, 255, cv2.THRESH_BINARY) img2 = input_image.copy() seed_seg = np.random.randint(0,10000) np.random.seed(seed_seg) colors = np.random.randint(0,255,(3)) colors2 = np.random.randint(0,255,(3)) contours,_ = cv2.findContours(thresh,cv2.RETR_LIST,cv2.CHAIN_APPROX_NONE) edited_image = cv2.drawContours(input_image,contours,-1,(int(colors[0]),int(colors[1]),int(colors[2])),3) for j in range(len(contours)): edited_image_2 = cv2.fillPoly(img2, [contours[j]], (int(colors2[0]),int(colors2[1]),int(colors2[2]))) img_merge = cv2.addWeighted(edited_image, 0.5,edited_image_2, 0.5, 0) edited_image = Image.fromarray(cv2.cvtColor(img_merge, cv2.COLOR_BGR2RGB)) if fnmatch(instruction_, "*depth*"): edited_image = cv2.cvtColor(np.array(edited_image), cv2.COLOR_RGB2GRAY) n_min = np.min(edited_image) n_max = np.max(edited_image) edited_image = (edited_image-n_min)/(n_max-n_min+1e-8) edited_image = (255*edited_image).astype(np.uint8) edited_image = cv2.applyColorMap(edited_image, cv2.COLORMAP_JET) edited_image = Image.fromarray(cv2.cvtColor(edited_image, cv2.COLOR_BGR2RGB)) text_cfg_scale = 7.5 image_cfg_scale = 1.5 return [seed, text_cfg_scale, image_cfg_scale, edited_image] with gr.Blocks() as demo: # gr.HTML("""

# InstructCV: Towards Universal Text-to-Image Vision Generalists #

""") gr.Markdown("

" + title + "

") gr.Markdown(description) with gr.Row(): with gr.Column(scale=1.5, min_width=100): generate_button = gr.Button("Generate result") with gr.Column(scale=1.5, min_width=100): load_button = gr.Button("Load example") with gr.Column(scale=3): instruction = gr.Textbox(lines=1, label="Instruction", interactive=True) with gr.Row(): input_image = gr.Image(label="Input Image", type="pil", interactive=True) edited_image = gr.Image(label=f"Output Image", type="pil", interactive=False) input_image.style(height=512, width=512) edited_image.style(height=512, width=512) with gr.Row(): randomize_seed = gr.Radio( ["Fix Seed", "Randomize Seed"], value="Randomize Seed", type="index", show_label=False, interactive=True, ) seed = gr.Number(value=90, precision=0, label="Seed", interactive=True) text_cfg_scale = gr.Number(value=7.5, label=f"Text weight", interactive=False) image_cfg_scale = gr.Number(value=1.5, label=f"Image weight", interactive=False) # gr.Markdown(Intro_text) load_button.click( fn=load_example, inputs=[seed, randomize_seed], outputs=[input_image, instruction, seed, text_cfg_scale, image_cfg_scale, edited_image], ) generate_button.click( fn=generate, inputs=[ input_image, instruction, seed, randomize_seed, ], outputs=[seed, text_cfg_scale, image_cfg_scale, edited_image], ) demo.queue(concurrency_count=1) demo.launch(share=False) if __name__ == "__main__": main()