Spaces:
Runtime error
Runtime error
Commit
·
c7e078f
1
Parent(s):
36ee6e5
Update app.py
Browse files
app.py
CHANGED
@@ -17,6 +17,13 @@ tokenizer, text_encoder, links, image_embeddings = load()
|
|
17 |
|
18 |
|
19 |
@st.experimental_memo
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
def get_html(url_list):
|
21 |
html = "<div style='margin-top: 50px; max-width: 1100px; display: flex; flex-wrap: wrap; justify-content: space-evenly'>"
|
22 |
for url in url_list:
|
@@ -24,13 +31,6 @@ def get_html(url_list):
|
|
24 |
html = html + html2
|
25 |
html += "</div>"
|
26 |
return html
|
27 |
-
|
28 |
-
@st.experimental_memo
|
29 |
-
def image_search(query, top_k=10):
|
30 |
-
with torch.no_grad():
|
31 |
-
text_embedding = text_encoder(**tokenizer(query, return_tensors='pt')).pooler_output
|
32 |
-
_, indices = torch.cosine_similarity(image_embeddings, text_embedding).sort(descending=True)
|
33 |
-
return [links[i] for i in indices[:top_k]]
|
34 |
|
35 |
|
36 |
description = '''
|
@@ -68,6 +68,7 @@ def main():
|
|
68 |
st.sidebar.markdown(description)
|
69 |
_, c, _ = st.columns((1, 3, 1))
|
70 |
query = c.text_input('Search Box (type in fa)', value='گل صورتی')
|
|
|
71 |
if len(query) > 0:
|
72 |
results = image_search(query)
|
73 |
st.markdown(get_html(results), unsafe_allow_html=True)
|
|
|
17 |
|
18 |
|
19 |
@st.experimental_memo
|
20 |
+
def image_search(query, top_k=10):
|
21 |
+
with torch.no_grad():
|
22 |
+
text_embedding = text_encoder(**tokenizer(query, return_tensors='pt')).pooler_output
|
23 |
+
_, indices = torch.cosine_similarity(image_embeddings, text_embedding).sort(descending=True)
|
24 |
+
return [links[i] for i in indices[:top_k]]
|
25 |
+
|
26 |
+
|
27 |
def get_html(url_list):
|
28 |
html = "<div style='margin-top: 50px; max-width: 1100px; display: flex; flex-wrap: wrap; justify-content: space-evenly'>"
|
29 |
for url in url_list:
|
|
|
31 |
html = html + html2
|
32 |
html += "</div>"
|
33 |
return html
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
|
35 |
|
36 |
description = '''
|
|
|
68 |
st.sidebar.markdown(description)
|
69 |
_, c, _ = st.columns((1, 3, 1))
|
70 |
query = c.text_input('Search Box (type in fa)', value='گل صورتی')
|
71 |
+
c.text("It'll take about 20s to load all new images")
|
72 |
if len(query) > 0:
|
73 |
results = image_search(query)
|
74 |
st.markdown(get_html(results), unsafe_allow_html=True)
|