Spaces:
Runtime error
Runtime error
Commit
·
2c174d9
1
Parent(s):
33b2bbe
Update app.py
Browse files
app.py
CHANGED
@@ -7,16 +7,19 @@ import torch
|
|
7 |
import transformers
|
8 |
from transformers import RobertaModel, AutoTokenizer
|
9 |
|
10 |
-
|
11 |
-
def load():
|
12 |
-
text_encoder = RobertaModel.from_pretrained('SajjadAyoubi/clip-fa-text')
|
13 |
-
image_embeddings = torch.load('embedding.pt')
|
14 |
-
links = np.load('data.npy', allow_pickle=True)
|
15 |
-
return text_encoder, links, image_embeddings
|
16 |
|
17 |
|
18 |
tokenizer = AutoTokenizer.from_pretrained('SajjadAyoubi/clip-fa-text')
|
19 |
-
text_encoder
|
|
|
|
|
|
|
20 |
|
21 |
|
22 |
def get_html(url_list, height=224):
|
@@ -26,8 +29,7 @@ def get_html(url_list, height=224):
|
|
26 |
html = html + html2
|
27 |
html += "</div>"
|
28 |
return html
|
29 |
-
|
30 |
-
@st.cache(show_spinner=False)
|
31 |
def image_search(query, top_k=8):
|
32 |
with torch.no_grad():
|
33 |
text_embedding = text_encoder(**tokenizer(query, return_tensors='pt')).pooler_output
|
@@ -74,7 +76,7 @@ def main():
|
|
74 |
unsafe_allow_html=True)
|
75 |
st.sidebar.markdown(description)
|
76 |
_, c, _ = st.columns((1, 3, 1))
|
77 |
-
query = c.text_input('', value='دریا')
|
78 |
if len(query) > 0:
|
79 |
results = image_search(query)
|
80 |
st.markdown(get_html(results), unsafe_allow_html=True)
|
|
|
7 |
import transformers
|
8 |
from transformers import RobertaModel, AutoTokenizer
|
9 |
|
10 |
+
#@st.cache(show_spinner=False)
|
11 |
+
#def load():
|
12 |
+
# text_encoder = RobertaModel.from_pretrained('SajjadAyoubi/clip-fa-text')
|
13 |
+
# image_embeddings = torch.load('embedding.pt')
|
14 |
+
# links = np.load('data.npy', allow_pickle=True)
|
15 |
+
# return text_encoder, links, image_embeddings
|
16 |
|
17 |
|
18 |
tokenizer = AutoTokenizer.from_pretrained('SajjadAyoubi/clip-fa-text')
|
19 |
+
text_encoder = RobertaModel.from_pretrained('SajjadAyoubi/clip-fa-text').eval()
|
20 |
+
image_embeddings = torch.load('embedding.pt')
|
21 |
+
links = np.load('data.npy', allow_pickle=True)
|
22 |
+
|
23 |
|
24 |
|
25 |
def get_html(url_list, height=224):
|
|
|
29 |
html = html + html2
|
30 |
html += "</div>"
|
31 |
return html
|
32 |
+
|
|
|
33 |
def image_search(query, top_k=8):
|
34 |
with torch.no_grad():
|
35 |
text_embedding = text_encoder(**tokenizer(query, return_tensors='pt')).pooler_output
|
|
|
76 |
unsafe_allow_html=True)
|
77 |
st.sidebar.markdown(description)
|
78 |
_, c, _ = st.columns((1, 3, 1))
|
79 |
+
query = c.text_input('Search text', value='دریا')
|
80 |
if len(query) > 0:
|
81 |
results = image_search(query)
|
82 |
st.markdown(get_html(results), unsafe_allow_html=True)
|