File size: 2,790 Bytes
4f12085
6b7167b
4f58d6c
4f12085
 
 
90ecddb
4f12085
 
76fc8a2
4f12085
 
 
f0fa5d9
7a4d3b4
e3b861a
4f12085
 
e3b861a
4f12085
 
 
 
e3b861a
 
4f12085
 
3abd33c
4f12085
3abd33c
 
4f12085
3abd33c
4f12085
c2cf659
3abd33c
4f12085
e3b861a
4f12085
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76fc8a2
4f12085
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import streamlit as st
import pandas as pd
import numpy as np
from html import escape
import os
import torch
import transformers
from transformers import RobertaModel, AutoTokenizer

@st.cache(
    hash_funcs={transformers.models.roberta.tokenization_roberta_fast.RobertaTokenizerFast: hash},
    suppress_st_warning=True,
    allow_output_mutation=True,
)
def load():
    """Load the model and data the app needs, cached by Streamlit.

    Returns:
        A 4-tuple ``(text_encoder, tokenizer, links, image_embeddings)``:
        the pretrained RoBERTa text encoder, its tokenizer, the array read
        from ``data.npy`` and the object read from ``embedding.pt``.
    """
    # Encoder and tokenizer come from the same pretrained checkpoint.
    checkpoint = 'SajjadAyoubi/clip-fa-text'
    encoder = RobertaModel.from_pretrained(checkpoint)
    tok = AutoTokenizer.from_pretrained(checkpoint)
    # NOTE(review): allow_pickle=True unpickles arbitrary objects; this is
    # only safe as long as data.npy is a trusted, locally produced file.
    image_links = np.load('data.npy', allow_pickle=True)
    embeddings = torch.load('embedding.pt')
    return encoder, tok, image_links, embeddings


# Load everything once at import time; these module-level names are read by
# compute_embeddings() and image_search() below.
text_encoder, tokenizer, links, image_embeddings = load()


def get_html(url_list, height=224):
    """Build an HTML snippet that shows the given images in a wrapping flex row.

    Args:
        url_list: Iterable of image URLs (strings). Each URL is HTML-escaped
            before being placed inside the ``src`` attribute.
        height: Height in pixels applied to every ``<img>`` tag.

    Returns:
        A single string: one ``<div>`` container holding one ``<img>`` per
        URL, in the order given. An empty ``url_list`` yields an empty
        container.
    """
    container_open = "<div style='margin-top: 20px; max-width: 1200px; display: flex; flex-wrap: wrap; justify-content: space-evenly'>"
    # join() builds the markup in one pass instead of repeated concatenation.
    images = "".join(
        f"<img style='height: {height}px; margin: 5px' src='{escape(url)}'>"
        for url in url_list
    )
    # Return the assembled markup (the previous code returned the undefined
    # name ``HTML``, which raised NameError on every call).
    return f"{container_open}{images}</div>"

def compute_embeddings(query):
    """Tokenize *query* and return the text encoder's ``pooler_output``."""
    # The tokenizer returns PyTorch tensors that are passed to the encoder
    # as keyword arguments.
    encoded = tokenizer(query, return_tensors='pt')
    encoder_output = text_encoder(**encoded)
    return encoder_output.pooler_output

@st.cache(show_spinner=False)
def image_search(query, top_k=8):
    """Return the ``top_k`` entries of ``links`` that best match *query*.

    The query is embedded with ``compute_embeddings`` and compared against
    ``image_embeddings`` by cosine similarity; entries are returned in order
    of decreasing similarity.
    """
    # Gradients are not needed for inference.
    with torch.no_grad():
        query_embedding = compute_embeddings(query)
    similarity = torch.cosine_similarity(query_embedding, image_embeddings)
    # Only the ordering matters here; the sorted similarity values are unused.
    _, ranking = similarity.sort(descending=True)
    best = ranking[:top_k]
    return [links[idx] for idx in best]


# Markdown text rendered in the sidebar by main().
description = '''
# Semantic image search :)
'''


def main():
    """Render the page: inject CSS, show the sidebar text, run the search.

    A text box in the middle column holds the query. When the query is
    non-empty, the matching images are looked up with ``image_search`` and
    rendered through ``get_html``.
    """
    # Custom CSS injected as raw HTML to adjust the page layout.
    page_css = '''
              <style>
              .block-container{
                max-width: 1200px;
              }
              div.row-widget.stRadio > div{
                flex-direction:row;
                display: flex;
                justify-content: center;
              }
              div.row-widget.stRadio > div > label{
                margin-left: 5px;
                margin-right: 5px;
              }
              section.main>div:first-child {
                padding-top: 0px;
              }
              section:not(.main)>div:first-child {
                padding-top: 30px;
              }
              div.reportview-container > section:first-child{
                max-width: 320px;
              }
              #MainMenu {
                visibility: hidden;
              }
              footer {
                visibility: hidden;
              }
              </style>'''
    st.markdown(page_css, unsafe_allow_html=True)
    st.sidebar.markdown(description)

    # Three columns in a 1:3:1 ratio; only the middle one hosts a widget.
    _left, center, _right = st.columns((1, 3, 1))
    query = center.text_input('', value='غروب خورشید')

    # Nothing to search for when the box is empty.
    if not query:
        return

    matches = image_search(query)
    gallery = get_html(matches)
    st.markdown(gallery, unsafe_allow_html=True)


# Build the page only when this file is run as a script, not when imported.
if __name__ == '__main__':
    main()