|
import torch |
|
import os |
|
from PIL import Image |
|
import numpy as np |
|
from ipycanvas import Canvas |
|
import cv2 |
|
|
|
from visualize_attention_src.utils import get_image |
|
|
|
exp_dir = "saved_attention_map_results" |
|
|
|
style_name = "line_art" |
|
src_name = "cat" |
|
tgt_name = "dog" |
|
|
|
steps = ["20"] |
|
seed = "4" |
|
saved_dtype = "tensor" |
|
|
|
|
|
# Load every saved attention map listed in `steps` into memory.
attn_map_raws = []

for step in steps:
    # The filename encodes the full experiment configuration; it must match
    # the naming scheme of the saving code exactly.
    attn_map_name_wo_ext = f"attn_map_raw_{style_name}_src_{src_name}_tgt_{tgt_name}_activate_layer_(0, 0)(108, 140)_attn_map_step_{step}_seed_{seed}"

    if saved_dtype == 'uint8':
        # Maps were quantized to uint8 and stored as a pickled numpy object.
        attn_map_name = attn_map_name_wo_ext + '_uint8.npy'
        attn_map_path = os.path.join(exp_dir, attn_map_name)
        attn_map_raws.append(np.load(attn_map_path, allow_pickle=True))
    else:
        # Maps were stored as raw torch tensors.
        attn_map_name = attn_map_name_wo_ext + '.pt'
        attn_map_path = os.path.join(exp_dir, attn_map_name)
        attn_map_raws.append(torch.load(attn_map_path))
        print(attn_map_path)

    # (Removed a redundant recomputation of attn_map_path here: both
    # branches above already assign the identical value.)
    print(f"{step} is on memory")
|
|
|
# Collect the attention-layer keys from the first loaded map. A plain
# list() call replaces the redundant identity comprehension over .keys().
keys = list(attn_map_raws[0].keys())

print(len(keys))
key = keys[0]

# Column index of the target image inside the paired RGB result grid.
tgt_idx = 3
|
|
|
# Load the paired RGB result grid and crop out the source / target images.
grid_file_name = f"{style_name}_src_{src_name}_tgt_{tgt_name}_scale_1.0_activate_layer_(0, 0)(108, 140)_seed_{seed}.png"
grid_file_path = os.path.join(exp_dir, grid_file_name)
print(grid_file_path)
paired_grid = Image.open(grid_file_path)

# Column 0 of the grid holds the source image; column `tgt_idx` the target.
attn_map_src_img = get_image(paired_grid, row = 0, col = 0, image_size = 1024, grid_width = 10)
attn_map_tgt_img = get_image(paired_grid, row = 0, col = tgt_idx, image_size = 1024, grid_width = 10)
|
|
|
|
|
# Tile size (pixels) for every image drawn on the canvas.
h, w = 256, 256

# Default attention-grid side length; on_click recomputes this per key from
# the stored map's actual shape.
num_of_grid = 64

# Vertical offset tweak, currently disabled.
plus_50 = 0

# Indices into `keys` selecting which attention layers to visualize.
key_idx_list = [6, 28]
|
|
|
|
|
|
|
|
|
# Visualize only the first loaded attention-map step.
saved_attention_map_idx = [0]

# Crops taken from the paired result grid.
source_image = attn_map_src_img
target_image = attn_map_tgt_img

# Downscale to the canvas tile size, then convert PIL -> numpy for ipycanvas.
source_image = np.array(source_image.resize((h, w)))
target_image = np.array(target_image.resize((h, w)))

# Canvas layout: one row per visualized key, four tiles per step
# (target | raw attention | blend | source).
canvas = Canvas(width=4 * w, height=h * len(key_idx_list), sync_image_data=True)

# Seed every row with the source (rightmost tile) and target (leftmost tile)
# images; on_click fills in the middle tiles. The previous version hard-coded
# exactly two rows, which matched the canvas height (h * len(key_idx_list))
# only because key_idx_list happened to have two entries.
for row_i in range(len(key_idx_list)):
    canvas.put_image_data(source_image, w * 3, h * row_i)
    canvas.put_image_data(target_image, 0, h * row_i)
|
|
|
|
|
|
|
|
|
|
|
def save_to_file(*args, **kwargs):
    """Observer callback: dump the current canvas pixels to 'my_file1.png'.

    The positional/keyword arguments (ipycanvas change-event metadata) are
    ignored.
    """
    canvas.to_file("my_file1.png")
|
|
|
|
|
|
|
# Re-save the PNG every time the canvas's pixel data changes.
canvas.observe(save_to_file, "image_data")
|
|
|
|
|
def on_click(x, y):
    """Mouse handler: render the attention maps for the clicked pixel.

    Maps the canvas position (x, y) on the target tile to a cell of each
    selected attention grid, then draws that cell's attention row both as a
    grayscale map and as a JET heatmap blended over the source image.

    Reads module-level state: canvas, target_image, source_image,
    attn_map_raws, keys, key_idx_list, saved_attention_map_idx, tgt_idx,
    w, h.
    """
    cnt = 0  # number of tiles drawn; informational only
    # Repaint the target tile to erase any previous click marker.
    canvas.put_image_data(target_image, 0, 0)

    print(x, y)

    # Mark the clicked point in red.
    canvas.fill_style = 'red'
    canvas.fill_circle(x, y, 4)

    # NOTE(review): `step` shadows the module-level loop variable and always
    # equals step_i here; only the index is actually used.
    for step_i, step in enumerate(range(len(saved_attention_map_idx))):

        attn_map_raw = attn_map_raws[step_i]

        for key_i, key_idx in enumerate(key_idx_list):
            key = keys[key_idx]

            # Attention is stored over flattened spatial positions; the grid
            # side length is the square root of the last dimension.
            num_of_grid = int(attn_map_raw[key].shape[-1] ** (0.5))

            # Convert canvas coordinates (w x h tile) to a grid cell index.
            grid_x_idx = int(x / (w / num_of_grid))
            grid_y_idx = int(y / (h / num_of_grid))

            print(grid_x_idx, grid_y_idx)

            grid_idx = grid_x_idx + grid_y_idx * num_of_grid

            # Select the 10 rows belonging to image `tgt_idx` (presumably 10
            # attention heads per image -- TODO confirm against the saving
            # code), keeping only the clicked query cell.
            attn_map = attn_map_raw[key][tgt_idx * 10:10 + tgt_idx * 10, grid_idx, :]

            # Sum over the head dimension...
            attn_map = attn_map.sum(dim=0)

            # ...then unflatten the key axis back into a 2-D grid.
            attn_map = attn_map.reshape(num_of_grid, num_of_grid)

            attn_map = attn_map.detach().cpu().numpy()

            # Min-max normalize (epsilon guards against a constant map), then
            # invert so high attention renders dark in grayscale.
            normalized_attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min() + 1e-8)

            normalized_attn_map = 1.0 - normalized_attn_map

            # NOTE(review): cv2.applyColorMap returns BGR but the result is
            # drawn on the canvas without a BGR->RGB conversion, so the
            # heatmap's red/blue channels may appear swapped -- verify.
            heatmap = cv2.applyColorMap(np.uint8(255 * normalized_attn_map), cv2.COLORMAP_JET)

            heatmap = cv2.resize(heatmap, (w, h))

            # Grayscale rendering of the same map, expanded to 3 channels for
            # ipycanvas and resized to the tile size.
            attn_map = normalized_attn_map * 255

            attn_map = attn_map.astype(np.uint8)

            attn_map = cv2.cvtColor(attn_map, cv2.COLOR_GRAY2RGB)

            attn_map = cv2.resize(attn_map, (w, h))

            # Tile 2 of the row: the raw (grayscale) attention map.
            canvas.put_image_data(attn_map, w + step_i * 4 * w, h * key_i)

            # Tile 3: heatmap alpha-blended over the source image.
            alpha = 0.85

            blended_image = cv2.addWeighted(source_image, 1 - alpha, heatmap, alpha, 0)

            canvas.put_image_data(blended_image, w * 2 + step_i * 4 * w, h * key_i)

            cnt += 1
|
|
|
|
|
|
|
|
|
# Register the click handler: clicking the canvas renders the attention maps
# for the clicked pixel.
canvas.on_mouse_down(on_click)