Go-Raw commited on
Commit
f8d8e0f
·
verified ·
1 Parent(s): c48fa80

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -21
app.py CHANGED
@@ -1,6 +1,5 @@
1
  from sentence_transformers import SentenceTransformer, util
2
  from huggingface_hub import hf_hub_download
3
- import os
4
  import pickle
5
  import pandas as pd
6
  from PIL import Image
@@ -8,33 +7,69 @@ import requests
8
  from io import BytesIO
9
  import gradio as gr
10
 
11
- pd.options.mode.chained_assignment = None # Turn off SettingWithCopyWarning
 
12
 
13
- embeddings = pickle.load(open(hf_hub_download("bhavyagiri/semantic-memes", repo_type="dataset", filename="meme-embeddings.pkl"), "rb"))
 
 
 
 
14
  df = pd.read_csv(hf_hub_download("bhavyagiri/semantic-memes", repo_type="dataset", filename="input.csv"))
15
 
 
16
  model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')
17
 
 
18
  def generate_memes(prompt):
 
19
  prompt_embedding = model.encode(prompt, convert_to_tensor=True)
 
 
20
  hits = util.semantic_search(prompt_embedding, embeddings, top_k=6)
21
- hits = pd.DataFrame(hits[0], columns=['corpus_id', 'score'])
22
- desired_ids = hits["corpus_id"]
23
- filtered_df = df.loc[df['id'].isin(desired_ids)]
24
- filtered_list = list(filtered_df["url"])
25
- images = [Image.open(BytesIO(requests.get(img).content)) for img in filtered_list]
26
- return (
27
- images
28
- )
29
- input_textbox = gr.inputs.Textbox(lines=1, label="Search something cool")
30
- output_gallery = gr.Gallery(
31
- label="Retrieved Memes", show_label=False, elem_id="gallery"
32
- ).style(columns=[3], rows=[2], object_fit="contain", height="auto")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  title = "Semantic Search for Memes"
34
- description = "Search Memes from small dataset of 6k memes. Check out [GitHub Repo](https://github.com/bhavya-giri/retrieving-memes)"
35
- examples = ["Get Shreked","Going Crazy","Spiderman is my teacher"]
36
- interpretation='default'
37
- enable_queue=True
 
 
 
 
 
 
 
 
 
 
 
38
 
39
- iface = gr.Interface(fn=generate_memes, inputs=input_textbox, outputs=output_gallery,examples=examples,cache_examples=True,title=title,description=description,interpretation=interpretation,enable_queue=enable_queue)
40
- iface.launch(inline=False)
 
1
  from sentence_transformers import SentenceTransformer, util
2
  from huggingface_hub import hf_hub_download
 
3
  import pickle
4
  import pandas as pd
5
  from PIL import Image
 
7
  from io import BytesIO
8
  import gradio as gr
9
 
10
+ # Silence SettingWithCopyWarning from pandas
11
+ pd.options.mode.chained_assignment = None
12
 
13
+ # Load meme embeddings (pre-computed)
14
+ embeddings = pickle.load(open(
15
+ hf_hub_download("bhavyagiri/semantic-memes", repo_type="dataset", filename="meme-embeddings.pkl"), "rb"))
16
+
17
+ # Load meme metadata (with 'id' and 'url' columns)
18
  df = pd.read_csv(hf_hub_download("bhavyagiri/semantic-memes", repo_type="dataset", filename="input.csv"))
19
 
20
+ # Load sentence transformer model
21
  model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')
22
 
23
+ # Meme search function
24
  def generate_memes(prompt):
25
+ # Encode user prompt into embedding
26
  prompt_embedding = model.encode(prompt, convert_to_tensor=True)
27
+
28
+ # Perform semantic search
29
  hits = util.semantic_search(prompt_embedding, embeddings, top_k=6)
30
+ hits_df = pd.DataFrame(hits[0], columns=['corpus_id', 'score'])
31
+
32
+ # Get matching meme URLs from original DataFrame
33
+ matched_ids = hits_df['corpus_id']
34
+ matched_memes = df[df['id'].isin(matched_ids)]
35
+
36
+ # Download and display meme images
37
+ images = []
38
+ for url in matched_memes["url"]:
39
+ try:
40
+ response = requests.get(url)
41
+ image = Image.open(BytesIO(response.content))
42
+ images.append(image)
43
+ except Exception as e:
44
+ print(f"Error loading image from {url}: {e}")
45
+ continue
46
+
47
+ return images
48
+
49
+ # Gradio UI setup
50
+ input_textbox = gr.Textbox(lines=1, label="Search something cool")
51
+
52
+ output_gallery = gr.Gallery(label="Retrieved Memes").style(
53
+ columns=[3], rows=[2], object_fit="contain", height="auto"
54
+ )
55
+
56
+ # App info
57
  title = "Semantic Search for Memes"
58
+ description = "Search memes from a dataset of ~6k memes using semantic similarity. [GitHub Repo](https://github.com/bhavya-giri/retrieving-memes)"
59
+ examples = ["Get Shreked", "Going Crazy", "Spiderman is my teacher"]
60
+
61
+ # Gradio interface
62
+ iface = gr.Interface(
63
+ fn=generate_memes,
64
+ inputs=input_textbox,
65
+ outputs=output_gallery,
66
+ examples=examples,
67
+ cache_examples=True,
68
+ title=title,
69
+ description=description,
70
+ interpretation='default',
71
+ enable_queue=True
72
+ )
73
 
74
+ # Launch the app
75
+ iface.launch()