SajjadAyoubi commited on
Commit
2c174d9
·
1 Parent(s): 33b2bbe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -10
app.py CHANGED
@@ -7,16 +7,19 @@ import torch
7
  import transformers
8
  from transformers import RobertaModel, AutoTokenizer
9
 
10
- @st.cache(show_spinner=False)
11
- def load():
12
- text_encoder = RobertaModel.from_pretrained('SajjadAyoubi/clip-fa-text')
13
- image_embeddings = torch.load('embedding.pt')
14
- links = np.load('data.npy', allow_pickle=True)
15
- return text_encoder, links, image_embeddings
16
 
17
 
18
  tokenizer = AutoTokenizer.from_pretrained('SajjadAyoubi/clip-fa-text')
19
- text_encoder, links, image_embeddings = load()
 
 
 
20
 
21
 
22
  def get_html(url_list, height=224):
@@ -26,8 +29,7 @@ def get_html(url_list, height=224):
26
  html = html + html2
27
  html += "</div>"
28
  return html
29
-
30
- @st.cache(show_spinner=False)
31
  def image_search(query, top_k=8):
32
  with torch.no_grad():
33
  text_embedding = text_encoder(**tokenizer(query, return_tensors='pt')).pooler_output
@@ -74,7 +76,7 @@ def main():
74
  unsafe_allow_html=True)
75
  st.sidebar.markdown(description)
76
  _, c, _ = st.columns((1, 3, 1))
77
- query = c.text_input('', value='دریا')
78
  if len(query) > 0:
79
  results = image_search(query)
80
  st.markdown(get_html(results), unsafe_allow_html=True)
 
7
  import transformers
8
  from transformers import RobertaModel, AutoTokenizer
9
 
10
+ #@st.cache(show_spinner=False)
11
+ #def load():
12
+ # text_encoder = RobertaModel.from_pretrained('SajjadAyoubi/clip-fa-text')
13
+ # image_embeddings = torch.load('embedding.pt')
14
+ # links = np.load('data.npy', allow_pickle=True)
15
+ # return text_encoder, links, image_embeddings
16
 
17
 
18
  tokenizer = AutoTokenizer.from_pretrained('SajjadAyoubi/clip-fa-text')
19
+ text_encoder = RobertaModel.from_pretrained('SajjadAyoubi/clip-fa-text').eval()
20
+ image_embeddings = torch.load('embedding.pt')
21
+ links = np.load('data.npy', allow_pickle=True)
22
+
23
 
24
 
25
  def get_html(url_list, height=224):
 
29
  html = html + html2
30
  html += "</div>"
31
  return html
32
+
 
33
  def image_search(query, top_k=8):
34
  with torch.no_grad():
35
  text_embedding = text_encoder(**tokenizer(query, return_tensors='pt')).pooler_output
 
76
  unsafe_allow_html=True)
77
  st.sidebar.markdown(description)
78
  _, c, _ = st.columns((1, 3, 1))
79
+ query = c.text_input('Search text', value='دریا')
80
  if len(query) > 0:
81
  results = image_search(query)
82
  st.markdown(get_html(results), unsafe_allow_html=True)