# CLIPfa-Demo / app.py
# Author: SajjadAyoubi
# Commit 1710a12: "Update app.py" (2.71 kB)
import streamlit as st
import pandas as pd
import numpy as np
from html import escape
import os
import torch
import transformers
from transformers import RobertaModel, AutoTokenizer
def _cache_resource(func):
    """Wrap *func* so its return value is cached across Streamlit reruns.

    Streamlit re-executes the whole script on every widget interaction, so
    without caching the model and embeddings below would be reloaded each
    time. ``st.cache_resource`` is used when this Streamlit version has it;
    older releases only provide ``st.cache``, where
    ``allow_output_mutation=True`` skips hashing the returned model objects.
    """
    if hasattr(st, 'cache_resource'):
        return st.cache_resource(show_spinner=False)(func)
    return st.cache(allow_output_mutation=True, show_spinner=False)(func)


@_cache_resource
def _load_resources():
    """Load the tokenizer, text encoder, image embeddings and image links.

    Returns:
        ``(tokenizer, text_encoder, image_embeddings, links)``.
    """
    tokenizer = AutoTokenizer.from_pretrained('SajjadAyoubi/clip-fa-text')
    # eval() disables dropout so query embeddings are deterministic.
    text_encoder = RobertaModel.from_pretrained('SajjadAyoubi/clip-fa-text').eval()
    # The text encoder runs on CPU; map_location keeps this load working on
    # CPU-only hosts even if the tensor was saved from a GPU.
    image_embeddings = torch.load('embedding.pt', map_location='cpu')
    # NOTE(review): allow_pickle=True unpickles data and can execute code.
    # This is presumably acceptable because data.npy ships with the app and
    # is not user input; confirm before pointing this at external files.
    links = np.load('data.npy', allow_pickle=True)
    return tokenizer, text_encoder, image_embeddings, links


tokenizer, text_encoder, image_embeddings, links = _load_resources()
def get_html(url_list, height=180):
    """Build an HTML image gallery for ``url_list``.

    Args:
        url_list: Iterable of image URL strings. Each URL is HTML-escaped
            (quotes included), so it cannot break out of the ``src``
            attribute.
        height: Image height in pixels.

    Returns:
        One flex-wrapped ``<div>`` containing an ``<img>`` per URL, in order.
    """
    images = ''.join(
        f"<img style='height: {height}px; margin: 5px' src='{escape(url)}'>"
        for url in url_list
    )
    return (
        "<div style='margin-top: 20px; max-width: 1200px; display: flex; "
        "flex-wrap: wrap; justify-content: space-evenly'>"
        f"{images}</div>"
    )
def image_search(query, top_k=8):
    """Return the links of the ``top_k`` images most similar to ``query``.

    The query is embedded with the text encoder (``pooler_output``). It is
    then compared with every precomputed image embedding by cosine
    similarity.

    Args:
        query: Search text.
        top_k: Maximum number of links to return. A value larger than the
            number of indexed images returns every image; a value <= 0
            returns an empty list.

    Returns:
        A list of entries from ``links``, best match first.
    """
    with torch.no_grad():
        # truncation=True clips over-long queries to the tokenizer's maximum
        # length instead of overflowing the model's position embeddings.
        inputs = tokenizer(query, return_tensors='pt', truncation=True)
        text_embedding = text_encoder(**inputs).pooler_output
        # Reduces over the embedding dim, giving one score per image.
        scores = torch.cosine_similarity(text_embedding, image_embeddings)
    # topk avoids sorting every score. Clamp k because topk raises when k
    # exceeds the number of candidates.
    k = max(0, min(top_k, scores.shape[0]))
    indices = scores.topk(k).indices.tolist()
    return [links[i] for i in indices]
# Markdown text rendered in the sidebar by main().
description = "\n# Semantic image search :)\n"
# Page-level CSS injected by main(). It caps the content width, lays radio
# buttons out horizontally and hides the #MainMenu and the footer.
_PAGE_CSS = '''
<style>
.block-container{
max-width: 1200px;
}
div.row-widget.stRadio > div{
flex-direction:row;
display: flex;
justify-content: center;
}
div.row-widget.stRadio > div > label{
margin-left: 5px;
margin-right: 5px;
}
section.main>div:first-child {
padding-top: 0px;
}
section:not(.main)>div:first-child {
padding-top: 30px;
}
div.reportview-container > section:first-child{
max-width: 320px;
}
#MainMenu {
visibility: hidden;
}
footer {
visibility: hidden;
}
</style>'''


def main():
    """Render the page: custom CSS, sidebar text, query box and result gallery."""
    st.markdown(_PAGE_CSS, unsafe_allow_html=True)
    st.sidebar.markdown(description)
    # Place the search box in the middle column (weight 3 of 1:3:1).
    _, c, _ = st.columns((1, 3, 1))
    query = c.text_input('Search text', value='مرغ دریای')
    # Skip empty or whitespace-only queries: they carry no search intent but
    # would still run the text encoder.
    if query.strip():
        results = image_search(query)
        st.markdown(get_html(results), unsafe_allow_html=True)
# Script entry point: build the page when this file is executed directly.
if __name__ == "__main__":
    main()