import os
import pickle
import gradio as gr
import sentence_transformers
from PIL import Image
from sentence_transformers import SentenceTransformer, util
from transformers import CLIPModel, CLIPProcessor, CLIPTokenizer
# Embedding shapes, for reference (captured in a pdb session):
#   query_emb.shape -> torch.Size([1, 512])
#   img_emb.shape   -> (24996, 512)
## Define model
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
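# Note: only the tokenizer and model are used at query time (text encoding);
# the processor would be needed if images were encoded on the fly.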
# Open the precomputed embeddings
emb_filename = "lv-handbags.pkl"
# emb_filename = 'unsplash-25k-photos-embeddings.pkl'
with open(emb_filename, "rb") as fIn:
    img_names, img_emb = pickle.load(fIn)
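# A minimal sketch of how an embeddings file like this could be precomputed
# (hypothetical: assumes the images live in lvphotos/ and uses the
# sentence-transformers CLIP wrapper; run once offline, not in this app):
#
# st_model = sentence_transformers.SentenceTransformer("clip-ViT-B-32")
# img_names = os.listdir("lvphotos/")
# img_emb = st_model.encode(
#     [Image.open(os.path.join("lvphotos/", name)) for name in img_names],
#     convert_to_tensor=True,
#     show_progress_bar=True,
# )
# with open(emb_filename, "wb") as fOut:
#     pickle.dump((img_names, img_emb), fOut)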
# Optional sanity check:
# print(img_names)
# print(img_emb)
def search_text(query, top_k=4):
    """Search for images matching a text query.

    Args:
        query (str): The text query to search for.
        top_k (int, optional): Number of images to return. Defaults to 4.

    Returns:
        list: A list of PIL images related to the query.
    """
    # First, we encode the query.
    inputs = tokenizer([query], padding=True, return_tensors="pt")
    query_emb = model.get_text_features(**inputs)
    # Then, we use the util.semantic_search function, which computes the
    # cosine similarity between the query embedding and all image embeddings.
    # It then returns the top_k highest ranked images, which we output.
    hits = util.semantic_search(query_emb, img_emb, top_k=top_k)[0]
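    # Each hit is a dict of the form
    # {"corpus_id": <index into img_names>, "score": <cosine similarity>}.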
    images = []
    for hit in hits:
        # img = Image.open(os.path.join("photos/", img_names[hit["corpus_id"]]))
        img = Image.open(os.path.join("lvphotos/", img_names[hit["corpus_id"]]))
        images.append(img)
    return images
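# Example usage (hypothetical query; returns a list of PIL images):
#   results = search_text("red leather handbag", top_k=4)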
iface = gr.Interface(
    title="Hushh Vibe Search Model on Louis Vuitton API",
    description="Quick demo of using text to perform vector search on an image collection",
    article="TBD",
    fn=search_text,
    inputs=[
        gr.Textbox(
            lines=4,
            label="Write what you are looking for in an image...",
            placeholder="Text Here...",
        )
    ],
    outputs=[
        gr.Gallery(
            label="Generated images", show_label=False, elem_id="gallery", columns=2
        )
    ],
    examples=[
        ["Vacation Star"],
        ["Rock Star"],
        ["Barbie"],
        ["Small Purse"],
        ["Big Bag"],
    ],
).launch(debug=True)