SajjadAyoubi commited on
Commit
665ec8e
·
1 Parent(s): 7d696ff

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -9
app.py CHANGED
@@ -32,16 +32,9 @@ st.cache(show_spinner=False)
32
  def image_search(query, top_k=10):
33
  with torch.no_grad():
34
  text_embedding = text_encoder(**tokenizer(query, return_tensors='pt')).pooler_output
35
- values, indices = torch.cosine_similarity(text_embedding, image_embeddings).sort(descending=True)
36
  return [links[i] for i in indices[:top_k]]
37
 
38
- st.cache(show_spinner=False)
39
- def image_search_(query, top_k=10):
40
- with torch.no_grad():
41
- text_embedding = text_encoder(**tokenizer(query, return_tensors='pt')).pooler_output.numpy()
42
- results = np.argsort((image_embeddings@text_embedding.T)[:, 0])[-1:-top_k-1:-1]
43
- return [links[i] for i in results]
44
-
45
 
46
  description = '''
47
  # Persian (fa) image search
@@ -79,7 +72,7 @@ def main():
79
  _, c, _ = st.columns((1, 3, 1))
80
  query = c.text_input('Search Box (type in fa)', value='گل صورتی')
81
  if len(query) > 0:
82
- results = image_search_(query)
83
  st.markdown(get_html(results), unsafe_allow_html=True)
84
 
85
 
 
32
  def image_search(query, top_k=10):
33
  with torch.no_grad():
34
  text_embedding = text_encoder(**tokenizer(query, return_tensors='pt')).pooler_output
35
+ values, indices = torch.dot(image_embeddings, text_embedding.T).sort(descending=True)
36
  return [links[i] for i in indices[:top_k]]
37
 
 
 
 
 
 
 
 
38
 
39
  description = '''
40
  # Persian (fa) image search
 
72
  _, c, _ = st.columns((1, 3, 1))
73
  query = c.text_input('Search Box (type in fa)', value='گل صورتی')
74
  if len(query) > 0:
75
+ results = image_search(query)
76
  st.markdown(get_html(results), unsafe_allow_html=True)
77
 
78