import datetime
import time
from collections import OrderedDict
from pathlib import Path

import einops
import numpy as np
import torch
import torch.nn.functional as F

from src.tools.files import json_dump


class TestCirr:
    def __init__(self):
        pass

    @staticmethod
    @torch.no_grad()
    def __call__(model, data_loader, fabric):
        model.eval()

        fabric.print("Computing features for test...")
        start_time = time.time()

        tar_img_feats = []
        query_feats = []
        pair_ids = []

        for ref_img, tar_feat, caption, pair_id, *_ in data_loader:
            pair_ids.extend(pair_id.cpu().numpy().tolist())
            device = ref_img.device

            # Encode the reference image
            ref_img_embs = model.visual_encoder(ref_img)
            ref_img_atts = torch.ones(ref_img_embs.size()[:-1], dtype=torch.long).to(
                device
            )

            text = model.tokenizer(
                caption,
                padding="longest",
                truncation=True,
                max_length=64,
                return_tensors="pt",
            ).to(device)

            # Shift encoder: replace the [CLS] token with the encoder token
            encoder_input_ids = text.input_ids.clone()
            encoder_input_ids[:, 0] = model.tokenizer.enc_token_id
            query_embs = model.text_encoder(
                encoder_input_ids,
                attention_mask=text.attention_mask,
                encoder_hidden_states=ref_img_embs,
                encoder_attention_mask=ref_img_atts,
                return_dict=True,
            )
            query_feat = query_embs.last_hidden_state[:, 0, :]
            query_feat = F.normalize(model.text_proj(query_feat), dim=-1)
            query_feats.append(query_feat.cpu())

            # Target image features are precomputed by the loader
            tar_img_feats.append(tar_feat.cpu())

        pair_ids = torch.tensor(pair_ids, dtype=torch.long)
        query_feats = torch.cat(query_feats, dim=0)
        tar_img_feats = torch.cat(tar_img_feats, dim=0)

        if fabric.world_size > 1:
            # Gather tensors from every process, then flatten the process dim
            query_feats = fabric.all_gather(query_feats)
            tar_img_feats = fabric.all_gather(tar_img_feats)
            pair_ids = fabric.all_gather(pair_ids)

            query_feats = einops.rearrange(query_feats, "d b e -> (d b) e")
            tar_img_feats = einops.rearrange(tar_img_feats, "d b e -> (d b) e")
            pair_ids = einops.rearrange(pair_ids, "d b -> (d b)")

        if fabric.global_rank == 0:
            pair_ids = pair_ids.cpu().numpy().tolist()
            assert len(query_feats) == len(pair_ids)

            img_ids = [data_loader.dataset.pairid2ref[pair_id] for pair_id in pair_ids]
            assert len(img_ids) == len(pair_ids)

            # Deduplicate target embeddings, keeping the first occurrence per image id
            id2emb = OrderedDict()
            for img_id, tar_img_feat in zip(img_ids, tar_img_feats):
                if img_id not in id2emb:
                    id2emb[img_id] = tar_img_feat

            tar_feats = torch.stack(list(id2emb.values()), dim=0)
            sims_q2t = query_feats @ tar_feats.T

            # Map pair_id to row index for faster lookup
            pairid2index = {pair_id: i for i, pair_id in enumerate(pair_ids)}

            # Map target image id to column index for faster lookup
            tarid2index = {tar_id: j for j, tar_id in enumerate(id2emb.keys())}

            # Mask out each query's own reference image so it cannot be retrieved
            for pair_id in pair_ids:
                que_id = data_loader.dataset.pairid2ref[pair_id]
                if que_id in tarid2index:
                    sims_q2t[pairid2index[pair_id], tarid2index[que_id]] = -100

            sims_q2t = sims_q2t.cpu().numpy()

            total_time = time.time() - start_time
            total_time_str = str(datetime.timedelta(seconds=int(total_time)))
            print("Evaluation time {}".format(total_time_str))

            recalls = {"version": "rc2", "metric": "recall"}
            recalls_subset = {"version": "rc2", "metric": "recall_subset"}

            target_imgs = np.array(list(id2emb.keys()))

            assert len(sims_q2t) == len(pair_ids)
            for pair_id, query_sims in zip(pair_ids, sims_q2t):
                sorted_indices = np.argsort(query_sims)[::-1]

                # Top-50 ranking over the full gallery
                query_id_recalls = list(target_imgs[sorted_indices][:50])
                query_id_recalls = [
                    str(data_loader.dataset.int2id[x]) for x in query_id_recalls
                ]
                recalls[str(pair_id)] = query_id_recalls

                # Top-3 ranking restricted to the pair's subset members
                members = data_loader.dataset.pairid2members[pair_id]
                query_id_recalls_subset = [
                    target
                    for target in target_imgs[sorted_indices]
                    if target in members
                ]
                query_id_recalls_subset = [
                    data_loader.dataset.int2id[x] for x in query_id_recalls_subset
                ][:3]
                recalls_subset[str(pair_id)] = query_id_recalls_subset

            json_dump(recalls, "recalls_cirr.json")
            json_dump(recalls_subset, "recalls_cirr_subset.json")
            print(
                f"Recalls saved in {Path.cwd()} as "
                "recalls_cirr.json and recalls_cirr_subset.json"
            )

        fabric.barrier()
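

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of this module): since __call__ is a
# staticmethod, a TestCirr instance is invoked directly with the model, the
# test DataLoader, and a launched Fabric object. The names `blip_model` and
# `test_loader`, and the Fabric configuration below, are assumptions for the
# example, not code from this repository.
#
#     import lightning as L
#
#     fabric = L.Fabric(accelerator="auto", devices=1)
#     fabric.launch()
#
#     evaluate = TestCirr()
#     evaluate(blip_model, test_loader, fabric)  # writes recalls_cirr*.json
# ---------------------------------------------------------------------------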