image2image search
- app.py +5 -19
- home.py +1 -1
- image2image.py +109 -0
app.py
CHANGED
@@ -1,31 +1,17 @@
-import streamlit as st
-import pandas as pd
 import home
-import numpy as np
-from PIL import Image
-import requests
-import transformers
 import text2image
-import
-import tokenizers
-from io import BytesIO
+import image2image
 import streamlit as st
-
-
-
-from transformers import (
-    VisionTextDualEncoderModel,
-    AutoFeatureExtractor,
-    AutoTokenizer
-)
-from transformers import AutoProcessor
+
+
+
 
 st.sidebar.title("Explore our PLIP Demo")
 
 PAGES = {
     "Introduction": home,
     "Text to Image": text2image,
-    "Image
+    "Image to Image": image2image,
 }
 
 page = st.sidebar.radio("", list(PAGES.keys()))
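For context, the hunk ends just before the selected page is actually rendered; the dispatch line itself is outside this diff. A minimal sketch of that step, assuming each page module (home, text2image, image2image) exposes an app() entry point:

    page = st.sidebar.radio("", list(PAGES.keys()))
    PAGES[page].app()  # assumed dispatch: call the selected module's entry point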
home.py
CHANGED
@@ -11,7 +11,7 @@ def app():
     intro_markdown = read_markdown_file("introduction.md")
     st.markdown(intro_markdown, unsafe_allow_html=True)
 
-    st.text('An example of
+    st.text('An example of tweet:')
     components.html('''
     <blockquote class="twitter-tweet">
     <a href="https://twitter.com/xxx/status/1580753362059788288"></a>
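home.py and the new image2image.py embed tweets the same way: a bare blockquote pointing at the tweet URL, upgraded in the browser by Twitter's widgets.js, all rendered through Streamlit's components.html. A minimal sketch of the pattern, with a hypothetical tweet_url parameter:

    import streamlit.components.v1 as components

    def show_tweet(tweet_url, height=600):
        # widgets.js finds the blockquote and swaps in a rendered tweet card
        components.html('''
            <blockquote class="twitter-tweet"><a href="%s"></a></blockquote>
            <script async src="https://platform.twitter.com/widgets.js" charset="utf-8"></script>
        ''' % tweet_url, height=height)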
image2image.py
ADDED
@@ -0,0 +1,109 @@
+import streamlit as st
+import pandas as pd
+from plip_support import embed_text
+import numpy as np
+from PIL import Image
+import requests
+import tokenizers
+from io import BytesIO
+import torch
+from transformers import (
+    VisionTextDualEncoderModel,
+    AutoFeatureExtractor,
+    AutoTokenizer,
+    CLIPModel,
+    AutoProcessor
+)
+import streamlit.components.v1 as components
+
+
+def embed_images(model, images, processor):
+    inputs = processor(images=images)
+    pixel_values = torch.tensor(np.array(inputs["pixel_values"]))
+
+    with torch.no_grad():
+        embeddings = model.get_image_features(pixel_values=pixel_values)
+    return embeddings
+
+@st.cache
+def load_embeddings(embeddings_path):
+    print("loading embeddings")
+    return np.load(embeddings_path)
+
+@st.cache(
+    hash_funcs={
+        torch.nn.parameter.Parameter: lambda _: None,
+        tokenizers.Tokenizer: lambda _: None,
+        tokenizers.AddedToken: lambda _: None
+    }
+)
+def load_path_clip():
+    model = CLIPModel.from_pretrained("vinid/plip")
+    processor = AutoProcessor.from_pretrained("vinid/plip")
+    return model, processor
+
+
+def app():
+    st.title('PLIP Image Search')
+
+    plip_imgURL = pd.read_csv("tweet_eval_retrieval.tsv", sep="\t")
+    plip_weblink = pd.read_csv("tweet_eval_retrieval_twlnk.tsv", sep="\t")
+
+    model, processor = load_path_clip()
+
+    image_embedding = load_embeddings("tweet_eval_embeddings.npy")
+
+    query = st.file_uploader("Choose a file")
+
+
+    if query:
+        image = Image.open(query)
+        single_image = embed_images(model, [image], processor)[0].detach().cpu().numpy()
+
+        single_image = single_image / np.linalg.norm(single_image)
+
+        # Sort IDs by cosine similarity from high to low
+        similarity_scores = single_image.dot(image_embedding.T)
+        id_sorted = np.argsort(similarity_scores)[::-1]
+
+
+        best_id = id_sorted[0]
+        score = similarity_scores[best_id]
+
+        target_weblink = plip_weblink.iloc[best_id]["weblink"]
+
+        st.caption('Most relevant image (similarity = %.4f)' % score)
+
+        components.html('''
+        <blockquote class="twitter-tweet">
+        <a href="%s"></a>
+        </blockquote>
+        <script async src="https://platform.twitter.com/widgets.js" charset="utf-8">
+        </script>
+        ''' % target_weblink,
+        height=600)
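The heart of the new page is nearest-neighbor retrieval by cosine similarity: the query image is embedded with PLIP, L2-normalized, and dotted against the precomputed tweet embeddings. A minimal standalone sketch of that step, assuming the rows of tweet_eval_embeddings.npy are already unit-norm (the code above normalizes only the query):

    import numpy as np

    def top_k(query_vec, embeddings, k=1):
        # Normalizing the query makes the dot product a cosine similarity,
        # provided each row of `embeddings` is already unit-norm.
        q = query_vec / np.linalg.norm(query_vec)
        scores = embeddings @ q              # (N,) similarity per stored image
        ids = np.argsort(scores)[::-1][:k]   # indices, highest similarity first
        return ids, scores[ids]

If the stored embeddings were not normalized, vector magnitude would leak into the ranking alongside angle, so the unit-norm assumption matters.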