0-ma commited on
Commit
054f74a
·
verified ·
1 Parent(s): 2559312

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -3
app.py CHANGED
@@ -16,7 +16,7 @@ model_names = [
16
  "0-ma/vit-geometric-shapes-tiny",
17
  ]
18
 
19
- examples = [
20
  'example/1_None.jpg',
21
  'example/2_Circle.jpg',
22
  'example/3_Triangle.jpg',
@@ -25,7 +25,7 @@ examples = [
25
  'example/6_Hexagone.jpg'
26
  ]
27
 
28
- labels = [example.split("_")[1].split(".")[0] for example in examples]
29
 
30
  feature_extractors = {model_name: AutoImageProcessor.from_pretrained(model_name) for model_name in model_names}
31
  classification_models = {model_name: AutoModelForImageClassification.from_pretrained(model_name) for model_name in model_names}
@@ -49,6 +49,9 @@ def predict(image, selected_model):
49
  title = "Geometric Shape Classifier"
50
  description = "Select a model to classify geometric shapes."
51
 
 
 
 
52
  # Create the Gradio interface
53
  iface = gr.Interface(
54
  fn=predict,
@@ -62,4 +65,5 @@ iface = gr.Interface(
62
  examples=examples
63
  )
64
 
65
- # Launch the interf
 
 
16
  "0-ma/vit-geometric-shapes-tiny",
17
  ]
18
 
19
+ example_images = [
20
  'example/1_None.jpg',
21
  'example/2_Circle.jpg',
22
  'example/3_Triangle.jpg',
 
25
  'example/6_Hexagone.jpg'
26
  ]
27
 
28
+ labels = [example.split("_")[1].split(".")[0] for example in example_images]
29
 
30
  feature_extractors = {model_name: AutoImageProcessor.from_pretrained(model_name) for model_name in model_names}
31
  classification_models = {model_name: AutoModelForImageClassification.from_pretrained(model_name) for model_name in model_names}
 
49
  title = "Geometric Shape Classifier"
50
  description = "Select a model to classify geometric shapes."
51
 
52
+ # Create examples with both image and default model
53
+ examples = [[img, model_names[0]] for img in example_images]
54
+
55
  # Create the Gradio interface
56
  iface = gr.Interface(
57
  fn=predict,
 
65
  examples=examples
66
  )
67
 
68
+ # Launch the interface
69
+ iface.launch()