import os

import faiss
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, Normalize, Resize, ToTensor
from tqdm import tqdm

from inference import load_trained_model

# Resize, convert to a tensor, and normalize with the standard ImageNet statistics.
transforms = Compose([
    Resize((232, 232)),
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


class FashionDataset(Dataset):
    """Images stored as a flat directory of <index>.jpg files (0.jpg, 1.jpg, ...)."""

    def __init__(self, transforms, path='./data/small_images/'):
        self.transforms = transforms
        self.path = path

    def __len__(self):
        return len(os.listdir(self.path))

    def __getitem__(self, index):
        image_id = str(index) + '.jpg'
        # Convert to RGB so grayscale or RGBA files don't break the 3-channel Normalize.
        the_image = Image.open(os.path.join(self.path, image_id)).convert('RGB')
        return self.transforms(the_image)


def build_flat_index(img_path, index_path, embeddings_model):
    dimension = 2048  # size of the ResNet-152 global-pooled features
    index = faiss.IndexFlatL2(dimension)  # exact (flat) L2 index
    the_dataset = FashionDataset(transforms, path=img_path)
    loader = DataLoader(the_dataset, batch_size=256, shuffle=False, pin_memory=False,
                        num_workers=16)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)
    embeddings_model.to(device)
    embeddings_model.eval()  # inference behaviour for dropout/batch-norm layers

    all_embeddings = []
    for batch in tqdm(loader):
        batch = batch.to(device)
        with torch.inference_mode():
            # The boolean flag asks the model for embeddings rather than class logits.
            embeddings = embeddings_model(batch, True)
        all_embeddings.append(embeddings)

    # Concatenate once at the end instead of re-allocating on every batch.
    all_embeddings = torch.cat(all_embeddings, dim=0).numpy(force=True)
    print(all_embeddings.shape)
    index.add(all_embeddings)
    print('total embeddings', index.ntotal)
    os.makedirs(index_path, exist_ok=True)
    faiss.write_index(index, os.path.join(index_path, 'resnet152_unweighted_flat.index'))
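
# Illustrative usage sketch (not part of the original script): once the flat index
# has been written, it can be queried for nearest neighbours roughly as below.
# The helper name `search_index` and the default k are hypothetical; it assumes the
# same `transforms` and the same embedding-model calling convention as above.
def search_index(query_image_path, index_file, embeddings_model, k=5):
    index = faiss.read_index(index_file)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    embeddings_model.to(device)
    embeddings_model.eval()
    image = Image.open(query_image_path).convert('RGB')
    query = transforms(image).unsqueeze(0).to(device)  # shape (1, 3, 232, 232)
    with torch.inference_mode():
        query_embedding = embeddings_model(query, True).numpy(force=True)
    # Returned ids map back to <id>.jpg because images were added in order, unshuffled.
    distances, neighbour_ids = index.search(query_embedding, k)  # each shaped (1, k)
    return distances[0], neighbour_ids[0]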


if __name__ == "__main__":
    model_path = './data/final-models/resnet_152_classification.pt'
    image_path = './data/small_images'
    index_path = './data/index_files'
    model = load_trained_model(model_path)
    build_flat_index(image_path, index_path, model)