import streamlit as st
import numpy as np
from html import escape
import torch
from transformers import RobertaModel, AutoTokenizer


@st.cache(show_spinner=False,
          hash_funcs={AutoTokenizer: lambda _: None,
                      RobertaModel: lambda _: None,
                      dict: lambda _: None})
def load():
    # Load the CLIP-fa tokenizer and text encoder, plus the precomputed
    # image embeddings and their matching image URLs.
    tokenizer = AutoTokenizer.from_pretrained('SajjadAyoubi/clip-fa-text')
    text_encoder = RobertaModel.from_pretrained('SajjadAyoubi/clip-fa-text').eval()
    image_embeddings = torch.load('embedding.pt').numpy()
    links = np.load('data.npy', allow_pickle=True)
    return tokenizer, text_encoder, image_embeddings, links


tokenizer, text_encoder, image_embeddings, links = load()
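
# The app assumes 'embedding.pt' and 'data.npy' were produced offline. A
# minimal sketch of how such embeddings could be built with the companion
# CLIP-fa image encoder follows; the 'SajjadAyoubi/clip-fa-vision' checkpoint
# name and the build_embeddings helper are assumptions for illustration and
# are not called anywhere in this app.
def build_embeddings(images):
    """Hypothetical offline step: embed a list of PIL images and save them."""
    from transformers import CLIPFeatureExtractor, CLIPVisionModel
    vision_encoder = CLIPVisionModel.from_pretrained('SajjadAyoubi/clip-fa-vision').eval()
    preprocessor = CLIPFeatureExtractor.from_pretrained('SajjadAyoubi/clip-fa-vision')
    with torch.no_grad():
        pixel_values = preprocessor(images, return_tensors='pt').pixel_values
        torch.save(vision_encoder(pixel_values=pixel_values).pooler_output, 'embedding.pt')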


def get_html(url_list):
    # Render the result thumbnails as a responsive flexbox grid.
    html = "<div style='margin-top: 50px; max-width: 1100px; display: flex; flex-wrap: wrap; justify-content: space-evenly'>"
    for url in url_list:
        html += f"<img style='height: 200px; margin: 2px' src='{escape(url)}'>"
    html += "</div>"
    return html
 
@st.cache(show_spinner=False)
def image_search(query, top_k=10):
    # Torch-based ranking: cosine similarity between the query embedding and
    # every precomputed image embedding, sorted in descending order.
    with torch.no_grad():
        text_embedding = text_encoder(**tokenizer(query, return_tensors='pt')).pooler_output
    values, indices = torch.cosine_similarity(text_embedding, torch.from_numpy(image_embeddings)).sort(descending=True)
    return [links[int(i)] for i in indices[:top_k]]


@st.cache(show_spinner=False)
def image_search_(query, top_k=10):
    # NumPy variant of the same ranking: score every image with a dot product
    # against the query embedding and keep the top_k highest-scoring indices.
    with torch.no_grad():
        text_embedding = text_encoder(**tokenizer(query, return_tensors='pt')).pooler_output.numpy()
    results = np.argsort((image_embeddings @ text_embedding.T)[:, 0])[-1:-top_k - 1:-1]
    return [links[i] for i in results]
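
# Note: the dot product above equals cosine similarity only if the stored
# embeddings are L2-normalized. If they are raw pooler outputs, a normalized
# variant would look like the sketch below (cosine_search is a hypothetical
# helper, not called by the app):
def cosine_search(query, top_k=10):
    with torch.no_grad():
        text_embedding = text_encoder(**tokenizer(query, return_tensors='pt')).pooler_output.numpy()
    image_norms = image_embeddings / np.linalg.norm(image_embeddings, axis=1, keepdims=True)
    text_norm = text_embedding / np.linalg.norm(text_embedding)
    results = np.argsort((image_norms @ text_norm.T)[:, 0])[-1:-top_k - 1:-1]
    return [links[i] for i in results]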


description = '''
# Persian (fa) image search
- Enter your query and hit Enter

Built with the [CLIP-fa](https://github.com/sajjjadayobi/CLIPfa) model and 25k images from [Unsplash](https://unsplash.com/)
'''


def main():
    # Hide Streamlit's default menu and footer and widen the main container.
    st.markdown('''
              <style>
              .block-container{
                max-width: 1200px;
              }
              section.main>div:first-child {
                padding-top: 0px;
              }
              section:not(.main)>div:first-child {
                padding-top: 30px;
              }
              div.reportview-container > section:first-child{
                max-width: 320px;
              }
              #MainMenu {
                visibility: hidden;
              }
              footer {
                visibility: hidden;
              }
              </style>''',
                unsafe_allow_html=True)
                
    st.sidebar.markdown(description)
    _, c, _ = st.columns((1, 3, 1))
    # Default query: 'گل صورتی' ("pink flower" in Persian)
    query = c.text_input('Search Box (type in fa)', value='گل صورتی')
    if len(query) > 0:
        results = image_search_(query)
        st.markdown(get_html(results), unsafe_allow_html=True)


if __name__ == '__main__':
    main()
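
# To run locally (assuming this file is saved as app.py and 'embedding.pt'
# and 'data.npy' sit next to it):
#   streamlit run app.py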