# Source: visualizing_diffusion_attention/src/app_cached_videos.py
# (Hugging Face page header: last commit 06062dc by anshuln, "Update src/app_cached_videos.py", 5.35 kB)
from pathlib import Path
import cv2
import sys
import gradio as gr
import os
import numpy as np
from gradio_utils import *
from transformers import CLIPTokenizer
def image_mod(image):
    """Return ``image`` rotated by 45 degrees using its own ``rotate`` method."""
    angle_degrees = 45
    rotated = image.rotate(angle_degrees)
    return rotated
# Make the parent directory importable.
# NOTE(review): this runs after the imports above, so it does not affect
# `from gradio_utils import *` — confirm whether it was meant to precede them.
sys.path.insert(1, os.path.join(sys.path[0], '..'))
# NOTE(review): NUM_POINTS, NUM_FRAMES, LARGE_BOX_SIZE and `data` are not used
# in the code shown here; possibly leftovers or used by importers — confirm.
NUM_POINTS = 3
NUM_FRAMES = 16
LARGE_BOX_SIZE = 176
data = {}
# CLIP tokenizer used by get_token_number() to turn a prompt word into the
# token index that appears in the cached mask file names. from_pretrained may
# fetch the vocabulary from the Hugging Face Hub on first use.
tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-base-patch32')
def get_token_number(prompt, word, *, tok=None):
    """Return the position of ``word``'s first sub-token within the tokenized ``prompt``.

    The position indexes the prompt's full id list, including the leading
    start token, and is the ``token_<n>`` value used in cached mask names.

    Args:
        prompt: The full prompt text.
        word: A single word from ``prompt``. If it splits into several
            sub-tokens, only the first one is matched.
        tok: Optional tokenizer (callable returning an object with
            ``input_ids``). Defaults to the module-level CLIP ``tokenizer``.

    Returns:
        Index of the first occurrence of the word's first sub-token.

    Raises:
        ValueError: If ``word`` produces no tokens or does not occur in ``prompt``.
    """
    tok = tokenizer if tok is None else tok
    prompt_ids = tok(prompt).input_ids
    # CLIP wraps ids as [start, sub-token(s)..., end], so index 1 is the first
    # real sub-token; fewer than 3 ids means the word tokenized to nothing.
    word_ids = tok(word).input_ids
    if len(word_ids) < 3:
        raise ValueError(f"word {word!r} produced no tokens")
    try:
        return prompt_ids.index(word_ids[1])
    except ValueError:
        raise ValueError(f"word {word!r} does not occur in prompt {prompt!r}") from None
def overlay_mask(img, mask, opacity=0.3):
    """Blend a grayscale attention mask over a 3-channel image.

    Args:
        img: 3-channel image array (e.g. as returned by ``cv2.imread``).
        mask: Single-channel mask. It is resized to ``img``'s size with
            nearest-neighbour interpolation, so mask cells stay sharp.
        opacity: Weight of the mask in the blend; the image gets ``1 - opacity``.

    Returns:
        ``cv2.addWeighted(mask, opacity, img, 1 - opacity, 0)`` with the mask
        replicated to three channels.

    Raises:
        ValueError: If ``mask`` is ``None`` (e.g. ``cv2.imread`` failed to load it).
    """
    if mask is None:
        raise ValueError("mask is None; the attention mask image could not be loaded")
    mask_resized = cv2.resize(mask, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_NEAREST)
    # Replicate the single mask channel so it can be blended with the 3-channel image.
    mask_3ch = cv2.cvtColor(mask_resized, cv2.COLOR_GRAY2BGR)
    return cv2.addWeighted(mask_3ch, opacity, img, 1 - opacity, 0)
def fetch_proper_img(prompt, word, frame_num, diffusion_step, layer_num=0):
    """Load a cached video frame and, optionally, the same frame with an attention mask.

    Args:
        prompt: Prompt of the cached video; selects the data sub-directories.
        word: Word from ``prompt`` whose attention mask is overlaid, or ``None``
            to return the plain frame twice.
        frame_num: 1-based frame number (cached files are numbered from 0).
        diffusion_step: Diffusion-step index used in the mask file name.
        layer_num: Layer index from the UI; ``None`` means 0, and 3 is mapped
            to the stored layer id 100.

    Returns:
        ``(frame, overlaid_frame)`` as arrays from ``cv2.imread``; when ``word``
        is ``None`` both entries are the same frame.

    Raises:
        FileNotFoundError: If the frame or mask image cannot be read.
    """
    # Slider values may arrive as floats (e.g. 5.0); the cached file names use
    # plain integers, and the `:04d` format below rejects floats outright.
    frame_idx = int(frame_num) - 1  # 1-based UI value -> 0-based file index
    diffusion_step = int(diffusion_step)
    if layer_num is None:
        layer_num = 0
    else:
        layer_num = int(layer_num)
        # UI layer 3 corresponds to the masks saved under layer id 100.
        layer_num = 100 if layer_num == 3 else layer_num

    video_file_name = f"./data/videos/{prompt.replace(' ', '_')}/video/frame_{frame_idx:04d}.png"
    # NOTE(review): cv2.imread returns BGR channel order while gr.Image treats
    # numpy arrays as RGB — confirm the cached frames display with correct colours.
    img = cv2.imread(video_file_name)
    if img is None:  # cv2.imread reports a missing/unreadable file by returning None
        raise FileNotFoundError(f"Cached video frame not found: {video_file_name}")

    if word is None:
        return img, img

    # NOTE(review): the mask directory keeps spaces in the prompt while the video
    # directory above uses underscores — confirm this matches the data on disk.
    token_idx = get_token_number(prompt, word)
    mask_file_name = (
        f'./data/final_masks/attention_probs_{prompt}/frame_{frame_idx}_layer_{layer_num}'
        f'_diffusionstep_{diffusion_step}_token_{token_idx}.png'
    )
    mask = cv2.imread(mask_file_name, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        raise FileNotFoundError(f"Cached attention mask not found: {mask_file_name}")
    return img, overlay_mask(img, mask)
def fetch_proper_img_and_change_prompt(prompt, word, frame_num, diffusion_step, layer_num=0):
    """Build a word selector for ``prompt`` and fetch the matching frame pair.

    Returns a three-element list ``[frame, overlaid_frame, word_selector]``
    (e.g. for three Gradio outputs).
    """
    word_selector = change_text_prompt(prompt)
    frame, overlaid = fetch_proper_img(prompt, word, frame_num, diffusion_step, layer_num)
    return [frame, overlaid, word_selector]
# Extra CSS passed to gr.Blocks for the .word-btn / .word-btns-container classes.
# NOTE(review): no component in the code shown sets these classes — confirm
# whether this CSS is still needed.
css = """
.word-btn {
width: fit-content;
padding: 3px;
}
.word-btns-container {
flex-direction: row;
}
"""
# NOTE(review): `registry` and `data_path` are not referenced in the code shown;
# presumably leftovers or used by importers — confirm before relying on them.
registry = {
    'spider': 'mask_1',
    'descending': 'mask_2',
}
data_path = Path(
    'data'
)
# Prompts offered in the "Video Prompt" dropdown. Each needs cached frames under
# ./data/videos/ and masks under ./data/final_masks/ (see fetch_proper_img).
available_prompts = ['a dog and a cat sitting','A fish swimming in the water', 'A spider descending from its web', 'An astronaut riding a horse']
# Gradio UI: pick a cached prompt, pick one of its words, then scrub through
# frames / diffusion steps / layers to see that word's attention mask.
with gr.Blocks(css=css) as demo:
    # Raw frame next to the same frame with the attention mask blended in.
    with gr.Row():
        video_1 = gr.Image(label="Image")
        video_2 = gr.Image(label="Image with Attention Mask")

    def change_text_prompt(text):
        """Return a word-selection Radio whose choices are the words of ``text``.

        Called inside the Blocks context it renders a new Radio; returned from
        the "Get words" handler it replaces the choices of ``radio``.
        """
        # split() with no argument ignores repeated/leading/trailing whitespace,
        # so no empty-string choices are produced.
        return gr.Radio(text.split(), value=None, label='Choose a word to visualize its attention mask.')

    # Initial prompt shown before the user clicks "Get words".
    text = available_prompts[0]
    gr.Markdown("""
## Visualizing Attention Masks
* Select a prompt from the drop down
* Click on "Get words" to get the words in the prompt
* Select a radio button from the words to visualize the attention mask
* Play around with the index of diffusion steps, layers to visualize different masks
* Brighter mask corresponds to larger values of attention.
""")
    # "Video Selection" section. gr.Group takes no title: the string previously
    # passed positionally landed in its `visible` parameter, so it is a comment now.
    with gr.Group():
        txt_1 = gr.Dropdown(choices=available_prompts, label="Video Prompt", value=available_prompts[0])
        submit_btn = gr.Button('Get words')
    # "Word Selection" section: word, frame, diffusion-step and layer controls.
    with gr.Group():
        radio = change_text_prompt(text)
        range_slider = gr.Slider(1, 16, 1, step=2, label='Frame of the generated video to visualize the attention mask.')
        diffusion_slider = gr.Slider(0, 35, 0, step=5, label='Index of diffusion steps.')
        layer_num_slider = gr.Slider(0, 6, 0, step=1, label='Layer number for attention mask.')

    frame_inputs = [txt_1, radio, range_slider, diffusion_slider, layer_num_slider]
    frame_outputs = [video_1, video_2]
    # Changing the selected word or any slider re-renders both images.
    for control in (radio, range_slider, diffusion_slider, layer_num_slider):
        control.change(fetch_proper_img, inputs=frame_inputs, outputs=frame_outputs)
    # "Get words" rebuilds the word choices for the prompt selected in the dropdown.
    submit_btn.click(change_text_prompt, inputs=[txt_1], outputs=[radio])
def _launch():
    """Start the Gradio server listening on all network interfaces."""
    demo.launch(server_name='0.0.0.0')


if __name__ == "__main__":
    _launch()