File size: 2,580 Bytes
4f12085
4f58d6c
4f12085
 
 
 
 
55b4896
 
 
 
 
 
 
d3e4cd0
55b4896
7d696ff
55b4896
7d696ff
 
4f12085
 
9213df3
c3b6f23
e3b861a
9213df3
4f12085
 
260e0f1
16b83cd
55b4896
c3b6f23
a7746f5
 
ad93ba3
e3b861a
4f12085
 
 
cd19d7a
 
 
55b4896
4f12085
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9213df3
b5870b3
4f12085
55b4896
4f12085
665ec8e
62a6cd4
8f4e395
 
4f12085
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import streamlit as st
import numpy as np
from html import escape
import torch
from transformers import RobertaModel, AutoTokenizer


@st.cache(show_spinner=False,
          hash_funcs={AutoTokenizer: lambda _: None,
                      RobertaModel: lambda _: None,
                      dict: lambda _: None})
def load():
    """Load the text encoder, its tokenizer, the image links and image embeddings.

    Wrapped in ``st.cache`` so Streamlit can reuse the loaded objects across
    reruns; the ``hash_funcs`` entries give the listed types a constant hash
    instead of hashing them by content.

    Returns:
        tuple: ``(tokenizer, text_encoder, links, image_embeddings)`` where
        ``text_encoder`` is in eval mode, ``links`` is the array stored in
        ``data.npy`` and ``image_embeddings`` is the tensor stored in
        ``embedding.pt`` (placed on the CPU).
    """
    tokenizer = AutoTokenizer.from_pretrained('SajjadAyoubi/clip-fa-text')
    text_encoder = RobertaModel.from_pretrained('SajjadAyoubi/clip-fa-text').eval()
    # The text encoder is never moved off the CPU, so load the embeddings
    # there too: without map_location a tensor saved from a GPU session either
    # fails to load on a CPU-only host or ends up on a different device than
    # the encoder output it is multiplied with in image_search().
    image_embeddings = torch.load('embedding.pt', map_location='cpu')
    # allow_pickle is needed for an object array (presumably URL strings);
    # unpickling can execute code, so only ever load a trusted local data.npy.
    links = np.load('data.npy', allow_pickle=True)
    return tokenizer, text_encoder, links, image_embeddings
  
# Load the model, tokenizer, image links and image embeddings once at import
# time; image_search() reads these module-level globals. load() is wrapped in
# st.cache, so script reruns reuse the cached objects.
tokenizer, text_encoder, links, image_embeddings = load()



def get_html(url_list):
    """Build an HTML gallery of ``<img>`` tags for the given image URLs.

    Args:
        url_list: Iterable of image URL strings. Each URL is HTML-escaped
            (quotes included) before it is placed in the single-quoted
            ``src`` attribute, so a URL cannot break out of the tag.

    Returns:
        str: A flex-wrapped ``<div>`` holding one 200px-high ``<img>`` per URL,
        in input order; an empty ``<div>`` when ``url_list`` is empty.
    """
    # join() builds the result in one pass instead of repeated += copies.
    images = "".join(
        f"<img style='height: 200px; margin: 2px' src='{escape(url)}'>"
        for url in url_list
    )
    return (
        "<div style='margin-top: 50px; max-width: 1100px; display: flex; "
        "flex-wrap: wrap; justify-content: space-evenly'>"
        f"{images}</div>"
    )
 
# Not wrapped in st.cache: the legacy decorator also hashes the external
# variables a function reads (here the model, tokenizer and tensors), which
# may not be hashable — confirm that works before adding caching here.
def image_search(query, top_k=10):
    """Return the links of the ``top_k`` images that best match ``query``.

    The query is encoded with the module-level ``tokenizer`` and
    ``text_encoder``; every row of ``image_embeddings`` is scored by its dot
    product with the query embedding, and links are returned best-first.

    Args:
        query: Search text (Persian).
        top_k: Maximum number of links to return; fewer are returned when
            there are fewer images.

    Returns:
        list: Entries of ``links`` ordered by descending score.
    """
    with torch.no_grad():
        text_embedding = text_encoder(**tokenizer(query, return_tensors='pt')).pooler_output
    # Assumes image_embeddings is (num_images, dim) and text_embedding is
    # (1, dim) — TODO confirm. The product is then a (num_images, 1) column;
    # sorting that along its size-1 last axis would leave every index at 0,
    # so flatten to one score per image before ranking.
    scores = torch.matmul(image_embeddings, text_embedding.T).flatten()
    best_first = torch.argsort(scores, descending=True)[:top_k]
    # tolist() yields plain Python ints, which index numpy arrays and lists alike.
    return [links[i] for i in best_first.tolist()]


# Markdown rendered in the sidebar by main(): app title, usage hint, credits.
description = '''
# Persian (fa) image search 
- Enter your query and hit enter

Built with [CLIP-fa](https://github.com/sajjjadayobi/CLIPfa) model and 25k images from [Unsplash](https://unsplash.com/)
'''


def main():
    """Render the search page.

    Injects custom CSS (page width, padding, hidden menu and footer), shows
    ``description`` in the sidebar, and draws the image grid returned by
    ``image_search`` for any non-blank query typed into the centred search
    box, which is pre-filled with the default query 'گل صورتی' ("pink flower").
    """
    st.markdown('''
              <style>
              .block-container{
                max-width: 1200px;
              }
              section.main>div:first-child {
                padding-top: 0px;
              }
              section:not(.main)>div:first-child {
                padding-top: 30px;
              }
              div.reportview-container > section:first-child{
                max-width: 320px;
              }
              #MainMenu {
                visibility: hidden;
              }
              footer {
                visibility: hidden;
              }
              </style>''',
                unsafe_allow_html=True)

    st.sidebar.markdown(description)
    # 1:3:1 column split keeps the search box centred and narrower than the page.
    _, center, _ = st.columns((1, 3, 1))
    query = center.text_input('Search Box (type in fa)', value='گل صورتی')
    # A blank or whitespace-only query carries no search terms, so skip it
    # rather than ranking images against an empty text embedding.
    query = query.strip()
    if query:
        results = image_search(query)
        st.markdown(get_html(results), unsafe_allow_html=True)


# Build the page only when this file is run as the main script, not on import.
if __name__ == '__main__':
    main()