File size: 2,449 Bytes
4f12085
4f58d6c
4f12085
 
 
 
 
9e38862
55b4896
 
 
d3e4cd0
55b4896
7d696ff
55b4896
7d696ff
 
4f12085
e6c4b07
c7e078f
 
 
 
 
 
 
9213df3
c3b6f23
e3b861a
9213df3
4f12085
 
260e0f1
4f12085
 
 
cd19d7a
 
 
55b4896
4f12085
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9213df3
b5870b3
4f12085
55b4896
c7e078f
4f12085
665ec8e
62a6cd4
8f4e395
 
4f12085
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import streamlit as st
import numpy as np
from html import escape
import torch
from transformers import RobertaModel, AutoTokenizer


@st.experimental_singleton
def load():
    """Load the tokenizer, text encoder, image links and image embeddings.

    Cached with ``st.experimental_singleton`` rather than
    ``st.experimental_memo``: the memo cache pickles the return value and
    hands every caller a fresh deserialized copy, which is wasteful for a
    model and a large embedding tensor that are only ever read. The
    singleton cache keeps one shared instance for the whole process.

    Returns:
        tuple: ``(tokenizer, text_encoder, links, image_embeddings)`` where
        the tokenizer and encoder come from ``SajjadAyoubi/clip-fa-text``,
        ``links`` is the array stored in ``data.npy`` and
        ``image_embeddings`` is the object stored in ``embedding.pt``.
    """
    model_name = 'SajjadAyoubi/clip-fa-text'
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # .eval() switches the encoder to inference mode (no dropout).
    text_encoder = RobertaModel.from_pretrained(model_name).eval()
    # NOTE(security): torch.load and np.load(allow_pickle=True) both unpickle
    # arbitrary objects; only ship 'embedding.pt' / 'data.npy' from a trusted
    # source.
    image_embeddings = torch.load('embedding.pt')
    links = np.load('data.npy', allow_pickle=True)
    return tokenizer, text_encoder, links, image_embeddings
  
# Module-level handles used by image_search(); load() is cache-decorated, so
# repeated executions of this script reuse the cached result.
tokenizer, text_encoder, links, image_embeddings = load()


@st.experimental_memo
def image_search(query, top_k=10):
    """Return the links of the images most similar to a text query.

    The query is embedded with the module-level ``text_encoder`` and compared
    against ``image_embeddings`` by cosine similarity; results are memoized
    per ``(query, top_k)`` by Streamlit.

    Args:
        query: Free-text search string.
        top_k: Maximum number of links to return.

    Returns:
        list: Entries of ``links`` ordered from most to least similar.
    """
    # truncation=True guards against very long pasted queries overflowing the
    # encoder's position embeddings. It truncates to the tokenizer's configured
    # maximum length -- NOTE(review): confirm the checkpoint's tokenizer
    # defines one; if not, transformers falls back to no truncation.
    encoded = tokenizer(query, return_tensors='pt', truncation=True)
    with torch.no_grad():
        text_embedding = text_encoder(**encoded).pooler_output
    similarity = torch.cosine_similarity(image_embeddings, text_embedding)
    ranked = similarity.argsort(descending=True)
    # Convert to plain Python ints so indexing `links` does not depend on
    # tensor-to-index coercion.
    return [links[i] for i in ranked[:top_k].tolist()]


def get_html(url_list):
    """Build the HTML gallery markup for a sequence of image URLs.

    Each URL is HTML-escaped and placed in the ``src`` attribute of an
    ``<img>`` tag; all tags are wrapped in a single flex-layout ``<div>``.

    Args:
        url_list: Iterable of image URL strings.

    Returns:
        str: The complete ``<div>...</div>`` markup.
    """
    opening = "<div style='margin-top: 50px; max-width: 1100px; display: flex; flex-wrap: wrap; justify-content: space-evenly'>"
    image_tags = "".join(
        f"<img style='height: 200px; margin: 2px' src='{escape(link)}'>"
        for link in url_list
    )
    return opening + image_tags + "</div>"


# Markdown text rendered in the sidebar by main().
description = '''
# Persian (fa) image search 
- Enter your query and hit enter

Built with [CLIP-fa](https://github.com/sajjjadayobi/CLIPfa) model and 25k images from [Unsplash](https://unsplash.com/)
'''


def main():
    """Render the search page.

    Injects the page CSS, shows ``description`` in the sidebar, reads the
    query from a text input in the middle of three columns and, when the
    query is non-empty, renders the matching images as raw HTML.
    """
    page_css = '''
              <style>
              .block-container{
                max-width: 1200px;
              }
              section.main>div:first-child {
                padding-top: 0px;
              }
              section:not(.main)>div:first-child {
                padding-top: 30px;
              }
              div.reportview-container > section:first-child{
                max-width: 320px;
              }
              #MainMenu {
                visibility: hidden;
              }
              footer {
                visibility: hidden;
              }
              </style>'''
    st.markdown(page_css, unsafe_allow_html=True)

    st.sidebar.markdown(description)

    # Only the middle (widest) column is used; the outer two act as margins.
    _left, centre, _right = st.columns((1, 3, 1))
    query = centre.text_input('Search Box (type in fa)', value='گل صورتی')
    centre.text("It'll take about 20s to load all new images")

    # Nothing to search for until the user has typed something.
    if not query:
        return

    gallery = get_html(image_search(query))
    st.markdown(gallery, unsafe_allow_html=True)


# Build the page only when this file is executed as the main script.
if __name__ == '__main__':
    main()