Spaces:
Runtime error
Runtime error
import os | |
import numpy as np | |
import faiss | |
from torch.utils.data import Dataset | |
from torch.utils.data import DataLoader | |
from tqdm import tqdm | |
import torch | |
from PIL import Image | |
from inference import load_trained_model | |
from torchvision.transforms import Compose, Resize, ToTensor, Normalize | |
# Preprocessing applied to every image before it reaches the embedding model:
# resize to a fixed 232x232, convert to a tensor, then normalize each of the
# three channels with the mean/std values listed below.
transforms = Compose(
    [
        Resize((232, 232)),
        ToTensor(),
        Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
class FashionDataset(Dataset):
    """Dataset of images stored as ``<index>.jpg`` files in a single directory.

    Args:
        transforms: Callable applied to each loaded PIL image; whatever it
            returns is what ``__getitem__`` returns.
        path: Directory that contains the ``.jpg`` files.
    """

    def __init__(self, transforms, path='./data/small_images/'):
        self.transforms = transforms
        self.path = path

    def __len__(self):
        """Return the number of ``.jpg`` files found in ``self.path``."""
        # Count only the files __getitem__ can actually address. Counting every
        # directory entry (hidden files, sub-folders, ...) would inflate the
        # length and make the highest indices fail with FileNotFoundError.
        return sum(1 for name in os.listdir(self.path) if name.endswith('.jpg'))

    def __getitem__(self, index):
        """Load ``<index>.jpg``, force it to RGB and apply the transforms.

        Raises:
            FileNotFoundError: If ``<index>.jpg`` does not exist in ``self.path``.
        """
        image_id = str(index) + '.jpg'
        # The context manager closes the underlying file as soon as the pixel
        # data has been copied by convert().
        with Image.open(os.path.join(self.path, image_id)) as the_image:
            # Grayscale / palette / RGBA files would otherwise yield tensors
            # whose channel count differs from image to image, which breaks
            # three-channel normalization and batch collation.
            rgb_image = the_image.convert('RGB')
        return self.transforms(rgb_image)
def build_flat_index(img_path, index_path, embeddings_model):
    """Embed every image under ``img_path`` and write a flat L2 FAISS index.

    Args:
        img_path: Directory of images, read through ``FashionDataset``.
        index_path: Directory in which ``resnet152_unweighted_flat.index`` is
            written. It is created if it does not exist.
        embeddings_model: Model invoked as ``embeddings_model(batch, True)``;
            it must return one embedding row per image in the batch.

    NOTE(review): the model's train/eval mode is not changed here; this
    presumably relies on the caller passing a model already set up for
    inference -- confirm against ``load_trained_model``.
    """
    # Dimension used only when there are no images to infer it from.
    default_dimension = 2048

    # Honour the img_path argument instead of silently using the dataset default.
    the_dataset = FashionDataset(transforms, path=img_path)
    loader = DataLoader(the_dataset, batch_size=256, shuffle=False, pin_memory=False,
                        num_workers=16)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)
    embeddings_model.to(device)

    # Collect per-batch results on the host and concatenate once at the end.
    # Growing a device tensor with torch.cat on every step re-copies all
    # previous embeddings each time and keeps them all in device memory.
    chunks = []
    rows_so_far = 0
    for batch in loader:
        # Use the detected device; a hard-coded 'cuda' fails on CPU-only hosts.
        batch = batch.to(device)
        with torch.inference_mode():
            embeddings = embeddings_model(batch, True)
        chunk = embeddings.numpy(force=True)
        chunks.append(chunk)
        rows_so_far += chunk.shape[0]
        print((rows_so_far,) + tuple(chunk.shape[1:]))

    if chunks:
        # FAISS expects a C-contiguous float32 matrix.
        all_embeddings = np.ascontiguousarray(np.concatenate(chunks, axis=0),
                                              dtype=np.float32)
    else:
        all_embeddings = np.empty((0, default_dimension), dtype=np.float32)
    print(all_embeddings.shape)

    # Size the index from the data so a model with a different embedding width
    # does not trip FAISS's dimension check.
    index = faiss.IndexFlatL2(all_embeddings.shape[1])  # build the index
    if all_embeddings.shape[0]:
        index.add(all_embeddings)
    print('total embeddings', index.ntotal)

    # write_index does not create missing directories.
    os.makedirs(index_path, exist_ok=True)
    faiss.write_index(index, os.path.join(index_path, 'resnet152_unweighted_flat.index'))
if __name__ == "__main__":
    # Script entry point: load the trained model from disk, then build the
    # index for the image folder and store it in the index directory.
    weights_file = './data/final-models/resnet_152_classification.pt'
    images_dir = './data/small_images'
    output_dir = './data/index_files'
    build_flat_index(images_dir, output_dir, load_trained_model(weights_file))