tooba248 committed
Commit e54509f · verified
1 Parent(s): 5685a0e

Upload 2 files

Files changed (2)
  1. best_model.pt +3 -0
  2. eval.py +71 -0
best_model.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ed57591f55d71c06050876296cfabd390e5265ca035dd98e4b8eaecd12203cfe
+size 605264460
eval.py ADDED
@@ -0,0 +1,71 @@
+import torch
+import clip
+from datasets import load_dataset
+import faiss
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+model_clip, preprocess = clip.load("ViT-B/32", device=device)
+
+# Load Flickr30k. The nlphuji/flickr30k dataset ships everything under a
+# single "test" split; the per-example "split" column marks the standard
+# train/val/test partition.
+dataset = load_dataset("nlphuji/flickr30k", split="test")
+
+image_embeddings = []
+text_embeddings = []
+ground_truth = []
+
+print("Extracting embeddings...")
+
+for example in dataset:
+    try:
+        # The "image" column decodes to a PIL image directly
+        img = example["image"].convert("RGB")
+        # Each image carries a list of reference captions; use the first
+        caption = example["caption"][0]
+
+        img_tensor = preprocess(img).unsqueeze(0).to(device)
+        with torch.no_grad():
+            img_feat = model_clip.encode_image(img_tensor)
+            img_feat /= img_feat.norm(dim=-1, keepdim=True)
+            image_embeddings.append(img_feat.cpu())
+
+            txt_token = clip.tokenize([caption]).to(device)
+            txt_feat = model_clip.encode_text(txt_token)
+            txt_feat /= txt_feat.norm(dim=-1, keepdim=True)
+            text_embeddings.append(txt_feat.cpu())
+
+        # Paired image/text embeddings share a row, so the ground-truth id
+        # is the row position (robust to skipped examples)
+        ground_truth.append(len(ground_truth))
+    except Exception:
+        continue
+
+# FAISS expects float32; CLIP returns float16 on GPU
+image_embeddings = torch.cat(image_embeddings, dim=0).float()
+text_embeddings = torch.cat(text_embeddings, dim=0).float()
+
+# Build FAISS indexes (inner product equals cosine similarity on
+# L2-normalized embeddings)
+image_index = faiss.IndexFlatIP(image_embeddings.shape[1])
+image_index.add(image_embeddings.numpy())
+
+text_index = faiss.IndexFlatIP(text_embeddings.shape[1])
+text_index.add(text_embeddings.numpy())
+
+# Retrieval accuracy (Recall@1, 5, 10)
+def compute_recall(query_embeddings, index, ground_truth, k_values=(1, 5, 10)):
+    D, I = index.search(query_embeddings.numpy(), max(k_values))
+    recalls = {k: 0 for k in k_values}
+    for i, gt in enumerate(ground_truth):
+        for k in k_values:
+            if gt in I[i][:k]:
+                recalls[k] += 1
+    total = len(ground_truth)
+    return {f"Recall@{k}": round((recalls[k] / total) * 100, 2) for k in k_values}
+
+print("Evaluating text-to-image retrieval...")
+text_to_image_recall = compute_recall(text_embeddings, image_index, ground_truth)
+print("Text-to-Image:", text_to_image_recall)
+
+print("Evaluating image-to-text retrieval...")
+image_to_text_recall = compute_recall(image_embeddings, text_index, ground_truth)
+print("Image-to-Text:", image_to_text_recall)