SajjadAyoubi commited on
Commit
3abd33c
·
1 Parent(s): 76fc8a2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -3
app.py CHANGED
@@ -25,13 +25,15 @@ def get_html(url_list, height=224):
25
  html2 = f"<img style='height: {height}px; margin: 5px' src='{escape(url)}'>"
26
  html = html + html2
27
  html += "</div>"
28
- return html
29
 
 
 
30
 
31
- #@st.cache(show_spinner=False)
32
  def image_search(query, top_k=8):
33
  with torch.no_grad():
34
- text_embedding = text_encoder(**tokenizer(query, return_tensors='pt')).pooler_output
35
  values, indices = torch.cosine_similarity(text_embedding, image_embeddings).sort(descending=True)
36
  return [links[i] for i in indices[:top_k]]
37
 
 
25
  html2 = f"<img style='height: {height}px; margin: 5px' src='{escape(url)}'>"
26
  html = html + html2
27
  html += "</div>"
28
+ return HTML
29
 
30
+ def compute_embeddings(query):
31
+ return text_encoder(**tokenizer(query, return_tensors='pt')).pooler_output
32
 
33
+ @st.cache(show_spinner=False)
34
  def image_search(query, top_k=8):
35
  with torch.no_grad():
36
+ text_embedding = compute_embeddings(query)
37
  values, indices = torch.cosine_similarity(text_embedding, image_embeddings).sort(descending=True)
38
  return [links[i] for i in indices[:top_k]]
39