File size: 5,350 Bytes
8fece0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1315958
8fece0a
 
 
 
 
 
 
 
 
 
 
 
87c8975
8fece0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ba67d41
8fece0a
 
 
 
 
ba67d41
8fece0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4b550a1
9a19ad2
 
 
 
 
 
06062dc
9a19ad2
 
8fece0a
 
 
 
 
ba67d41
8fece0a
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
from pathlib import Path

import cv2
import sys
import gradio as gr
import os
import numpy as np
from gradio_utils import *

from transformers import CLIPTokenizer

def image_mod(image):
    """Return ``image`` rotated by 45 degrees.

    Delegates to the object's own ``rotate`` method (presumably a PIL image —
    confirm with any caller). Not referenced elsewhere in this file.
    """
    angle_degrees = 45
    return image.rotate(angle_degrees)


# Put the parent of sys.path[0] (typically the script's own directory) on the
# import path. NOTE(review): this runs after `from gradio_utils import *` above,
# so it cannot influence that import — confirm what it is still needed for.
sys.path.insert(1, os.path.join(sys.path[0], '..'))


# NOTE(review): NUM_POINTS, NUM_FRAMES and LARGE_BOX_SIZE are not referenced in
# the code visible in this file. NUM_FRAMES = 16 matches the maximum of the frame
# slider in the UI.
NUM_POINTS = 3
NUM_FRAMES = 16
LARGE_BOX_SIZE = 176



# Not referenced in the code visible in this file.
data = {}
# CLIP tokenizer used by get_token_number to locate a word's token position in a
# prompt. Loaded once at import time; from_pretrained may fetch the files on
# first use.
tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-base-patch32')
def get_token_number(prompt, word, clip_tokenizer=None):
    """Return the position in the tokenized ``prompt`` of the first token of ``word``.

    This position is what ``fetch_proper_img`` puts in the ``token_<n>`` part of
    attention-mask file names.

    Only the first sub-token of ``word`` is matched, and the first occurrence in
    the prompt wins. For example, ``'a'`` in ``'a dog and a cat sitting'`` always
    maps to the first ``a``.

    Args:
        prompt: Full prompt text.
        word: A word of the prompt.
        clip_tokenizer: Callable returning an object with ``input_ids``; defaults
            to the module-level CLIP ``tokenizer``.

    Returns:
        Index into ``clip_tokenizer(prompt).input_ids``.

    Raises:
        ValueError: If ``word`` yields no tokens, or its first token does not
            occur in the tokenized prompt.
    """
    if clip_tokenizer is None:
        clip_tokenizer = tokenizer
    prompt_ids = clip_tokenizer(prompt).input_ids
    word_ids = clip_tokenizer(word).input_ids
    # word_ids is [start-of-text, sub-token(s)..., end-of-text]. Without a real
    # sub-token, word_ids[1] would be the end-of-text id, which is also in
    # prompt_ids, and a meaningless index would be returned silently.
    if len(word_ids) < 3:
        raise ValueError(f"word {word!r} does not produce any tokens")
    first_word_id = word_ids[1]
    try:
        return prompt_ids.index(first_word_id)
    except ValueError:
        raise ValueError(f"word {word!r} does not occur in prompt {prompt!r}") from None

def overlay_mask(img, mask, opacity=0.3):
    """Blend a single-channel attention mask over a 3-channel image.

    Args:
        img: 3-channel image array. Its height and width set the output size.
        mask: Single-channel mask array with the same dtype as ``img``, as
            required by ``cv2.addWeighted``. It is resized to ``img`` with
            nearest-neighbour interpolation, so mask cells stay crisp.
        opacity: Weight of the mask in the blend; the image gets
            ``1 - opacity``. Defaults to 0.3, the value used originally.

    Returns:
        ``opacity * mask + (1 - opacity) * img`` per pixel, computed by
        ``cv2.addWeighted`` and saturated to the input dtype. Same shape as ``img``.

    Raises:
        ValueError: If ``mask`` is None (e.g. a failed ``cv2.imread``).
    """
    if mask is None:
        raise ValueError("mask is None; the mask image could not be read")
    height, width = img.shape[:2]
    # cv2.resize takes (width, height), the reverse of numpy's shape order.
    mask_resized = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)
    # Replicate the single channel so it can be blended with the 3-channel image.
    mask_3ch = cv2.cvtColor(mask_resized, cv2.COLOR_GRAY2BGR)
    return cv2.addWeighted(mask_3ch, opacity, img, 1 - opacity, 0)

def fetch_proper_img(prompt, word, frame_num, diffusion_step, layer_num=0):
    """Load a generated video frame and, optionally, its attention-mask overlay.

    Args:
        prompt: Prompt of the generated video; selects the data directories.
        word: Word of ``prompt`` whose attention mask is overlaid, or ``None``
            to return the plain frame for both outputs.
        frame_num: 1-based frame number as shown on the UI slider.
        diffusion_step: Diffusion-step index embedded in the mask file name.
        layer_num: Layer index from the UI slider. ``None`` is treated as 0,
            and 3 is written as layer 100 in the mask file name.

    Returns:
        ``(frame, overlaid_frame)``, both in RGB channel order for ``gr.Image``.

    Raises:
        FileNotFoundError: If the frame or mask image cannot be read.
        ValueError: Propagated from ``get_token_number`` when ``word`` cannot
            be located in ``prompt``.
    """
    # Slider values are 1-based; frame files are 0-based. int() keeps the
    # ``:04d`` format and the file names valid if a slider yields floats.
    frame_idx = int(frame_num) - 1
    step = int(diffusion_step)
    if layer_num is None:
        layer = 0
    else:
        layer = int(layer_num)
        # NOTE(review): slider value 3 is stored on disk as layer 100; the
        # meaning of 100 is not visible here — confirm with the mask exporter.
        if layer == 3:
            layer = 100

    video_file_name = f"./data/videos/{prompt.replace(' ', '_')}/video/frame_{frame_idx:04d}.png"
    img = cv2.imread(video_file_name)
    # cv2.imread reports a missing/unreadable file by returning None, not raising.
    if img is None:
        raise FileNotFoundError(f"Could not read video frame: {video_file_name}")
    # cv2 decodes to BGR, while gr.Image interprets numpy arrays as RGB.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    if word is None:
        return img, img

    token_idx = get_token_number(prompt, word)
    # NOTE(review): unlike the video path, this path keeps spaces in the prompt
    # and does not zero-pad the frame number — presumably matching the exporter.
    mask_file_name = (
        f'./data/final_masks/attention_probs_{prompt}/'
        f'frame_{frame_idx}_layer_{layer}_diffusionstep_{step}_token_{token_idx}.png'
    )
    mask = cv2.imread(mask_file_name, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        raise FileNotFoundError(f"Could not read attention mask: {mask_file_name}")
    return img, overlay_mask(img, mask)


def fetch_proper_img_and_change_prompt(prompt, word, frame_num, diffusion_step, layer_num=0):
    """Rebuild the word selector for ``prompt`` and load the matching images.

    The Radio is built via ``change_text_prompt`` (defined inside the
    ``gr.Blocks`` context) before any image is read. The two images come from
    ``fetch_proper_img`` with the same arguments.

    Returns:
        ``[frame, overlaid_frame, word_radio]``.
    """
    word_radio = change_text_prompt(prompt)
    frame, overlaid_frame = fetch_proper_img(
        prompt, word, frame_num, diffusion_step, layer_num
    )
    return [frame, overlaid_frame, word_radio]

# Extra CSS handed to gr.Blocks. NOTE(review): no component in this file sets
# the word-btn / word-btns-container classes — confirm they are still needed.
css = """
.word-btn {
    width: fit-content;
    padding: 3px;
}
.word-btns-container {
    flex-direction: row;
}
"""

# Word -> mask identifier. Not read by the code visible in this file.
registry = dict(
    spider='mask_1',
    descending='mask_2',
)

# Root of the generated assets. The loaders currently build './data/...'
# strings directly rather than using this Path.
data_path = Path('data')

# Prompts offered in the video dropdown; the first one is the default.
available_prompts = [
    'a dog and a cat sitting',
    'A fish swimming in the water',
    'A spider descending from its web',
    'An astronaut riding a horse',
]

with gr.Blocks(css=css) as demo:
    with gr.Row():
        video_1 = gr.Image(label="Image", )
        video_2 = gr.Image(label="Image with Attention Mask", )

    def change_text_prompt(text):
        """Return a word-selection Radio whose choices are the words of ``text``.

        Bound to the "Get words" button with ``radio`` as the output, so the
        returned component replaces the radio's choices and clears its value.
        """
        return gr.Radio(text.strip().split(' '), value=None, label='Choose a word to visualize its attention mask.')

    # Words offered before "Get words" is first clicked. Tied to the dropdown's
    # default so the two start out in sync.
    text = available_prompts[0]
    gr.Markdown("""
                ## Visualizing Attention Masks
                * Select a prompt from the drop down
                * Click on "Get words" to get the words in the prompt
                * Select a radio button from the words to visualize the attention mask
                * Play around with the index of diffusion steps, layers to visualize different masks
                * Brighter mask corresponds to larger values of attention.
                """)

    # Video selection. gr.Group takes no title argument; the title strings
    # passed positionally before were never rendered as headings.
    with gr.Group():
        txt_1 = gr.Dropdown(choices=available_prompts, label="Video Prompt", value=available_prompts[0])
        submit_btn = gr.Button('Get words')
    # Word selection and visualization controls.
    with gr.Group():
        radio = gr.Radio(text.split(' '), value=None, label='Choose a word to visualize its attention mask.')
        # NOTE(review): with minimum 1 and step 2, only odd slider values
        # (frames 0, 2, ..., 14 on disk) are selectable — confirm intended.
        range_slider = gr.Slider(1, 16, 1, step=2, label='Frame of the generated video to visualize the attention mask.')
        diffusion_slider = gr.Slider(0, 35, 0, step=5, label='Index of diffusion steps.')
        layer_num_slider = gr.Slider(0, 6, 0, step=1, label='Layer number for attention mask.')

    # Changing the word or any slider re-renders both images from the current settings.
    viz_inputs = [txt_1, radio, range_slider, diffusion_slider, layer_num_slider]
    viz_outputs = [video_1, video_2]
    for control in (radio, range_slider, diffusion_slider, layer_num_slider):
        control.change(fetch_proper_img, inputs=viz_inputs, outputs=viz_outputs)
    submit_btn.click(change_text_prompt, inputs=[txt_1], outputs=[radio])

if __name__ == "__main__":
    # 0.0.0.0 listens on every network interface, so the demo is reachable
    # from other hosts, not just localhost.
    demo.launch(server_name='0.0.0.0')