import os

import faiss
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, Normalize, Resize, ToTensor
from tqdm import tqdm

from inference import load_trained_model
# Preprocessing: resize to 232x232, then apply the standard ImageNet
# normalization expected by torchvision's pretrained ResNet weights.
transforms = Compose([
    Resize((232, 232)),
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


class FashionDataset(Dataset):
    """Serves images stored as '<index>.jpg' in a flat directory."""

    def __init__(self, transforms, path='./data/small_images/'):
        self.transforms = transforms
        self.path = path

    def __len__(self):
        return len(os.listdir(self.path))

    def __getitem__(self, index):
        image_id = str(index) + '.jpg'
        # convert('RGB') guards against grayscale or RGBA files breaking the transforms
        the_image = Image.open(os.path.join(self.path, image_id)).convert('RGB')
        return self.transforms(the_image)
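
# Example usage (assumes images named 0.jpg .. N-1.jpg exist under the path):
#   dataset = FashionDataset(transforms)
#   sample = dataset[0]   # tensor of shape (3, 232, 232)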


def build_flat_index(img_path, index_path, embeddings_model):
    dimension = 2048  # size of the ResNet-152 embedding vectors
    index = faiss.IndexFlatL2(dimension)  # exact (non-quantized) L2 index

    the_dataset = FashionDataset(transforms, path=img_path)
    loader = DataLoader(the_dataset, batch_size=256, shuffle=False,
                        pin_memory=False, num_workers=16)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)
    embeddings_model.to(device)
    embeddings_model.eval()  # freeze batch-norm/dropout behavior for inference

    all_embeddings = torch.tensor([]).to(device)
    for batch in tqdm(loader):
        batch = batch.to(device)  # use the selected device, not hard-coded 'cuda'
        with torch.inference_mode():
            # the boolean flag is defined in inference.py; presumably it selects
            # the 2048-d embedding output instead of the classification logits
            embeddings = embeddings_model(batch, True)
        all_embeddings = torch.cat((all_embeddings, embeddings), dim=0)

    all_embeddings = all_embeddings.numpy(force=True)  # copy to CPU as a NumPy array for FAISS
    print(all_embeddings.shape)
    index.add(all_embeddings)
    print('total embeddings', index.ntotal)
    faiss.write_index(index, os.path.join(index_path, 'resnet152_unweighted_flat.index'))
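
# Sketch of how the written index might be queried afterwards (not part of the
# original pipeline; 'query_embedding' is a hypothetical (1, 2048) float32 array
# produced the same way as the vectors above):
#
#   index = faiss.read_index('./data/index_files/resnet152_unweighted_flat.index')
#   distances, neighbors = index.search(query_embedding, 5)
#   print(neighbors[0])  # dataset indices of the 5 nearest images, i.e. '<idx>.jpg'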


if __name__ == "__main__":
    model_path = './data/final-models/resnet_152_classification.pt'
    image_path = './data/small_images'
    index_path = './data/index_files'
    model = load_trained_model(model_path)
    build_flat_index(image_path, index_path, model)