File size: 4,636 Bytes
de292ea 71947f0 de292ea 71947f0 de292ea 71947f0 de292ea 71947f0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
import html
import json
from pathlib import Path

import gradio as gr
import numpy as np

from app.models.text_encoder import TextEncoder
# Directory that is globbed for summary text files (*.txt)
SUMMARY_DIR = Path("data/summaries")
# Directory that is globbed for stored embedding arrays (*.npy)
EMBEDDING_DIR = Path("data/embeddings")
# JSON file mapping an embedding file's stem to the original filename
NAME_MAP_FILE = Path("data/name_map.json")
def search_memes(prompt, top_k=10):
    """
    Search for memes based on the input prompt.

    Every ``*.npy`` file in ``EMBEDDING_DIR`` is scored against the encoded
    prompt with cosine similarity, and the best ``top_k`` matches are
    returned.

    Args:
        prompt: The text prompt to search for
        top_k: Number of results to return; values below 1 yield no results

    Returns:
        List of dictionaries ordered from most to least similar. A
        successful entry has the keys ``rank``, ``similarity``,
        ``filename``, ``original_filename`` and ``summary``; an entry that
        could not be built has ``rank`` and ``error`` instead. The list is
        empty when ``top_k`` is below 1 or no embeddings are stored.
    """
    # A slice of [-0:] selects *every* index, so reject non-positive
    # requests before doing any work.
    top_k = int(top_k)
    if top_k <= 0:
        return []

    # Sorted so that ties are broken the same way on every run.
    embedding_paths = sorted(EMBEDDING_DIR.glob("*.npy"))
    if not embedding_paths:
        return []

    # One row per stored embedding; ravel() also accepts files saved as (1, d).
    embeddings = np.stack([np.load(path).ravel() for path in embedding_paths])

    # NOTE(review): the encoder is constructed on every call; consider
    # caching it if construction turns out to be expensive.
    text_encoder = TextEncoder()
    prompt_embedding = np.asarray(text_encoder.encode(prompt)).ravel()

    # Cosine similarity between every stored embedding and the prompt.
    dots = embeddings @ prompt_embedding
    norms = np.linalg.norm(embeddings, axis=1) * np.linalg.norm(prompt_embedding)
    # A zero-length vector also has a zero dot product, so dividing by 1
    # scores it 0 instead of NaN (which argsort would rank first).
    norms[norms == 0] = 1
    similarities = dots / norms

    # Best match first.
    ranked_indices = np.argsort(similarities)[-top_k:][::-1]

    with open(NAME_MAP_FILE, "r", encoding="utf-8") as f:
        name_map = json.load(f)

    # Summaries are looked up by file stem rather than by list position:
    # two independent globs give no guarantee of matching order or length.
    # NOTE(review): assumes a summary file shares its embedding's stem (the
    # same key name_map uses) - confirm against the data pipeline.
    summary_paths = {path.stem: path for path in SUMMARY_DIR.glob("*.txt")}

    results = []
    for rank, index in enumerate(ranked_indices, start=1):
        try:
            stem = embedding_paths[index].stem
            summary_path = summary_paths.get(stem)
            if summary_path is None:
                summary = "No summary available"
            else:
                # Only the summaries that are actually returned get read.
                summary = summary_path.read_text(encoding="utf-8")
            results.append(
                {
                    "rank": rank,
                    "similarity": round(float(similarities[index]), 3),
                    "filename": stem,
                    "original_filename": name_map.get(stem, "Unknown"),
                    "summary": summary,
                }
            )
        except (IndexError, KeyError, OSError) as e:
            # Report the failure in place so the remaining results still show.
            results.append(
                {"rank": rank, "error": f"Error processing result {rank}: {str(e)}"}
            )
    return results
def format_results(results):
    """Format the results for display in the Gradio interface.

    Args:
        results: Iterable of result dictionaries. An entry containing the
            key ``error`` is rendered as an error card from ``rank`` and
            ``error``; any other entry is rendered from ``rank``,
            ``similarity``, ``original_filename`` and ``summary``.

    Returns:
        A single HTML string with one ``<div>`` card per result, in input
        order; an empty string when there are no results.

    Raises:
        KeyError: If an entry lacks one of the keys its card needs.
    """
    cards = []
    for result in results:
        if "error" in result:
            # Error text, filenames and summaries are free text; escape them
            # so characters such as "<" or "&" cannot break the markup.
            message = html.escape(str(result["error"]))
            cards.append(
                "<div style='margin-bottom: 20px; padding: 10px; border: 1px solid #ff6b6b; border-radius: 5px;'>"
                f"<p><b>Rank {result['rank']}:</b> {message}</p>"
                "</div>"
            )
        else:
            filename = html.escape(str(result["original_filename"]))
            summary = html.escape(str(result["summary"]))
            cards.append(
                "<div style='margin-bottom: 20px; padding: 10px; border: 1px solid #ddd; border-radius: 5px;'>"
                f"<p><b>Rank {result['rank']}</b> (Similarity: {result['similarity']})</p>"
                f"<p><b>File:</b> {filename}</p>"
                f"<p><b>Summary:</b> {summary}</p>"
                "</div>"
            )
    return "".join(cards)
# Gradio callback: connects the query box and the slider to the search
def search_interface(prompt, top_k):
    """Main search interface function for Gradio.

    Args:
        prompt: Query text from the search box.
        top_k: Requested number of results; converted to ``int`` before
            the search is run.

    Returns:
        The HTML produced by ``format_results`` for the search results, or
        the plain message "Please enter a search query" when the query is
        missing, empty, or only whitespace.
    """
    # A whitespace-only query carries no search terms; treat it as empty
    # instead of running a full search on it.
    if not prompt or not prompt.strip():
        return "Please enter a search query"
    return format_results(search_memes(prompt, int(top_k)))
# Build the Gradio UI: a query box with a result-count slider, a search
# button, an HTML results pane, and short usage instructions
with gr.Blocks(title="Meme Search", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🔍 Meme Search")
    gr.Markdown("Search for memes using natural language descriptions")

    # Query box and slider sit side by side, split 4:1
    with gr.Row():
        with gr.Column(scale=4):
            prompt_input = gr.Textbox(
                label="Search Query",
                placeholder="Enter your search here...",
            )
        with gr.Column(scale=1):
            top_k_slider = gr.Slider(
                minimum=1,
                maximum=20,
                value=5,
                step=1,
                label="Number of Results",
            )

    search_button = gr.Button("Search", variant="primary")
    output = gr.HTML(label="Results")

    # Clicking the button and pressing Enter in the query box both run the
    # same search callback
    search_button.click(
        fn=search_interface,
        inputs=[prompt_input, top_k_slider],
        outputs=output,
    )
    prompt_input.submit(
        fn=search_interface,
        inputs=[prompt_input, top_k_slider],
        outputs=output,
    )

    gr.Markdown("## How to use")
    gr.Markdown(
        "\n"
        "1. Enter a description of the meme you're looking for\n"
        "2. Adjust the number of results to show\n"
        "3. Click 'Search' or press Enter\n"
        "4. Results are sorted by similarity to your query\n"
    )
# Launch the app only when this file is run as a script, not when imported
if __name__ == "__main__":
    demo.launch()
|