# NOTE: this file was captured from a (presumably Hugging Face) "Spaces" page
# whose status read "Runtime error" at capture time.
import streamlit as st | |
import numpy as np | |
from html import escape | |
import torch | |
from transformers import RobertaModel, AutoTokenizer | |
# Streamlit re-executes this whole script on every widget interaction, so
# without caching the tokenizer, model and embeddings would be reloaded for
# every query. Prefer ``st.cache_resource`` (newer Streamlit) and fall back to
# the legacy ``st.cache``; ``allow_output_mutation=True`` stops the legacy
# cache from hashing the returned model/tensors on every cache hit.
_cache_resource = getattr(st, "cache_resource", None)
if _cache_resource is None:
    _cache_resource = st.cache(allow_output_mutation=True)


@_cache_resource
def load():
    """Load the text encoder, its tokenizer, the image links and embeddings.

    Returns:
        A ``(tokenizer, text_encoder, links, image_embeddings)`` tuple: the
        'SajjadAyoubi/clip-fa-text' tokenizer, the RoBERTa text encoder in
        eval mode, the array loaded from ``data.npy`` (indexed by the same
        positions as the embeddings in ``image_search``), and the tensor
        loaded from ``embedding.pt``.
    """
    tokenizer = AutoTokenizer.from_pretrained('SajjadAyoubi/clip-fa-text')
    text_encoder = RobertaModel.from_pretrained('SajjadAyoubi/clip-fa-text').eval()
    # map_location='cpu': if embedding.pt was saved from GPU tensors, loading it
    # on a CPU-only host would fail without this; it also keeps the embeddings
    # on the same device as the CPU text encoder for the similarity search.
    image_embeddings = torch.load('embedding.pt', map_location='cpu')
    # allow_pickle=True unpickles arbitrary objects: acceptable only because
    # data.npy is a trusted file shipped with the app -- never use this on
    # user-supplied files.
    links = np.load('data.npy', allow_pickle=True)
    return tokenizer, text_encoder, links, image_embeddings


tokenizer, text_encoder, links, image_embeddings = load()
def image_search(query, top_k=10):
    """Return the ``top_k`` links whose images best match ``query``.

    The query is encoded with the text encoder (its ``pooler_output``) and
    compared against ``image_embeddings`` by cosine similarity; the matching
    entries of ``links`` are returned best match first.
    """
    # NOTE(review): assumes image_embeddings is (num_images, dim) so that the
    # similarity is one score per image -- confirm against embedding.pt.
    encoded = tokenizer(query, return_tensors='pt')
    with torch.no_grad():
        query_vec = text_encoder(**encoded).pooler_output
        scores = torch.cosine_similarity(image_embeddings, query_vec)
        ranking = scores.sort(descending=True).indices
    return [links[idx] for idx in ranking[:top_k]]
def get_html(url_list):
    """Render image URLs as a wrapping flexbox gallery.

    Args:
        url_list: iterable of image URL strings, rendered in order.

    Returns:
        An HTML string: one ``<div>`` containing one 200px-high ``<img>`` per
        URL. Each URL is HTML-escaped (quotes included) so it cannot break out
        of the single-quoted ``src`` attribute.
    """
    # Build all tags in one pass with str.join instead of repeated
    # concatenation, which is quadratic in the number of images.
    images = "".join(
        f"<img style='height: 200px; margin: 2px' src='{escape(url)}'>"
        for url in url_list
    )
    return (
        "<div style='margin-top: 50px; max-width: 1100px; display: flex; "
        "flex-wrap: wrap; justify-content: space-evenly'>"
        f"{images}</div>"
    )
# Sidebar Markdown rendered by main(). The blank line before "Built with" is
# required: in Markdown a line directly after a list item is a lazy
# continuation of that item, which would merge the credits into the bullet.
description = '''
# Persian (fa) image search
- Enter your query and hit enter

Built with [CLIP-fa](https://github.com/sajjjadayobi/CLIPfa) model and 25k images from [Unsplash](https://unsplash.com/)
'''
def _inject_page_style():
    """Inject the app's custom CSS into the page.

    Caps ``.block-container`` at 1200px wide, zeroes the top padding of the
    main section (30px on other sections), caps the first section of
    ``div.reportview-container`` at 320px, and hides ``#MainMenu`` and the
    footer.
    """
    # NOTE(review): selectors such as ``reportview-container`` target
    # Streamlit's internal DOM and may not match newer releases -- confirm.
    st.markdown(
        '''
<style>
.block-container{
    max-width: 1200px;
}
section.main>div:first-child {
    padding-top: 0px;
}
section:not(.main)>div:first-child {
    padding-top: 30px;
}
div.reportview-container > section:first-child{
    max-width: 320px;
}
#MainMenu {
    visibility: hidden;
}
footer {
    visibility: hidden;
}
</style>''',
        unsafe_allow_html=True,
    )


def main():
    """Render the search page: styling, sidebar text, search box and results."""
    _inject_page_style()
    st.sidebar.markdown(description)
    # Centre the search box in the middle column of a 1:3:1 split.
    _, center, _ = st.columns((1, 3, 1))
    # Default query is Persian for "pink flower".
    query = center.text_input('Search Box (type in fa)', value='گل صورتی')
    center.text("It'll take about 20s to load all new images")
    # Ignore empty and whitespace-only input instead of embedding blank text.
    query = query.strip()
    if query:
        results = image_search(query)
        # URLs are HTML-escaped inside get_html, so rendering raw HTML is safe.
        st.markdown(get_html(results), unsafe_allow_html=True)


# Entry point when the file is executed as a script (e.g. via `streamlit run`).
if __name__ == '__main__':
    main()