sebastiansarasti committed on
Commit
0d38ded
1 Parent(s): a8106b8

adding the files to run the app

Files changed (5)
  1. app.py +161 -0
  2. model.py +75 -0
  3. requirements.txt +3 -0
  4. search.py +105 -0
  5. utils.py +34 -0
app.py ADDED
@@ -0,0 +1,161 @@
+ import gradio as gr
+ import numpy as np
+ from PIL import Image, ImageDraw, ImageFont
+ from search import search_similarity, process_image_for_encoder_gradio
+ from utils import str_to_bytes
+ from io import BytesIO
+
+ def add_ranking_number(image, rank):
+     """Add a ranking number to the image."""
+     img_with_rank = image.copy()
+     draw = ImageDraw.Draw(img_with_rank)
+
+     width, height = image.size
+     circle_radius = min(width, height) // 15
+     circle_position = (circle_radius + 10, circle_radius + 10)
+
+     draw.ellipse(
+         [(circle_position[0] - circle_radius, circle_position[1] - circle_radius),
+          (circle_position[0] + circle_radius, circle_position[1] + circle_radius)],
+         fill='white',
+         outline='black'
+     )
+
+     font_size = circle_radius
+     try:
+         font = ImageFont.truetype("Arial.ttf", font_size)
+     except Exception:
+         font = ImageFont.load_default()
+
+     text = str(rank + 1)
+     text_bbox = draw.textbbox((0, 0), text, font=font)
+     text_width = text_bbox[2] - text_bbox[0]
+     text_height = text_bbox[3] - text_bbox[1]
+     text_position = (
+         circle_position[0] - text_width // 2,
+         circle_position[1] - text_height // 2
+     )
+
+     draw.text(text_position, text, fill='black', font=font)
+     return img_with_rank
+
+ def process_image_result(image_str, rank):
+     """Convert an image string into a PIL Image with a ranking overlay."""
+     try:
+         img = Image.open(BytesIO(str_to_bytes(image_str)))
+         return add_ranking_number(img, rank)
+     except Exception as e:
+         print(f"Error processing image: {e}")
+         return None
+
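+ # Note: for text queries, search_similarity returns base64-encoded image strings
+ # (encoded with bytes_to_str in utils.py); process_image_result decodes them back
+ # into PIL images before the ranking badge is drawn.
+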
+ def interface_fn(mode, input_text, input_image, top_k):
+     try:
+         # Determine which input to use based on the mode
+         if mode == "text":
+             if not input_text.strip():
+                 return [], "Please enter a text to search."
+             input_data = input_text
+         else:  # mode == "image"
+             if input_image is None:
+                 return [], "Please upload an image to search."
+             input_data = process_image_for_encoder_gradio(input_image, is_bytes=False)
+
+         # Show the input data
+         print(f"Input data: {input_data}")  # for debugging
+
+         # Run the search
+         results = search_similarity(input_data, mode, int(top_k))
+
+         # Format the results according to the mode
+         if mode == "text":  # returns images
+             processed_images = []
+             # If results is a list of lists, flatten it
+             if results and isinstance(results[0], list):
+                 print("Received a list of lists, flattening...")  # for debugging
+                 results = [item for sublist in results for item in sublist]
+
+             for idx, img_str in enumerate(results):
+                 img = process_image_result(img_str, idx)
+                 if img is not None:
+                     processed_images.append(img)
+
+             if not processed_images:
+                 return [], "The images could not be processed"
+             return processed_images, None
+
+         else:  # mode == "image" - returns texts
+             if isinstance(results, list):
+                 numbered_texts = [f"{i+1}. {text}" for i, text in enumerate(results)]
+                 return [], "\n\n".join(numbered_texts)
+             else:
+                 return [], str(results)
+
+     except Exception as e:
+         print(f"Error in interface_fn: {str(e)}")
+         # 'results' is only defined if the search itself succeeded
+         if 'results' in locals():
+             print(f"Results type: {type(results)}")  # for debugging
+         return [], f"Error during the search: {str(e)}"
+
+
+ def search_text(input_text, top_k):
+     try:
+         if not input_text.strip():
+             return []
+
+         # Run the search
+         results = search_similarity(input_text, "text", int(top_k))
+
+         processed_images = []
+         # If results is a list of lists, flatten it
+         if results and isinstance(results[0], list):
+             results = [item for sublist in results for item in sublist]
+
+         for idx, img_str in enumerate(results):
+             img = process_image_result(img_str, idx)
+             if img is not None:
+                 processed_images.append(img)
+
+         return processed_images
+
+     except Exception as e:
+         print(f"Error in search_text: {str(e)}")
+         return []
+
+ with gr.Blocks() as demo:
+     gr.Markdown("# Text Similarity Search")
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             input_text = gr.Textbox(
+                 label="Search text",
+                 placeholder="Enter your text here...",
+                 lines=3
+             )
+
+             top_k = gr.Slider(
+                 minimum=1,
+                 maximum=20,
+                 value=5,
+                 step=1,
+                 label="Number of results",
+                 info="How many similar results do you want to see?"
+             )
+
+             search_button = gr.Button("Search")
+
+         with gr.Column(scale=1):
+             output_gallery = gr.Gallery(
+                 label="Similar images",
+                 columns=3,
+                 height="auto"
+             )
+
+     search_button.click(
+         fn=search_text,
+         inputs=[input_text, top_k],
+         outputs=output_gallery
+     )
+
+ if __name__ == "__main__":
+     from multiprocessing import freeze_support
+     freeze_support()
+     demo.launch()
model.py ADDED
@@ -0,0 +1,75 @@
+ import torch.nn as nn
+ from huggingface_hub import PyTorchModelHubMixin
+
+ class TextEncoderHead(nn.Module):
+     def __init__(self, model):
+         super(TextEncoderHead, self).__init__()
+         self.model = model
+         for param in self.model.parameters():
+             param.requires_grad = False
+         # uncomment this for chemberta
+         # self.seq1 = nn.Sequential(
+         #     nn.Flatten(),
+         #     nn.Linear(767*256, 2000),
+         #     nn.Dropout(0.3),
+         #     nn.ReLU(),
+         #     nn.Linear(2000, 512),
+         #     nn.LayerNorm(512)
+         # )
+         self.seq1 = nn.Sequential(
+             nn.Flatten(),
+             nn.Linear(768*256, 2000),
+             nn.Dropout(0.3),
+             nn.ReLU(),
+             nn.Linear(2000, 512),
+             nn.LayerNorm(512)
+         )
+
+     def forward(self, input_ids, attention_mask):
+         outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
+         # uncomment this for chemberta
+         # outputs = outputs.logits
+         outputs = outputs.last_hidden_state
+         outputs = self.seq1(outputs)
+         return outputs.contiguous()
+
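+ # Note: 768*256 flattens DistilBERT's last_hidden_state (hidden size 768) over the
+ # 256-token padded sequence produced by the tokenizer in search.py (max_length=256).
+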
+ class ImageEncoderHead(nn.Module):
+     def __init__(self, model):
+         super(ImageEncoderHead, self).__init__()
+         self.model = model
+         for param in self.model.parameters():
+             param.requires_grad = False
+         # for resnet model
+         # self.seq1 = nn.Sequential(
+         #     nn.Flatten(),
+         #     nn.Linear(512*7*7, 1000),
+         #     nn.Linear(1000, 512),
+         #     nn.LayerNorm(512)
+         # )
+         # for vit model
+         self.seq1 = nn.Sequential(
+             nn.Linear(768, 1000),
+             nn.Dropout(0.3),
+             nn.ReLU(),
+             nn.Linear(1000, 512),
+             nn.LayerNorm(512)
+         )
+
+     def forward(self, pixel_values):
+         outputs = self.model(pixel_values)
+         outputs = outputs.last_hidden_state.mean(dim=1)
+         outputs = self.seq1(outputs)
+         return outputs.contiguous()
+
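+ # Note: the ViT last_hidden_state is mean-pooled over its patch tokens (dim=1) into a
+ # single 768-dim vector before being projected to the shared 512-dim embedding space.
+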
+ class CLIPChemistryModel(nn.Module, PyTorchModelHubMixin):
+     def __init__(self, text_encoder, image_encoder):
+         super(CLIPChemistryModel, self).__init__()
+         self.text_encoder = text_encoder
+         self.image_encoder = image_encoder
+
+     def forward(self, image, input_ids, attention_mask):
+         # calculate the embeddings
+         ie = self.image_encoder(image)
+         te = self.text_encoder(input_ids, attention_mask)
+         return ie, te
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ pinecone==5.4.2
+ transformers==4.47.0
+ huggingface-hub==0.26.5
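+ # Note: app.py and search.py also import gradio, torch, numpy, and Pillow; these are
+ # not pinned here and may need to be added depending on the runtime.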
search.py ADDED
@@ -0,0 +1,105 @@
+ from transformers import ViTModel, AutoModelForMaskedLM, AutoTokenizer, ViTImageProcessor, DistilBertModel
+ from pinecone import Pinecone
+ import torch
+
+
+ pc = Pinecone()
+ index = pc.Index("clipmodel")
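+ # Note: Pinecone() is constructed without an explicit api_key, so the client is
+ # expected to pick it up from the PINECONE_API_KEY environment variable
+ # (e.g., a Space secret).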
+
+
+ from io import BytesIO
+ import base64
+ from PIL import Image
+
+ import sys
+
+ sys.path.append('../src')
+
+ from model import CLIPChemistryModel, TextEncoderHead, ImageEncoderHead
+
+
+ ENCODER_BASE = DistilBertModel.from_pretrained("distilbert-base-uncased")
+ IMAGE_BASE = ViTModel.from_pretrained("google/vit-base-patch16-224")
+ text_encoder = TextEncoderHead(model=ENCODER_BASE)
+ image_encoder = ImageEncoderHead(model=IMAGE_BASE)
+
+ clip_model = CLIPChemistryModel(text_encoder=text_encoder, image_encoder=image_encoder)
+
+ clip_model.load_state_dict(torch.load('/Users/sebastianalejandrosarastizambonino/Documents/projects/CLIP_Pytorch/src/best_model_fashion.pth', map_location=torch.device('cpu')))
+
+ te_final = clip_model.text_encoder
+ ie_final = clip_model.image_encoder
+
+ def process_text_for_encoder(text, model):
+     # tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
+     tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
+     encoded_input = tokenizer(text, return_tensors='pt', padding='max_length', truncation=True, max_length=256)
+     input_ids = encoded_input['input_ids']
+     attention_mask = encoded_input['attention_mask']
+     output = model(input_ids=input_ids, attention_mask=attention_mask)
+     return output.detach().numpy().tolist()[0]
+
+ def process_image_for_encoder(image, model):
+     # image = Image.open(BytesIO(image))
+     print(type(image))
+     image_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
+     image_tensor = image_processor(image,
+                                    return_tensors="pt",
+                                    do_resize=True
+                                    )['pixel_values']
+     output = model(pixel_values=image_tensor)
+     return output.detach().numpy().tolist()[0]
+
+ def search_similarity(input, mode, top_k=5):
+     if mode == 'text':
+         output = process_text_for_encoder(input, model=te_final)
+     else:
+         output = input
+
+     if mode == 'text':
+         mode_search = 'image'
+         response = index.query(
+             namespace="space-" + mode_search + "-fashion",
+             vector=output,
+             top_k=top_k,
+             include_values=True,
+             include_metadata=True
+         )
+         similar_images = [value['metadata']['image'] for value in response['matches']]
+         return similar_images
+     elif mode == 'image':
+         mode_search = 'text'
+         response = index.query(
+             namespace="space-" + mode_search + "-fashion",
+             vector=output,
+             top_k=top_k,
+             include_values=True,
+             include_metadata=True
+         )
+         similar_text = [value['metadata']['text'] for value in response['matches']]
+         return similar_text
+     else:
+         raise ValueError("mode must be either 'text' or 'image'")
+
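+ # Usage sketch (illustrative query string): a text query returns the base64-encoded
+ # image strings stored in the Pinecone metadata, ready for app.py to decode:
+ #   images = search_similarity("red summer dress", "text", top_k=5)
+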
+ def process_image_for_encoder_gradio(image, is_bytes=True):
+     """Process images given either as bytes or as PIL Image objects."""
+     try:
+         if is_bytes:
+             # The image comes in as bytes
+             image = Image.open(BytesIO(image))
+         else:
+             # The image is already a PIL Image or comes from gradio
+             if not isinstance(image, Image.Image):
+                 # Coming from gradio, it could be a numpy array
+                 image = Image.fromarray(image)
+
+         image_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
+         image_tensor = image_processor(image,
+                                        return_tensors="pt",
+                                        do_resize=True
+                                        )['pixel_values']
+         output = ie_final(pixel_values=image_tensor)
+         return output.detach().numpy().tolist()[0]
+     except Exception as e:
+         print(f"Error in process_image_for_encoder_gradio: {e}")
+         raise
utils.py ADDED
@@ -0,0 +1,34 @@
+ from model import CLIPChemistryModel, TextEncoderHead, ImageEncoderHead
+ from transformers import ViTModel, AutoModelForMaskedLM, AutoTokenizer, ViTImageProcessor
+
+ from io import BytesIO
+ import base64
+
+ from PIL import Image
+
+ def bytes_to_str(bytes_data):
+     return base64.b64encode(bytes_data).decode('utf-8')
+
+ def str_to_bytes(str_data):
+     return base64.b64decode(str_data)
+
+ def push_embeddings_to_pine_cone(index, embeddings, df, mode, length):
+     records = []
+     for i in range(length):
+         if mode == 'text':
+             records.append({
+                 "id": str(mode) + str(i),
+                 "values": embeddings[i],
+                 "metadata": {str(mode): df[mode].iloc[i]}})
+         elif mode == 'image':
+             records.append({
+                 "id": str(mode) + str(i),
+                 "values": embeddings[i],
+                 "metadata": {str(mode): bytes_to_str(df[mode].iloc[i]['bytes'])}})
+         else:
+             raise ValueError("mode must be either 'text' or 'image'")
+
+     index.upsert(
+         vectors=records,
+         namespace="space-" + mode
+     )
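+
+ # Usage sketch (hypothetical dataframe `df` and embedding list `embeddings`): push
+ # text embeddings into the index, e.g. the "clipmodel" index used by search.py;
+ # the namespace written here becomes "space-text":
+ #   from pinecone import Pinecone
+ #   index = Pinecone().Index("clipmodel")
+ #   push_embeddings_to_pine_cone(index, embeddings, df, mode="text", length=len(df))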