# File size: 5,276 Bytes
# Blame hashes captured by the web file viewer (not part of the program):
# 8fece0a 1315958 8fece0a 0b9869b 8fece0a ba67d41 8fece0a ba67d41 8fece0a 4b550a1 9a19ad2 8fece0a ba67d41 8fece0a
# Line-number gutter captured by the web file viewer: 1 ... 144
from pathlib import Path
import cv2
import sys
import gradio as gr
import os
import numpy as np
from gradio_utils import *
from transformers import CLIPTokenizer
def image_mod(image):
    """Return ``image`` rotated by 45 degrees.

    The work is delegated to the ``rotate`` method of ``image`` and its
    result is returned unchanged.  Presumably ``image`` is a PIL image --
    confirm against callers; the function is not called in the visible source.
    """
    angle = 45
    rotated = image.rotate(angle)
    return rotated
# Put the parent of sys.path[0] (normally the running script's directory) on
# the import search path, right after the first entry.
# NOTE(review): this runs after the `from gradio_utils import *` line above,
# so it cannot influence that import -- confirm whether it is still needed.
sys.path.insert(1, os.path.join(sys.path[0], '..'))
# NOTE(review): NUM_POINTS, NUM_FRAMES, LARGE_BOX_SIZE and `data` are not
# referenced anywhere else in the visible source; presumably leftovers from
# another demo -- confirm before removing.
NUM_POINTS = 3
NUM_FRAMES = 16
LARGE_BOX_SIZE = 176
data = {}
# Tokenizer used by get_token_number() to locate a word inside a prompt.
# Loaded once at import time from the 'openai/clip-vit-base-patch32' checkpoint.
tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-base-patch32')
def get_token_number(prompt, word):
    """Return the position of ``word``'s first token inside the tokenized ``prompt``.

    Both strings are run through the module-level ``tokenizer``.  Per the
    original author's note, the id list starts with a special leading token,
    so element 1 is taken as the first real token of ``word``.  The position
    of its first occurrence in the prompt's id list is returned.

    The token ids and the word are printed as a debugging aid.

    Raises:
        ValueError: if that token id does not occur in the tokenized prompt
            (assuming ``input_ids`` is a plain list).
    """
    prompt_ids = tokenizer(prompt).input_ids
    word_ids = tokenizer(word).input_ids
    print(prompt_ids, word_ids, word)
    first_word_id = word_ids[1]  # index 0 is the special leading token
    return prompt_ids.index(first_word_id)
def overlay_mask(img, mask, opacity=0.75):
    """Blend a grayscale mask over an image.

    The mask is resized with nearest-neighbour interpolation to the image's
    width and height, expanded to three channels, and mixed with the image as
    ``opacity * mask + (1 - opacity) * img``.

    The previous implementation also built an alpha channel and a 4-channel
    image that were never used in the result, and assigned ``opacity`` twice
    (0.25, then 0.75).  Only the blend below ever reached the caller, so the
    dead work is removed and 0.75 is kept as the default.

    Args:
        img: Image array; must have the same channel count as the expanded
            mask (three) for the blend to succeed.
        mask: Single-channel (grayscale) mask array.
        opacity: Weight of the mask in the blend; the image gets
            ``1 - opacity``.  Defaults to 0.75, the value that was hard-coded.

    Returns:
        The blended image, with the width and height of ``img``.

    Raises:
        ValueError: if ``img`` or ``mask`` is ``None`` (e.g. a failed
            ``cv2.imread``).
    """
    if img is None or mask is None:
        raise ValueError("overlay_mask() needs both an image and a mask; got None")

    height, width = img.shape[:2]
    # cv2.resize takes the target size as (width, height).
    mask_resized = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)
    # addWeighted needs both inputs to have the same number of channels.
    mask_3ch = cv2.cvtColor(mask_resized, cv2.COLOR_GRAY2BGR)
    return cv2.addWeighted(mask_3ch, opacity, img, 1 - opacity, 0)
def fetch_proper_img(prompt, word, frame_num, diffusion_step, layer_num=0):
    """Load a generated video frame and, if a word is given, overlay its attention mask.

    Args:
        prompt: Prompt the video was generated from.  It selects the frame
            directory (spaces replaced by underscores) and the mask directory
            (prompt used verbatim).
        word: Word of the prompt whose attention mask should be overlaid, or
            ``None`` to return the plain frame twice.
        frame_num: 1-based frame number; frame files are 0-based.
        diffusion_step: Diffusion step index used in the mask file name.
        layer_num: Layer index used in the mask file name.  ``None`` is
            treated as 0, and 3 is mapped to 100 to match the file naming.

    Returns:
        Tuple ``(img, overlaid_image)``.  When ``word`` is ``None`` both
        entries are the frame as returned by ``cv2.imread`` (which is ``None``
        if the file could not be read).

    Raises:
        FileNotFoundError: if a word is given and the frame or the mask file
            cannot be read.
    """
    # UI sliders may hand over floats; coerce so ':04d' formatting works and
    # file names do not end up containing e.g. '5.0'.
    frame_idx = int(frame_num) - 1
    if layer_num is None:
        layer_num = 0
    else:
        layer_num = int(layer_num)
        if layer_num == 3:
            layer_num = 100

    video_file_name = f"./data/videos/{prompt.replace(' ', '_')}/video/frame_{frame_idx:04d}.png"
    # NOTE(review): cv2.imread returns channels in BGR order, while gr.Image
    # presumably renders numpy arrays as RGB -- confirm whether the frames
    # need a BGR->RGB conversion before being displayed.
    img = cv2.imread(video_file_name)

    if word is None:
        return img, img

    if img is None:
        raise FileNotFoundError(f"Could not read video frame: {video_file_name}")

    token_number = get_token_number(prompt, word)
    mask_file_name = (
        f'./data/final_masks/attention_probs_{prompt}/'
        f'frame_{frame_idx}_layer_{layer_num}_diffusionstep_{int(diffusion_step)}'
        f'_token_{token_number}.png'
    )
    mask = cv2.imread(mask_file_name, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        raise FileNotFoundError(f"Could not read attention mask: {mask_file_name}")

    overlaid_image = overlay_mask(img, mask)
    print(mask_file_name)
    return img, overlaid_image
def fetch_proper_img_and_change_prompt(prompt, word, frame_num, diffusion_step, layer_num=0):
    """Rebuild the word selector for ``prompt`` and fetch the matching images.

    Calls ``change_text_prompt(prompt)`` first, then ``fetch_proper_img`` with
    the remaining arguments passed through unchanged.

    Returns:
        A list ``[image, image_with_mask, word_selector]``.
    """
    word_selector = change_text_prompt(prompt)
    frame, frame_with_mask = fetch_proper_img(
        prompt, word, frame_num, diffusion_step, layer_num
    )
    return [frame, frame_with_mask, word_selector]
# Custom CSS handed to gr.Blocks when the demo UI is built.
css = """
.word-btn {
width: fit-content;
padding: 3px;
}
.word-btns-container {
flex-direction: row;
}
"""

# Word -> mask name table.
# NOTE(review): not referenced anywhere else in the visible source.
registry = dict(spider='mask_1', descending='mask_2')

# Relative data directory.
# NOTE(review): the loaders build their paths from './data/...' strings and do
# not use this value in the visible source.
data_path = Path('data')

# Prompts offered in the "Video Prompt" drop-down.
available_prompts = [
    'a dog and a cat sitting',
    'A fish swimming in the water',
    'A spider descending from its web',
    'An astronaut riding a horse',
]
# Build the demo UI.  Layout nesting was reconstructed from the statement
# order; NOTE(review): confirm the sliders belong inside the second group.
with gr.Blocks(css=css) as demo:
    # Outputs: the plain frame and the frame blended with the attention mask.
    with gr.Row():
        video_1 = gr.Image(label="Image")
        video_2 = gr.Image(label="Image with Attention Mask")

    def change_text_prompt(text):
        """Return a new, unselected word selector for the prompt ``text``.

        The prompt is split on whitespace so repeated spaces do not produce
        empty choices.
        """
        return gr.Radio(text.split(), value=None, label='Choose a word to visualize its attention mask.')

    # The initial word list follows the drop-down's default prompt.
    text = available_prompts[0]

    gr.Markdown("""
## Visualizing Attention Masks
* Select a prompt from the drop down
* Click on "Get words" to get the words in the prompt
* Select a radio button from the words to visualize the attention mask
* Play around with the index of diffusion steps, layers to visualize different masks
""")

    # Video selection.  gr.Group documents no title argument, so the former
    # positional strings are kept as comments instead of being passed in.
    with gr.Group():
        txt_1 = gr.Dropdown(choices=available_prompts, label="Video Prompt", value=available_prompts[0])
        submit_btn = gr.Button('Get words')

    # Word selection.
    with gr.Group():
        radio = gr.Radio(text.split(), value=None, label='Choose a word to visualize its attention mask.')
        range_slider = gr.Slider(1, 16, 1, step=2, label='Frame of the generated video to visualize the attention mask.')
        diffusion_slider = gr.Slider(0, 35, 0, step=5, label='Index of diffusion steps.')
        layer_num_slider = gr.Slider(0, 6, 0, step=1, label='Layer number for attention mask.')

    # Any change to the word or to one of the sliders refreshes both images.
    mask_inputs = [txt_1, radio, range_slider, diffusion_slider, layer_num_slider]
    mask_outputs = [video_1, video_2]
    for control in (radio, range_slider, diffusion_slider, layer_num_slider):
        control.change(fetch_proper_img, inputs=mask_inputs, outputs=mask_outputs)

    # "Get words" rebuilds the word selector from the chosen prompt.
    submit_btn.click(change_text_prompt, inputs=[txt_1], outputs=[radio])
if __name__ == "__main__":
    # Start the Gradio server only when run as a script.  '0.0.0.0' asks the
    # server to listen on all network interfaces rather than localhost only.
    bind_address = '0.0.0.0'
    demo.launch(server_name=bind_address)
# (end of web file viewer capture)