vinid committed
Commit dc3cb2a
1 Parent(s): 61448a4
Files changed (5)
  1. app.py +2 -0
  2. home.py +12 -1
  3. text2image.py +54 -13
  4. tweet_eval_retrieval_twlnk.tsv +0 -0
  5. zeroshot.py +0 -0
app.py CHANGED
@@ -6,6 +6,7 @@ from PIL import Image
 import requests
 import transformers
 import text2image
+import zeroshot
 import tokenizers
 from io import BytesIO
 import streamlit as st
@@ -24,6 +25,7 @@ st.sidebar.title("Explore our PLIP Demo")
 PAGES = {
     "Introduction": home,
     "Text to Image": text2image,
+    "Image Prediction": zeroshot,
 }
 
 page = st.sidebar.radio("", list(PAGES.keys()))
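Note on the routing above: st.sidebar.radio returns the selected option, i.e. one of the PAGES keys, so the dispatch further down in app.py (not shown in these hunks) is presumably a dict lookup followed by a call to the module's app() entry point, the convention home.py and text2image.py follow. A minimal, self-contained sketch of that pattern, with SimpleNamespace stand-ins for the real page modules:

# Sketch of the routing pattern in app.py; stand-ins, not the real modules.
import streamlit as st
from types import SimpleNamespace

# Each real page module (home, text2image, zeroshot) exposes an app() callable.
home = SimpleNamespace(app=lambda: st.write("introduction"))
text2image = SimpleNamespace(app=lambda: st.write("image search"))
zeroshot = SimpleNamespace(app=lambda: st.write("zero-shot prediction"))

PAGES = {
    "Introduction": home,
    "Text to Image": text2image,
    "Image Prediction": zeroshot,
}

# The radio widget returns the selected key, so the page module is
# looked up before its app() is invoked.
page = st.sidebar.radio("", list(PAGES.keys()))
PAGES[page].app()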
home.py CHANGED
@@ -1,5 +1,6 @@
 from pathlib import Path
 import streamlit as st
+import streamlit.components.v1 as components
 
 
 def read_markdown_file(markdown_file):
@@ -8,4 +9,14 @@ def read_markdown_file(markdown_file):
 
 def app():
     intro_markdown = read_markdown_file("introduction.md")
-    st.markdown(intro_markdown, unsafe_allow_html=True)
+    st.markdown(intro_markdown, unsafe_allow_html=True)
+
+    st.text('An example of twitter:')
+    components.html('''
+    <blockquote class="twitter-tweet">
+    <a href="https://twitter.com/xxx/status/1580753362059788288"></a>
+    </blockquote>
+    <script async src="https://platform.twitter.com/widgets.js" charset="utf-8">
+    </script>
+    ''',
+    height=600)
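The blockquote-plus-widgets.js snippet added here reappears almost verbatim in text2image.py below; a small shared helper would keep the two copies in sync. A sketch under that assumption (the name embed_tweet is hypothetical; the HTML is the commit's own):

# Hypothetical shared helper, not part of this commit.
import streamlit.components.v1 as components


def embed_tweet(tweet_url, height=600):
    # Twitter's widgets.js upgrades the anchor inside the blockquote
    # into a fully rendered tweet card.
    components.html('''
    <blockquote class="twitter-tweet">
    <a href="%s"></a>
    </blockquote>
    <script async src="https://platform.twitter.com/widgets.js" charset="utf-8">
    </script>
    ''' % tweet_url, height=height)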
text2image.py CHANGED
@@ -4,19 +4,17 @@ from plip_support import embed_text
 import numpy as np
 from PIL import Image
 import requests
-import transformers
 import tokenizers
 from io import BytesIO
-import streamlit as st
-from transformers import CLIPModel
-import clip
 import torch
 from transformers import (
     VisionTextDualEncoderModel,
     AutoFeatureExtractor,
-    AutoTokenizer
+    AutoTokenizer,
+    CLIPModel,
+    AutoProcessor
 )
-from transformers import AutoProcessor
+import streamlit.components.v1 as components
 
 
 def embed_texts(model, texts, processor):
@@ -51,7 +49,8 @@ def load_path_clip():
 def app():
     st.title('PLIP Image Search')
 
-    plip_dataset = pd.read_csv("tweet_eval_retrieval.tsv", sep="\t")
+    plip_imgURL = pd.read_csv("tweet_eval_retrieval.tsv", sep="\t")
+    plip_weblink = pd.read_csv("tweet_eval_retrieval_twlnk.tsv", sep="\t")
 
     model, processor = load_path_clip()
 
@@ -59,16 +58,58 @@ def app():
 
     query = st.text_input('Search Query', '')
 
-
     if query:
 
         text_embedding = embed_texts(model, [query], processor)[0].detach().cpu().numpy()
 
         text_embedding = text_embedding/np.linalg.norm(text_embedding)
+
+        # Sort IDs by cosine-similarity from high to low
+        similarity_scores = text_embedding.dot(image_embedding.T)
+        id_sorted = np.argsort(similarity_scores)[::-1]
+
+
+        best_id = id_sorted[0]
+        score = similarity_scores[best_id]
+        target_url = plip_imgURL.iloc[best_id]["imageURL"]
+        target_weblink = plip_weblink.iloc[best_id]["weblink"]
+
+        st.caption('Most relevant image (similarity = %.4f)' % score)
+        #response = requests.get(target_url)
+        #img = Image.open(BytesIO(response.content))
+        #st.image(img)
+
+
+        components.html('''
+        <blockquote class="twitter-tweet">
+        <a href="%s"></a>
+        </blockquote>
+        <script async src="https://platform.twitter.com/widgets.js" charset="utf-8">
+        </script>
+        ''' % target_weblink,
+        height=600)
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
 
-        best_id = np.argmax(text_embedding.dot(image_embedding.T))
-        url = (plip_dataset.iloc[best_id]["imageURL"])
 
-        response = requests.get(url)
-        img = Image.open(BytesIO(response.content))
-        st.image(img)
 
 
 
tweet_eval_retrieval_twlnk.tsv ADDED
The diff for this file is too large to render. See raw diff
 
zeroshot.py ADDED
File without changes
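Since zeroshot.py lands empty, selecting "Image Prediction" will presumably fail (an AttributeError on the missing app()) until the page is implemented. A hypothetical placeholder that honors the app() convention the other pages use:

# zeroshot.py -- hypothetical placeholder; this commit adds the file empty.
import streamlit as st


def app():
    st.title('PLIP Image Prediction')
    st.info('Zero-shot prediction is not implemented yet.')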