import os
import json
import argparse

import clip
import ruamel.yaml as yaml
import torch
import torchvision.transforms as transforms
from PIL import Image
from tqdm import tqdm

from albef.utils import *
from executor import AlbefExecutor
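
# Binary forced-choice evaluation: each example pairs one image with two
# candidate captions, or one caption with two candidate images. CLIP (or ALBEF,
# via its ITM or ITC output) scores both pairings, and an example counts as
# correct when the first-listed pairing scores higher.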

parser = argparse.ArgumentParser()
parser.add_argument("--input_path", type=str, help="Path to input JSON file")
parser.add_argument("--image_root", type=str, help="Path to directory containing images")
parser.add_argument("--albef_path", type=str, default=None, help="Path to ALBEF model/config/etc. if the goal is to use ALBEF")
parser.add_argument("--albef_itc", action="store_true", help="Use ITC output of ALBEF")
parser.add_argument("--clip_model", type=str, help="CLIP model to use")
parser.add_argument("--gpu", type=int, default=-1, help="Which gpu to use")
parser.add_argument("--batch_size", type=int, default=32, help="Batch size for running CLIP")
args = parser.parse_args()
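
# Fall back to CPU when no GPU index is given, mirroring the fallback the ALBEF
# branch already uses; the same device is used for both model branches below.
device = "cpu" if args.gpu < 0 else "cuda:" + str(args.gpu)
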
if args.albef_path is not None:
    # ALBEF: load checkpoint and config from the given directory via the executor.
    executor = AlbefExecutor(
        checkpoint_path=os.path.join(args.albef_path, "checkpoint.pth"),
        config_path=os.path.join(args.albef_path, "config.yaml"),
        device=device,
    )
    model = executor.models[0]
    preprocess = executor.preprocesses[0]
    model = model.eval()
else:
    # CLIP: load the requested model and force a square resize to its input resolution.
    model, preprocess = clip.load(args.clip_model, jit=False, device=device)
    preprocess.transforms[0] = transforms.Resize(
        (model.visual.input_resolution, model.visual.input_resolution),
        transforms.InterpolationMode.BICUBIC,
    )
    model = model.eval()

with open(args.input_path) as input_file:
    data = json.load(input_file)

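# Each datum provides "image_filename" and "text1", plus either "text2" (two
# candidate captions for one image) or "image_filename2" (two candidate images
# for one caption); the first-listed candidate is treated as the correct one.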
correct = 0
for i in tqdm(range(0, len(data), args.batch_size)):
    batch_images = []
    batch_text = []
    for datum in data[i:min(i + args.batch_size, len(data))]:
        img = Image.open(os.path.join(args.image_root, datum["image_filename"])).convert('RGB')
        batch_images.append(preprocess(img))
        if "text2" in datum:
            # One image, two candidate captions; prefix CLIP captions with "a photo of".
            if args.albef_path is None:
                datum["text1"] = "a photo of " + datum["text1"]
                datum["text2"] = "a photo of " + datum["text2"]
            batch_text.append(datum["text1"])
            batch_text.append(datum["text2"])
        else:
            # One caption, two candidate images.
            img2 = Image.open(os.path.join(args.image_root, datum["image_filename2"])).convert('RGB')
            batch_images.append(preprocess(img2))
            batch_text.append(datum["text1"])
    batch_images = torch.stack(batch_images).to(device)

    # CLIP uses its own tokenizer; ALBEF cleans captions and uses the executor's tokenizer.
    if args.albef_path is None:
        batch_text = clip.tokenize(batch_text).to(device)
    else:
        modified_text = [pre_caption(txt, executor.max_words) for txt in batch_text]
        batch_text = executor.tokenizer(modified_text, padding='longest', return_tensors="pt")
        for key in batch_text:
            batch_text[key] = batch_text[key].to(batch_images.device)

    with torch.no_grad():
        if args.albef_path is None:
            logits_per_image, logits_per_text = model(batch_images, batch_text)
        else:
            if not args.albef_itc:
                # ITM: score every image-text pair individually, so repeat the
                # smaller side of the batch to align images with texts.
                if batch_images.shape[0] * 2 == batch_text.input_ids.shape[0]:
                    batch_images = batch_images.unsqueeze(1).repeat(1, 2, 1, 1, 1).view(
                        batch_images.shape[0] * 2, batch_images.shape[1], batch_images.shape[2], batch_images.shape[3]
                    )
                else:
                    assert batch_images.shape[0] == 2 * batch_text.input_ids.shape[0]
                    batch_text.input_ids = batch_text.input_ids.unsqueeze(1).repeat(1, 2, 1).view(batch_images.shape[0], -1)
                    batch_text.attention_mask = batch_text.attention_mask.unsqueeze(1).repeat(1, 2, 1).view(batch_images.shape[0], -1)
                image_embeds = model.visual_encoder(batch_images)
                image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(batch_images.device)
                output = model.text_encoder(
                    batch_text.input_ids,
                    attention_mask=batch_text.attention_mask,
                    encoder_hidden_states=image_embeds,
                    encoder_attention_mask=image_atts,
                    return_dict=True,
                )
                # Keep the "match" logit from the ITM head and group consecutive
                # rows into pairs, one row of two candidates per example.
                vl_embeddings = output.last_hidden_state[:, 0, :]
                vl_output = model.itm_head(vl_embeddings)
                logits_per_image = vl_output[:, 1:2].view(-1, 2)
            else:
                # ITC: cosine similarity between projected image and text features,
                # scaled by the model's temperature parameter.
                image_embeds = model.visual_encoder(batch_images)
                image_feat = torch.nn.functional.normalize(model.vision_proj(image_embeds[:, 0, :]), dim=-1)
                text_output = model.text_encoder(batch_text.input_ids, attention_mask=batch_text.attention_mask,
                                                 return_dict=True, mode='text')
                text_embeds = text_output.last_hidden_state
                text_feat = torch.nn.functional.normalize(model.text_proj(text_embeds[:, 0, :]), dim=-1)
                sim = image_feat @ text_feat.t() / model.temp
                logits_per_image = sim

    if args.albef_path is None or args.albef_itc:
        # CLIP and ALBEF-ITC return an image-by-text similarity matrix; compare
        # each example's true pairing against its distractor.
        if logits_per_image.shape[0] * 2 == logits_per_image.shape[1]:
            for j in range(logits_per_image.shape[0]):
                correct += 1 if logits_per_image[j, 2 * j].item() > logits_per_image[j, 2 * j + 1].item() else 0
        else:
            assert logits_per_image.shape[0] == 2 * logits_per_image.shape[1]
            for j in range(logits_per_image.shape[1]):
                correct += 1 if logits_per_image[2 * j, j].item() > logits_per_image[2 * j + 1, j].item() else 0
    else:
        # ALBEF-ITM returns one row of two match scores per example, with the
        # correct candidate in column 0.
        correct += (logits_per_image[:, 0] > logits_per_image[:, 1]).long().sum().item()

print("Accuracy:", correct / len(data))