Spaces:
Runtime error
Runtime error
File size: 3,054 Bytes
cacf2d0 85c68e5 cacf2d0 be551bc 85c68e5 34c6949 0895376 cacf2d0 f7da7b5 cacf2d0 c4a7337 cacf2d0 c1a06ed 696769d cacf2d0 c4a7337 cacf2d0 725fc62 cacf2d0 fed1d29 cacf2d0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 |
import gradio as gr
from utils.predict import predict, predict_batch
import os
import glob
# Build the example rows for the demo: one single-element row per path
# matched under examples/set2. The matches are sorted because glob.glob
# returns them in arbitrary, filesystem-dependent order; sorting keeps the
# order of the examples stable across runs and machines.
example_list = [[path] for path in sorted(glob.glob("examples/set2/*"))]
# Assemble the two-tab demo UI. Every component and event binding below is
# created inside the Blocks context so that it is attached to `demo`.
with gr.Blocks() as demo:
    # Page header and a short introduction.
    # NOTE(review): "a lot researchers" in the second paragraph looks like a
    # typo for "a lot of researchers"; the text is left unchanged here.
    gr.Markdown(
        "# **<p align='center'>ShiftViT: A Vision Transformer without Attention</p>**"
    )
    gr.Markdown(
        "This space demonstrates the use of ShiftViT proposed in the paper: <a href=\"https://arxiv.org/abs/2201.10801/\">When Shift Operation Meets Vision Transformer: An Extremely Simple Alternative to Attention Mechanism</a> for image classification task."
    )
    gr.Markdown(
        "Vision Transformers have lately become very popular for computer vision problems and a lot researchers attribute their success to the attention layers."
    )
    gr.Markdown(
        "The authors of the ShiftViT paper have tried to show via the ShiftViT model that even without the attention operation, ViTs can reach SoTA results."
    )

    with gr.Tabs():
        # Tab 1: results table and plot, filled in by the "Run Model" button.
        with gr.TabItem("Batch Predict"):
            gr.Markdown("Just click *Run Model* below:")
            with gr.Box():
                gr.Markdown("**Top 3 Predictions** \n")
                output_df = gr.Dataframe(
                    headers=["image", "first", "second", "third"],
                    datatype=["str"] * 4,
                    label="Model Output",
                )
                gr.Markdown("**Output Plot** \n")
                output_plot = gr.Image(type="filepath")
            gr.Markdown("**Predict**")
            with gr.Box():
                with gr.Row():
                    compute_button = gr.Button("Run Model")

        # Tab 2: image input on the left, predicted label on the right.
        with gr.TabItem("Upload & Predict"):
            with gr.Box():
                with gr.Row():
                    input_image = gr.Image(
                        type="filepath",
                        label="Input Image",
                        show_label=True,
                    )
                    output_label = gr.Label(label="Model", show_label=True)
            gr.Markdown("**Predict**")
            with gr.Box():
                with gr.Row():
                    submit_button = gr.Button("Submit")
            gr.Markdown("**Examples:**")
            gr.Markdown(
                "The model is trained to classify images belonging to the following classes:"
            )
            with gr.Column():
                # Example rows feed `input_image` and are run through
                # `predict`, with example caching switched on.
                gr.Examples(
                    examples=example_list,
                    inputs=[input_image],
                    outputs=output_label,
                    fn=predict,
                    cache_examples=True,
                )

    # Event wiring.
    # NOTE(review): the batch button passes the upload tab's image as its
    # input -- confirm that predict_batch expects this argument.
    compute_button.click(
        fn=predict_batch,
        inputs=input_image,
        outputs=[output_plot, output_df],
    )
    submit_button.click(fn=predict, inputs=input_image, outputs=output_label)

    # Footer with credits and links.
    gr.Markdown(
        '\n Demo created by: <a href=\"https://www.linkedin.com/in/shivalika-singh/\">Shivalika Singh</a> <br> Based on this <a href=\"https://keras.io/examples/vision/shiftvit/\">Keras example</a> by <a href=\"https://twitter.com/ariG23498\">Aritra Roy Gosthipaty</a> and <a href=\"https://twitter.com/ritwik_raha\">Ritwik Raha</a> <br> Demo Powered by this <a href=\"https://huggingface.co/shivi/shiftvit/\">ShiftViT model</a>'
    )

# Start the app.
demo.launch()
|