File size: 3,054 Bytes
cacf2d0
 
 
85c68e5
cacf2d0
be551bc
85c68e5
 
34c6949
0895376
 
cacf2d0
 
 
 
f7da7b5
 
cacf2d0
 
 
c4a7337
cacf2d0
 
 
c1a06ed
696769d
cacf2d0
 
 
c4a7337
cacf2d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
725fc62
cacf2d0
 
 
 
 
fed1d29
cacf2d0
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import gradio as gr
from utils.predict import predict, predict_batch
import os
import glob

# Build the rows for the gr.Examples gallery: one row per example image.
# Each row is a one-element list because the gallery is wired to a single
# input component (the image path). glob returns matches in arbitrary,
# filesystem-dependent order, so sort the paths to keep the gallery order
# stable across machines and restarts.
example_list = [[path] for path in sorted(glob.glob("examples/set2/*"))]

# Gradio UI for the ShiftViT demo. The layout is a header with a short
# description, then two tabs:
#   * "Batch Predict": a Run Model button that calls predict_batch and shows
#     a top-3 predictions table plus an output plot image.
#   * "Upload & Predict": a single image input classified by predict, with
#     a gallery of example images (example_list).
# It ends with a credits footer. The app is launched unconditionally at
# module import.
demo = gr.Blocks()

with demo:

    # ---- Header / description ----
    gr.Markdown("# **<p align='center'>ShiftViT: A Vision Transformer without Attention</p>**")
    gr.Markdown("This space demonstrates the use of ShiftViT proposed in the paper: <a href=\"https://arxiv.org/abs/2201.10801\">When Shift Operation Meets Vision Transformer: An Extremely Simple Alternative to Attention Mechanism</a> for the image classification task.")
    gr.Markdown("Vision Transformers have lately become very popular for computer vision problems, and a lot of researchers attribute their success to the attention layers.")
    gr.Markdown("The authors of the ShiftViT paper have tried to show via the ShiftViT model that even without the attention operation, ViTs can reach SoTA results.")

    with gr.Tabs():

        # ---- Tab 1: batch prediction (no per-tab input widget) ----
        with gr.TabItem("Batch Predict"):

            gr.Markdown("Just click *Run Model* below:")
            with gr.Box():
                gr.Markdown("**Top 3 Predictions** \n")
                # One row per image: the image identifier followed by its
                # first/second/third predicted labels, all rendered as strings.
                output_df = gr.Dataframe(headers=["image","first", "second","third"],datatype=["str", "str", "str", "str"], label="Model Output")
                gr.Markdown("**Output Plot** \n")
                # predict_batch's plot output is shown from a file path.
                output_plot = gr.Image(type='filepath')

            gr.Markdown("**Predict**")

            with gr.Box():
                with gr.Row():
                    compute_button = gr.Button("Run Model")

        # ---- Tab 2: single-image upload and prediction ----
        with gr.TabItem("Upload & Predict"):
            with gr.Box():

                with gr.Row():
                    # The uploaded image reaches predict as a file path,
                    # not as pixel data.
                    input_image = gr.Image(type='filepath',label="Input Image", show_label=True)
                    output_label = gr.Label(label="Model", show_label=True)

            gr.Markdown("**Predict**")

            with gr.Box():
                with gr.Row():
                    submit_button = gr.Button("Submit")

            gr.Markdown("**Examples:**")
            # NOTE(review): this sentence promises a list of classes, but none
            # is rendered after it. Add the class names (or reword) once the
            # model's label set is confirmed.
            gr.Markdown("The model is trained to classify images belonging to the following classes:")

            with gr.Column():
                # Clicking an example feeds its path into input_image and shows
                # predict's result in output_label. cache_examples=True asks
                # Gradio to precompute and cache those outputs.
                gr.Examples(example_list, [input_image], output_label, predict, cache_examples=True)

    # ---- Event wiring ----
    # NOTE(review): the Batch Predict tab has no input widget of its own, so
    # this passes the Upload tab's input_image (possibly empty) to
    # predict_batch. Confirm whether predict_batch actually uses this argument.
    compute_button.click(predict_batch, inputs=input_image, outputs=[output_plot,output_df])
    submit_button.click(predict, inputs=input_image, outputs=output_label)

    # ---- Footer / credits ----
    gr.Markdown('\n Demo created by: <a href=\"https://www.linkedin.com/in/shivalika-singh/\">Shivalika Singh</a> <br> Based on this <a href=\"https://keras.io/examples/vision/shiftvit/\">Keras example</a> by <a href=\"https://twitter.com/ariG23498\">Aritra Roy Gosthipaty</a> and <a href=\"https://twitter.com/ritwik_raha\">Ritwik Raha</a> <br> Demo Powered by this <a href=\"https://huggingface.co/shivi/shiftvit/\">ShiftViT model</a>')

demo.launch()