Spaces · Runtime error

Commit 55b4896 · Parent(s): c3b6f23
Update app.py

app.py CHANGED
@@ -5,11 +5,18 @@ import torch
 from transformers import RobertaModel, AutoTokenizer
 
 
-
-
-
-
-
+@st.cache(show_spinner=False,
+          hash_funcs={AutoTokenizer: lambda _: None,
+                      RobertaModel: lambda _: None,
+                      dict: lambda _: None})
+def load():
+    tokenizer = AutoTokenizer.from_pretrained('SajjadAyoubi/clip-fa-text')
+    text_encoder = RobertaModel.from_pretrained('SajjadAyoubi/clip-fa-text').eval()
+    image_embeddings = torch.load('embedding.pt').numpy()
+    links = np.load('data.npy', allow_pickle=True)
+    return model, processor, df, embeddings
+
+model, processor, df, embeddings = load()
 
 
 def get_html(url_list):
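This hunk moves model loading behind Streamlit's legacy `st.cache`. The `hash_funcs` mapping tells Streamlit to skip hashing the tokenizer, model, and dict arguments (types it cannot hash cheaply), so the expensive `from_pretrained` calls and file loads run once instead of on every rerun; the five removed lines presumably held the old, uncached loading code (their content does not survive in this view). As committed, however, `load()` builds `tokenizer`, `text_encoder`, `image_embeddings`, and `links` but returns `model, processor, df, embeddings`, four names that are never defined, so the script raises a `NameError` at startup, which matches the Space's "Runtime error" status. A minimal corrected sketch, assuming the intent was to return the four objects the function actually builds:

```python
import numpy as np
import streamlit as st
import torch
from transformers import RobertaModel, AutoTokenizer

@st.cache(show_spinner=False,
          hash_funcs={AutoTokenizer: lambda _: None,  # skip hashing the (unhashable)
                      RobertaModel: lambda _: None,   # tokenizer and model objects
                      dict: lambda _: None})
def load():
    tokenizer = AutoTokenizer.from_pretrained('SajjadAyoubi/clip-fa-text')
    text_encoder = RobertaModel.from_pretrained('SajjadAyoubi/clip-fa-text').eval()
    image_embeddings = torch.load('embedding.pt').numpy()  # precomputed image vectors
    links = np.load('data.npy', allow_pickle=True)  # image URLs (assumed row-aligned)
    return tokenizer, text_encoder, image_embeddings, links  # return what was built

tokenizer, text_encoder, image_embeddings, links = load()
```

The search functions below then resolve `tokenizer`, `text_encoder`, `image_embeddings`, and `links` as module-level globals.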
@@ -20,19 +27,26 @@ def get_html(url_list):
     html += "</div>"
     return html
 
-
+st.cache(show_spinner=False)
 def image_search(query, top_k=10):
     with torch.no_grad():
         text_embedding = text_encoder(**tokenizer(query, return_tensors='pt')).pooler_output
         values, indices = torch.cosine_similarity(text_embedding, image_embeddings).sort(descending=True)
     return [links[i] for i in indices[:top_k]]
 
+st.cache(show_spinner=False)
+def image_search_(query, top_k=10):
+    with torch.no_grad():
+        text_embedding = text_encoder(**tokenizer(query, return_tensors='pt')).pooler_output.numpy()
+        results = np.argsort((image_embeddings@text_embedding.T)[:, 0])[-1:-top_k-1:-1]
+    return [links[i] for i in results]
+
 
 description = '''
 # Persian (fa) image search
 - Enter your query and hit enter
 
-Built with [CLIP-fa](https://github.com/sajjjadayobi/CLIPfa) model and 25k images from Unsplash
+Built with [CLIP-fa](https://github.com/sajjjadayobi/CLIPfa) model and 25k images from [Unsplash](https://unsplash.com/)
 '''
 
 
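Two details in this hunk stand out. First, both new `st.cache(show_spinner=False)` lines lack the leading `@`: written like that, each is a bare expression whose result is discarded, so neither search function is actually cached (unless the `@` was simply lost from this view). Second, `image_search_` replaces the torch `cosine_similarity` ranking with plain NumPy: `image_embeddings @ text_embedding.T` produces one dot-product score per image, `np.argsort` sorts those scores in ascending order, and the slice `[-1:-top_k-1:-1]` walks backwards from the end, yielding the indices of the `top_k` highest scores, best first. A small self-contained check of that slice; the shapes are illustrative (25k images comes from the app's description, the 768-dim embedding size is an assumption):

```python
import numpy as np

rng = np.random.default_rng(0)
image_embeddings = rng.standard_normal((25_000, 768))  # one row per image
text_embedding = rng.standard_normal((1, 768))         # a single encoded query

scores = (image_embeddings @ text_embedding.T)[:, 0]   # shape (25000,): dot products
top_k = 10
top = np.argsort(scores)[-1:-top_k - 1:-1]             # top_k indices, best first

# the reversed slice is equivalent to sorting by negated score
assert np.array_equal(top, np.argsort(-scores)[:top_k])
```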
@@ -62,9 +76,9 @@ def main():
 
     st.sidebar.markdown(description)
     _, c, _ = st.columns((1, 3, 1))
-    query = c.text_input('Search Box', value='گل صورتی')
+    query = c.text_input('Search Box (type in fa)', value='گل صورتی')
     if len(query) > 0:
-        results = image_search(query)
+        results = image_search_(query)
         st.markdown(get_html(results), unsafe_allow_html=True)
 
 
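Here `main()` gains a language hint on the input label and switches to the NumPy search path. The switch is actually required by this commit: `load()` now converts the embeddings with `.numpy()`, and `torch.cosine_similarity` does not accept a NumPy array, so the old `image_search` would raise a `TypeError` on the first query. One caveat on the replacement: dot-product scores rank images identically to cosine similarity only if the stored image embeddings are L2-normalized; the query's own norm is a shared positive factor that never affects the ordering, but differing per-image norms would. A quick illustration of the equivalence under unit-norm rows (random data, hypothetical sizes):

```python
import numpy as np

rng = np.random.default_rng(1)
emb = rng.standard_normal((1_000, 64))
emb /= np.linalg.norm(emb, axis=1, keepdims=True)  # unit-norm image rows
query = rng.standard_normal((1, 64))

dot = (emb @ query.T)[:, 0]
cos = dot / (np.linalg.norm(emb, axis=1) * np.linalg.norm(query))

# cosine = dot product / (1 * query norm): one shared positive constant,
# so the two rankings coincide
assert np.array_equal(np.argsort(-dot)[:10], np.argsort(-cos)[:10])
```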
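Taken together, the app encodes the query with the CLIP-fa text encoder (the RoBERTa `pooler_output`), ranks the precomputed Unsplash image embeddings by dot product, and renders the winning links as an HTML grid. For orientation, here is a toy, runnable version of that skeleton with the model replaced by a stub; `get_html`'s body and the placeholder image service are hypothetical, and only the layout calls mirror the diff above:

```python
# Toy stand-in for app.py; run with: streamlit run demo.py
import streamlit as st

def get_html(url_list):
    # hypothetical grid markup; only the closing "</div>" appears in the diff
    html = "<div style='display:flex;flex-wrap:wrap;justify-content:center'>"
    for url in url_list:
        html += f"<img src='{url}' style='margin:2px' width='180px'>"
    html += "</div>"
    return html

def image_search_(query, top_k=10):
    # placeholder ranking: the real version encodes `query` with CLIP-fa and
    # sorts the 25k image embeddings by dot-product score
    return [f"https://picsum.photos/seed/{i}/180" for i in range(top_k)]

st.sidebar.markdown('# Persian (fa) image search')
_, c, _ = st.columns((1, 3, 1))
query = c.text_input('Search Box (type in fa)', value='گل صورتی')
if len(query) > 0:
    results = image_search_(query)
    st.markdown(get_html(results), unsafe_allow_html=True)
```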