Spaces:
Running
Running
import gradio as gr | |
import numpy as np | |
from PIL import Image, ImageDraw, ImageFont | |
from search import search_similarity, process_image_for_encoder_gradio | |
from utils import str_to_bytes | |
from io import BytesIO | |
def add_ranking_number(image, rank): | |
"""Añade un número de ranking a la imagen""" | |
img_with_rank = image.copy() | |
draw = ImageDraw.Draw(img_with_rank) | |
width, height = image.size | |
circle_radius = min(width, height) // 15 | |
circle_position = (circle_radius + 10, circle_radius + 10) | |
draw.ellipse( | |
[(circle_position[0] - circle_radius, circle_position[1] - circle_radius), | |
(circle_position[0] + circle_radius, circle_position[1] + circle_radius)], | |
fill='white', | |
outline='black' | |
) | |
font_size = circle_radius | |
try: | |
font = ImageFont.truetype("Arial.ttf", font_size) | |
except: | |
font = ImageFont.load_default() | |
text = str(rank + 1) | |
text_bbox = draw.textbbox((0, 0), text, font=font) | |
text_width = text_bbox[2] - text_bbox[0] | |
text_height = text_bbox[3] - text_bbox[1] | |
text_position = ( | |
circle_position[0] - text_width // 2, | |
circle_position[1] - text_height // 2 | |
) | |
draw.text(text_position, text, fill='black', font=font) | |
return img_with_rank | |
def process_image_result(image_str, rank): | |
"""Convierte una cadena de imagen en un objeto PIL Image con ranking""" | |
try: | |
img = Image.open(BytesIO(str_to_bytes(image_str))) | |
return add_ranking_number(img, rank) | |
except Exception as e: | |
print(f"Error procesando imagen: {e}") | |
return None | |
def interface_fn(mode, input_text, input_image, top_k): | |
try: | |
# Determinar qué input usar basado en el modo | |
if mode == "text": | |
if not input_text.strip(): | |
return [], "Please, add the description of the clothes to search." | |
input_data = input_text | |
else: # mode == "image" | |
if input_image is None: | |
return [], "Please, select an image" | |
input_data = process_image_for_encoder_gradio(input_image, is_bytes=False) | |
# Show the input data | |
print(f"Input data: {input_data}") # Para debugging | |
# Realizar la búsqueda | |
results = search_similarity(input_data, mode, int(top_k)) | |
# Formatear resultados según el modo | |
if mode == "text": # Devuelve imágenes | |
processed_images = [] | |
# Si results es una lista de listas, la aplanamos | |
if results and isinstance(results[0], list): | |
print("Recibida lista de listas, aplanando...") # Para debugging | |
results = [item for sublist in results for item in sublist] | |
for idx, img_str in enumerate(results): | |
img = process_image_result(img_str, idx) | |
if img is not None: | |
processed_images.append(img) | |
if not processed_images: | |
return [], "No se pudieron procesar las imágenes" | |
return processed_images, None | |
else: # mode == "image" - Devuelve textos | |
if isinstance(results, list): | |
numbered_texts = [f"{i+1}. {text}" for i, text in enumerate(results)] | |
return [], "\n\n".join(numbered_texts) | |
else: | |
return [], str(results) | |
except Exception as e: | |
print(f"Error en interface_fn: {str(e)}") | |
print(f"Tipo de resultados: {type(results)}") # Para debugging | |
return [], f"Error durante la búsqueda: {str(e)}" | |
def search_text(input_text, top_k): | |
try: | |
if not input_text.strip(): | |
return [] | |
# Realizar la búsqueda | |
results = search_similarity(input_text, "text", int(top_k)) | |
processed_images = [] | |
# Si results es una lista de listas, la aplanamos | |
if results and isinstance(results[0], list): | |
results = [item for sublist in results for item in sublist] | |
for idx, img_str in enumerate(results): | |
img = process_image_result(img_str, idx) | |
if img is not None: | |
processed_images.append(img) | |
return processed_images | |
except Exception as e: | |
print(f"Error en search_text: {str(e)}") | |
return [] | |
with gr.Blocks() as demo: | |
gr.Markdown("# Image search engine based on descriptions (CLIP model)") | |
with gr.Row(): | |
with gr.Column(scale=1): | |
input_text = gr.Textbox( | |
label="Search text", | |
placeholder="Insert your description here...", | |
lines=3 | |
) | |
top_k = gr.Slider( | |
minimum=1, | |
maximum=20, | |
value=5, | |
step=1, | |
label="Number of images", | |
info="How many images do you want to see?" | |
) | |
search_button = gr.Button("Search") | |
with gr.Column(scale=1): | |
output_gallery = gr.Gallery( | |
label="Similar images", | |
columns=3, | |
height="auto" | |
) | |
search_button.click( | |
fn=search_text, | |
inputs=[input_text, top_k], | |
outputs=output_gallery | |
) | |
if __name__ == "__main__": | |
from multiprocessing import freeze_support | |
freeze_support() | |
demo.launch() |