import os
import pickle

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from PIL import Image
from torchvision.models import resnet50
from torchvision.transforms import transforms

from src.utils.path_utils import get_project_root


class ImageSimilarity:
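    """Extracts ResNet-50 features from images and compares them with cosine similarity."""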

    def __init__(self):
        # ResNet-50 pretrained on ImageNet, with the final classification layer
        # removed so the network outputs a feature vector instead of class logits.
        self.model = resnet50(weights="DEFAULT")
        self.model = nn.Sequential(*list(self.model.children())[:-1])
        self.model.eval()
        # Standard ImageNet preprocessing: resize, convert to tensor, normalize.
        self.transform = transforms.Compose(
            [
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

    def extract_features(self, image_stream):
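        """Return a flattened feature vector for a single image path or file-like object."""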
        image = Image.open(image_stream).convert("RGB")
        image = self.transform(image).unsqueeze(0)

        with torch.no_grad():
            features = self.model(image)
        features = features.flatten()
        return features

    def similarity(self, features1, features2):
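        """Return the cosine similarity between two feature vectors."""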
        cos = nn.CosineSimilarity(dim=1, eps=1e-6)
        similarity = cos(features1.unsqueeze(0), features2.unsqueeze(0))
        return similarity.item()


class ImageCorpus:
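    """Maintains a pickled dictionary that maps image paths to their feature vectors."""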

    def __init__(self, feature_corpus_path):
        self.feature_corpus_path = feature_corpus_path
        self.feature_dict = self.load_features()
        self.feature_extractor = ImageSimilarity()

    def load_features(self):
        try:
            with open(self.feature_corpus_path, "rb") as f:
                return pickle.load(f)
        except (FileNotFoundError, EOFError, pickle.UnpicklingError):
            print(
                "Warning: Pickle file is missing, empty, or corrupted. Initializing empty feature dict."
            )
            return {}

    def save_features(self):
        with open(self.feature_corpus_path, "wb") as f:
            pickle.dump(self.feature_dict, f)

    def add_image(self, image_path):
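        """Extract features for a single image and persist the updated corpus."""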
        features = self.feature_extractor.extract_features(image_path)
        self.feature_dict[image_path] = features
        self.save_features()

    def create_feature_corpus(self, image_dir):
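        """Extract features for every image in a directory and persist the corpus."""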
        for image_name in os.listdir(image_dir):
            image_path = os.path.join(image_dir, image_name)
            if os.path.isfile(image_path) and image_path.lower().endswith(
                (".png", ".jpg", ".jpeg")
            ):
                features = self.feature_extractor.extract_features(image_path)
                self.feature_dict[image_path] = features

        self.save_features()

    def retrieve_similar_images(self, query_image_path, top_k=50):
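        """Return up to top_k (image_path, score) pairs ranked by similarity to the query."""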
        query_features = self.feature_extractor.extract_features(query_image_path)
        similarity_scores = {}

        for image_name, corpus_feature in self.feature_dict.items():
            similarity = self.feature_extractor.similarity(
                query_features, corpus_feature
            )
            similarity_scores[image_name] = similarity

        retrieved_images = sorted(
            similarity_scores.items(), key=lambda x: x[1], reverse=True
        )

        # Keep only the first image for each distinct score so exact duplicates
        # in the corpus collapse into a single result.
        unique_scores = set()
        filtered_images = []

        for image_path, score in retrieved_images:
            if score not in unique_scores:
                unique_scores.add(score)
                filtered_images.append((image_path, score))

            if len(filtered_images) == top_k:
                break

        return filtered_images


def visualize_retrieved_images(query_image_path, top_retrievals):
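    """Plot the query image on the first row and the retrieved images, with scores, below it."""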
    query_image = Image.open(query_image_path).convert("RGB")
    project_base = get_project_root()

    retrieved_images = [
        (Image.open(os.path.join(project_base, img_path)).convert("RGB"), score)
        for img_path, score in top_retrievals
    ]

    # One row for the query image plus enough rows of five for the retrievals.
    total_retrieved = len(retrieved_images)
    rows = 2 + (total_retrieved - 1) // 5
    cols = 5

    plt.figure(figsize=(20, rows * 4))

    # Center the query image in the first row.
    plt.subplot(rows, cols, (cols // 2) + 1)
    plt.imshow(query_image)
    plt.title("Query Image", fontsize=12)
    plt.axis("off")

    for idx, (img, score) in enumerate(retrieved_images):
        plt.subplot(rows, cols, cols + idx + 1)
        plt.imshow(img)
        plt.title(f"Rank: {idx + 1}\nScore: {score:.4f}", fontsize=10)
        plt.axis("off")

    plt.tight_layout()
    plt.show()


if __name__ == "__main__":
    project_root = get_project_root()
    image_feature = os.path.join(project_root, "evidence_features.pkl")
    image_dir = os.path.join(
        project_root, "data", "raw", "factify", "extracted", "images", "evidence_corpus"
    )

    query_image_path = os.path.join(
        project_root,
        "data",
        "raw",
        "factify",
        "extracted",
        "images",
        "train",
        "1_claim.jpg",
    )

    image_corpus = ImageCorpus(image_feature)

    # Build the feature corpus on the first run, when no pickled features exist yet.
    if not image_corpus.feature_dict:
        image_corpus.create_feature_corpus(image_dir)

    print(list(image_corpus.feature_dict.keys())[0])

    top_retrievals = image_corpus.retrieve_similar_images(query_image_path, top_k=5)
    print(top_retrievals)
    visualize_retrieved_images(query_image_path, top_retrievals)