Spaces:
Running
Running
sebastiansarasti
commited on
Commit
•
0d38ded
1
Parent(s):
a8106b8
adding the files to run the app
Browse files- app.py +161 -0
- model.py +75 -0
- requirements.txt +3 -0
- search.py +105 -0
- utils.py +34 -0
app.py
ADDED
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import numpy as np
|
3 |
+
from PIL import Image, ImageDraw, ImageFont
|
4 |
+
from search import search_similarity, process_image_for_encoder_gradio
|
5 |
+
from utils import str_to_bytes
|
6 |
+
from io import BytesIO
|
7 |
+
|
8 |
+
def add_ranking_number(image, rank):
|
9 |
+
"""Añade un número de ranking a la imagen"""
|
10 |
+
img_with_rank = image.copy()
|
11 |
+
draw = ImageDraw.Draw(img_with_rank)
|
12 |
+
|
13 |
+
width, height = image.size
|
14 |
+
circle_radius = min(width, height) // 15
|
15 |
+
circle_position = (circle_radius + 10, circle_radius + 10)
|
16 |
+
|
17 |
+
draw.ellipse(
|
18 |
+
[(circle_position[0] - circle_radius, circle_position[1] - circle_radius),
|
19 |
+
(circle_position[0] + circle_radius, circle_position[1] + circle_radius)],
|
20 |
+
fill='white',
|
21 |
+
outline='black'
|
22 |
+
)
|
23 |
+
|
24 |
+
font_size = circle_radius
|
25 |
+
try:
|
26 |
+
font = ImageFont.truetype("Arial.ttf", font_size)
|
27 |
+
except:
|
28 |
+
font = ImageFont.load_default()
|
29 |
+
|
30 |
+
text = str(rank + 1)
|
31 |
+
text_bbox = draw.textbbox((0, 0), text, font=font)
|
32 |
+
text_width = text_bbox[2] - text_bbox[0]
|
33 |
+
text_height = text_bbox[3] - text_bbox[1]
|
34 |
+
text_position = (
|
35 |
+
circle_position[0] - text_width // 2,
|
36 |
+
circle_position[1] - text_height // 2
|
37 |
+
)
|
38 |
+
|
39 |
+
draw.text(text_position, text, fill='black', font=font)
|
40 |
+
return img_with_rank
|
41 |
+
|
42 |
+
def process_image_result(image_str, rank):
|
43 |
+
"""Convierte una cadena de imagen en un objeto PIL Image con ranking"""
|
44 |
+
try:
|
45 |
+
img = Image.open(BytesIO(str_to_bytes(image_str)))
|
46 |
+
return add_ranking_number(img, rank)
|
47 |
+
except Exception as e:
|
48 |
+
print(f"Error procesando imagen: {e}")
|
49 |
+
return None
|
50 |
+
|
51 |
+
def interface_fn(mode, input_text, input_image, top_k):
|
52 |
+
try:
|
53 |
+
# Determinar qué input usar basado en el modo
|
54 |
+
if mode == "text":
|
55 |
+
if not input_text.strip():
|
56 |
+
return [], "Por favor, ingresa un texto para buscar."
|
57 |
+
input_data = input_text
|
58 |
+
else: # mode == "image"
|
59 |
+
if input_image is None:
|
60 |
+
return [], "Por favor, sube una imagen para buscar."
|
61 |
+
input_data = process_image_for_encoder_gradio(input_image, is_bytes=False)
|
62 |
+
|
63 |
+
# Show the input data
|
64 |
+
print(f"Input data: {input_data}") # Para debugging
|
65 |
+
|
66 |
+
# Realizar la búsqueda
|
67 |
+
results = search_similarity(input_data, mode, int(top_k))
|
68 |
+
|
69 |
+
# Formatear resultados según el modo
|
70 |
+
if mode == "text": # Devuelve imágenes
|
71 |
+
processed_images = []
|
72 |
+
# Si results es una lista de listas, la aplanamos
|
73 |
+
if results and isinstance(results[0], list):
|
74 |
+
print("Recibida lista de listas, aplanando...") # Para debugging
|
75 |
+
results = [item for sublist in results for item in sublist]
|
76 |
+
|
77 |
+
for idx, img_str in enumerate(results):
|
78 |
+
img = process_image_result(img_str, idx)
|
79 |
+
if img is not None:
|
80 |
+
processed_images.append(img)
|
81 |
+
|
82 |
+
if not processed_images:
|
83 |
+
return [], "No se pudieron procesar las imágenes"
|
84 |
+
return processed_images, None
|
85 |
+
|
86 |
+
else: # mode == "image" - Devuelve textos
|
87 |
+
if isinstance(results, list):
|
88 |
+
numbered_texts = [f"{i+1}. {text}" for i, text in enumerate(results)]
|
89 |
+
return [], "\n\n".join(numbered_texts)
|
90 |
+
else:
|
91 |
+
return [], str(results)
|
92 |
+
|
93 |
+
except Exception as e:
|
94 |
+
print(f"Error en interface_fn: {str(e)}")
|
95 |
+
print(f"Tipo de resultados: {type(results)}") # Para debugging
|
96 |
+
return [], f"Error durante la búsqueda: {str(e)}"
|
97 |
+
|
98 |
+
|
99 |
+
def search_text(input_text, top_k):
|
100 |
+
try:
|
101 |
+
if not input_text.strip():
|
102 |
+
return []
|
103 |
+
|
104 |
+
# Realizar la búsqueda
|
105 |
+
results = search_similarity(input_text, "text", int(top_k))
|
106 |
+
|
107 |
+
processed_images = []
|
108 |
+
# Si results es una lista de listas, la aplanamos
|
109 |
+
if results and isinstance(results[0], list):
|
110 |
+
results = [item for sublist in results for item in sublist]
|
111 |
+
|
112 |
+
for idx, img_str in enumerate(results):
|
113 |
+
img = process_image_result(img_str, idx)
|
114 |
+
if img is not None:
|
115 |
+
processed_images.append(img)
|
116 |
+
|
117 |
+
return processed_images
|
118 |
+
|
119 |
+
except Exception as e:
|
120 |
+
print(f"Error en search_text: {str(e)}")
|
121 |
+
return []
|
122 |
+
|
123 |
+
with gr.Blocks() as demo:
|
124 |
+
gr.Markdown("# Buscador de Similitud por Texto")
|
125 |
+
|
126 |
+
with gr.Row():
|
127 |
+
with gr.Column(scale=1):
|
128 |
+
input_text = gr.Textbox(
|
129 |
+
label="Texto de búsqueda",
|
130 |
+
placeholder="Ingresa aquí tu texto...",
|
131 |
+
lines=3
|
132 |
+
)
|
133 |
+
|
134 |
+
top_k = gr.Slider(
|
135 |
+
minimum=1,
|
136 |
+
maximum=20,
|
137 |
+
value=5,
|
138 |
+
step=1,
|
139 |
+
label="Número de resultados",
|
140 |
+
info="¿Cuántos resultados similares quieres ver?"
|
141 |
+
)
|
142 |
+
|
143 |
+
search_button = gr.Button("Buscar")
|
144 |
+
|
145 |
+
with gr.Column(scale=1):
|
146 |
+
output_gallery = gr.Gallery(
|
147 |
+
label="Imágenes similares",
|
148 |
+
columns=3,
|
149 |
+
height="auto"
|
150 |
+
)
|
151 |
+
|
152 |
+
search_button.click(
|
153 |
+
fn=search_text,
|
154 |
+
inputs=[input_text, top_k],
|
155 |
+
outputs=output_gallery
|
156 |
+
)
|
157 |
+
|
158 |
+
if __name__ == "__main__":
|
159 |
+
from multiprocessing import freeze_support
|
160 |
+
freeze_support()
|
161 |
+
demo.launch()
|
model.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
from huggingface_hub import PyTorchModelHubMixin
|
3 |
+
|
4 |
+
class TextEncoderHead(nn.Module):
|
5 |
+
def __init__(self, model):
|
6 |
+
super(TextEncoderHead, self).__init__()
|
7 |
+
self.model = model
|
8 |
+
for param in self.model.parameters():
|
9 |
+
param.requires_grad = False
|
10 |
+
# uncomment this for chemberta
|
11 |
+
# self.seq1 = nn.Sequential(
|
12 |
+
# nn.Flatten(),
|
13 |
+
# nn.Linear(767*256, 2000),
|
14 |
+
# nn.Dropout(0.3),
|
15 |
+
# nn.ReLU(),
|
16 |
+
# nn.Linear(2000, 512),
|
17 |
+
# nn.LayerNorm(512)
|
18 |
+
# )
|
19 |
+
self.seq1 = nn.Sequential(
|
20 |
+
nn.Flatten(),
|
21 |
+
nn.Linear(768*256, 2000),
|
22 |
+
nn.Dropout(0.3),
|
23 |
+
nn.ReLU(),
|
24 |
+
nn.Linear(2000, 512),
|
25 |
+
nn.LayerNorm(512)
|
26 |
+
)
|
27 |
+
|
28 |
+
def forward(self, input_ids, attention_mask):
|
29 |
+
outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
|
30 |
+
# uncomment this for chemberta
|
31 |
+
# outputs = outputs.logits
|
32 |
+
outputs = outputs.last_hidden_state
|
33 |
+
outputs = self.seq1(outputs)
|
34 |
+
return outputs.contiguous()
|
35 |
+
|
36 |
+
class ImageEncoderHead(nn.Module):
|
37 |
+
def __init__(self, model):
|
38 |
+
super(ImageEncoderHead, self).__init__()
|
39 |
+
self.model = model
|
40 |
+
for param in self.model.parameters():
|
41 |
+
param.requires_grad = False
|
42 |
+
# for resnet model
|
43 |
+
# self.seq1 = nn.Sequential(
|
44 |
+
# nn.Flatten(),
|
45 |
+
# nn.Linear(512*7*7, 1000),
|
46 |
+
# nn.Linear(1000, 512),
|
47 |
+
# nn.LayerNorm(512)
|
48 |
+
# )
|
49 |
+
# for vit model
|
50 |
+
self.seq1 = nn.Sequential(
|
51 |
+
nn.Linear(768, 1000),
|
52 |
+
nn.Dropout(0.3),
|
53 |
+
nn.ReLU(),
|
54 |
+
nn.Linear(1000, 512),
|
55 |
+
nn.LayerNorm(512)
|
56 |
+
)
|
57 |
+
|
58 |
+
|
59 |
+
def forward(self, pixel_values):
|
60 |
+
outputs = self.model(pixel_values)
|
61 |
+
outputs = outputs.last_hidden_state.mean(dim=1)
|
62 |
+
outputs = self.seq1(outputs)
|
63 |
+
return outputs.contiguous()
|
64 |
+
|
65 |
+
class CLIPChemistryModel(nn.Module, PyTorchModelHubMixin):
|
66 |
+
def __init__(self, text_encoder, image_encoder):
|
67 |
+
super(CLIPChemistryModel, self).__init__()
|
68 |
+
self.text_encoder = text_encoder
|
69 |
+
self.image_encoder = image_encoder
|
70 |
+
|
71 |
+
def forward(self, image, input_ids, attention_mask):
|
72 |
+
# calculate the embeddings
|
73 |
+
ie = self.image_encoder(image)
|
74 |
+
te = self.text_encoder(input_ids, attention_mask)
|
75 |
+
return ie, te
|
requirements.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
pinecone=="5.4.2"
|
2 |
+
transformers=="4.47.0"
|
3 |
+
huggingface-hub=="0.26.5"
|
search.py
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import ViTModel, AutoModelForMaskedLM, AutoTokenizer, ViTImageProcessor, DistilBertModel
|
2 |
+
from pinecone import Pinecone
|
3 |
+
import torch
|
4 |
+
|
5 |
+
|
6 |
+
pc = Pinecone()
|
7 |
+
index = pc.Index("clipmodel")
|
8 |
+
|
9 |
+
|
10 |
+
from io import BytesIO
|
11 |
+
import base64
|
12 |
+
from PIL import Image
|
13 |
+
|
14 |
+
import sys
|
15 |
+
|
16 |
+
sys.path.append('../src')
|
17 |
+
|
18 |
+
from model import CLIPChemistryModel, TextEncoderHead, ImageEncoderHead
|
19 |
+
|
20 |
+
|
21 |
+
ENCODER_BASE = DistilBertModel.from_pretrained("distilbert-base-uncased")
|
22 |
+
IMAGE_BASE = ViTModel.from_pretrained("google/vit-base-patch16-224")
|
23 |
+
text_encoder = TextEncoderHead(model=ENCODER_BASE)
|
24 |
+
image_encoder = ImageEncoderHead(model=IMAGE_BASE)
|
25 |
+
|
26 |
+
clip_model = CLIPChemistryModel(text_encoder=text_encoder, image_encoder=image_encoder)
|
27 |
+
|
28 |
+
clip_model.load_state_dict(torch.load('/Users/sebastianalejandrosarastizambonino/Documents/projects/CLIP_Pytorch/src/best_model_fashion.pth', map_location=torch.device('cpu')))
|
29 |
+
|
30 |
+
te_final = clip_model.text_encoder
|
31 |
+
ie_final = clip_model.image_encoder
|
32 |
+
|
33 |
+
def process_text_for_encoder(text, model):
|
34 |
+
# tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
|
35 |
+
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
|
36 |
+
encoded_input = tokenizer(text, return_tensors='pt', padding='max_length', truncation=True, max_length=256)
|
37 |
+
input_ids = encoded_input['input_ids']
|
38 |
+
attention_mask = encoded_input['attention_mask']
|
39 |
+
output = model(input_ids=input_ids, attention_mask=attention_mask)
|
40 |
+
return output.detach().numpy().tolist()[0]
|
41 |
+
|
42 |
+
def process_image_for_encoder(image, model):
|
43 |
+
# image = Image.open(BytesIO(image))
|
44 |
+
print(type(image))
|
45 |
+
image_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
|
46 |
+
image_tensor = image_processor(image,
|
47 |
+
return_tensors="pt",
|
48 |
+
do_resize=True
|
49 |
+
)['pixel_values']
|
50 |
+
output = model(pixel_values=image_tensor)
|
51 |
+
return output.detach().numpy().tolist()[0]
|
52 |
+
|
53 |
+
def search_similarity(input, mode, top_k=5):
|
54 |
+
if mode == 'text':
|
55 |
+
output = process_text_for_encoder(input, model=te_final)
|
56 |
+
else:
|
57 |
+
output = input
|
58 |
+
|
59 |
+
if mode == 'text':
|
60 |
+
mode_search = 'image'
|
61 |
+
response = index.query(
|
62 |
+
namespace="space-" + mode_search + "-fashion",
|
63 |
+
vector=output,
|
64 |
+
top_k=top_k,
|
65 |
+
include_values=True,
|
66 |
+
include_metadata=True
|
67 |
+
)
|
68 |
+
similar_images = [value['metadata']['image'] for value in response['matches']]
|
69 |
+
return similar_images
|
70 |
+
elif mode == 'image':
|
71 |
+
mode_search = 'text'
|
72 |
+
response = index.query(
|
73 |
+
namespace="space-" + mode_search + "-fashion",
|
74 |
+
vector=output,
|
75 |
+
top_k=top_k,
|
76 |
+
include_values=True,
|
77 |
+
include_metadata=True
|
78 |
+
)
|
79 |
+
similar_text = [value['metadata']['text'] for value in response['matches']]
|
80 |
+
return similar_text
|
81 |
+
else:
|
82 |
+
raise ValueError("mode must be either 'text' or 'image'")
|
83 |
+
|
84 |
+
def process_image_for_encoder_gradio(image, is_bytes=True):
|
85 |
+
"""Procesa tanto imágenes en bytes como objetos PIL Image"""
|
86 |
+
try:
|
87 |
+
if is_bytes:
|
88 |
+
# Si la imagen viene en bytes
|
89 |
+
image = Image.open(BytesIO(image))
|
90 |
+
else:
|
91 |
+
# Si la imagen ya es un objeto PIL Image o viene de gradio
|
92 |
+
if not isinstance(image, Image.Image):
|
93 |
+
# Si viene de gradio, podría ser un numpy array
|
94 |
+
image = Image.fromarray(image)
|
95 |
+
|
96 |
+
image_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
|
97 |
+
image_tensor = image_processor(image,
|
98 |
+
return_tensors="pt",
|
99 |
+
do_resize=True
|
100 |
+
)['pixel_values']
|
101 |
+
output = ie_final(pixel_values=image_tensor)
|
102 |
+
return output.detach().numpy().tolist()[0]
|
103 |
+
except Exception as e:
|
104 |
+
print(f"Error en process_image_for_encoder: {e}")
|
105 |
+
raise
|
utils.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from model import CLIPChemistryModel, TextEncoderHead, ImageEncoderHead
|
2 |
+
from transformers import ViTModel, AutoModelForMaskedLM, AutoTokenizer, ViTImageProcessor
|
3 |
+
|
4 |
+
from io import BytesIO
|
5 |
+
import base64
|
6 |
+
|
7 |
+
from PIL import Image
|
8 |
+
|
9 |
+
def bytes_to_str(bytes_data):
|
10 |
+
return base64.b64encode(bytes_data).decode('utf-8')
|
11 |
+
|
12 |
+
def str_to_bytes(str_data):
|
13 |
+
return base64.b64decode(str_data)
|
14 |
+
|
15 |
+
def push_embeddings_to_pine_cone(index, embeddings, df, mode, length):
|
16 |
+
records = []
|
17 |
+
for i in range(length):
|
18 |
+
if mode == 'text':
|
19 |
+
records.append({
|
20 |
+
"id": str(mode) + str(i),
|
21 |
+
"values": embeddings[i],
|
22 |
+
"metadata": {str(mode): df[mode].iloc[i]}})
|
23 |
+
elif mode == 'image':
|
24 |
+
records.append({
|
25 |
+
"id": str(mode) + str(i),
|
26 |
+
"values": embeddings[i],
|
27 |
+
"metadata": {str(mode): bytes_to_str(df[mode].iloc[i]['bytes'])}})
|
28 |
+
else:
|
29 |
+
raise ValueError("mode must be either 'text' or 'image'")
|
30 |
+
|
31 |
+
index.upsert(
|
32 |
+
vectors=records,
|
33 |
+
namespace="space-" + mode
|
34 |
+
)
|