import os

import faiss
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from tqdm import tqdm

from inference import load_trained_model

# Preprocessing pipeline: resize, convert to tensor, normalize with ImageNet statistics.
transforms = Compose([Resize((232, 232)),
                      ToTensor(),
                      Normalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225])])


class FashionDataset(Dataset):
    def __init__(self, transforms, path='./data/small_images/'):
        self.transforms = transforms
        self.path = path

    def __len__(self):
        return len(os.listdir(self.path))

    def __getitem__(self, index):
        # Images are stored as <index>.jpg, so dataset position maps directly to file name.
        image_id = str(index) + '.jpg'
        the_image = Image.open(os.path.join(self.path, image_id)).convert('RGB')
        transformed_image = self.transforms(the_image)
        return transformed_image


def build_flat_index(img_path, index_path, embeddings_model):
    # Exact (brute-force) L2 index; 2048 matches ResNet-152's pooled feature dimension.
    dimension = 2048
    index = faiss.IndexFlatL2(dimension)

    the_dataset = FashionDataset(transforms, path=img_path)
    loader = DataLoader(the_dataset, batch_size=256, shuffle=False,
                        pin_memory=False, num_workers=16)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)

    all_embeddings = torch.tensor([]).to(device)
    embeddings_model.to(device)
    embeddings_model.eval()

    for step, batch in enumerate(tqdm(loader)):
        batch = batch.to(device)
        with torch.inference_mode():
            # The boolean flag asks the model's forward pass for embeddings rather than class logits.
            embeddings = embeddings_model(batch, True)
        all_embeddings = torch.cat((all_embeddings, embeddings), dim=0)

    print(all_embeddings.shape)
    all_embeddings = all_embeddings.numpy(force=True)
    print(all_embeddings.shape)

    index.add(all_embeddings)
    print('total embeddings', index.ntotal)
    faiss.write_index(index, os.path.join(index_path, 'resnet152_unweighted_flat.index'))


if __name__ == "__main__":
    model_path = './data/final-models/resnet_152_classification.pt'
    image_path = os.path.join('./data/small_images')
    index_path = os.path.join('./data/index_files')
    model = load_trained_model(model_path)
    build_flat_index(image_path, index_path, model)
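# A minimal sketch of querying the index written by build_flat_index. It assumes it
# lives in the same module as the code above (so `transforms` and the trained model
# are in scope), that the model's forward accepts the same boolean flag to return
# 2048-d embeddings, and that `query_image_path` / `k` are placeholders you supply.
def search_similar(query_image_path, index_file, embeddings_model, k=5):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    embeddings_model.to(device)
    embeddings_model.eval()

    # Embed the query image with the same preprocessing used at index time.
    image = Image.open(query_image_path).convert('RGB')
    batch = transforms(image).unsqueeze(0).to(device)  # shape (1, 3, 232, 232)
    with torch.inference_mode():
        query_embedding = embeddings_model(batch, True)

    # Search the flat L2 index; row ids correspond to the <id>.jpg file names
    # because the images were added in dataset order with shuffle=False.
    index = faiss.read_index(index_file)
    distances, ids = index.search(query_embedding.numpy(force=True), k)
    return distances[0], ids[0]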