# Page header captured along with this file (not part of the program):
# Spaces status "Runtime error"; file size 2,805 bytes.
# The per-line commit-hash column and the 1-87 line-number gutter were viewer residue.
import streamlit as st
import pandas as pd
import numpy as np
from html import escape
import os
import torch
import transformers
from transformers import RobertaModel, AutoTokenizer
@st.cache(hash_funcs={transformers.models.roberta.tokenization_roberta_fast.RobertaTokenizerFast: hash}, suppress_st_warning=True, allow_output_mutation=True)
def load():
    """Load everything the search needs: model, tokenizer, links, embeddings.

    The function is wrapped in ``st.cache`` so Streamlit can reuse the result
    instead of reloading the model on each script rerun.

    Returns:
        A ``(text_encoder, tokenizer, links, image_embeddings)`` tuple, in
        that order.
    """
    text_encoder = RobertaModel.from_pretrained('SajjadAyoubi/clip-fa-text')
    tokenizer = AutoTokenizer.from_pretrained('SajjadAyoubi/clip-fa-text')
    # SECURITY NOTE: allow_pickle=True lets NumPy unpickle arbitrary objects,
    # which can execute code. 'data.npy' must therefore be a trusted file.
    links = np.load('data.npy', allow_pickle=True)
    # The text encoder is never moved off the CPU in this file, so force the
    # stored embeddings onto the CPU as well. Without map_location, a file
    # saved from a GPU would fail to load on a CPU-only host, or land on a
    # different device than the query embedding it is compared against.
    image_embeddings = torch.load('embedding.pt', map_location='cpu')
    return text_encoder, tokenizer, links, image_embeddings
# Module-level state read by compute_embeddings() (text_encoder, tokenizer)
# and image_search() (links, image_embeddings).
text_encoder, tokenizer, links, image_embeddings = load()
def get_html(url_list, height=224):
    """Build the HTML for a wrapping flexbox gallery of images.

    Args:
        url_list: Iterable of image URLs; each one is HTML-escaped before
            being placed in an ``src`` attribute.
        height: Height in pixels applied to every ``<img>`` tag.

    Returns:
        A single string: an opening ``<div>``, one ``<img>`` per URL in the
        given order, and the closing ``</div>``.
    """
    container_open = "<div style='margin-top: 20px; max-width: 1200px; display: flex; flex-wrap: wrap; justify-content: space-evenly'>"
    image_tags = "".join(
        f"<img style='height: {height}px; margin: 5px' src='{escape(link)}'>"
        for link in url_list
    )
    return f"{container_open}{image_tags}</div>"
def compute_embeddings(query):
    """Encode `query` with the module-level tokenizer and text encoder.

    Args:
        query: Text passed straight to the tokenizer.

    Returns:
        The ``pooler_output`` of the text encoder for the tokenized query,
        computed with gradient tracking disabled.
    """
    encoded = tokenizer(query, return_tensors='pt')
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        encoder_output = text_encoder(**encoded)
    return encoder_output.pooler_output
# A bare `st.cache(show_spinner=False)` expression (no leading `@`) used to sit
# on the line above this function. It decorated nothing, so results were never
# cached, and it has been dropped as dead code.
# NOTE(review): if caching is wanted, apply it as a real decorator and confirm
# how Streamlit hashes the module-level model/tensors this function depends on.
def image_search(query, top_k=8):
    """Return the links of the images most similar to a text query.

    Args:
        query: Text to search for; embedded via ``compute_embeddings``.
        top_k: Maximum number of results. Values larger than the number of
            stored embeddings are clamped; zero or negative values yield an
            empty list.

    Returns:
        A list of entries from ``links``, ordered from most to least similar
        by cosine similarity between the query embedding and
        ``image_embeddings``.
    """
    text_embedding = compute_embeddings(query)
    # assumes image_embeddings is 2-D (one row per image), giving a 1-D
    # similarity vector — TODO confirm against embedding.pt
    similarities = torch.cosine_similarity(text_embedding, image_embeddings)
    # torch.topk raises if k exceeds the number of candidates, so clamp it.
    k = max(0, min(int(top_k), similarities.shape[-1]))
    if k == 0:
        return []
    # topk returns the k largest values already sorted in descending order,
    # which avoids sorting the whole similarity vector.
    best = torch.topk(similarities, k).indices
    # Convert to plain ints so indexing `links` does not rely on tensor
    # scalars being accepted as indices.
    return [links[i] for i in best.tolist()]
# Markdown text that main() renders in the sidebar.
description = '''
# Semantic image search :)
'''
def main():
    """Render the page: custom CSS, sidebar text, search box and results.

    The query box is placed in the middle of three columns. When the box is
    non-empty, the query is passed to ``image_search`` and the resulting
    links are displayed as an HTML gallery built by ``get_html``.
    """
    # Kept flush-left so the string sent to Streamlit is unchanged.
    page_css = '''
<style>
.block-container{
max-width: 1200px;
}
div.row-widget.stRadio > div{
flex-direction:row;
display: flex;
justify-content: center;
}
div.row-widget.stRadio > div > label{
margin-left: 5px;
margin-right: 5px;
}
section.main>div:first-child {
padding-top: 0px;
}
section:not(.main)>div:first-child {
padding-top: 30px;
}
div.reportview-container > section:first-child{
max-width: 320px;
}
#MainMenu {
visibility: hidden;
}
footer {
visibility: hidden;
}
</style>'''
    st.markdown(page_css, unsafe_allow_html=True)
    st.sidebar.markdown(description)

    # Columns weighted 1:3:1; only the middle one is used.
    _left, center, _right = st.columns((1, 3, 1))
    query = center.text_input('', value='غروب خورشید')
    if not query:
        # Nothing typed: skip the search and render no gallery.
        return

    matches = image_search(query)
    gallery = get_html(matches)
    st.markdown(gallery, unsafe_allow_html=True)
# Run the app when this file is executed as a script.
if __name__ == '__main__':
    main()