Spaces:
Runtime error
Runtime error
File size: 4,114 Bytes
dbb1308 83305ef cf029f9 83305ef 77403d5 83305ef 77403d5 3bc904e 50cd9e5 9e0ec3b a42c923 dbb1308 d21fc46 77403d5 fba895a dbb1308 77403d5 b2f105a a193d64 dbb1308 a193d64 83305ef b33e6dd cf029f9 50cd9e5 cf029f9 83305ef a193d64 83305ef a193d64 83305ef dbb1308 77403d5 9e0ec3b b2f105a dbb1308 b2f105a 50cd9e5 dbb1308 83305ef 50cd9e5 b33e6dd cf029f9 83305ef 50cd9e5 83305ef 50cd9e5 dbb1308 a193d64 50cd9e5 a193d64 50cd9e5 a193d64 50cd9e5 a193d64 50cd9e5 9e0ec3b 50cd9e5 b2f105a dbb1308 a42c923 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 |
import gradio as gr
import torch
from datasets import load_dataset
from transformers import AutoFeatureExtractor, AutoModel
from upstash_vector import AsyncIndex
# Tunables shared by the search handlers and the UI widgets.
TOP_K = 1000  # number of matches requested from the vector index per query
BASE_COUNT = 4  # results shown in the Basic tab; initial slider value in Advanced
MAX_COUNT = 10  # upper bound of the "Image Count" slider

# Vector index client (presumably configured from environment variables,
# per the constructor name -- confirm against the Upstash SDK docs).
index = AsyncIndex.from_env()

# Image preprocessing and encoder are loaded from the same checkpoint.
model_ckpt = "google/vit-base-patch16-224-in21k"
extractor = AutoFeatureExtractor.from_pretrained(model_ckpt)
model = AutoModel.from_pretrained(model_ckpt)
hidden_dim = model.config.hidden_size

# Image dataset; query result ids are used as row numbers into its "train" split.
dataset = load_dataset("BounharAbdelaziz/Face-Aging-Dataset")
def _embed_image(image):
    """Return the model embedding of *image* as a plain list of floats.

    The image is preprocessed with the module-level ``extractor`` and run
    through ``model``; the vector at token position 0 of the last hidden
    state of the first (only) batch item is returned.
    """
    inputs = extractor(images=image, return_tensors="pt")
    # Inference only: disable autograd so no graph is built and kept per request.
    with torch.no_grad():
        outputs = model(**inputs)
    return outputs.last_hidden_state[0][0].tolist()


async def _similar_images(image, count):
    """Return up to *count* dataset images most similar to *image*.

    Args:
        image: The uploaded image, or ``None`` when the input was cleared.
        count: How many results to return; converted with ``int()`` because
            UI widgets may deliver it as a float.

    Returns:
        A list of images from ``dataset["train"]``, or ``None`` when *image*
        is ``None``.
    """
    if image is None:
        return None
    # NOTE(review): the query always requests TOP_K matches even though only
    # `count` of them are displayed -- kept as in the original; confirm
    # whether the large top_k is deliberate before lowering it.
    matches = await index.query(vector=_embed_image(image), top_k=TOP_K)
    train_split = dataset["train"]
    return [train_split[int(match.id)]["image"] for match in matches[: int(count)]]


# Dataset rows offered as clickable examples in both tabs.
_EXAMPLE_ROWS = (6, 7, 8)

# NOTE(review): the nesting below was reconstructed from syntax; verify the
# placement of handlers and examples against the deployed layout.
with gr.Blocks() as demo:
    gr.Markdown(
        """
# Find Your Twins
Upload your face and find the most similar faces from [Face Aging Dataset](https://huggingface.co/datasets/BounharAbdelaziz/Face-Aging-Dataset) using Google's [VIT](https://huggingface.co/google/vit-base-patch16-224-in21k) model. For best results please use 1x1 ratio face images, take a look at examples. Also increasing count in the advanced section results with more accurate searches. Disclaimer, this demo doesn't find your twins :), it finds similar face parts, shapes, features(nose, cheek, face, forehead shapes) that are encoded in the model. The Vector similarity search is powered by [Upstash Vector](https://upstash.com) 🚀. You can check our blog [post](https://huggingface.co/blog/omerXfaruq/serverless-image-similarity-with-upstash-vector) to learn more.
"""
    )

    with gr.Tab("Basic"):
        with gr.Row():
            with gr.Column(scale=1):
                input_image = gr.Image(type="pil")
            with gr.Column(scale=2):
                output_images = gr.Gallery()

        @input_image.change(inputs=input_image, outputs=output_images)
        async def find_similar_faces_basic(image):
            """Show the BASE_COUNT most similar images for the uploaded image."""
            return await _similar_images(image, BASE_COUNT)

        gr.Examples(
            examples=[dataset["train"][row]["image"] for row in _EXAMPLE_ROWS],
            inputs=input_image,
            outputs=output_images,
            fn=find_similar_faces_basic,
            cache_examples=False,
        )

    with gr.Tab("Advanced"):
        with gr.Row():
            with gr.Column(scale=1):
                adv_input_image = gr.Image(type="pil")
                # step=1: the value is an image count, so only whole numbers
                # make sense (fractional values were silently truncated).
                adv_image_count = gr.Slider(
                    1, MAX_COUNT, BASE_COUNT, step=1, label="Image Count"
                )
                adv_button = gr.Button("Submit")
            with gr.Column(scale=2):
                adv_output_images = gr.Gallery()

        async def find_similar_faces(image, count):
            """Show the *count* most similar images for the uploaded image."""
            return await _similar_images(image, count)

        # Search both on explicit submit and whenever the image changes.
        adv_button.click(
            fn=find_similar_faces,
            inputs=[adv_input_image, adv_image_count],
            outputs=[adv_output_images],
        )
        adv_input_image.change(
            fn=find_similar_faces,
            inputs=[adv_input_image, adv_image_count],
            outputs=[adv_output_images],
        )

        gr.Examples(
            examples=[
                [dataset["train"][row]["image"], BASE_COUNT] for row in _EXAMPLE_ROWS
            ],
            inputs=[adv_input_image, adv_image_count],
            outputs=adv_output_images,
            fn=find_similar_faces,
            cache_examples=False,
        )
if __name__ == "__main__":
    # Enable the request queue with a default concurrency limit of 40,
    # then start the app server.
    demo.queue(default_concurrency_limit=40).launch()
|