# CLIPfa-Demo / app.py
# Source: Hugging Face Space file view (author: SajjadAyoubi; commit cda97bc,
# "Update app.py"; 2.99 kB).
import streamlit as st
import pandas as pd
import numpy as np
from html import escape
import os
import torch
import transformers
from transformers import RobertaModel, AutoTokenizer
#@st.cache(hash_funcs={transformers.models.roberta.tokenization_roberta_fast.RobertaTokenizerFast: hash}, suppress_st_warning=True, allow_output_mutation=True)
# The hash functions return a constant for these types, presumably so that
# Streamlit's cache does not try to hash the model/tokenizer contents.
@st.cache(show_spinner=False,
          hash_funcs={RobertaModel: lambda _: None,
                      AutoTokenizer: lambda _: None,
                      dict: lambda _: None})
def load():
    """Load the text encoder, tokenizer, image links and image embeddings.

    The model and tokenizer come from the 'SajjadAyoubi/clip-fa-text'
    checkpoint; the links and embeddings are read from 'data.npy' and
    'embedding.pt' relative to the current working directory.

    Returns:
        A 4-tuple ``(text_encoder, tokenizer, links, image_embeddings)``.
    """
    text_encoder = RobertaModel.from_pretrained('SajjadAyoubi/clip-fa-text')
    tokenizer = AutoTokenizer.from_pretrained('SajjadAyoubi/clip-fa-text')
    # NOTE(security): allow_pickle=True (and torch.load below) unpickle the
    # file contents, which can execute arbitrary code. Only ship data files
    # that come from a trusted source.
    links = np.load('data.npy', allow_pickle=True)
    # The text encoder is never moved off the CPU, so force the embeddings
    # onto the CPU as well; this also lets a file saved on a GPU machine load
    # on a CPU-only host.
    image_embeddings = torch.load('embedding.pt', map_location='cpu')
    return text_encoder, tokenizer, links, image_embeddings
# Module-level resources shared by the search helpers below; load() is wrapped
# in st.cache above.
text_encoder, tokenizer, links, image_embeddings = load()
def get_html(url_list, height=224):
    """Build an HTML flexbox gallery with one ``<img>`` tag per URL.

    Args:
        url_list: Iterable of image URL strings. Each URL is HTML-escaped
            before being placed in the ``src`` attribute.
        height: Image height in pixels applied to every ``<img>``.

    Returns:
        The gallery markup as a single string: a wrapping ``<div>`` that
        contains the image tags in input order (an empty ``<div>`` when
        ``url_list`` is empty).
    """
    # Join once instead of growing the string inside the loop.
    images = "".join(
        f"<img style='height: {height}px; margin: 5px' src='{escape(url)}'>"
        for url in url_list
    )
    opening = "<div style='margin-top: 20px; max-width: 1200px; display: flex; flex-wrap: wrap; justify-content: space-evenly'>"
    return opening + images + "</div>"
def compute_embeddings(query):
    """Encode a text query with the module-level text encoder.

    Args:
        query: Text passed to the module-level ``tokenizer``.

    Returns:
        The encoder's ``pooler_output`` for the tokenized query
        (presumably one embedding row per input text; confirm against the
        model).
    """
    # Inference only: skip gradient tracking.
    with torch.no_grad():
        encoded = tokenizer(query, return_tensors='pt')
        return text_encoder(**encoded).pooler_output
# NOTE(review): this is a bare call, not a decorator (there is no leading
# '@'), so image_search is NOT cached. Turning it into a decorator would make
# the legacy cache consider the globals used below (model, tensor); confirm
# the required hash_funcs before enabling it.
st.cache(show_spinner=False)
def image_search(query, top_k=8):
    """Return the links of the images most similar to a text query.

    Args:
        query: Text query, embedded via ``compute_embeddings``.
        top_k: Maximum number of links to return. Values larger than the
            number of stored embeddings are clamped; values <= 0 yield an
            empty list.

    Returns:
        A list of entries from the module-level ``links``, ordered from most
        to least similar by cosine similarity.
    """
    text_embedding = compute_embeddings(query)
    # Assumes image_embeddings is (num_images, dim), giving one score per
    # image -- TODO confirm against embedding.pt.
    scores = torch.cosine_similarity(text_embedding, image_embeddings)
    # topk rejects k larger than the dimension size, so clamp it first.
    k = max(0, min(int(top_k), scores.shape[-1]))
    # Select only the k best scores instead of sorting every score.
    best = torch.topk(scores, k).indices
    return [links[i] for i in best.tolist()]
# Markdown shown in the sidebar by main().
description = '''
# Semantic image search :)
'''
def main():
    """Render the page: custom CSS, sidebar text, query box and results.

    Reads the query from a centered text input and, when it is non-empty,
    shows the images returned by ``image_search`` as an HTML gallery.
    """
    # Inject page-wide CSS overrides (layout widths, radio layout, and hiding
    # the main menu and footer).
    st.markdown('''
<style>
.block-container{
max-width: 1200px;
}
div.row-widget.stRadio > div{
flex-direction:row;
display: flex;
justify-content: center;
}
div.row-widget.stRadio > div > label{
margin-left: 5px;
margin-right: 5px;
}
section.main>div:first-child {
padding-top: 0px;
}
section:not(.main)>div:first-child {
padding-top: 30px;
}
div.reportview-container > section:first-child{
max-width: 320px;
}
#MainMenu {
visibility: hidden;
}
footer {
visibility: hidden;
}
</style>''',
                unsafe_allow_html=True)
    st.sidebar.markdown(description)
    # Use the middle (widest) of three columns so the input is centered.
    _, c, _ = st.columns((1, 3, 1))
    query = c.text_input('', value='غروب خورشید')
    if query:
        results = image_search(query)
        # get_html escapes each URL before embedding it in the markup.
        st.markdown(get_html(results), unsafe_allow_html=True)
# Render the app only when this file is run as a script, not when imported.
if __name__ == '__main__':
    main()