import datetime
import os
from collections import defaultdict
from io import BytesIO

import numpy as np
import requests
import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
from PIL import Image
from torchvision.transforms import ToPILImage

from tinyllava.data import *
from tinyllava.model import *
from tinyllava.utils import *


def load_image(image_file):
    if image_file.startswith(("http://", "https://")):
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert("RGB")
    else:
        image = Image.open(image_file).convert("RGB")
    return image


def load_images(image_files):
    return [load_image(image_file) for image_file in image_files]


def extract_max_values_and_indices(tensor, k):
    # Top-k over the vocabulary dimension; returns [batch, len, k, 2] where
    # [..., 0] holds the token index and [..., 1] its probability.
    max_values, max_indices = torch.topk(tensor, k, dim=2)
    return torch.stack((max_indices, max_values), dim=3)


def visualize_grid_to_grid(name, mask, image, output_dir, grid_size=27, alpha=0.6):
    if not isinstance(grid_size, tuple):
        grid_size = (grid_size, grid_size)
    mask = mask.detach().cpu().numpy()
    # Upsample the grid mask to the image resolution; convert back to an
    # ndarray because PIL images do not support arithmetic.
    mask = np.array(Image.fromarray(mask).resize((384, 384)))

    fig, ax = plt.subplots(1, 2, figsize=(10, 7))
    fig.tight_layout()

    ax[0].imshow(image)
    ax[0].axis('off')

    ax[1].imshow(image)
    im = ax[1].imshow(mask / np.max(mask), alpha=alpha, cmap='rainbow')
    ax[1].axis('off')

    cbar = fig.colorbar(im, ax=ax[1])
    cbar.set_label('Color Temperature')

    plt.savefig(os.path.join(output_dir, "hot_image", f"{name}.png"))
    plt.close(fig)


def generate_square_subsequent_mask(sz):
    # Causal mask: 0 on and below the diagonal, -inf above it.
    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask


def generate_word_images(tokenizer, top_words_tensor, num, input_ids, embed_tokens, output_dir):
    # One figure per generated position, listing the top-k candidate tokens
    # colored by probability.
    num_top_words = top_words_tensor.shape[1]
    for i in range(num_top_words - num, num_top_words):
        fig, ax = plt.subplots()
        word_indices = top_words_tensor[0, i, :, 0].detach().cpu().numpy()
        probabilities = top_words_tensor[0, i, :, 1].detach().cpu().numpy()
        colors = plt.cm.viridis(probabilities)
        for j, (word_index, color, prob) in enumerate(zip(word_indices, colors, probabilities)):
            word = tokenizer.decode([int(word_index)])
            prob_text = f"{word} P: {prob:.2f}"
            ax.text(0.5, 0.9 - j * 0.1, prob_text, color=color,
                    ha='center', va='center', transform=ax.transAxes)
        ax.axis('off')
        ax.set_title('Top Words for Index {}'.format(i - num_top_words + num + 1))
        plt.savefig(os.path.join(output_dir, 'word', f"word_image_{i - num_top_words + num + 1}.png"))
        plt.close()


def generate_word_images_before(tokenizer, input_ids, tensor, num, top_words_tensor, output_dir):
    seq_len = tensor.shape[2]
    # Average the attention weights over heads -> [1, len, len].
    result = tensor.mean(dim=1)
    prompt_ids = input_ids[input_ids != -200].unsqueeze(0)
    for i in range(seq_len - num, seq_len - 1):
        top1_index = top_words_tensor[0, i, 0, 0].long()
        fig, ax = plt.subplots()
        # Attention from position i to the prompt (image placeholder removed).
        result_1 = result[0, i, 0:input_ids.shape[1]]
        result_1 = result_1[input_ids.squeeze() != -200]
        ids_for_lookup = prompt_ids
        if i != seq_len - num:
            # Also include attention to the tokens generated so far; the id
            # lookup is rebuilt each iteration so its indices stay aligned
            # with result_1.
            result_2 = result[0, i, seq_len - num + 1:i + 1]
            result_1 = torch.cat((result_1, result_2), dim=0)
            output_ids = top_words_tensor[0, seq_len - num:i, 0, 0].unsqueeze(0).long()
            ids_for_lookup = torch.cat((prompt_ids, output_ids), dim=1)
        tv, ti = torch.topk(result_1.squeeze(), 8)
        tv = tv / torch.max(tv)
        probabilities = tv.detach().cpu().numpy()
        colors = plt.cm.viridis(probabilities)
        for j, (word_index, color, prob) in enumerate(zip(ti, colors, probabilities)):
            word = tokenizer.decode([ids_for_lookup[0, word_index].item()])
            prob_text = f"{word} P: {prob:.2f}"
            ax.text(0.5, 0.9 - j * 0.1, prob_text, color=color,
                    ha='center', va='center', transform=ax.transAxes)
        ax.axis('off')
        ax.set_title('Similarities of output word {}'.format(tokenizer.decode([top1_index.item()])))
        plt.savefig(os.path.join(output_dir, 'word_before', f"word_image_{i - (seq_len - num - 1)}.png"))
        plt.close()


class Monitor:
    def __init__(self, args, model, llm_layers_index):
        self.model = model
        self.args = args
        self.input_ids = None
        self.image = None
        self.params = list(model.parameters())
        self.output = defaultdict(dict)
        self.attentions = []
        self.hidden = []
        self.logit = []
        self.image_token = []
        self.llm_layers_index = llm_layers_index
        self._register(llm_layers_index)

    def _register(self, llm_layers_index):
        # Forward hooks: capture the monitored decoder layer's input hidden
        # states, the lm_head logits, and the connector's image tokens.
        def attention_hook(module, input, output):
            self.hidden.append(input[0])

        def output_hook(module, input, output):
            self.logit.append(output)

        def image_hook(module, input, output):
            self.image_token.append(output)

        mod = self.model
        mod.language_model.model.layers[llm_layers_index].register_forward_hook(attention_hook)
        mod.language_model.lm_head.register_forward_hook(output_hook)
        mod.connector.register_forward_hook(image_hook)

    def prepare_input(self):
        # Build input_ids from the text query.
        qs = DEFAULT_IMAGE_TOKEN + "\n" + self.args.query
        text_processor = TextPreprocess(self.model.tokenizer, self.args.conv_mode)
        msg = Message()
        msg.add_message(qs)
        result = text_processor(msg.messages, mode='eval')
        self.input_ids = result['input_ids'].unsqueeze(0).cuda()

        # Build the image tensor and keep a de-normalized PIL copy for plotting.
        data_args = self.model.config
        image_processor = ImagePreprocess(self.model.vision_tower._image_processor, data_args)
        image_files = self.args.image_file.split(self.args.sep)
        images = load_images(image_files)[0]
        images_tensor = image_processor(images)
        image_tensor = 255 * (images_tensor - images_tensor.min()) / (images_tensor.max() - images_tensor.min())
        image_tensor = image_tensor.clamp(0, 255).byte()
        self.image = ToPILImage()(image_tensor).convert('RGB')

        self.model.cuda()
        # Token probabilities from the logits collected by the lm_head hook.
        self.logit = F.softmax(torch.cat(self.logit, dim=1), dim=2)
        # Re-run the monitored layer over the full sequence with a causal
        # mask to obtain its attention weights.
        hidden_tensor = torch.cat(self.hidden, dim=1)
        length = hidden_tensor.shape[1]
        attention_mask = generate_square_subsequent_mask(length).unsqueeze(0).unsqueeze(0).cuda()
        self.hidden = self.model.language_model.model.layers[self.llm_layers_index](
            hidden_tensor, output_attentions=True, attention_mask=attention_mask)
        self.image_token = self.image_token[0].squeeze()
        # Prepend a zero row so image-token indices line up with the grid.
        self.image_token = torch.cat((torch.zeros(1, 2560).cuda(), self.image_token), dim=0)

    def get_output(self, output_dir='results/'):
        print("Starting visualization...")
        self.prepare_input()
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        output_dir = os.path.join(output_dir, f"run_{timestamp}")
        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(os.path.join(output_dir, 'word'), exist_ok=True)
        os.makedirs(os.path.join(output_dir, 'word_before'), exist_ok=True)
        os.makedirs(os.path.join(output_dir, 'hot_image'), exist_ok=True)
        # Number of newly generated tokens; 726 offsets the positions
        # occupied by the expanded image tokens.
        num = self.logit.shape[1] - 726 - len(self.input_ids[0])
        result = extract_max_values_and_indices(self.logit, 8)
        generate_word_images(self.model.tokenizer, result, num, self.input_ids,
                             self.model.language_model.model.embed_tokens.weight, output_dir)
        generate_word_images_before(self.model.tokenizer, self.input_ids, self.hidden[1], num, result, output_dir)
        result_top1 = result[0, :, 0, 0].squeeze()
        for i in range(len(result_top1) - num, len(result_top1)):
            word_id = result_top1[i].long()
            word_id_tensor = word_id.unsqueeze(0).cuda()
            word_vector = self.model.language_model.model.embed_tokens(word_id_tensor).squeeze().detach()
            # Cosine similarity between the predicted token's embedding and
            # every image token, softmax-normalized into a heat map.
            vector_norm = F.normalize(word_vector.unsqueeze(0).expand_as(self.image_token), p=2, dim=1)
            matrix_norm = F.normalize(self.image_token, p=2, dim=1)
            cosine_similarities = torch.sum(vector_norm * matrix_norm, dim=1)
            normalized_similarities = F.softmax(cosine_similarities, dim=0)
            visualize_grid_to_grid('hot_image_' + str(i - (len(result_top1) - num) + 1),
                                   normalized_similarities.view(27, 27), self.image, output_dir)
        print("Completed visualization.")
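

# A minimal usage sketch, not part of the original module: the loader name,
# checkpoint id, and args field values below are assumptions. Monitor expects
# an `args` object with `query`, `conv_mode`, `image_file`, and `sep`, a model
# exposing `language_model`, `connector`, `vision_tower`, and `tokenizer`, and
# a generation pass run before `get_output` so the forward hooks can record
# logits, hidden states, and image tokens.
if __name__ == "__main__":
    from argparse import Namespace

    args = Namespace(
        query="What is in this image?",
        conv_mode="phi",                           # assumed conversation template
        image_file="https://example.com/cat.jpg",  # placeholder image URL
        sep=",",
    )
    # Hypothetical loader; adjust to however checkpoints are loaded in this repo.
    model = load_model("tinyllava/TinyLLaVA-Phi-2-SigLIP-3.1B")
    monitor = Monitor(args, model, llm_layers_index=-1)
    # Run generation first so the hooks fire, e.g. via the eval pipeline:
    # model.generate(...)
    monitor.get_output(output_dir="results/")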