File size: 1,713 Bytes
9c7a90d
 
 
 
 
 
 
 
 
e5126f7
9c7a90d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d140a00
226adc9
9c7a90d
 
 
 
 
 
 
 
362d9c3
9c7a90d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from sentence_transformers import SentenceTransformer, util
from huggingface_hub import hf_hub_download
import os
import pickle
import pandas as pd
import gradio as gr

# Silence pandas' SettingWithCopyWarning for chained assignments.
pd.options.mode.chained_assignment = None

# Precomputed meme embeddings, downloaded from the Hugging Face Hub dataset repo.
# NOTE(security): pickle.load can execute arbitrary code embedded in the file;
# this is only acceptable while the "bhavyagiri/semantic-memes" dataset is trusted.
_embeddings_path = hf_hub_download(
    "bhavyagiri/semantic-memes", repo_type="dataset", filename="meme-embeddings.pkl"
)
# Use a context manager so the file handle is closed as soon as loading finishes.
with open(_embeddings_path, "rb") as _embeddings_file:
    embeddings = pickle.load(_embeddings_file)

# Meme metadata table; generate_memes reads its "id" and "url" columns.
df = pd.read_csv(
    hf_hub_download("bhavyagiri/semantic-memes", repo_type="dataset", filename="input.csv")
)

# Sentence encoder used to embed search prompts.
# NOTE(review): presumably the same model that produced meme-embeddings.pkl - confirm.
model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')

def generate_memes(prompt):
    """Return the URLs of the memes most similar to ``prompt``, best match first.

    The prompt is embedded with the module-level ``model`` and compared against
    the precomputed ``embeddings``; the top 5 hits are mapped to rows of ``df``
    through its ``id`` column.

    Args:
        prompt: Free-text search query.

    Returns:
        A list of image URLs (values of ``df["url"]``), ordered by decreasing
        similarity. Hits whose ``corpus_id`` has no matching ``id`` in ``df``
        are skipped, so the list may hold fewer than 5 entries. A plain list of
        URLs is what the ``gr.Gallery`` output component consumes.
    """
    prompt_embedding = model.encode(prompt, convert_to_tensor=True)
    # One result list per query; we send a single query, hence the [0].
    # Each hit is a dict with 'corpus_id' and 'score', sorted best-first.
    hits = util.semantic_search(prompt_embedding, embeddings, top_k=5)[0]
    rank_by_id = {hit["corpus_id"]: rank for rank, hit in enumerate(hits)}

    matches = df.loc[df["id"].isin(list(rank_by_id))]
    # isin() keeps the DataFrame's own row order, which would discard the
    # similarity ranking; re-sort the matches by their position in the hit list.
    matches = matches.sort_values(
        "id", key=lambda ids: ids.map(rank_by_id), kind="stable"
    )
    return matches["url"].tolist()
# Single-line text box that receives the search prompt.
# gr.Textbox replaces the deprecated gr.inputs.Textbox alias (same arguments).
input_textbox = gr.Textbox(lines=1, label="Search something cool")

# Gallery that displays the retrieved memes.
# NOTE(review): .style() is itself deprecated in later Gradio 3.x releases in
# favour of constructor arguments; left as-is because the pinned Gradio
# version is not visible here.
output_gallery = gr.Gallery(
    label="Retrieved Memes", show_label=False, elem_id="gallery"
).style(columns=[3], rows=[2], object_fit="contain", height="auto")

# Static texts and settings forwarded to gr.Interface below.
title = "Semantic Search for Memes"
description = "Search Memes from small dataset of 6k memes"
examples = ['Spiderman giving lecture', 'Angry Karen']
interpretation = 'default'
enable_queue = True

# Wire the search function to the UI components and start the app.
iface = gr.Interface(
    fn=generate_memes,
    inputs=input_textbox,
    outputs=output_gallery,
    examples=examples,
    title=title,
    description=description,
    interpretation=interpretation,
    enable_queue=enable_queue,
)
iface.launch(inline=False)