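"""Hook-based visualization utilities for TinyLLaVA.

A Monitor registers forward hooks on one decoder layer, the lm_head, and the
multimodal connector. After generation it renders, per generated token, the
top-k next-token candidates, similarities to earlier tokens, and a heat map
of image-token similarity over the input image.
"""
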
from collections import defaultdict
from io import BytesIO
import datetime
import os

import numpy as np
import requests
import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
from PIL import Image
from torchvision.transforms import ToPILImage

from tinyllava.data import *
from tinyllava.utils import *
from tinyllava.model import *


def load_image(image_file):
    """Load an RGB image from a URL or a local path."""
    if image_file.startswith(("http://", "https://")):
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert("RGB")
    else:
        image = Image.open(image_file).convert("RGB")
    return image


def load_images(image_files):
    return [load_image(image_file) for image_file in image_files]


def extract_max_values_and_indices(tensor, k):
    # Top-k over the vocabulary dimension, packed as (index, value) pairs:
    # the result has shape [batch, seq_len, k, 2].
    max_values, max_indices = torch.topk(tensor, k, dim=2)
    max_values_with_indices = torch.stack((max_indices, max_values), dim=3)
    return max_values_with_indices
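
# Shape sketch (hypothetical inputs): for softmaxed logits of shape
# [1, seq_len, vocab_size] and k=8, the helper above returns a tensor of
# shape [1, seq_len, 8, 2], e.g.:
#   probs = torch.rand(1, 5, 100).softmax(dim=2)
#   top = extract_max_values_and_indices(probs, 8)  # top.shape == (1, 5, 8, 2)
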
def visualize_grid_to_grid(i, mask, image, output_dir, grid_size=27, alpha=0.6):
    if not isinstance(grid_size, tuple):
        grid_size = (grid_size, grid_size)
    # Upsample the grid of scores to display resolution, then convert back to
    # a float array so it can be normalized below.
    mask = mask.detach().cpu().numpy()
    mask = np.asarray(Image.fromarray(mask).resize((384, 384)), dtype=np.float32)
    fig, ax = plt.subplots(1, 2, figsize=(10, 7))
    fig.tight_layout()
    ax[0].imshow(image)
    ax[0].axis('off')
    ax[1].imshow(image)
    im = ax[1].imshow(mask / np.max(mask), alpha=alpha, cmap='rainbow')
    ax[1].axis('off')
    cbar = fig.colorbar(im, ax=ax[1])
    cbar.set_label('Color Temperature')
    name = os.path.join(output_dir, "hot_image", f"{i}.png")
    plt.savefig(name)
    plt.close(fig)


def generate_square_subsequent_mask(sz):
    # Standard causal mask: 0.0 where attention is allowed, -inf where it is
    # masked out.
    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask
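
# A minimal illustration: generate_square_subsequent_mask(3) returns
#   tensor([[0., -inf, -inf],
#           [0.,   0., -inf],
#           [0.,   0.,   0.]])
# i.e. position i may only attend to positions j <= i.
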
def generate_word_images(tokenizer, top_words_tensor, num, input_ids, embed_tokens, output_dir):
    # For each of the last `num` generation steps, render the top-k candidate
    # tokens colored by probability. `input_ids` and `embed_tokens` are kept
    # in the signature for interface compatibility but are not used here.
    num_top_words = top_words_tensor.shape[1]
    for i in range(num_top_words - num, num_top_words):
        fig, ax = plt.subplots()
        word_indices = top_words_tensor[0, i, :, 0].detach().cpu().numpy()
        probabilities = top_words_tensor[0, i, :, 1].detach().cpu().numpy()
        colors = plt.cm.viridis(probabilities)
        for j, (word_index, color, prob) in enumerate(zip(word_indices, colors, probabilities)):
            word = tokenizer.decode([int(word_index)])
            prob_text = f"{word} P: {prob:.2f}"
            ax.text(0.5, 0.9 - j * 0.1, prob_text, color=color, ha='center', va='center', transform=ax.transAxes)
        ax.axis('off')
        ax.set_title('Top Words for Index {}'.format(i - num_top_words + num + 1))
        plt.savefig(os.path.join(output_dir, 'word', f"word_image_{i - num_top_words + num + 1}.png"))
        plt.close()


def generate_word_images_before(tokenizer, input_ids, tensor, num, top_words_tensor, output_dir):
    # For each generated token, visualize which earlier tokens it is most
    # similar to (attention averaged over heads). -200 is the image-token
    # placeholder in the prompt and is filtered out before decoding.
    num_top_words = tensor.shape[2]
    result = tensor.mean(dim=1)  # average over heads: [1, len, len]
    base_ids = input_ids[input_ids != -200].unsqueeze(0)
    for i in range(num_top_words - num, num_top_words - 1):
        top1_indices = top_words_tensor[0, i, 0, 0].long()
        fig, ax = plt.subplots()
        result_1 = result[0, i, 0:input_ids.shape[1]]
        result_1 = result_1[input_ids.squeeze() != -200]
        # Rebuild the id sequence (prompt + tokens generated before step i)
        # each iteration, so indices into result_1 line up with input_ids_fir.
        input_ids_fir = base_ids
        if i != num_top_words - num:
            result_2 = result[0, i, num_top_words - num + 1:i + 1]
            result_1 = torch.cat((result_1, result_2), dim=0)
            output_ids = top_words_tensor[0, num_top_words - num:i, 0, 0].unsqueeze(0).long()
            input_ids_fir = torch.cat((base_ids, output_ids), dim=1)
        tv, ti = torch.topk(result_1.squeeze(), 8)
        tv = tv / torch.max(tv)
        probabilities = tv.detach().cpu().numpy()
        colors = plt.cm.viridis(probabilities)
        for j, (word_index, color, prob) in enumerate(zip(ti, colors, probabilities)):
            word = tokenizer.decode([input_ids_fir[0, word_index].item()])
            prob_text = f"{word} P: {prob:.2f}"
            ax.text(0.5, 0.9 - j * 0.1, prob_text, color=color, ha='center', va='center', transform=ax.transAxes)
        ax.axis('off')
        ax.set_title('Similarities of output word {}'.format(tokenizer.decode([int(top1_indices.item())])))
        plt.savefig(os.path.join(output_dir, 'word_before', f"word_image_{i - (num_top_words - num - 1)}.png"))
        plt.close()


class Monitor:
    def __init__(self, args, model, llm_layers_index):
        self.model = model
        self.args = args
        self.input_ids = None
        self.image = None
        self.params = list(model.parameters())
        self.output = defaultdict(dict)
        self.attentions = []
        self.hidden = []
        self.logit = []
        self.image_token = []
        self.llm_layers_index = llm_layers_index
        self._register(llm_layers_index)

    def _register(self, llm_layers_index):
        # Forward hooks fill these buffers while the model generates: the
        # input hidden states of the monitored decoder layer, the lm_head
        # logits, and the connector's projected image tokens.
        def attention_hook(module, input, output):
            self.hidden.append(input[0])

        def output_hook(module, input, output):
            self.logit.append(output)

        def image_hook(module, input, output):
            self.image_token.append(output)

        mod = self.model
        mod.language_model.model.layers[llm_layers_index].register_forward_hook(attention_hook)
        mod.language_model.lm_head.register_forward_hook(output_hook)
        mod.connector.register_forward_hook(image_hook)

    def prepare_input(self):
        # Build input_ids from the text query.
        qs = self.args.query
        qs = DEFAULT_IMAGE_TOKEN + "\n" + qs
        text_processor = TextPreprocess(self.model.tokenizer, self.args.conv_mode)
        msg = Message()
        msg.add_message(qs)
        result = text_processor(msg.messages, mode='eval')
        self.input_ids = result['input_ids'].unsqueeze(0).cuda()
        # Build a displayable PIL image from the preprocessed image tensor.
        data_args = self.model.config
        image_processor = self.model.vision_tower._image_processor
        image_processor = ImagePreprocess(image_processor, data_args)
        image_files = self.args.image_file.split(self.args.sep)
        images = load_images(image_files)[0]
        images_tensor = image_processor(images)
        image_tensor = 255 * (images_tensor - images_tensor.min()) / (images_tensor.max() - images_tensor.min())
        image_tensor = image_tensor.clamp(0, 255)
        image_tensor = image_tensor.byte()
        to_pil = ToPILImage()
        self.image = to_pil(image_tensor).convert('RGB')
        self.model.cuda()
        # Post-process the buffers collected by the forward hooks during
        # generation: softmax the concatenated logits, then re-run the
        # monitored decoder layer on the full hidden-state sequence with a
        # causal mask to recover its attention weights.
        self.logit = F.softmax(torch.cat(self.logit, dim=1), dim=2)
        hidden_tensor = torch.cat(self.hidden, dim=1)
        length = hidden_tensor.shape[1]
        attention_mask = generate_square_subsequent_mask(length).unsqueeze(0).unsqueeze(0).cuda()
        self.hidden = self.model.language_model.model.layers[self.llm_layers_index](hidden_tensor,
                                                                                    output_attentions=True,
                                                                                    attention_mask=attention_mask)
        self.image_token = self.image_token[0].squeeze()
        # Prepend a zero row so image-token indices are shifted by one; 2560
        # is the LLM hidden size this script assumes.
        self.image_token = torch.cat((torch.zeros(1, 2560).cuda(), self.image_token), dim=0)

    def get_output(self, output_dir='results/'):
        print("Starting visualization...")
        self.prepare_input()
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        output_dir = os.path.join(output_dir, f"run_{timestamp}")
        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(os.path.join(output_dir, 'word'), exist_ok=True)
        os.makedirs(os.path.join(output_dir, 'word_before'), exist_ok=True)
        os.makedirs(os.path.join(output_dir, 'hot_image'), exist_ok=True)
        # Number of generated tokens: total sequence length minus the prompt
        # and the expanded image tokens (726 appears to be hardcoded for this
        # model's image-token count).
        num = self.logit.shape[1] - 726 - len(self.input_ids[0])
        result = extract_max_values_and_indices(self.logit, 8)
        generate_word_images(self.model.tokenizer, result, num, self.input_ids,
                             self.model.language_model.model.embed_tokens.weight, output_dir)
        generate_word_images_before(self.model.tokenizer, self.input_ids, self.hidden[1], num, result, output_dir)
        # For each generated token, compute the cosine similarity between its
        # embedding and every projected image token, and render the result as
        # a 27x27 heat map over the input image.
        result_top1 = result[0, :, 0, 0].squeeze()
        for i in range(len(result_top1) - num, len(result_top1)):
            word_id = result_top1[i].long()
            word_id_tensor = word_id.view(1).cuda()
            word_vector = self.model.language_model.model.embed_tokens(word_id_tensor).squeeze().detach()
            vector_expanded = word_vector.unsqueeze(0).expand_as(self.image_token)
            vector_norm = F.normalize(vector_expanded, p=2, dim=1)
            matrix_norm = F.normalize(self.image_token, p=2, dim=1)
            cosine_similarities = torch.sum(vector_norm * matrix_norm, dim=1)
            normalized_similarities = F.softmax(cosine_similarities, dim=0)
            visualize_grid_to_grid('hot_image_' + str(i - (len(result_top1) - num) + 1),
                                   normalized_similarities.view(27, 27),
                                   self.image, output_dir)
        print("Completed visualization.")