import torch
from PIL import Image
from utils.load_models import fclip_model, fclip_processor
from utils.load_models import siglip_model, siglip_preprocess_train, siglip_preprocess_val, siglip_tokenizer

def get_info(catalog, column):
    # Collect image paths and text descriptions from the catalog DataFrame.
    # Each row's "Id" is assumed to map to an image at /home/user/app/images/<Id>.jpg.
    image_paths = []
    text_descriptions = []

    for _, row in catalog.iterrows():
        path = "/home/user/app/images/" + str(row["Id"]) + ".jpg"
        image_paths.append(path)
        text_descriptions.append(row[column])

    return image_paths, text_descriptions

def normalize_embedding(embedding):
    # L2-normalize a single embedding of shape (1, D) and return it as a NumPy array
    norm = torch.norm(embedding, p=2, dim=-1, keepdim=True)
    embedding = embedding / norm
    return embedding.detach().cpu().numpy()

def normalize_embeddings(embeddings):
    norm = torch.norm(embeddings, p=2, dim=-1, keepdim=True)
    normalized_embeddings = embeddings / norm
    return normalized_embeddings

def generate_fclip_embeddings(image_paths, texts, batch_size, alpha):
    image_embeds_list = []
    text_embeds_list = []

    # Batch processing loop
    for i in range(0, len(image_paths), batch_size):
        batch_image_paths = image_paths[i:i + batch_size]
        batch_texts = texts[i:i + batch_size]

        # Load and preprocess batch of images and texts
        images = [Image.open(path).convert("RGB") for path in batch_image_paths]

        # Set the maximum sequence length to 77 to match the position embeddings
        inputs = fclip_processor(text=batch_texts, images=images, return_tensors="pt", padding=True, truncation=True, max_length=77)

        # Move inputs to the GPU if one is available
        if torch.cuda.is_available():
            inputs = {k: v.to("cuda") for k, v in inputs.items()}

        # Generate embeddings
        with torch.no_grad():
            outputs = fclip_model(**inputs)

        image_embeds_list.append(outputs.image_embeds)
        text_embeds_list.append(outputs.text_embeds)

    # Concatenate all embeddings
    image_embeds = torch.cat(image_embeds_list, dim=0)
    text_embeds = torch.cat(text_embeds_list, dim=0)

    # Normalize embeddings
    image_embeds = normalize_embeddings(image_embeds)
    text_embeds = normalize_embeddings(text_embeds)

    # Fuse image and text embeddings (simple and alpha-weighted averages), then re-normalize
    avg_embeds = (image_embeds + text_embeds) / 2
    weighted_avg_embeds = alpha * image_embeds + (1 - alpha) * text_embeds
    avg_embeds = normalize_embeddings(avg_embeds)
    weighted_avg_embeds = normalize_embeddings(weighted_avg_embeds)

    return image_embeds.cpu().numpy(), text_embeds.cpu().numpy(), avg_embeds.cpu().numpy(), weighted_avg_embeds.cpu().numpy()

def generate_siglip_embeddings(image_paths, texts, batch_size, alpha):
    image_embeds_list = []
    text_embeds_list = []

    # Batch processing loop
    for i in range(0, len(image_paths), batch_size):
        batch_image_paths = image_paths[i:i + batch_size]
        batch_texts = texts[i:i + batch_size]

        # Load and preprocess batch of images and texts
        images = [siglip_preprocess_val(Image.open(image_path).convert('RGB')).unsqueeze(0) for image_path in batch_image_paths]
        images = torch.cat(images)

        tokens = siglip_tokenizer(batch_texts)

        # Move images and tokens to the GPU if one is available
        if torch.cuda.is_available():
            images = images.cuda()
            tokens = tokens.cuda()

        # Generate embeddings
        with torch.no_grad():
            image_embeddings_batch = siglip_model.encode_image(images)
            text_embeddings_batch = siglip_model.encode_text(tokens)

        # Store embeddings
        image_embeds_list.append(image_embeddings_batch)
        text_embeds_list.append(text_embeddings_batch)
        
    # Concatenate all embeddings
    image_embeds = torch.cat(image_embeds_list, dim=0)
    text_embeds = torch.cat(text_embeds_list, dim=0)

    # Normalize embeddings
    image_embeds = normalize_embeddings(image_embeds)
    text_embeds = normalize_embeddings(text_embeds)

    # Fuse image and text embeddings (simple and alpha-weighted averages), then re-normalize
    avg_embeds = (image_embeds + text_embeds) / 2
    weighted_avg_embeds = alpha * image_embeds + (1 - alpha) * text_embeds
    avg_embeds = normalize_embeddings(avg_embeds)
    weighted_avg_embeds = normalize_embeddings(weighted_avg_embeds)

    return image_embeds.cpu().numpy(), text_embeds.cpu().numpy(), avg_embeds.cpu().numpy(), weighted_avg_embeds.cpu().numpy()
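
# Since all embeddings returned above are L2-normalized, cosine similarity between a
# query and the catalog reduces to a plain dot product. The helper below is an
# illustrative sketch only (not part of the original pipeline): `catalog_embeds` is
# assumed to be an (N, D) NumPy array from one of the generators above, and
# `query_embed` the (1, D) array produced by generate_query_embedding.
def top_k_matches(query_embed, catalog_embeds, k=5):
    import numpy as np  # local import keeps this optional helper self-contained

    scores = catalog_embeds @ np.asarray(query_embed).ravel()  # cosine similarities, shape (N,)
    return np.argsort(-scores)[:k]  # indices of the k most similar catalog items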

# Function to process text embedding for any model
def generate_text_embedding(model, tokenizer, query, model_type):
    if model_type == "fashionCLIP":
        # Tokenize the text and move the inputs to the GPU if one is available
        inputs = tokenizer(text=query, return_tensors="pt", padding=True, truncation=True, max_length=77)

        if torch.cuda.is_available():
            inputs = {k: v.to("cuda") for k, v in inputs.items()}

        # Get the text embedding from the model
        with torch.no_grad():
            text_embed = model.get_text_features(**inputs)
    elif model_type == "fashionSigLIP":
        # Tokenize the text and move the tokens to the GPU if one is available
        tokens = tokenizer(query)

        if torch.cuda.is_available():
            tokens = tokens.to("cuda")

        # Get the text embedding from the model
        with torch.no_grad():
            text_embed = model.encode_text(tokens)
    else:
        raise ValueError("Invalid model type. Choose 'fashionCLIP' or 'fashionSigLIP'.")

    return normalize_embedding(text_embed)

# Function to process image embedding for any model
def generate_image_embedding(model, processor, image_path, model_type):
    image = Image.open(image_path).convert("RGB")

    if model_type == "fashionCLIP":
        # Preprocess the image for FashionCLIP and move the inputs to the GPU if one is available
        inputs = processor(images=image, return_tensors="pt")

        if torch.cuda.is_available():
            inputs = {k: v.to("cuda") for k, v in inputs.items()}

        # Get the image embedding from the model
        with torch.no_grad():
            image_embed = model.get_image_features(**inputs)
    elif model_type == "fashionSigLIP":
        # Preprocess the image for SigLIP and move the tensor to the GPU if one is available
        image_tensor = processor(image).unsqueeze(0)

        if torch.cuda.is_available():
            image_tensor = image_tensor.to("cuda")

        # Get the image embedding from the model
        with torch.no_grad():
            image_embed = model.encode_image(image_tensor)
    else:
        raise ValueError("Invalid model type. Choose 'fashionCLIP' or 'fashionSigLIP'.")

    return normalize_embedding(image_embed)

# Unified function to generate embeddings for both models and query types
def generate_query_embedding(query, query_type, model, processor, tokenizer, model_type):
    if query_type == "text":
        return generate_text_embedding(model, tokenizer, query, model_type)
    elif query_type == "image":
        return generate_image_embedding(model, processor, query, model_type)
    else:
        raise ValueError("Invalid query type. Choose 'text' or 'image'.")
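
# The block below is a hypothetical end-to-end usage sketch, not part of the original
# module: the catalog path, column name, and hyperparameters (batch_size, alpha) are
# illustrative assumptions, and it reuses fclip_processor as the tokenizer argument,
# assuming it accepts text-only calls (as a transformers CLIPProcessor does).
if __name__ == "__main__":
    import pandas as pd

    # Catalog with an "Id" column matching the image filenames and a text column.
    catalog = pd.read_csv("catalog.csv")  # hypothetical path
    image_paths, texts = get_info(catalog, "Description")  # hypothetical column name

    # Catalog embeddings with FashionCLIP, weighting images slightly over text.
    img_emb, txt_emb, avg_emb, weighted_emb = generate_fclip_embeddings(
        image_paths, texts, batch_size=32, alpha=0.6
    )

    # Embed a text query and retrieve the closest catalog items.
    query_emb = generate_query_embedding(
        "red floral summer dress", "text",
        fclip_model, fclip_processor, fclip_processor, "fashionCLIP",
    )
    print(top_k_matches(query_emb, weighted_emb, k=5))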