Spaces:
Runtime error
Runtime error
File size: 2,204 Bytes
4f12085 4f58d6c 4f12085 28dd995 2c174d9 4f12085 9213df3 c3b6f23 e3b861a 9213df3 4f12085 260e0f1 16b83cd 9492b8f c3b6f23 a7746f5 4f12085 e3b861a 4f12085 cd19d7a c3b6f23 4f12085 9213df3 b5870b3 4f12085 cd19d7a 4f12085 62a6cd4 8f4e395 4f12085 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
import streamlit as st
import numpy as np
from html import escape
import torch
from transformers import RobertaModel, AutoTokenizer
# Text side of the CLIP-fa model: tokenizer plus RoBERTa text encoder, put in
# inference mode with .eval() because the app only runs forward passes.
tokenizer = AutoTokenizer.from_pretrained('SajjadAyoubi/clip-fa-text')
text_encoder = RobertaModel.from_pretrained('SajjadAyoubi/clip-fa-text').eval()
# Precomputed image embeddings. Presumably one row per gallery image, aligned
# index-for-index with `links` below (image_search relies on that) - TODO confirm.
# map_location='cpu' lets a tensor that was saved from a GPU session load on a
# CPU-only host; the text encoder above is never moved off the CPU, so the
# embeddings must live there too for cosine_similarity to work.
image_embeddings = torch.load('embedding.pt', map_location='cpu')
# allow_pickle=True unpickles arbitrary objects. That is acceptable only because
# data.npy ships with the app; never point this at untrusted files.
links = np.load('data.npy', allow_pickle=True)
# NOTE(review): Streamlit re-executes the whole script on every interaction, so
# the four loads above repeat on each rerun. Consider wrapping them in a cached
# loader (e.g. st.cache_resource) if the deployed Streamlit version supports it.
def get_html(url_list):
    """Render image URLs as an HTML gallery suitable for ``st.markdown``.

    Args:
        url_list: Iterable of image URL strings.

    Returns:
        One ``<div>`` flex container (wrapping, evenly spaced) that holds one
        200px-high ``<img>`` tag per URL, in input order. Each URL is passed
        through ``html.escape`` so quotes or angle brackets in it cannot break
        out of the single-quoted ``src`` attribute.
    """
    container_open = (
        "<div style='margin-top: 50px; max-width: 1100px; display: flex; "
        "flex-wrap: wrap; justify-content: space-evenly'>"
    )
    # join builds the markup in one pass instead of repeated string +=.
    images = "".join(
        f"<img style='height: 200px; margin: 2px' src='{escape(url)}'>"
        for url in url_list
    )
    return container_open + images + "</div>"
def image_search(query, top_k=10):
    """Return the links of the images whose embeddings best match ``query``.

    The query is encoded with the CLIP-fa text encoder and scored by cosine
    similarity against every row of the precomputed ``image_embeddings``.
    The entries of ``links`` at the best-scoring positions are returned,
    most similar first.

    Args:
        query: Search text.
        top_k: Maximum number of results. A value larger than the number of
            images returns every image; a value <= 0 returns an empty list.

    Returns:
        A list of at most ``top_k`` items taken from ``links``.
    """
    with torch.no_grad():
        # truncation=True caps over-long queries at the tokenizer's configured
        # maximum length instead of letting the encoder fail on too many
        # positions.
        inputs = tokenizer(query, return_tensors='pt', truncation=True)
        text_embedding = text_encoder(**inputs).pooler_output
        # The (1, D) query broadcasts against the image embeddings to give one
        # score per image. This assumes image_embeddings is 2-D with one row
        # per entry of `links` - TODO confirm.
        scores = torch.cosine_similarity(text_embedding, image_embeddings)
    # topk picks the best k without fully sorting all N scores. Clamp k because
    # topk raises when k exceeds the number of scores or is negative.
    k = max(0, min(top_k, scores.shape[-1]))
    best = torch.topk(scores, k).indices.tolist()
    return [links[i] for i in best]
# Sidebar text (Markdown), rendered with st.sidebar.markdown in main().
# The blank line before "Built with" ends the bullet list. Without it, CommonMark
# treats that line as a lazy continuation of the "- Enter your query" item.
description = '''
# Persian (fa) image search
- Enter your query and hit enter

Built with [CLIP-fa](https://github.com/sajjjadayobi/CLIPfa) model and 25k images from Unsplash
'''
def main():
    """Render the Streamlit image-search page.

    The page is built in four steps:

    1. Inject page-level CSS.
    2. Show ``description`` in the sidebar.
    3. Place a search box in the middle of a 1:3:1 column layout.
    4. For a non-blank query, render the results of ``image_search`` as an
       HTML gallery produced by ``get_html``.
    """
    # Widen the main container, trim top padding and hide Streamlit's menu and
    # footer. NOTE(review): selectors such as `.reportview-container` and
    # `section.main` target older Streamlit DOM class names and may no longer
    # match in newer releases - verify against the deployed version.
    css = '''
<style>
.block-container{
max-width: 1200px;
}
section.main>div:first-child {
padding-top: 0px;
}
section:not(.main)>div:first-child {
padding-top: 30px;
}
div.reportview-container > section:first-child{
max-width: 320px;
}
#MainMenu {
visibility: hidden;
}
footer {
visibility: hidden;
}
</style>'''
    st.markdown(css, unsafe_allow_html=True)
    st.sidebar.markdown(description)
    # The 1:3:1 ratio centres the search box; the side columns stay empty.
    _, c, _ = st.columns((1, 3, 1))
    # The default query 'گل صورتی' is Persian for "pink flower".
    query = c.text_input('Search Box', value='گل صورتی').strip()
    # Skip empty or whitespace-only input rather than encoding a blank query.
    if query:
        results = image_search(query)
        st.markdown(get_html(results), unsafe_allow_html=True)
# Entry point: `streamlit run` executes this file with __name__ == '__main__'.
if __name__ == '__main__':
    main()