import html
import json
from functools import lru_cache
from pathlib import Path

import gradio as gr
import numpy as np

from app.models.text_encoder import TextEncoder

SUMMARY_DIR = Path("data/summaries")
EMBEDDING_DIR = Path("data/embeddings")
NAME_MAP_FILE = Path("data/name_map.json")

# Lower bound for the cosine-similarity denominator: an all-zero vector then
# scores 0 instead of NaN (argsort places NaN last, i.e. ranks it *best*).
_EPS = 1e-12


@lru_cache(maxsize=1)
def _get_text_encoder():
    """Return a shared TextEncoder, constructed on first use.

    Avoids rebuilding the encoder on every search request.
    NOTE(review): assumes TextEncoder keeps no per-query state and is safe to
    reuse across calls -- confirm against app.models.text_encoder.
    """
    return TextEncoder()


def _load_name_map():
    """Return the stem -> original-filename mapping, or {} if the file is absent."""
    try:
        with open(NAME_MAP_FILE, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        return {}


def _load_summary(stem):
    """Return the summary text stored for *stem*, or a placeholder if none exists.

    NOTE(review): assumes data/summaries/<stem>.txt pairs with
    data/embeddings/<stem>.npy -- confirm against the indexing pipeline.
    """
    try:
        return (SUMMARY_DIR / f"{stem}.txt").read_text(encoding="utf-8")
    except FileNotFoundError:
        return "No summary available"


def search_memes(prompt, top_k=10):
    """
    Search for memes based on the input prompt.

    Args:
        prompt: The text prompt to search for.
        top_k: Maximum number of results to return; values below 1 yield [].

    Returns:
        List of result dicts ordered best-first. Each has ``rank``,
        ``similarity`` (cosine, rounded to 3 places), ``filename`` (embedding
        file stem), ``original_filename`` and ``summary``; a result that could
        not be built has only ``rank`` and ``error``.
    """
    # Sorted so ordering (and tie-breaking) is deterministic across calls.
    embedding_paths = sorted(EMBEDDING_DIR.glob("*.npy"))
    if not embedding_paths or top_k < 1:
        return []

    # vstack accepts per-file arrays shaped (d,) or (1, d) -> (N, d) matrix.
    embeddings = np.vstack([np.load(path) for path in embedding_paths])
    prompt_embedding = np.asarray(_get_text_encoder().encode(prompt)).ravel()

    # Cosine similarity, clamped so zero-norm vectors cannot produce NaN.
    norms = np.linalg.norm(embeddings, axis=1) * np.linalg.norm(prompt_embedding)
    similarities = (embeddings @ prompt_embedding) / np.maximum(norms, _EPS)

    # Best-first; slicing by an explicit count avoids the [-0:] "everything" trap.
    top_indices = np.argsort(similarities)[::-1][:top_k]

    name_map = _load_name_map()
    results = []
    for rank, index in enumerate(top_indices, start=1):
        stem = embedding_paths[index].stem
        try:
            summary = _load_summary(stem)
        except (OSError, UnicodeDecodeError) as e:
            results.append(
                {"rank": rank, "error": f"Error processing result {rank}: {e}"}
            )
            continue
        results.append(
            {
                "rank": rank,
                "similarity": round(float(similarities[index]), 3),
                "filename": stem,
                "original_filename": name_map.get(stem, "Unknown"),
                "summary": summary,
            }
        )
    return results


def format_results(results):
    """Format the results for display in the Gradio interface.

    Every interpolated value is HTML-escaped: summaries, filenames and error
    messages come from files on disk and may contain markup characters.
    """
    if not results:
        return "<p>No results found.</p>"

    parts = []
    for result in results:
        rank = html.escape(str(result["rank"]))
        if "error" in result:
            error = html.escape(str(result["error"]))
            parts.append(
                "<div style='border: 1px solid #e57373; border-radius: 8px; "
                "padding: 10px; margin-bottom: 10px;'>"
                f"<h3>Rank {rank}: {error}</h3>"
                "</div>"
            )
        else:
            similarity = html.escape(str(result["similarity"]))
            original = html.escape(str(result["original_filename"]))
            summary = html.escape(str(result["summary"]))
            parts.append(
                "<div style='border: 1px solid #ccc; border-radius: 8px; "
                "padding: 10px; margin-bottom: 10px;'>"
                f"<h3>Rank {rank} (Similarity: {similarity})</h3>"
                f"<p><strong>File:</strong> {original}</p>"
                # pre-wrap keeps the summary's own line breaks visible.
                f"<p style='white-space: pre-wrap;'><strong>Summary:</strong> {summary}</p>"
                "</div>"
            )
    return "".join(parts)


# Define the Gradio interface
def search_interface(prompt, top_k):
    """Main search interface function for Gradio: validate, search, render HTML."""
    # Treat None and whitespace-only input as "no query".
    if not prompt or not prompt.strip():
        return "Please enter a search query"

    results = search_memes(prompt, int(top_k))
    return format_results(results)


# Create the Gradio app
with gr.Blocks(title="Meme Search", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🔍 Meme Search")
    gr.Markdown("Search for memes using natural language descriptions")

    with gr.Row():
        with gr.Column(scale=4):
            prompt_input = gr.Textbox(
                label="Search Query", placeholder="Enter your search here..."
            )
        with gr.Column(scale=1):
            top_k_slider = gr.Slider(
                minimum=1, maximum=20, value=5, step=1, label="Number of Results"
            )

    search_button = gr.Button("Search", variant="primary")
    output = gr.HTML(label="Results")

    search_button.click(
        fn=search_interface, inputs=[prompt_input, top_k_slider], outputs=output
    )
    # Also trigger search on Enter key
    prompt_input.submit(
        fn=search_interface, inputs=[prompt_input, top_k_slider], outputs=output
    )

    gr.Markdown("## How to use")
    # Written without leading indentation so Markdown renders a list, not code.
    gr.Markdown(
        "1. Enter a description of the meme you're looking for\n"
        "2. Adjust the number of results to show\n"
        "3. Click 'Search' or press Enter\n"
        "4. Results are sorted by similarity to your query"
    )

# Launch the app
if __name__ == "__main__":
    demo.launch()