ClassCat commited on
Commit
78911a4
1 Parent(s): 9f45196

update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -1
app.py CHANGED
@@ -9,6 +9,9 @@ from PIL import Image
9
  feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
10
  model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
11
 
 
 
 
12
  def classify_image(image):
13
 
14
  with torch.no_grad():
@@ -34,7 +37,10 @@ with gr.Blocks(title="ViT ImageNet Classification - ClassCat",
34
  gr.HTML("""<div style="font-family:'Times New Roman', 'Serif'; font-size:16pt; font-weight:bold; text-align:center; color:royalblue;">ViT - ImageNet Classification</div>""")
35
 
36
  with gr.Row():
37
- input_image = gr.Image(type="pil", image_mode="RGB", shape=(224, 224))
 
 
 
38
  output_label=gr.Label(label="Probabilities", num_top_classes=3)
39
 
40
  send_btn = gr.Button("Infer")
 
9
  feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
10
  model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
11
 
12
+ examples_dir = './samples'
13
+ example_files = glob.glob(os.path.join(examples_dir, '*.jpg'))
14
+
15
  def classify_image(image):
16
 
17
  with torch.no_grad():
 
37
  gr.HTML("""<div style="font-family:'Times New Roman', 'Serif'; font-size:16pt; font-weight:bold; text-align:center; color:royalblue;">ViT - ImageNet Classification</div>""")
38
 
39
  with gr.Row():
40
+ with gr.Column():
41
+ input_image = gr.Image(type="pil", image_mode="RGB", shape=(224, 224))
42
+ gr.Examples(example_files, inputs=input_image)
43
+
44
  output_label=gr.Label(label="Probabilities", num_top_classes=3)
45
 
46
  send_btn = gr.Button("Infer")