File size: 4,016 Bytes
c81898a ba03fb2 c81898a b59b1d0 c81898a a55de09 c81898a a55de09 000d238 c81898a a55de09 c81898a ba03fb2 c81898a ba03fb2 c81898a a55de09 c81898a 000d238 c81898a a55de09 c81898a a55de09 c81898a a55de09 c81898a a55de09 c81898a ff968d5 9ea8c8c ba03fb2 9ea8c8c a55de09 c81898a a55de09 c81898a ff968d5 c81898a 7600dc3 555584f 7600dc3 586f7e5 55fea56 586f7e5 c81898a a55de09 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 |
import streamlit as st
import pandas as pd, numpy as np
from html import escape
import os
from transformers import CLIPProcessor, CLIPModel
@st.cache(
    show_spinner=False,
    hash_funcs={
        CLIPModel: lambda _: None,
        CLIPProcessor: lambda _: None,
        dict: lambda _: None,
    },
)
def load():
    """Load the CLIP model, its processor and both image corpora.

    Returns:
        A 4-tuple ``(model, processor, df, embeddings)``. ``df`` and
        ``embeddings`` are dicts keyed by corpus index: 0 is read from
        ``data.csv`` / ``embeddings.npy``, 1 from ``data2.csv`` /
        ``embeddings2.npy``. Each embedding matrix is rescaled so that every
        row has unit L2 norm.

    The ``hash_funcs`` above stop Streamlit from hashing the model, the
    processor and the returned dicts.
    """
    checkpoint = "openai/clip-vit-base-patch16"
    model = CLIPModel.from_pretrained(checkpoint)
    processor = CLIPProcessor.from_pretrained(checkpoint)

    df = {0: pd.read_csv("data.csv"), 1: pd.read_csv("data2.csv")}

    raw = {0: np.load("embeddings.npy"), 1: np.load("embeddings2.npy")}
    embeddings = {}
    for key, matrix in raw.items():
        # Unit-normalise each row so a dot product with a query vector ranks
        # rows the same way cosine similarity would.
        row_norms = np.sqrt(np.sum(matrix ** 2, axis=1, keepdims=True))
        embeddings[key] = np.divide(matrix, row_norms)
    return model, processor, df, embeddings
# Obtain the model, processor and both corpora once per script run (load() is
# wrapped in st.cache). The search helpers read these module-level globals.
(model, processor, df, embeddings) = load()
# Attribution suffix appended to each result's tooltip, keyed by corpus index
# (0 = Unsplash, 1 = TMDB movies).
source = dict(
    enumerate(
        (
            "\nSource: Unsplash",
            "\nSource: The Movie Database (TMDB)",
        )
    )
)
def get_html(url_list, height=200):
    """Render image results as a flex-wrapped HTML gallery.

    Args:
        url_list: Iterable of ``(url, title, link)`` triples. ``url`` is the
            image source and ``title`` becomes the hover tooltip. ``link``
            optionally wraps the image in an anchor that opens in a new tab.
            An empty string, ``None`` or any other non-string value (e.g. the
            NaN pandas produces for an empty CSV cell) means "no link".
        height: Image height in pixels.

    Returns:
        The gallery as one HTML string. Every interpolated url/title/link is
        HTML-escaped.
    """
    parts = [
        "<div style='margin-top: 20px; max-width: 1200px; display: flex; "
        "flex-wrap: wrap; justify-content: space-evenly'>"
    ]
    for url, title, link in url_list:
        img = (
            f"<img title='{escape(title)}' style='height: {height}px; margin: 5px' "
            f"src='{escape(url)}'>"
        )
        # Guard with isinstance: len() raises TypeError on NaN/None links.
        if isinstance(link, str) and link:
            img = f"<a href='{escape(link)}' target='_blank'>{img}</a>"
        parts.append(img)
    parts.append("</div>")
    # A single join avoids quadratic repeated string concatenation.
    return "".join(parts)
def compute_text_embeddings(list_of_strings):
    """Encode text prompts with the module-level CLIP ``model``/``processor``.

    Args:
        list_of_strings: Prompts to encode. They are padded to a common length
            so they can be batched.

    Returns:
        The tensor returned by ``model.get_text_features``, one row per prompt.
        It is computed under ``torch.no_grad()``, so it carries no autograd
        graph.
    """
    import torch  # torch is already required by return_tensors="pt"

    inputs = processor(text=list_of_strings, return_tensors="pt", padding=True)
    # Inference only: disabling gradient tracking avoids building an autograd
    # graph on every query. The result is detached by the caller anyway.
    with torch.no_grad():
        return model.get_text_features(**inputs)
# Results are cached per call arguments. The hash_funcs mirror load(): they
# keep Streamlit from hashing the CLIP model/processor and the df/embeddings/
# source dicts that this function reads through globals.
@st.cache(
    show_spinner=False,
    hash_funcs={
        CLIPModel: lambda _: None,
        CLIPProcessor: lambda _: None,
        dict: lambda _: None,
    },
)
def image_search(query, corpus, n_results=24):
    """Return the images in ``corpus`` that best match the text ``query``.

    Args:
        query: Free-text search prompt.
        corpus: ``"Unsplash"`` selects corpus 0. Any other value selects
            corpus 1 (TMDB movies).
        n_results: Maximum number of results. Values <= 0 yield an empty list.

    Returns:
        A list of ``(path, tooltip, link)`` tuples ordered from best to worst
        match. ``tooltip`` has the corpus attribution from ``source`` appended.
    """
    if n_results <= 0:
        return []
    text_embeddings = compute_text_embeddings([query]).detach().numpy()
    k = 0 if corpus == "Unsplash" else 1
    # One score per corpus row: dot product with the query embedding.
    scores = (embeddings[k] @ text_embeddings.T)[:, 0]
    # argsort is ascending, so reverse it and keep the top n_results.
    best = np.argsort(scores)[::-1][:n_results]
    table = df[k]
    results = []
    for i in best:
        row = table.iloc[i]  # fetch the row once rather than three times
        results.append((row["path"], row["tooltip"] + source[k], row["link"]))
    return results
# Sidebar markdown. Blank lines are required between paragraphs: CommonMark
# joins consecutive non-blank lines into a single paragraph.
description = """
# Semantic image search

**Enter your query and hit enter**

*Built with OpenAI's [CLIP](https://openai.com/blog/clip/) model, 🤗 Hugging Face's [transformers library](https://huggingface.co/transformers/), [Streamlit](https://streamlit.io/), 25k images from [Unsplash](https://unsplash.com/) and 8k images from [The Movie Database (TMDB)](https://www.themoviedb.org/)*

*Inspired by [Unsplash Image Search](https://github.com/haltakov/natural-language-image-search) from Vladimir Haltakov and [Alph, The Sacred River](https://github.com/thoppe/alph-the-sacred-river) from Travis Hoppe*
"""
def main():
    """Render the search page.

    Injects custom CSS (content width, horizontally laid-out radio buttons,
    padding tweaks, hidden main menu and footer) and shows the sidebar
    description. It then reads a query and a corpus choice and, for a
    non-empty query, displays the matching images.
    """
    page_css = (
        "\n"
        "<style>\n"
        ".block-container{\n"
        "max-width: 1200px;\n"
        "}\n"
        "div.row-widget.stRadio > div{\n"
        "flex-direction:row;\n"
        "display: flex;\n"
        "justify-content: center;\n"
        "}\n"
        "div.row-widget.stRadio > div > label{\n"
        "margin-left: 5px;\n"
        "margin-right: 5px;\n"
        "}\n"
        "section.main>div:first-child {\n"
        "padding-top: 0px;\n"
        "}\n"
        "section:not(.main)>div:first-child {\n"
        "padding-top: 30px;\n"
        "}\n"
        "div.reportview-container > section:first-child{\n"
        "max-width: 320px;\n"
        "}\n"
        "#MainMenu {\n"
        "visibility: hidden;\n"
        "}\n"
        "footer {\n"
        "visibility: hidden;\n"
        "}\n"
        "</style>"
    )
    st.markdown(page_css, unsafe_allow_html=True)
    st.sidebar.markdown(description)

    # Centre the query box in the middle column of a 1:3:1 layout.
    _, center, _ = st.columns((1, 3, 1))
    query = center.text_input("", value="clouds at sunset")
    corpus = st.radio("", ["Unsplash", "Movies"])
    if not query:
        return
    hits = image_search(query, corpus)
    st.markdown(get_html(hits), unsafe_allow_html=True)
if __name__ == "__main__":
    # Entry point: render the page when this file is executed as a script.
    main()
|