File size: 2,413 Bytes
4f12085
4f58d6c
4f12085
 
 
 
 
9e38862
55b4896
 
 
d3e4cd0
55b4896
7d696ff
55b4896
7d696ff
 
4f12085
e6c4b07
9213df3
c3b6f23
e3b861a
9213df3
4f12085
 
260e0f1
16b83cd
36ee6e5
c3b6f23
a7746f5
 
cf4befe
e3b861a
4f12085
 
 
cd19d7a
 
 
55b4896
4f12085
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9213df3
b5870b3
4f12085
55b4896
4f12085
665ec8e
62a6cd4
8f4e395
 
4f12085
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import streamlit as st
import numpy as np
from html import escape
import torch
from transformers import RobertaModel, AutoTokenizer


@st.experimental_singleton
def load():
    """Load the tokenizer, text encoder, image links and image embeddings.

    The result is cached as a single shared object (singleton cache) rather
    than a memoized value: memoization pickles and copies the returned
    objects on every call, which is wasteful for a model that is only read.

    Returns:
        tuple: ``(tokenizer, text_encoder, links, image_embeddings)`` where
        ``text_encoder`` has been switched to eval mode, ``links`` is the
        array stored in ``data.npy`` and ``image_embeddings`` is the object
        stored in ``embedding.pt``.
    """
    tokenizer = AutoTokenizer.from_pretrained('SajjadAyoubi/clip-fa-text')
    text_encoder = RobertaModel.from_pretrained('SajjadAyoubi/clip-fa-text').eval()
    # The encoder is never moved to another device in this file, so pin the
    # embeddings to the CPU as well: a file saved from a GPU would otherwise
    # fail to load on a CPU-only host or mismatch the query embedding's device.
    image_embeddings = torch.load('embedding.pt', map_location='cpu')
    # NOTE(review): allow_pickle=True unpickles arbitrary objects; this is only
    # safe while data.npy is a trusted, locally shipped file.
    links = np.load('data.npy', allow_pickle=True)
    return tokenizer, text_encoder, links, image_embeddings
  
# Load the shared resources at import time; image_search reads these
# module-level names directly.
(
    tokenizer,
    text_encoder,
    links,
    image_embeddings,
) = load()


@st.experimental_memo
def get_html(url_list):
    """Build the HTML gallery markup for a sequence of image URLs.

    Args:
        url_list: Iterable of image URL strings; each one is HTML-escaped
            before being placed in an ``<img>`` ``src`` attribute.

    Returns:
        str: One flex-wrapped ``<div>`` containing an ``<img>`` tag per URL,
        in the order given.
    """
    opening_tag = "<div style='margin-top: 50px; max-width: 1100px; display: flex; flex-wrap: wrap; justify-content: space-evenly'>"
    image_tags = [
        f"<img style='height: 200px; margin: 2px' src='{escape(link)}'>"
        for link in url_list
    ]
    return opening_tag + "".join(image_tags) + "</div>"
 
@st.experimental_memo
def image_search(query, top_k=10):
    """Return the links of the images most similar to a text query.

    The query is tokenized and encoded with the module-level ``tokenizer``
    and ``text_encoder``; the encoder's ``pooler_output`` is compared with
    ``image_embeddings`` by cosine similarity.

    Args:
        query: Text to search for.
        top_k: Number of best-matching entries to return.

    Returns:
        list: Entries of ``links`` for the ``top_k`` highest similarities,
        most similar first.
    """
    encoded_query = tokenizer(query, return_tensors='pt')
    # Inference only: no autograd graph is needed for the query embedding.
    with torch.no_grad():
        text_embedding = text_encoder(**encoded_query).pooler_output
    similarity = torch.cosine_similarity(image_embeddings, text_embedding)
    ranking = similarity.sort(descending=True).indices
    return [links[position] for position in ranking[:top_k]]


# Markdown text that main() renders in the sidebar. Written as explicit
# per-line literals so the leading/trailing newlines and the trailing space
# after the heading stay visible.
description = (
    '\n'
    '# Persian (fa) image search \n'
    '- Enter your query and hit enter\n'
    '\n'
    'Built with [CLIP-fa](https://github.com/sajjjadayobi/CLIPfa) model and 25k images from [Unsplash](https://unsplash.com/)\n'
)


def main():
    """Render the page: custom CSS, sidebar text, the query box and results.

    When the query box is non-empty, the query is passed to ``image_search``
    and the returned links are shown as an HTML gallery built by ``get_html``.
    """
    # Layout overrides injected as raw HTML; the string is passed unchanged.
    page_css = '''
              <style>
              .block-container{
                max-width: 1200px;
              }
              section.main>div:first-child {
                padding-top: 0px;
              }
              section:not(.main)>div:first-child {
                padding-top: 30px;
              }
              div.reportview-container > section:first-child{
                max-width: 320px;
              }
              #MainMenu {
                visibility: hidden;
              }
              footer {
                visibility: hidden;
              }
              </style>'''
    st.markdown(page_css, unsafe_allow_html=True)

    st.sidebar.markdown(description)
    # Three columns with width ratio 1:3:1; only the middle one holds a widget.
    center_column = st.columns((1, 3, 1))[1]
    query = center_column.text_input('Search Box (type in fa)', value='گل صورتی')
    if not query:
        # Nothing typed: skip the search and render no gallery.
        return
    gallery_html = get_html(image_search(query))
    st.markdown(gallery_html, unsafe_allow_html=True)


# Build the page only when this file is run as the main script, not on import.
if __name__ == "__main__":
    main()