import streamlit as st
import pandas as pd
import numpy as np
from html import escape
from transformers import CLIPProcessor, CLIPModel


@st.cache(
    show_spinner=False,
    hash_funcs={
        CLIPModel: lambda _: None,
        CLIPProcessor: lambda _: None,
        dict: lambda _: None,
    },
)
def load():
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
    df = {0: pd.read_csv("data.csv"), 1: pd.read_csv("data2.csv")}
    embeddings = {0: np.load("embeddings.npy"), 1: np.load("embeddings2.npy")}
    # L2-normalize the precomputed image embeddings so that ranking by dot
    # product with a query embedding is equivalent to ranking by cosine
    # similarity.
    for k in [0, 1]:
        embeddings[k] = np.divide(
            embeddings[k],
            np.sqrt(np.sum(embeddings[k] ** 2, axis=1, keepdims=True)),
        )
    return model, processor, df, embeddings


model, processor, df, embeddings = load()
source = {0: "\nSource: Unsplash", 1: "\nSource: The Movie Database (TMDB)"}


def get_html(url_list, height=200):
    # Flex container for the grid of result thumbnails (inline styles are
    # illustrative).
    html = (
        "<div style='margin-top: 20px; max-width: 1200px; display: flex; "
        "flex-wrap: wrap; justify-content: space-evenly'>"
    )
" for url, title, link in url_list: html2 = f"" if len(link) > 0: html2 = f"" + html2 + "" html = html + html2 html += "
" return html def compute_text_embeddings(list_of_strings): inputs = processor(text=list_of_strings, return_tensors="pt", padding=True) return model.get_text_features(**inputs) st.cache(show_spinner=False) def image_search(query, corpus, n_results=24): text_embeddings = compute_text_embeddings([query]).detach().numpy() k = 0 if corpus == "Unsplash" else 1 results = np.argsort((embeddings[k] @ text_embeddings.T)[:, 0])[ -1 : -n_results - 1 : -1 ] return [ ( df[k].iloc[i]["path"], df[k].iloc[i]["tooltip"] + source[k], df[k].iloc[i]["link"], ) for i in results ] description = """ # Semantic image search **Enter your query and hit enter** *Built with OpenAI's [CLIP](https://openai.com/blog/clip/) model, 🤗 Hugging Face's [transformers library](https://huggingface.co/transformers/), [Streamlit](https://streamlit.io/), 25k images from [Unsplash](https://unsplash.com/) and 8k images from [The Movie Database (TMDB)](https://www.themoviedb.org/)* *Inspired by [Unsplash Image Search](https://github.com/haltakov/natural-language-image-search) from Vladimir Haltakov and [Alph, The Sacred River](https://github.com/thoppe/alph-the-sacred-river) from Travis Hoppe* """ def main(): st.markdown( """ """, unsafe_allow_html=True, ) st.sidebar.markdown(description) _, c, _ = st.columns((1, 3, 1)) query = c.text_input("", value="clouds at sunset") corpus = st.radio("", ["Unsplash", "Movies"]) if len(query) > 0: results = image_search(query, corpus) st.markdown(get_html(results), unsafe_allow_html=True) if __name__ == "__main__": main()