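"""Zero-shot pairing evaluation for CLIP or ALBEF.

Each entry in the input JSON names either one image with two candidate captions
("text1", "text2") or two candidate images ("image_filename", "image_filename2")
with one caption; the first candidate is treated as the correct one. The script
reports the fraction of entries for which the model scores the correct pairing
higher than the distractor.
"""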
import os
import clip
import json
import argparse
import ruamel.yaml as yaml
from PIL import Image
import torch
import torchvision.transforms as transforms
from tqdm import tqdm
from albef.utils import *
from executor import AlbefExecutor
parser = argparse.ArgumentParser()
parser.add_argument("--input_path", type=str, help="Path to input JSON file")
parser.add_argument("--image_root", type=str, help="Path to directory containing images")
parser.add_argument("--albef_path", type=str, default=None, help="Path to ALBEF model/config/etc. if the goal is to use ALBEF")
parser.add_argument("--albef_itc", action="store_true", help="Use ITC output of ALBEF")
parser.add_argument("--clip_model", type=str, help="CLIP model to use")
parser.add_argument("--gpu", type=int, default=-1, help="Which gpu to use")
parser.add_argument("--batch_size", type=int, default=32, help="Batch size for running CLIP")
args = parser.parse_args()
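
# Choose the compute device and load ALBEF (when --albef_path is given) or CLIP.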
device = "cpu" if args.gpu < 0 else "cuda:" + str(args.gpu)
if args.albef_path is not None:
    executor = AlbefExecutor(checkpoint_path=os.path.join(args.albef_path, "checkpoint.pth"),
                             config_path=os.path.join(args.albef_path, "config.yaml"),
                             device=device)
    model = executor.models[0]
    preprocess = executor.preprocesses[0]
    model = model.eval()
else:
    model, preprocess = clip.load(args.clip_model, jit=False, device=device)
    # Resize both sides to the model's input resolution rather than only the shorter side.
    preprocess.transforms[0] = transforms.Resize((model.visual.input_resolution, model.visual.input_resolution),
                                                 transforms.InterpolationMode.BICUBIC)
    model = model.eval()

with open(args.input_path) as input_file:
    data = json.load(input_file)

correct = 0
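
# Build batches: each datum contributes either one image with two candidate captions
# ("text2" present) or two candidate images ("image_filename2" present) with one caption.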
for i in tqdm(range(0, len(data), args.batch_size)):
    batch_images = []
    batch_text = []
    for datum in data[i:min(i + args.batch_size, len(data))]:
        img = Image.open(os.path.join(args.image_root, datum["image_filename"])).convert('RGB')
        batch_images.append(preprocess(img))
        if "text2" in datum:
            if args.albef_path is None:
                datum["text1"] = "a photo of " + datum["text1"]
                datum["text2"] = "a photo of " + datum["text2"]
            batch_text.append(datum["text1"])
            batch_text.append(datum["text2"])
        else:
            img2 = Image.open(os.path.join(args.image_root, datum["image_filename2"])).convert('RGB')
            batch_images.append(preprocess(img2))
            batch_text.append(datum["text1"])
    batch_images = torch.stack(batch_images).to(device)
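    # Tokenize captions: CLIP's tokenizer, or ALBEF's tokenizer after pre_caption cleanup.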
    if args.albef_path is None:
        batch_text = clip.tokenize(batch_text).to(device)
    else:
        modified_text = [pre_caption(txt, executor.max_words) for txt in batch_text]
        batch_text = executor.tokenizer(modified_text, padding='longest', return_tensors="pt")
        for key in batch_text:
            batch_text[key] = batch_text[key].to(batch_images.device)
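    # Score every image-text pair without tracking gradients.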
    with torch.no_grad():
        if args.albef_path is None:
            logits_per_image, logits_per_text = model(batch_images, batch_text)
        else:
            if not args.albef_itc:
                # ITM head: build one row per image-text pair by duplicating whichever
                # side of the batch is not already doubled.
                if batch_images.shape[0] * 2 == batch_text.input_ids.shape[0]:
                    batch_images = batch_images.unsqueeze(1).repeat(1, 2, 1, 1, 1).view(
                        batch_images.shape[0] * 2, batch_images.shape[1], batch_images.shape[2], batch_images.shape[3])
                else:
                    assert batch_images.shape[0] == 2 * batch_text.input_ids.shape[0]
                    batch_text.input_ids = batch_text.input_ids.unsqueeze(1).repeat(1, 2, 1).view(batch_images.shape[0], -1)
                    batch_text.attention_mask = batch_text.attention_mask.unsqueeze(1).repeat(1, 2, 1).view(batch_images.shape[0], -1)
                image_embeds = model.visual_encoder(batch_images)
                image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(batch_images.device)
                output = model.text_encoder(
                    batch_text.input_ids,
                    attention_mask=batch_text.attention_mask,
                    encoder_hidden_states=image_embeds,
                    encoder_attention_mask=image_atts,
                    return_dict=True,
                )
                vl_embeddings = output.last_hidden_state[:, 0, :]
                vl_output = model.itm_head(vl_embeddings)
                # Keep the "match" logit and arrange each datum's two candidates side by side.
                logits_per_image = vl_output[:, 1:2].view(-1, 2)
            else:
                # ITC head: cosine similarity between projected image and text features, scaled by temperature.
                image_embeds = model.visual_encoder(batch_images)
                image_feat = torch.nn.functional.normalize(model.vision_proj(image_embeds[:, 0, :]), dim=-1)
                text_output = model.text_encoder(batch_text.input_ids, attention_mask=batch_text.attention_mask,
                                                 return_dict=True, mode='text')
                text_embeds = text_output.last_hidden_state
                text_feat = torch.nn.functional.normalize(model.text_proj(text_embeds[:, 0, :]), dim=-1)
                logits_per_image = image_feat @ text_feat.t() / model.temp
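    # A datum counts as correct when its true pairing scores higher than the distractor.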
    if args.albef_path is None or args.albef_itc:
        if logits_per_image.shape[0] * 2 == logits_per_image.shape[1]:
            # One image, two texts: columns 2j and 2j+1 hold the two captions for image j.
            for j in range(logits_per_image.shape[0]):
                correct += 1 if logits_per_image[j, 2 * j].item() > logits_per_image[j, 2 * j + 1].item() else 0
        else:
            assert logits_per_image.shape[0] == 2 * logits_per_image.shape[1]
            # Two images, one text: rows 2j and 2j+1 hold the two images for caption j.
            for j in range(logits_per_image.shape[1]):
                correct += 1 if logits_per_image[2 * j, j].item() > logits_per_image[2 * j + 1, j].item() else 0
    else:
        # ALBEF ITM: column 0 is the first (correct) candidate, column 1 the distractor.
        correct += (logits_per_image[:, 0] > logits_per_image[:, 1]).long().sum().item()

print("Accuracy:", correct / len(data))