|
import json |
|
from pathlib import Path |
|
|
|
import gradio as gr |
|
import numpy as np |
|
|
|
from app.models.text_encoder import TextEncoder |
|
|
|
# Locations of the search data. The paths are relative, so they resolve
# against the process's current working directory.
SUMMARY_DIR = Path("data", "summaries")  # searched for *.txt files
EMBEDDING_DIR = Path("data", "embeddings")  # searched for *.npy files
NAME_MAP_FILE = Path("data", "name_map.json")  # JSON read via json.load
|
|
|
|
|
def search_memes(prompt, top_k=10):
    """
    Search for memes based on the input prompt.

    Every ``*.npy`` embedding under ``EMBEDDING_DIR`` is scored against the
    encoded prompt with cosine similarity, and the highest-scoring entries
    are returned best-first.

    Args:
        prompt: The text prompt to search for
        top_k: Number of results to return. Values below 1 yield no results.

    Returns:
        List of dictionaries containing search results. A successful entry
        has the keys ``rank``, ``similarity``, ``filename``,
        ``original_filename`` and ``summary``; an entry that could not be
        assembled has ``rank`` and ``error`` instead. The list is empty when
        ``top_k`` is below 1 or no embeddings are stored.
    """
    # Checked first: a slice like ``[-0:]`` selects *everything*, so a
    # non-positive top_k must never reach the slicing step below.
    if top_k < 1:
        return []

    # glob() yields files in arbitrary order; sorting makes the index ->
    # file mapping deterministic.
    embedding_paths = sorted(EMBEDDING_DIR.glob("*.npy"))
    if not embedding_paths:
        # Nothing indexed yet; norm(..., axis=1) would fail on an empty array.
        return []

    embeddings = np.asarray([np.load(path) for path in embedding_paths])

    text_encoder = TextEncoder()
    prompt_embedding = text_encoder.encode(prompt)

    # Cosine similarity between each stored embedding and the prompt.
    similarities = np.dot(embeddings, prompt_embedding) / (
        np.linalg.norm(embeddings, axis=1) * np.linalg.norm(prompt_embedding)
    )

    # argsort is ascending: take the tail, then reverse for best-first order.
    top_k_indices = np.argsort(similarities)[-top_k:][::-1]

    # Summaries are paired with embeddings by position, so they are sorted
    # the same way as the embeddings.
    # NOTE(review): this assumes summary files sort in the same order as
    # their embeddings (e.g. shared stems) - confirm against the indexer.
    summary_paths = sorted(SUMMARY_DIR.glob("*.txt"))

    with open(NAME_MAP_FILE, "r", encoding="utf-8") as f:
        name_map = json.load(f)

    results = []
    for rank, index in enumerate(top_k_indices, start=1):
        try:
            stem = embedding_paths[index].stem
            # Only the summaries of returned results are read from disk.
            if index < len(summary_paths):
                summary = summary_paths[index].read_text(encoding="utf-8")
            else:
                summary = "No summary available"
            results.append(
                {
                    "rank": rank,
                    "similarity": round(float(similarities[index]), 3),
                    "filename": stem,
                    "original_filename": name_map.get(stem, "Unknown"),
                    "summary": summary,
                }
            )
        except (IndexError, KeyError) as e:
            # Keep the remaining results; report this one as an error entry.
            results.append(
                {"rank": rank, "error": f"Error processing result {rank}: {str(e)}"}
            )

    return results
|
|
|
|
|
def format_results(results):
    """Format the results for display in the Gradio interface.

    Args:
        results: Iterable of result dictionaries. An entry containing an
            ``error`` key is rendered as an error card using ``rank`` and
            ``error``; any other entry must provide ``rank``,
            ``similarity``, ``original_filename`` and ``summary``.

    Returns:
        One HTML string holding a bordered ``<div>`` card per result, in the
        order given. Returns an empty string when there are no results.
    """
    # Imported locally so the function does not rely on a file-level import.
    from html import escape

    cards = []
    for result in results:
        if "error" in result:
            # Error text is escaped: it may echo arbitrary exception content.
            cards.append(
                "<div style='margin-bottom: 20px; padding: 10px; border: 1px solid #ff6b6b; border-radius: 5px;'>"
                f"<p><b>Rank {result['rank']}:</b> {escape(str(result['error']))}</p>"
                "</div>"
            )
        else:
            # Filenames and summaries come from data files, so characters
            # such as '<' or '&' must not be interpreted as markup.
            cards.append(
                "<div style='margin-bottom: 20px; padding: 10px; border: 1px solid #ddd; border-radius: 5px;'>"
                f"<p><b>Rank {result['rank']}</b> (Similarity: {result['similarity']})</p>"
                f"<p><b>File:</b> {escape(str(result['original_filename']))}</p>"
                f"<p><b>Summary:</b> {escape(str(result['summary']))}</p>"
                "</div>"
            )

    return "".join(cards)
|
|
|
|
|
|
|
def search_interface(prompt, top_k):
    """Main search interface function for Gradio.

    Args:
        prompt: The query text from the search box. ``None``, an empty
            string, or a whitespace-only string counts as "no query".
        top_k: Number of results requested; converted with ``int()`` before
            being passed to ``search_memes``.

    Returns:
        The formatted HTML produced by ``format_results``, or the plain
        message ``"Please enter a search query"`` when no query was given.
    """
    # Strip first so a query made only of spaces does not trigger a search.
    query = prompt.strip() if isinstance(prompt, str) else prompt
    if not query:
        return "Please enter a search query"

    results = search_memes(query, int(top_k))
    return format_results(results)
|
|
|
|
|
|
|
# Build the Gradio page. Components are created in the same order as before:
# heading, query row, search button, results area, usage notes.
with gr.Blocks(title="Meme Search", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🔍 Meme Search")
    gr.Markdown("Search for memes using natural language descriptions")

    # Query row: a wide text box next to a narrow result-count slider.
    with gr.Row():
        with gr.Column(scale=4):
            prompt_input = gr.Textbox(
                placeholder="Enter your search here...",
                label="Search Query",
            )
        with gr.Column(scale=1):
            top_k_slider = gr.Slider(
                label="Number of Results",
                minimum=1,
                maximum=20,
                step=1,
                value=5,
            )

    search_button = gr.Button("Search", variant="primary")

    output = gr.HTML(label="Results")

    # Clicking the button and pressing Enter in the text box run the same
    # handler with the same inputs and output.
    for _register in (search_button.click, prompt_input.submit):
        _register(
            fn=search_interface,
            inputs=[prompt_input, top_k_slider],
            outputs=output,
        )

    gr.Markdown("## How to use")
    gr.Markdown("""
    1. Enter a description of the meme you're looking for
    2. Adjust the number of results to show
    3. Click 'Search' or press Enter
    4. Results are sorted by similarity to your query
    """)


# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()
|
|