shiftvit / app.py
shivalikasingh's picture
Update app.py
c4a7337
import gradio as gr
from utils.predict import predict, predict_batch
import os
import glob
# Build the rows for the example gallery: one single-element row per file
# found under examples/set2, since the gallery is fed a list of rows with one
# value per input component.
#
# glob.glob() yields matches in arbitrary (filesystem-dependent) order, so the
# paths are sorted to keep the gallery order reproducible across runs and
# hosts.
# NOTE(review): the gallery is created with cache_examples=True; a stable
# order presumably also keeps cached outputs aligned with their examples --
# confirm against the installed Gradio version.
example_list = [[path] for path in sorted(glob.glob("examples/set2/*"))]
# Assemble the demo page: intro text, a "Batch Predict" tab, an
# "Upload & Predict" tab, the button wiring, and a credits footer.
# NOTE(review): the nesting of the layout containers is reconstructed from the
# order of the layout calls -- confirm it matches the intended page layout.
with gr.Blocks() as demo:
    # Page title followed by three introductory paragraphs, rendered in order.
    # NOTE(review): "a lot researchers" in the third entry reads like a typo
    # for "a lot of researchers"; the text is left as-is here.
    for intro_text in (
        "# **<p align='center'>ShiftViT: A Vision Transformer without Attention</p>**",
        "This space demonstrates the use of ShiftViT proposed in the paper: <a href=\"https://arxiv.org/abs/2201.10801/\">When Shift Operation Meets Vision Transformer: An Extremely Simple Alternative to Attention Mechanism</a> for image classification task.",
        "Vision Transformers have lately become very popular for computer vision problems and a lot researchers attribute their success to the attention layers.",
        "The authors of the ShiftViT paper have tried to show via the ShiftViT model that even without the attention operation, ViTs can reach SoTA results.",
    ):
        gr.Markdown(intro_text)

    with gr.Tabs():
        # Tab 1: a table of top-3 predictions plus a plot image, both filled
        # in by the "Run Model" button.
        with gr.TabItem("Batch Predict"):
            gr.Markdown("Just click *Run Model* below:")
            with gr.Box():
                gr.Markdown("**Top 3 Predictions** \n")
                output_df = gr.Dataframe(
                    headers=["image", "first", "second", "third"],
                    datatype=["str", "str", "str", "str"],
                    label="Model Output",
                )
                gr.Markdown("**Output Plot** \n")
                output_plot = gr.Image(type="filepath")

            gr.Markdown("**Predict**")
            with gr.Box():
                with gr.Row():
                    compute_button = gr.Button("Run Model")

        # Tab 2: a single image input next to a label output, a submit
        # button, and a gallery of example images.
        with gr.TabItem("Upload & Predict"):
            with gr.Box():
                with gr.Row():
                    input_image = gr.Image(
                        type="filepath",
                        label="Input Image",
                        show_label=True,
                    )
                    output_label = gr.Label(label="Model", show_label=True)

            gr.Markdown("**Predict**")
            with gr.Box():
                with gr.Row():
                    submit_button = gr.Button("Submit")

            gr.Markdown("**Examples:**")
            gr.Markdown("The model is trained to classify images belonging to the following classes:")
            with gr.Column():
                # Each example row feeds input_image; predict is run on the
                # examples and its outputs are cached (cache_examples=True).
                gr.Examples(
                    examples=example_list,
                    inputs=[input_image],
                    outputs=output_label,
                    fn=predict,
                    cache_examples=True,
                )

    # Button wiring. The batch handler fills the plot first, then the table.
    # NOTE(review): predict_batch receives input_image, which belongs to the
    # "Upload & Predict" tab, while the batch tab only says "Just click
    # *Run Model*" -- confirm whether predict_batch uses or ignores it.
    compute_button.click(
        fn=predict_batch,
        inputs=input_image,
        outputs=[output_plot, output_df],
    )
    submit_button.click(fn=predict, inputs=input_image, outputs=output_label)

    # Credits footer.
    gr.Markdown('\n Demo created by: <a href=\"https://www.linkedin.com/in/shivalika-singh/\">Shivalika Singh</a> <br> Based on this <a href=\"https://keras.io/examples/vision/shiftvit/\">Keras example</a> by <a href=\"https://twitter.com/ariG23498\">Aritra Roy Gosthipaty</a> and <a href=\"https://twitter.com/ritwik_raha\">Ritwik Raha</a> <br> Demo Powered by this <a href=\"https://huggingface.co/shivi/shiftvit/\">ShiftViT model</a>')

# Start the web server for the assembled page.
demo.launch()