# Author: niks-salodkar
# Commit d9f5274: "added code and files"
import os
import numpy as np
import faiss
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch
from PIL import Image
from inference import load_trained_model
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
# Preprocessing applied to every catalogue image before embedding:
#   1. squash to a fixed 232x232 size (aspect ratio is not preserved),
#   2. convert the PIL image to a float tensor,
#   3. normalise each of the three channels with the standard torchvision
#      ImageNet mean/std values.
# NOTE(review): query-time preprocessing must match this pipeline exactly for
# the index distances to be meaningful — confirm against inference.py.
transforms = Compose(
    [
        Resize((232, 232)),
        ToTensor(),
        Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
class FashionDataset(Dataset):
    """Map-style dataset over images named ``0.jpg``, ``1.jpg``, ... in one folder.

    Item ``i`` is the file ``<path>/<i>.jpg``, decoded as an RGB image and
    passed through ``transforms``. The length is the number of entries in
    ``path``. The folder is therefore expected to contain exactly the files
    ``0.jpg`` .. ``N-1.jpg``: any other entry inflates the length and makes the
    last index(es) fail with a missing-file error.

    Args:
        transforms: callable applied to each RGB ``PIL.Image``; its result is
            what ``__getitem__`` returns.
        path: directory holding the numbered ``.jpg`` files.
    """

    def __init__(self, transforms, path='./data/small_images/'):
        self.transforms = transforms
        self.path = path

    def __len__(self):
        # Counts every directory entry, not only .jpg files (see class docstring).
        return len(os.listdir(self.path))

    def __getitem__(self, index):
        image_file = os.path.join(self.path, f'{index}.jpg')
        # The context manager releases the file handle deterministically.
        # Image.open is lazy, and without it closing depends on load() or on
        # garbage collection inside the DataLoader workers. convert('RGB')
        # forces the pixel load before the file closes. It also normalises
        # grayscale/RGBA/CMYK inputs to the 3-channel layout that the 3-value
        # Normalize in the transform pipeline requires.
        with Image.open(image_file) as image:
            rgb_image = image.convert('RGB')
        return self.transforms(rgb_image)
def build_flat_index(img_path, index_path, embeddings_model, batch_size=256,
                     num_workers=16, index_name='resnet152_unweighted_flat.index'):
    """Embed every image in ``img_path`` and save an exact (flat) L2 FAISS index.

    Images are read through ``FashionDataset`` (files ``0.jpg`` .. ``N-1.jpg``)
    without shuffling. The FAISS row id of each vector is therefore the number
    in its image's file name.

    Args:
        img_path: directory containing the numbered ``.jpg`` images.
        index_path: directory the index file is written into (created if missing).
        embeddings_model: torch module called as ``embeddings_model(batch, True)``.
            It is assumed to return a ``(batch, dim)`` tensor.
            NOTE(review): confirm the meaning of the ``True`` flag against
            ``inference.load_trained_model``.
        batch_size: images per forward pass.
        num_workers: number of DataLoader worker processes.
        index_name: file name of the index written inside ``index_path``.

    Raises:
        ValueError: if the dataset yields no images.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)
    the_dataset = FashionDataset(transforms, path=img_path)
    loader = DataLoader(the_dataset, batch_size=batch_size, shuffle=False,
                        pin_memory=device.type == 'cuda', num_workers=num_workers)
    all_embeddings = _compute_embeddings(embeddings_model, loader, device)
    print(all_embeddings.shape)
    # Size the index from the data rather than a hard-coded 2048, so a model
    # with a different embedding width still works.
    index = faiss.IndexFlatL2(all_embeddings.shape[1])
    index.add(all_embeddings)
    print('total embeddings', index.ntotal)
    os.makedirs(index_path, exist_ok=True)
    faiss.write_index(index, os.path.join(index_path, index_name))


def _compute_embeddings(embeddings_model, loader, device):
    """Run ``embeddings_model`` over every batch of ``loader``; return an (N, dim) float32 array."""
    embeddings_model.to(device)
    # inference_mode only disables autograd. eval() additionally switches
    # dropout/BatchNorm to inference behaviour, so each embedding does not
    # depend on the other images in its batch.
    embeddings_model.eval()
    chunks = []
    total = 0
    with torch.inference_mode():
        for batch in loader:
            # Use the detected device; hard-coding 'cuda' breaks CPU-only hosts.
            embeddings = embeddings_model(batch.to(device), True)
            # Move each batch to host memory at once and concatenate a single
            # time at the end. This avoids quadratic re-copying and avoids
            # holding the whole collection in GPU memory.
            chunks.append(embeddings.numpy(force=True))
            total += len(chunks[-1])
            print('embedded', total)
    if not chunks:
        raise ValueError('no images found to index')
    # faiss requires a C-contiguous float32 matrix.
    return np.ascontiguousarray(np.concatenate(chunks, axis=0), dtype=np.float32)
if __name__ == "__main__":
    # All locations are relative to the directory the script is launched from.
    trained_model_file = './data/final-models/resnet_152_classification.pt'
    images_dir = './data/small_images'
    index_dir = './data/index_files'
    # The model is loaded first, then the index is built from the image folder.
    build_flat_index(images_dir, index_dir, load_trained_model(trained_model_file))