|
from pathlib import Path |
|
|
|
import cv2 |
|
import sys |
|
import gradio as gr |
|
import os |
|
import numpy as np |
|
from gradio_utils import * |
|
|
|
from transformers import CLIPTokenizer |
|
|
|
def image_mod(image):
    """Return ``image`` rotated by 45 degrees using its own ``rotate`` method.

    NOTE(review): presumably ``image`` is a PIL ``Image``. This helper is not
    referenced by the code visible in this file and looks like a leftover from
    a Gradio example — confirm before relying on it.
    """
    rotation_degrees = 45
    rotated = image.rotate(rotation_degrees)
    return rotated
|
|
|
|
|
# Insert the parent of sys.path[0] (normally the script's own directory) at
# position 1 of the import search path.
# NOTE(review): this runs after `from gradio_utils import *` above, so it cannot
# influence that import — confirm it is still needed.
sys.path.insert(1, os.path.join(sys.path[0], '..'))

# NOTE(review): NUM_POINTS and LARGE_BOX_SIZE are not referenced by the code
# visible in this file; presumably shared settings kept for other tools.
NUM_POINTS = 3
# Presumably the number of frames per generated video (the frame slider in the
# UI also tops out at 16).
NUM_FRAMES = 16
LARGE_BOX_SIZE = 176

# NOTE(review): this dict is never read or written by the code visible here.
data = {}

# CLIP tokenizer used by get_token_number to find a word's token position in a
# prompt. It is loaded once, at import time; from_pretrained may download the
# files from the Hugging Face Hub on first use.
tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-base-patch32')
|
def get_token_number(prompt, word):
    """Return the index of ``word``'s first token within the tokenized ``prompt``.

    Both strings are run through the module-level CLIP ``tokenizer``, which
    wraps every input in begin-of-text / end-of-text special tokens. The
    word's first real sub-token is looked up in the prompt's token ids.
    That position is what the attention-mask file names are keyed by.

    Args:
        prompt: The full text prompt.
        word: A word taken from ``prompt`` (e.g. one entry of ``prompt.split(' ')``).

    Returns:
        Index of the word's first sub-token in the prompt's ``input_ids``.
        Special tokens are counted, so the first word is normally at index 1.

    Raises:
        ValueError: if ``word`` produces no content tokens (e.g. an empty
            string), or its first token does not occur in ``prompt``.
    """
    all_tokens = tokenizer(prompt).input_ids
    # Drop the begin/end-of-text tokens the tokenizer adds around every input.
    # Without this, an empty word would silently resolve to the end-of-text
    # token's position.
    word_tokens = tokenizer(word).input_ids[1:-1]
    if not word_tokens:
        raise ValueError(f"word {word!r} produced no tokens")
    try:
        return all_tokens.index(word_tokens[0])
    except ValueError:
        raise ValueError(f"word {word!r} not found in tokenized prompt {prompt!r}") from None
|
|
|
def overlay_mask(img, mask, opacity=0.75):
    """Blend a single-channel mask over an image.

    Steps:
    - The mask is resized to the image's width/height with nearest-neighbour
      interpolation, so mask values are not smoothed.
    - It is replicated to three channels (GRAY2BGR).
    - It is blended as ``opacity * mask + (1 - opacity) * img``.

    Args:
        img: 3-channel image, e.g. as returned by ``cv2.imread``.
        mask: single-channel mask, e.g. from
            ``cv2.imread(path, cv2.IMREAD_GRAYSCALE)``. ``cv2.addWeighted``
            requires it to end up with the same depth as ``img``
            (uint8 for both when loaded via imread — assumption, confirm for
            other sources).
        opacity: weight of the mask in the blend. The default of 0.75 is the
            value previously hard-coded here.

    Returns:
        The blended image, with the same size and channel count as ``img``.
    """
    height, width = img.shape[:2]
    mask_resized = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)
    mask_3ch = cv2.cvtColor(mask_resized, cv2.COLOR_GRAY2BGR)
    return cv2.addWeighted(mask_3ch, opacity, img, 1 - opacity, 0)
|
|
|
def fetch_proper_img(prompt, word, frame_num, diffusion_step, layer_num=0):
    """Load a generated video frame and, if a word is selected, its attention overlay.

    Args:
        prompt: Video prompt. It selects the frame directory (spaces replaced
            by underscores) and the mask directory (prompt used verbatim).
        word: Word from ``prompt`` whose attention mask is overlaid, or None
            to return the plain frame for both outputs.
        frame_num: 1-based frame number, as shown on the UI slider.
        diffusion_step: Diffusion-step index used in the mask file name.
        layer_num: Layer index used in the mask file name. None is treated
            as 0, and 3 is mapped to 100.

    Returns:
        ``(frame, frame_with_overlay)`` as RGB uint8 arrays, ready for ``gr.Image``.

    Raises:
        FileNotFoundError: if the frame or the mask image cannot be read.
        ValueError: propagated from ``get_token_number`` when ``word`` is not
            in ``prompt``.
    """
    # Slider values may arrive as floats; the file names use integer indices.
    frame_idx = int(frame_num) - 1
    diffusion_step = int(diffusion_step)
    if layer_num is None:
        layer_num = 0
    else:
        layer_num = int(layer_num)
        # NOTE(review): layer 3 is stored under index 100 on disk — magic value
        # kept from the original; confirm against the mask export code.
        layer_num = 100 if layer_num == 3 else layer_num

    video_file_name = f"./data/videos/{prompt.replace(' ', '_')}/video/frame_{frame_idx:04d}.png"
    img = cv2.imread(video_file_name)
    if img is None:  # cv2.imread signals a missing/unreadable file by returning None
        raise FileNotFoundError(f"Could not read video frame: {video_file_name}")

    if word is None:
        overlaid_image = img
    else:
        token_idx = get_token_number(prompt, word)
        # Mask directories use the raw prompt (with spaces), unlike the frame
        # directories above.
        mask_file_name = (
            f'./data/final_masks/attention_probs_{prompt}/frame_{frame_idx}_layer_{layer_num}'
            f'_diffusionstep_{diffusion_step}_token_{token_idx}.png'
        )
        mask = cv2.imread(mask_file_name, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read attention mask: {mask_file_name}")
        overlaid_image = overlay_mask(img, mask)

    # cv2 loads images as BGR, while gr.Image interprets numpy arrays as RGB.
    # Convert so the colours are not shown with red and blue swapped.
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    if overlaid_image is img:
        return img_rgb, img_rgb
    return img_rgb, cv2.cvtColor(overlaid_image, cv2.COLOR_BGR2RGB)
|
|
|
|
|
def fetch_proper_img_and_change_prompt(prompt, word, frame_num, diffusion_step, layer_num=0):
    """Rebuild the word radio for ``prompt`` and fetch the matching frame pair.

    The radio is built first (via ``change_text_prompt``), then the images are
    loaded with ``fetch_proper_img``.

    Returns:
        ``[frame, frame_with_overlay, radio]``.

    NOTE(review): not wired to any event in the code visible here.
    """
    updated_radio = change_text_prompt(prompt)
    frame, frame_with_overlay = fetch_proper_img(
        prompt, word, frame_num, diffusion_step, layer_num
    )
    return [frame, frame_with_overlay, updated_radio]
|
|
|
# Extra CSS passed to gr.Blocks.
# NOTE(review): the .word-btn / .word-btns-container classes are not attached to
# any component in the code visible here.
css = """
.word-btn {
    width: fit-content;
    padding: 3px;
}
.word-btns-container {
    flex-direction: row;
}
"""

# Word -> mask identifier lookup.
# NOTE(review): not referenced by the code visible in this file.
registry = dict(
    spider='mask_1',
    descending='mask_2',
)

# Root of the on-disk data. NOTE(review): unused here; fetch_proper_img builds
# './data/...' paths as strings instead.
data_path = Path('data')

# Prompts offered in the "Video Prompt" dropdown. The first entry is the default.
available_prompts = [
    'a dog and a cat sitting',
    'A fish swimming in the water',
    'A spider descending from its web',
    'An astronaut riding a horse',
]
|
|
|
with gr.Blocks(css=css) as demo:
    # Side-by-side output: the raw frame and the same frame with the mask
    # blended in. (The components are named video_* but each shows one frame.)
    with gr.Row():
        video_1 = gr.Image(label="Image")
        video_2 = gr.Image(label="Image with Attention Mask")

    def change_text_prompt(text):
        """Return a fresh word-selection Radio built from the words of ``text``.

        Words come from stripping ``text`` and splitting on single spaces. The
        function is used as the "Get words" click handler (its result replaces
        ``radio``) and by ``fetch_proper_img_and_change_prompt``.
        """
        return gr.Radio(text.strip().split(' '), value=None, label='Choose a word to visualize its attention mask.')

    # Words offered initially; kept in sync with the dropdown's default value.
    text = available_prompts[0]

    gr.Markdown("""
## Visualizing Attention Masks
* Select a prompt from the drop down
* Click on "Get words" to get the words in the prompt
* Select a radio button from the words to visualize the attention mask
* Play around with the index of diffusion steps, layers to visualize different masks
""")

    # gr.Group has no title/label parameter, so the section names are kept as
    # comments rather than passed positionally.
    with gr.Group():  # Video selection
        txt_1 = gr.Dropdown(choices=available_prompts, label="Video Prompt", value=available_prompts[0])
        submit_btn = gr.Button('Get words')

    with gr.Group():  # Word selection
        radio = gr.Radio(text.split(' '), value=None, label='Choose a word to visualize its attention mask.')
        # NOTE(review): starting at 1 with step=2, only odd frame numbers
        # (1, 3, ..., 15) are selectable, so the last frame is never shown —
        # confirm this is intended.
        range_slider = gr.Slider(
            minimum=1, maximum=NUM_FRAMES, value=1, step=2,
            label='Frame of the generated video to visualize the attention mask.',
        )
        diffusion_slider = gr.Slider(minimum=0, maximum=35, value=0, step=5, label='Index of diffusion steps.')
        layer_num_slider = gr.Slider(minimum=0, maximum=6, value=0, step=1, label='Layer number for attention mask.')

    # Any change to the selected word or to the frame/step/layer selectors
    # re-renders both images with the same set of inputs.
    image_inputs = [txt_1, radio, range_slider, diffusion_slider, layer_num_slider]
    for trigger in (radio, range_slider, diffusion_slider, layer_num_slider):
        trigger.change(fetch_proper_img, inputs=image_inputs, outputs=[video_1, video_2])

    # Rebuild the word radio for the currently selected prompt.
    submit_btn.click(change_text_prompt, inputs=[txt_1], outputs=[radio])
|
|
|
if __name__ == "__main__":
    # Listen on every network interface, not only localhost.
    host = "0.0.0.0"
    demo.launch(server_name=host)
|
|