SajjadAyoubi committed
Commit 9d1171b · 1 Parent(s): 3abd33c

Update app.py

Files changed (1):
  app.py (+9 -5)
app.py CHANGED
@@ -7,7 +7,11 @@ import torch
 import transformers
 from transformers import RobertaModel, AutoTokenizer
 
-@st.cache(hash_funcs={transformers.models.roberta.tokenization_roberta_fast.RobertaTokenizerFast: hash}, suppress_st_warning=True, allow_output_mutation=True)
+#@st.cache(hash_funcs={transformers.models.roberta.tokenization_roberta_fast.RobertaTokenizerFast: hash}, suppress_st_warning=True, allow_output_mutation=True)
+@st.cache(show_spinner=False,
+          hash_funcs={RobertaModel: lambda _: None,
+                      AutoTokenizer: lambda _: None,
+                      dict: lambda _: None})
 def load():
     text_encoder = RobertaModel.from_pretrained('SajjadAyoubi/clip-fa-text')
     tokenizer = AutoTokenizer.from_pretrained('SajjadAyoubi/clip-fa-text')
@@ -28,12 +32,12 @@ def get_html(url_list, height=224):
     return HTML
 
 def compute_embeddings(query):
-    return text_encoder(**tokenizer(query, return_tensors='pt')).pooler_output
+    with torch.no_grad():
+        return text_encoder(**tokenizer(query, return_tensors='pt')).pooler_output
 
-@st.cache(show_spinner=False)
+st.cache(show_spinner=False)
 def image_search(query, top_k=8):
-    with torch.no_grad():
-        text_embedding = compute_embeddings(query)
+    text_embedding = compute_embeddings(query)
     values, indices = torch.cosine_similarity(text_embedding, image_embeddings).sort(descending=True)
     return [links[i] for i in indices[:top_k]]
 
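The key change is how load() is cached. Streamlit's legacy st.cache re-runs a function only when the hashes of its inputs and referenced objects change, and hash_funcs overrides the hasher per type: mapping RobertaModel, AutoTokenizer, and dict to lambda _: None tells Streamlit to treat those objects as constants instead of trying to hash them. Below is a minimal sketch of the pattern; only the model identifiers are taken from the diff, the rest is illustrative and not the app's full code.

    import streamlit as st
    import torch
    from transformers import RobertaModel, AutoTokenizer

    # Legacy st.cache API: hash_funcs maps a type to the function used to
    # hash values of that type; a constant hash makes Streamlit skip them.
    @st.cache(show_spinner=False,
              hash_funcs={RobertaModel: lambda _: None,
                          AutoTokenizer: lambda _: None,
                          dict: lambda _: None})
    def load():
        text_encoder = RobertaModel.from_pretrained('SajjadAyoubi/clip-fa-text')
        tokenizer = AutoTokenizer.from_pretrained('SajjadAyoubi/clip-fa-text')
        return text_encoder, tokenizer

    text_encoder, tokenizer = load()

    def compute_embeddings(query):
        # The commit moves torch.no_grad() here, so every caller of the
        # encoder skips gradient tracking automatically.
        with torch.no_grad():
            return text_encoder(**tokenizer(query, return_tensors='pt')).pooler_output

One caveat in the new version: the st.cache(show_spinner=False) line above image_search has no leading @, so as written it merely calls st.cache and discards the returned decorator, leaving image_search itself uncached.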