SajjadAyoubi commited on
Commit
c7e078f
·
1 Parent(s): 36ee6e5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -7
app.py CHANGED
@@ -17,6 +17,13 @@ tokenizer, text_encoder, links, image_embeddings = load()
17
 
18
 
19
  @st.experimental_memo
 
 
 
 
 
 
 
20
  def get_html(url_list):
21
  html = "<div style='margin-top: 50px; max-width: 1100px; display: flex; flex-wrap: wrap; justify-content: space-evenly'>"
22
  for url in url_list:
@@ -24,13 +31,6 @@ def get_html(url_list):
24
  html = html + html2
25
  html += "</div>"
26
  return html
27
-
28
- @st.experimental_memo
29
- def image_search(query, top_k=10):
30
- with torch.no_grad():
31
- text_embedding = text_encoder(**tokenizer(query, return_tensors='pt')).pooler_output
32
- _, indices = torch.cosine_similarity(image_embeddings, text_embedding).sort(descending=True)
33
- return [links[i] for i in indices[:top_k]]
34
 
35
 
36
  description = '''
@@ -68,6 +68,7 @@ def main():
68
  st.sidebar.markdown(description)
69
  _, c, _ = st.columns((1, 3, 1))
70
  query = c.text_input('Search Box (type in fa)', value='گل صورتی')
 
71
  if len(query) > 0:
72
  results = image_search(query)
73
  st.markdown(get_html(results), unsafe_allow_html=True)
 
17
 
18
 
19
  @st.experimental_memo
20
+ def image_search(query, top_k=10):
21
+ with torch.no_grad():
22
+ text_embedding = text_encoder(**tokenizer(query, return_tensors='pt')).pooler_output
23
+ _, indices = torch.cosine_similarity(image_embeddings, text_embedding).sort(descending=True)
24
+ return [links[i] for i in indices[:top_k]]
25
+
26
+
27
  def get_html(url_list):
28
  html = "<div style='margin-top: 50px; max-width: 1100px; display: flex; flex-wrap: wrap; justify-content: space-evenly'>"
29
  for url in url_list:
 
31
  html = html + html2
32
  html += "</div>"
33
  return html
 
 
 
 
 
 
 
34
 
35
 
36
  description = '''
 
68
  st.sidebar.markdown(description)
69
  _, c, _ = st.columns((1, 3, 1))
70
  query = c.text_input('Search Box (type in fa)', value='گل صورتی')
71
+ c.text("It'll take about 20s to load all new images")
72
  if len(query) > 0:
73
  results = image_search(query)
74
  st.markdown(get_html(results), unsafe_allow_html=True)