SajjadAyoubi committed
Commit 55b4896 · 1 Parent(s): c3b6f23

Update app.py

Files changed (1)
  1. app.py +23 -9
app.py CHANGED
@@ -5,11 +5,18 @@ import torch
 from transformers import RobertaModel, AutoTokenizer
 
 
-tokenizer = AutoTokenizer.from_pretrained('SajjadAyoubi/clip-fa-text')
-text_encoder = RobertaModel.from_pretrained('SajjadAyoubi/clip-fa-text').eval()
-image_embeddings = torch.load('embedding.pt')
-links = np.load('data.npy', allow_pickle=True)
-
+@st.cache(show_spinner=False,
+          hash_funcs={AutoTokenizer: lambda _: None,
+                      RobertaModel: lambda _: None,
+                      dict: lambda _: None})
+def load():
+    tokenizer = AutoTokenizer.from_pretrained('SajjadAyoubi/clip-fa-text')
+    text_encoder = RobertaModel.from_pretrained('SajjadAyoubi/clip-fa-text').eval()
+    image_embeddings = torch.load('embedding.pt').numpy()
+    links = np.load('data.npy', allow_pickle=True)
+    return tokenizer, text_encoder, image_embeddings, links
+
+tokenizer, text_encoder, image_embeddings, links = load()
 
 
 def get_html(url_list):
@@ -20,19 +27,26 @@ def get_html(url_list):
     html += "</div>"
     return html
 
-
+@st.cache(show_spinner=False)
 def image_search(query, top_k=10):
     with torch.no_grad():
         text_embedding = text_encoder(**tokenizer(query, return_tensors='pt')).pooler_output
         values, indices = torch.cosine_similarity(text_embedding, image_embeddings).sort(descending=True)
     return [links[i] for i in indices[:top_k]]
 
+@st.cache(show_spinner=False)
+def image_search_(query, top_k=10):
+    with torch.no_grad():
+        text_embedding = text_encoder(**tokenizer(query, return_tensors='pt')).pooler_output.numpy()
+    results = np.argsort((image_embeddings @ text_embedding.T)[:, 0])[-1:-top_k-1:-1]
+    return [links[i] for i in results]
+
 
 description = '''
 # Persian (fa) image search
 - Enter your query and hit enter
 
-Built with [CLIP-fa](https://github.com/sajjjadayobi/CLIPfa) model and 25k images from Unsplash
+Built with [CLIP-fa](https://github.com/sajjjadayobi/CLIPfa) model and 25k images from [Unsplash](https://unsplash.com/)
 '''
 
 
@@ -62,9 +76,9 @@ def main():
 
     st.sidebar.markdown(description)
     _, c, _ = st.columns((1, 3, 1))
-    query = c.text_input('Search Box', value='گل صورتی')
+    query = c.text_input('Search Box (type in fa)', value='گل صورتی')
     if len(query) > 0:
-        results = image_search(query)
+        results = image_search_(query)
         st.markdown(get_html(results), unsafe_allow_html=True)
 
 
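For context on the new ranking path: `image_search_` scores every image with a raw dot product between the encoded text query and the precomputed image embeddings, then takes the top-k indices with a reversed `argsort` slice. A minimal self-contained sketch of that indexing trick with synthetic data (the 25k×768 shapes are assumptions for illustration, not values taken from the commit):

```python
import numpy as np

# Synthetic stand-ins for the app's precomputed image embeddings and an
# encoded text query (assumed shapes: 25k images, 768-dim embeddings).
rng = np.random.default_rng(0)
image_embeddings = rng.normal(size=(25_000, 768)).astype(np.float32)
text_embedding = rng.normal(size=(1, 768)).astype(np.float32)

top_k = 10
# (25_000, 768) @ (768, 1) -> (25_000, 1); column 0 gives one score per image.
scores = (image_embeddings @ text_embedding.T)[:, 0]
# np.argsort is ascending, so the reversed slice [-1:-top_k-1:-1] walks the
# last top_k positions backwards: indices of the top_k scores, descending.
top_indices = np.argsort(scores)[-1:-top_k-1:-1]
assert scores[top_indices[0]] == scores.max()
```

If the image and text embeddings are L2-normalized, this inner-product ranking coincides with the cosine-similarity ranking that the torch-based `image_search` computes; the diff itself does not show whether `embedding.pt` stores normalized vectors.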
 
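A note on the caching pattern this commit introduces: the legacy `st.cache` decorator hashes a function's inputs and referenced objects to decide when to rerun it, and transformer models and tokenizers are slow or impossible to hash. Mapping their types to `lambda _: None` in `hash_funcs` tells Streamlit to treat them as constants. A minimal sketch of the same pattern (pre-1.0 Streamlit API; `load_encoder` is a hypothetical name, and newer Streamlit replaces this idiom with `@st.cache_resource`):

```python
import streamlit as st
from transformers import AutoTokenizer, RobertaModel

# Legacy st.cache: hash_funcs maps a type to the function used to hash
# objects of that type; returning a constant (None) means the cache is
# never invalidated on account of that object.
@st.cache(show_spinner=False,
          hash_funcs={AutoTokenizer: lambda _: None,
                      RobertaModel: lambda _: None})
def load_encoder():
    tokenizer = AutoTokenizer.from_pretrained('SajjadAyoubi/clip-fa-text')
    text_encoder = RobertaModel.from_pretrained('SajjadAyoubi/clip-fa-text').eval()
    return tokenizer, text_encoder
```

With this in place, the weights are downloaded and loaded once rather than on every Streamlit script rerun, which is the point of wrapping the module-level loading in `load()` in the diff above.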