import streamlit as st
import pandas as pd
import numpy as np
from html import escape
import os
from transformers import CLIPProcessor, CLIPModel
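# Load the CLIP model/processor plus, for each corpus (0 = Unsplash, 1 = Movies),
# the image metadata CSV and the precomputed image embeddings. Each embedding row
# is L2-normalized so the dot product in image_search() ranks by cosine similarity.
# Note: this is not cached, so everything reloads on each Streamlit rerun; wrapping
# it in st.cache(allow_output_mutation=True) would be one way to avoid that.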
def load():
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    df = {0: pd.read_csv('data.csv'), 1: pd.read_csv('data2.csv')}
    embeddings = {0: np.load('embeddings.npy'), 1: np.load('embeddings2.npy')}
    for k in [0, 1]:
        embeddings[k] = np.divide(embeddings[k], np.sqrt(np.sum(embeddings[k]**2, axis=1, keepdims=True)))
    return model, processor, df, embeddings


model, processor, df, embeddings = load()
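# Attribution line appended to each image tooltip, keyed by corpus index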
source = {0: '\nSource: Unsplash', 1: '\nSource: The Movie Database (TMDB)'}
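# Render the search results as an HTML flexbox grid of thumbnails; each image shows
# its tooltip on hover and, when a link is available, is wrapped in a new-tab link.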
def get_html(url_list, height=200):
    html = "<div style='margin-top: 20px; max-width: 1200px; display: flex; flex-wrap: wrap; justify-content: space-evenly'>"
    for url, title, link in url_list:
        html2 = f"<img title='{escape(title)}' style='height: {height}px; margin: 5px' src='{escape(url)}'>"
        if len(link) > 0:
            html2 = f"<a href='{escape(link)}' target='_blank'>" + html2 + "</a>"
        html = html + html2
    html += "</div>"
    return html
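# Tokenize the query string(s) with the CLIP processor and return the CLIP text
# embeddings as a torch tensor (one row per input string).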
def compute_text_embeddings(list_of_strings):
    inputs = processor(text=list_of_strings, return_tensors="pt", padding=True)
    return model.get_text_features(**inputs)
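# Rank all images in the selected corpus by dot product between the normalized image
# embeddings and the query's text embedding (equivalent to ranking by cosine similarity,
# since the text vector is the same for every image), and return the top n_results as
# (image path, tooltip + attribution, link) tuples. st.cache memoizes results per query/corpus.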
@st.cache(show_spinner=False)
def image_search(query, corpus, n_results=24):
    text_embeddings = compute_text_embeddings([query]).detach().numpy()
    k = 0 if corpus == 'Unsplash' else 1
    results = np.argsort((embeddings[k] @ text_embeddings.T)[:, 0])[-1:-n_results - 1:-1]
    return [(df[k].iloc[i]['path'],
             df[k].iloc[i]['tooltip'] + source[k],
             df[k].iloc[i]['link']) for i in results]
description = '''
# Semantic image search

**Enter your query and hit enter**

*Built with OpenAI's [CLIP](https://openai.com/blog/clip/) model, 🤗 Hugging Face's [transformers library](https://huggingface.co/transformers/), [Streamlit](https://streamlit.io/), 25k images from [Unsplash](https://unsplash.com/) and 8k images from [The Movie Database (TMDB)](https://www.themoviedb.org/)*

*Inspired by [Unsplash Image Search](https://github.com/haltakov/natural-language-image-search) from Vladimir Haltakov and [Alph, The Sacred River](https://github.com/thoppe/alph-the-sacred-river) from Travis Hoppe*
'''
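# Streamlit page: inject CSS tweaks (wider content area, centered radio buttons,
# narrower sidebar, hidden menu/footer), show the description in the sidebar, then
# render the query box, the corpus selector, and the result grid.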
def main():
    st.markdown('''
        <style>
        .block-container{
            max-width: 1200px;
        }
        div.row-widget.stRadio > div{
            flex-direction:row;
            display: flex;
            justify-content: center;
        }
        div.row-widget.stRadio > div > label{
            margin-left: 5px;
            margin-right: 5px;
        }
        section.main>div:first-child {
            padding-top: 0px;
        }
        section:not(.main)>div:first-child {
            padding-top: 30px;
        }
        div.reportview-container > section:first-child{
            max-width: 320px;
        }
        #MainMenu {
            visibility: hidden;
        }
        footer {
            visibility: hidden;
        }
        </style>''',
        unsafe_allow_html=True)
    st.sidebar.markdown(description)
    _, c, _ = st.columns((1, 3, 1))
    query = c.text_input('', value='clouds at sunset')
    corpus = st.radio('', ["Unsplash", "Movies"])
    if len(query) > 0:
        results = image_search(query, corpus)
        st.markdown(get_html(results), unsafe_allow_html=True)
if __name__ == '__main__':
    main()
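# To try this locally (assuming the script is saved as app.py and the CSV/NPY files
# loaded above sit next to it): streamlit run app.py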