File size: 5,433 Bytes
2ab45c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
from sentence_transformers import SentenceTransformer, util as st_util
from transformers import CLIPModel, CLIPProcessor

from PIL import Image
import requests
import os
import torch
torch.set_printoptions(precision=10)
from tqdm import tqdm
import s3fs
from io import BytesIO
import vector_db

# Models enabled for embedding generation. Other available keys (see
# model_name_to_ids): "sentence-transformer-clip-ViT-L-14", "openai-clip".
# (These were previously stray no-op string literals; now a comment.)
model_names = ["fashion"]

# Maps a friendly model key to its sentence-transformers / Hugging Face model id.
model_name_to_ids = {
    "sentence-transformer-clip-ViT-L-14": "clip-ViT-L-14",
    "fashion": "patrickjohncyh/fashion-clip",
    "openai-clip": "openai/clip-vit-base-patch32",
}

# AWS credentials are required; os.environ[...] raises KeyError at import
# time if either variable is missing, which fails fast by design.
AWS_ACCESS_KEY_ID = os.environ["AWS_ACCESS_KEY_ID"]
AWS_SECRET_ACCESS_KEY = os.environ["AWS_SECRET_ACCESS_KEY"]

# Root S3 bucket that holds all datasets.
S3_BUCKET = "s3://disco-io"

# Shared module-level S3 filesystem handle used by all helpers below.
fs = s3fs.S3FileSystem(
    key=AWS_ACCESS_KEY_ID,
    secret=AWS_SECRET_ACCESS_KEY,
)

# All datasets live under <bucket>/data/<dataset_name>.
ROOT_DATA_PATH = os.path.join(S3_BUCKET, 'data')

def get_data_path():
    """Return the S3 root path of the currently selected dataset."""
    dataset_root = os.path.join(ROOT_DATA_PATH, cur_dataset)
    return dataset_root

def get_image_path():
    """Return the S3 directory holding the current dataset's images."""
    base = get_data_path()
    return os.path.join(base, 'images')

def get_metadata_path():
    """Return the S3 directory holding the current dataset's metadata."""
    base = get_data_path()
    return os.path.join(base, 'metadata')

def get_embeddings_path():
    """Return the S3 path of the embeddings parquet file for the current dataset."""
    filename = cur_dataset + '_embeddings.pq'
    return os.path.join(get_metadata_path(), filename)

# Cache of loaded models keyed by friendly model name; filled by load_models().
model_dict = {}


def download_to_s3(url, s3_path, timeout=60):
    """Stream a file from `url` and upload it to `s3_path` on S3.

    Args:
        url (str): HTTP(S) URL to download.
        s3_path (str): Destination S3 path, including the bucket.
        timeout (float): Seconds to wait for connect/read before giving
            up (new parameter; default keeps existing callers working).

    Raises:
        requests.HTTPError: If the server responds with an error status.
        requests.Timeout: If the request exceeds `timeout`.
    """
    # stream=True avoids holding large files fully in memory; the context
    # manager releases the HTTP connection (the original leaked it), and
    # the timeout prevents a stalled server from hanging the pipeline.
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()

        # Copy the payload chunk-by-chunk straight into the S3 object.
        with fs.open(s3_path, "wb") as s3_file:
            for chunk in response.iter_content(chunk_size=8192):
                s3_file.write(chunk)


def remove_all_files_from_s3_directory(s3_directory):
    """Delete every object directly under `s3_directory` (best-effort).

    Individual failures are logged and skipped so one bad key does not
    abort the whole cleanup, preserving the original semantics.

    Args:
        s3_directory (str): S3 directory whose contents are removed.
    """
    # List all objects in the S3 directory.
    objects = fs.ls(s3_directory)

    for obj in objects:
        try:
            fs.rm(obj)
        except Exception:
            # Narrowed from a bare `except:` so KeyboardInterrupt and
            # SystemExit still propagate instead of being swallowed.
            print('Error removing file: ' + obj)

def download_images(df, img_folder):
    """Download every image referenced in `df` into `img_folder` on S3.

    Expects columns 'IMG_URL' and 'title'; the sanitized title becomes the
    .jpg filename. The target folder is emptied first. Per-row failures
    are logged and skipped (best-effort, as before).

    Args:
        df: DataFrame with 'IMG_URL' and 'title' columns.
        img_folder (str): Destination S3 directory for the images.
    """
    remove_all_files_from_s3_directory(img_folder)
    for index, row in df.iterrows():
        try:
            # Slashes and newlines would corrupt the S3 key, so strip them.
            safe_name = row['title'].replace('/', '_').replace('\n', '') + '.jpg'
            download_to_s3(row['IMG_URL'], os.path.join(img_folder, safe_name))
        except Exception:
            # Narrowed from a bare `except:`; keeps Ctrl-C working while
            # still treating any per-row error as non-fatal.
            print('Error downloading image: ' + str(index) + row['title'])


def load_models():
    """Populate the module-level `model_dict` with every configured model.

    sentence-transformer keys load via SentenceTransformer; all other keys
    load a Hugging Face CLIP model plus its processor. Keys already present
    in `model_dict` are skipped, so repeated calls are cheap.
    """
    for name, hf_id in model_name_to_ids.items():
        if name in model_dict:
            continue  # already loaded
        # Register the entry first, then fill it (same order as before).
        entry = model_dict[name] = dict()
        if name.startswith('sentence-transformer'):
            entry['model'] = SentenceTransformer(hf_id)
        else:
            entry['hf_dir'] = hf_id
            entry['model'] = CLIPModel.from_pretrained(hf_id)
            entry['processor'] = CLIPProcessor.from_pretrained(hf_id)


# Eagerly load models at import time so the first request doesn't pay the
# cost; the guard keeps a re-execution from reloading everything.
if not model_dict:  # idiomatic truthiness instead of len(...) == 0
    print('Loading models...')
    load_models()


def get_image_embedding(model_name, image):
    """Embed an image with the named model and return its feature vector.

    sentence-transformer models encode the image directly; Hugging Face
    CLIP models run it through their processor and get_image_features.
    """
    entry = model_dict[model_name]
    if model_name.startswith('sentence-transformer'):
        return entry['model'].encode(image)
    inputs = entry['processor'](images=image, return_tensors="pt")
    features = entry['model'].get_image_features(**inputs)
    return features.detach().numpy()[0]

def s3_path_to_image(fs, s3_path):
    """Load the object at `s3_path` from S3 and return it as a PIL Image.

    Args:
        fs: An s3fs-style filesystem object providing `open`.
        s3_path (str): The path to the image in the S3 bucket, including
            the bucket name (e.g., "bucket_name/path/to/image.jpg").

    Returns:
        Image: A PIL Image object.
    """
    with fs.open(s3_path, "rb") as f:
        raw = f.read()
    # The bytes are fully buffered, so the S3 handle can close before decode.
    return Image.open(BytesIO(raw))

def generate_and_save_embeddings():
    """Embed every .jpg in the current dataset and store the vectors.

    For each image, one embedding per configured model is written to the
    vector DB, tagged with model name, dataset, S3 path, and file name.
    """
    with torch.no_grad():  # inference only — skip autograd bookkeeping
        for file_path in tqdm(fs.ls(get_image_path()), desc="Generate embeddings for Images"):
            if not file_path.endswith('.jpg'):
                continue
            image_name = file_path.split('/')[-1]
            full_s3_path = 's3://' + file_path  # invariant across models; hoisted
            for model_name in model_name_to_ids:
                vector_db.add_image_embedding_to_db(
                    embedding=get_image_embedding(model_name, s3_path_to_image(fs, full_s3_path)),
                    model_name=model_name,
                    dataset_name=cur_dataset,
                    path_to_image=full_s3_path,
                    image_name=image_name,
                )


def get_immediate_subdirectories(s3_path):
    """Return the names (last path component) of directories directly under `s3_path`."""
    subdir_names = []
    for entry in fs.glob(f"{s3_path}/*"):
        if fs.isdir(entry):
            subdir_names.append(entry.split('/')[-1])
    return subdir_names

# Discover datasets under the data root at import time; the first one found
# becomes the active dataset.
# NOTE(review): assumes at least one dataset directory exists — an empty
# data root raises IndexError here at import time. Confirm this is intended.
all_datasets = get_immediate_subdirectories(ROOT_DATA_PATH)
cur_dataset = all_datasets[0]

def set_cur_dataset(dataset):
    """Make `dataset` the module's active dataset, refreshing the list first."""
    global cur_dataset
    refresh_all_datasets()
    print(f"Setting current dataset to {dataset}")
    cur_dataset = dataset

def refresh_all_datasets():
    """Re-scan the S3 data root and update the module-level dataset list."""
    global all_datasets
    discovered = get_immediate_subdirectories(ROOT_DATA_PATH)
    all_datasets = discovered
    print(f"Refreshing all datasets: {all_datasets}")

def url_to_image(url, timeout=30):
    """Fetch `url` and decode the response body as a PIL Image.

    Args:
        url (str): The image URL.
        timeout (float): Seconds before the HTTP request is abandoned
            (new parameter; default keeps existing callers working).

    Returns:
        Image | None: The decoded image, or None if the request fails.
    """
    try:
        # A timeout stops a stalled server from hanging the caller forever.
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        img = Image.open(BytesIO(response.content))
        return img
    except requests.exceptions.RequestException:
        # The exception binding was unused; drop it and keep the message.
        print(f"Error fetching image from URL: {url}")
        return None