José Guillermo Araya
search interface
de292ea
import html
import json
from pathlib import Path

import gradio as gr
import numpy as np

from app.models.text_encoder import TextEncoder
# Every artefact the search reads lives under one data directory, resolved
# relative to the process's working directory.
_DATA_DIR = Path("data")
SUMMARY_DIR = _DATA_DIR / "summaries"  # text summaries ("*.txt")
EMBEDDING_DIR = _DATA_DIR / "embeddings"  # stored embedding vectors ("*.npy")
NAME_MAP_FILE = _DATA_DIR / "name_map.json"  # JSON map keyed by embedding file stem
def search_memes(prompt: str, top_k: int = 10) -> list:
    """
    Search for memes whose stored embeddings best match a text prompt.

    Every ``*.npy`` file in ``EMBEDDING_DIR`` is loaded and compared with the
    encoded prompt by cosine similarity; the ``top_k`` best matches are
    returned in descending order of similarity.

    Args:
        prompt: The text prompt to search for.
        top_k: Maximum number of results to return. Values below 1 yield an
            empty list.

    Returns:
        List of dictionaries, best match first. A normal entry has the keys
        ``rank``, ``similarity``, ``filename``, ``original_filename`` and
        ``summary``. If a summary file exists but cannot be read, the entry
        holds ``rank`` and ``error`` instead. The list is empty when no
        embedding files are found.

    Raises:
        OSError: If ``NAME_MAP_FILE`` cannot be opened.
    """
    top_k = int(top_k)
    if top_k < 1:
        # A slice such as [-0:] would select every item rather than none.
        return []

    # Sort so ties are broken the same way on every run; glob order depends
    # on the filesystem.
    embedding_paths = sorted(EMBEDDING_DIR.glob("*.npy"))
    if not embedding_paths:
        return []

    # Flatten each stored vector so (d,) and (1, d) files are both accepted.
    embeddings = np.stack([np.ravel(np.load(path)) for path in embedding_paths])

    # NOTE(review): the encoder is constructed on every call; if construction
    # is expensive, consider caching a single instance.
    text_encoder = TextEncoder()
    prompt_embedding = np.ravel(text_encoder.encode(prompt))

    # Cosine similarity; a zero-norm vector scores 0 instead of NaN, which
    # argsort would otherwise place last, i.e. rank as the best match.
    dot_products = np.dot(embeddings, prompt_embedding)
    norm_products = np.linalg.norm(embeddings, axis=1) * np.linalg.norm(
        prompt_embedding
    )
    similarities = np.divide(
        dot_products,
        norm_products,
        out=np.zeros(dot_products.shape, dtype=float),
        where=norm_products > 0,
    )

    # argsort is ascending: reverse for best-first, then keep the top k.
    best_first = np.argsort(similarities)[::-1][:top_k]

    with open(NAME_MAP_FILE, "r", encoding="utf-8") as f:
        name_map = json.load(f)

    results = []
    for rank, index in enumerate(best_first, start=1):
        stem = embedding_paths[index].stem

        # Pair the summary with its embedding by file stem, not by position
        # in two independent directory listings, whose orders need not agree.
        # Assumes summaries are stored as "<stem>.txt" using the same stem
        # the name map is keyed by -- verify against the code that writes them.
        summary_path = SUMMARY_DIR / f"{stem}.txt"
        try:
            summary = summary_path.read_text(encoding="utf-8")
        except FileNotFoundError:
            summary = "No summary available"
        except (OSError, UnicodeDecodeError) as e:
            # One unreadable summary should not abort the whole search.
            results.append(
                {"rank": rank, "error": f"Error processing result {rank}: {str(e)}"}
            )
            continue

        results.append(
            {
                "rank": rank,
                "similarity": round(float(similarities[index]), 3),
                "filename": stem,
                "original_filename": name_map.get(stem, "Unknown"),
                "summary": summary,
            }
        )
    return results
def format_results(results) -> str:
    """Format search results as HTML for display in the Gradio interface.

    Args:
        results: Iterable of result dictionaries. An entry containing the key
            ``error`` is rendered as an error card from its ``rank`` and
            ``error`` values; any other entry is rendered from its ``rank``,
            ``similarity``, ``original_filename`` and ``summary`` values.

    Returns:
        One HTML string with a ``<div>`` card per result, in input order, or
        an empty string when ``results`` is empty.
    """
    cards = []
    for result in results:
        if "error" in result:
            # Error text may carry exception messages; escape it so it is
            # shown literally rather than interpreted as markup.
            error_text = html.escape(str(result["error"]))
            cards.append(
                "<div style='margin-bottom: 20px; padding: 10px; border: 1px solid #ff6b6b; border-radius: 5px;'>"
                f"<p><b>Rank {result['rank']}:</b> {error_text}</p>"
                "</div>"
            )
        else:
            # File names and summaries are free text read from disk; escape
            # them so characters such as "<" and "&" cannot break the page.
            file_name = html.escape(str(result["original_filename"]))
            summary = html.escape(str(result["summary"]))
            cards.append(
                "<div style='margin-bottom: 20px; padding: 10px; border: 1px solid #ddd; border-radius: 5px;'>"
                f"<p><b>Rank {result['rank']}</b> (Similarity: {result['similarity']})</p>"
                f"<p><b>File:</b> {file_name}</p>"
                f"<p><b>Summary:</b> {summary}</p>"
                "</div>"
            )
    return "".join(cards)
# Define the Gradio interface
def search_interface(prompt, top_k) -> str:
    """Run a search for the Gradio UI and return the results as HTML.

    Args:
        prompt: Query text from the search box.
        top_k: Number of results requested; converted to ``int`` before the
            search is run.

    Returns:
        The HTML produced by ``format_results`` for the matches, or the
        message "Please enter a search query" when the prompt is empty or
        contains only whitespace.
    """
    # A blank or whitespace-only query carries nothing to search for, so
    # reject it rather than ranking results against it.
    if not prompt or not str(prompt).strip():
        return "Please enter a search query"
    results = search_memes(prompt, int(top_k))
    return format_results(results)
# Create the Gradio app
with gr.Blocks(title="Meme Search", theme=gr.themes.Soft()) as demo:
    # Page heading and tagline.
    gr.Markdown("# 🔍 Meme Search")
    gr.Markdown("Search for memes using natural language descriptions")

    # Query box in the wide column, result-count slider in the narrow one.
    with gr.Row():
        with gr.Column(scale=4):
            prompt_input = gr.Textbox(
                label="Search Query",
                placeholder="Enter your search here...",
            )
        with gr.Column(scale=1):
            top_k_slider = gr.Slider(
                label="Number of Results",
                minimum=1,
                maximum=20,
                step=1,
                value=5,
            )

    search_button = gr.Button("Search", variant="primary")
    output = gr.HTML(label="Results")

    # Clicking the button and pressing Enter in the query box run the same
    # search, so both listeners are registered with identical arguments.
    for _trigger in (search_button.click, prompt_input.submit):
        _trigger(
            fn=search_interface,
            inputs=[prompt_input, top_k_slider],
            outputs=output,
        )

    # Short usage guide shown below the results.
    gr.Markdown("## How to use")
    gr.Markdown(
        "\n"
        "1. Enter a description of the meme you're looking for\n"
        "2. Adjust the number of results to show\n"
        "3. Click 'Search' or press Enter\n"
        "4. Results are sorted by similarity to your query\n"
    )

# Start the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()