File size: 4,636 Bytes
de292ea
 
 
71947f0
de292ea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71947f0
de292ea
 
 
 
 
 
 
71947f0
de292ea
71947f0
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import json
from pathlib import Path

import gradio as gr
import numpy as np

from app.models.text_encoder import TextEncoder

# All on-disk inputs live under one relative "data" directory.
_DATA_ROOT = Path("data")

# Directory searched for "*.txt" summary files.
SUMMARY_DIR = _DATA_ROOT / "summaries"
# Directory searched for "*.npy" embedding files.
EMBEDDING_DIR = _DATA_ROOT / "embeddings"
# JSON file mapping an embedding file stem to the original filename.
NAME_MAP_FILE = _DATA_ROOT / "name_map.json"


def search_memes(prompt, top_k=10):
    """
    Search for memes based on the input prompt.

    Every ``*.npy`` file in ``EMBEDDING_DIR`` is compared with the encoded
    prompt using cosine similarity, and the best matches are returned in
    descending order of similarity.

    Args:
        prompt: The text prompt to search for
        top_k: Maximum number of results to return. Values <= 0 yield an
            empty list; values larger than the number of stored embeddings
            are clamped.

    Returns:
        List of dictionaries, best match first. A successful entry has the
        keys ``rank``, ``similarity``, ``filename``, ``original_filename``
        and ``summary``. If a summary file cannot be read, the entry has
        ``rank`` and ``error`` instead.
    """
    top_k = int(top_k)
    if top_k <= 0:
        # A slice of ``[-0:]`` would return *every* result, so bail out early.
        return []

    # Sorted so that ranking among ties is reproducible across filesystems.
    embedding_paths = sorted(EMBEDDING_DIR.glob("*.npy"))
    if not embedding_paths:
        return []

    embeddings = np.asarray([np.load(path) for path in embedding_paths])

    # NOTE(review): the encoder is constructed on every call; if construction
    # is expensive, consider caching it at module level.
    text_encoder = TextEncoder()
    prompt_embedding = text_encoder.encode(prompt)

    # Cosine similarity. Zero-length vectors get similarity 0 instead of NaN,
    # because NaN would sort to the end of argsort and be ranked first.
    dot_products = np.asarray(np.dot(embeddings, prompt_embedding), dtype=float)
    denominators = np.asarray(
        np.linalg.norm(embeddings, axis=1) * np.linalg.norm(prompt_embedding),
        dtype=float,
    )
    similarities = np.zeros_like(dot_products)
    np.divide(dot_products, denominators, out=similarities, where=denominators > 0)

    result_count = min(top_k, len(embedding_paths))
    ranked_indices = np.argsort(similarities)[-result_count:][::-1]

    # Summaries are matched to embeddings by file stem rather than by their
    # position in a second directory listing, which breaks as soon as one
    # file is missing or the listing order differs.
    # NOTE(review): assumes a summary shares its embedding's stem -- confirm.
    summary_paths = {path.stem: path for path in SUMMARY_DIR.glob("*.txt")}

    with open(NAME_MAP_FILE, "r", encoding="utf-8") as f:
        name_map = json.load(f)

    results = []
    for rank, index in enumerate(ranked_indices, start=1):
        stem = embedding_paths[index].stem
        summary_path = summary_paths.get(stem)
        try:
            # Only the summaries that are actually returned get read.
            if summary_path is None:
                summary = "No summary available"
            else:
                summary = summary_path.read_text(encoding="utf-8")
        except (OSError, UnicodeError) as e:
            results.append(
                {"rank": rank, "error": f"Error processing result {rank}: {str(e)}"}
            )
            continue

        results.append(
            {
                "rank": rank,
                "similarity": round(float(similarities[index]), 3),
                "filename": stem,
                "original_filename": name_map.get(stem, "Unknown"),
                "summary": summary,
            }
        )

    return results


def format_results(results):
    """Format the results for display in the Gradio interface.

    Args:
        results: Iterable of result dictionaries. An entry containing the key
            ``error`` is rendered as an error card using ``rank`` and
            ``error``; any other entry must provide ``rank``, ``similarity``,
            ``original_filename`` and ``summary``.

    Returns:
        A single HTML string with one ``<div>`` card per result, or an empty
        string when there are no results.
    """
    # Function-scope import keeps this block self-contained.
    from html import escape

    error_card_open = "<div style='margin-bottom: 20px; padding: 10px; border: 1px solid #ff6b6b; border-radius: 5px;'>"
    result_card_open = "<div style='margin-bottom: 20px; padding: 10px; border: 1px solid #ddd; border-radius: 5px;'>"

    def _text(value):
        # Summaries, filenames and error messages come from files on disk and
        # must be shown as text, never interpreted as markup.
        return escape(str(value))

    parts = []
    for result in results:
        rank = _text(result["rank"])
        if "error" in result:
            parts.append(error_card_open)
            parts.append(f"<p><b>Rank {rank}:</b> {_text(result['error'])}</p>")
        else:
            parts.append(result_card_open)
            parts.append(
                f"<p><b>Rank {rank}</b> (Similarity: {_text(result['similarity'])})</p>"
            )
            parts.append(f"<p><b>File:</b> {_text(result['original_filename'])}</p>")
            parts.append(f"<p><b>Summary:</b> {_text(result['summary'])}</p>")
        parts.append("</div>")

    return "".join(parts)


# Define the Gradio interface
def search_interface(prompt, top_k):
    """Main search interface function for Gradio.

    Args:
        prompt: Query text from the search box.
        top_k: Requested number of results; converted to ``int`` before the
            search because the UI control may supply a float.

    Returns:
        An HTML string with the formatted results, or a short plain-text
        message asking for a query when the prompt is empty or consists only
        of whitespace.
    """
    # A whitespace-only query carries no meaning, so treat it like an empty
    # one instead of running a full search on it.
    if not prompt or not prompt.strip():
        return "Please enter a search query"

    results = search_memes(prompt, int(top_k))
    return format_results(results)


# Create the Gradio app
with gr.Blocks(title="Meme Search", theme=gr.themes.Soft()) as demo:
    # Page heading and tagline.
    gr.Markdown("# 🔍 Meme Search")
    gr.Markdown("Search for memes using natural language descriptions")

    # Input row: query box in the wider column, result-count slider in the
    # narrower one (4:1 column scale).
    with gr.Row():
        with gr.Column(scale=4):
            prompt_input = gr.Textbox(
                label="Search Query",
                placeholder="Enter your search here...",
            )
        with gr.Column(scale=1):
            top_k_slider = gr.Slider(
                label="Number of Results",
                minimum=1,
                maximum=20,
                step=1,
                value=5,
            )

    search_button = gr.Button("Search", variant="primary")
    output = gr.HTML(label="Results")

    # Clicking the button and pressing Enter in the textbox run the same
    # search with the same inputs and output.
    for _register in (search_button.click, prompt_input.submit):
        _register(
            fn=search_interface,
            inputs=[prompt_input, top_k_slider],
            outputs=output,
        )

    # Usage instructions shown below the results.
    gr.Markdown("## How to use")
    gr.Markdown("""
    1. Enter a description of the meme you're looking for
    2. Adjust the number of results to show
    3. Click 'Search' or press Enter
    4. Results are sorted by similarity to your query
    """)

# Start the web server only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()