prithivMLmods committed
Commit 364cb51 · verified · 1 Parent(s): 614af1b

Update app.py

Files changed (1):
  1. app.py +93 -35

app.py CHANGED
@@ -1,4 +1,6 @@
 import gradio as gr
+import torch
+from transformers import AutoModel, AutoProcessor
 from gender_classification import gender_classification
 from emotion_classification import emotion_classification
 from dog_breed import dog_breed_classification
@@ -14,7 +16,7 @@ from alphabet_sign_language_detection import sign_language_classification
 from rice_leaf_disease import classify_leaf_disease
 from traffic_density import traffic_density_classification
 
-# Main classification function that calls the appropriate model based on selection.
+# Main classification function for multi-model classification.
 def classify(image, model_name):
     if model_name == "gender":
         return gender_classification(image)
@@ -58,42 +60,98 @@ def select_model(model_name):
     model_variants[model_name] = "primary"
     return (model_name, *(gr.update(variant=model_variants[key]) for key in model_variants))
 
-with gr.Blocks() as demo:
-    with gr.Sidebar():
-        gr.Markdown("# Choose Domain")
-        with gr.Row():
-            age_btn = gr.Button("Age Classification", variant="primary")
-            gender_btn = gr.Button("Gender Classification", variant="secondary")
-            emotion_btn = gr.Button("Emotion Classification", variant="secondary")
-            dog_breed_btn = gr.Button("Dog Breed Classification", variant="secondary")
-            deepfake_btn = gr.Button("Deepfake vs Real", variant="secondary")
-            gym_workout_btn = gr.Button("Gym Workout Classification", variant="secondary")
-            waste_btn = gr.Button("Waste Classification", variant="secondary")
-            mnist_btn = gr.Button("Digit Classify (0-9)", variant="secondary")
-            fashion_mnist_btn = gr.Button("Fashion MNIST Classification", variant="secondary")
-            food_btn = gr.Button("Indian/Western Food", variant="secondary")
-            bird_btn = gr.Button("Bird Species Classification", variant="secondary")
-            leaf_disease_btn = gr.Button("Rice Leaf Disease", variant="secondary")
-            sign_language_btn = gr.Button("Alphabet Sign Language", variant="secondary")
-            traffic_density_btn = gr.Button("Traffic Density", variant="secondary")
-
-        selected_model = gr.State("age")
-        gr.Markdown("### Current Model:")
-        model_display = gr.Textbox(value="age", interactive=False)
-        selected_model.change(lambda m: m, selected_model, model_display)
+# Zero-Shot Classification Setup (SigLIP models)
 
-        buttons = [gender_btn, emotion_btn, dog_breed_btn, deepfake_btn, gym_workout_btn, waste_btn, age_btn, mnist_btn, fashion_mnist_btn, food_btn, bird_btn, leaf_disease_btn, sign_language_btn, traffic_density_btn]
-        model_names = ["gender", "emotion", "dog breed", "deepfake", "gym workout", "waste", "age", "mnist", "fashion_mnist", "food", "bird", "leaf disease", "sign language", "traffic density"]
-
-        for btn, name in zip(buttons, model_names):
-            btn.click(fn=lambda n=name: select_model(n), inputs=[], outputs=[selected_model] + buttons)
+# Load the SigLIP models and processors
+sg1_ckpt = "google/siglip-so400m-patch14-384"
+siglip1_model = AutoModel.from_pretrained(sg1_ckpt, device_map="cpu").eval()
+siglip1_processor = AutoProcessor.from_pretrained(sg1_ckpt)
+
+sg2_ckpt = "google/siglip2-so400m-patch14-384"
+siglip2_model = AutoModel.from_pretrained(sg2_ckpt, device_map="cpu").eval()
+siglip2_processor = AutoProcessor.from_pretrained(sg2_ckpt)
+
+# Utilities for zero-shot classification.
+def postprocess_siglip(sg1_probs, sg2_probs, labels):
+    sg1_output = {labels[i]: sg1_probs[0][i].item() for i in range(len(labels))}
+    sg2_output = {labels[i]: sg2_probs[0][i].item() for i in range(len(labels))}
+    return sg1_output, sg2_output
+
+def siglip_detector(image, texts):
+    sg1_inputs = siglip1_processor(
+        text=texts, images=image, return_tensors="pt", padding="max_length", max_length=64
+    ).to("cpu")
+    sg2_inputs = siglip2_processor(
+        text=texts, images=image, return_tensors="pt", padding="max_length", max_length=64
+    ).to("cpu")
+    with torch.no_grad():
+        sg1_outputs = siglip1_model(**sg1_inputs)
+        sg2_outputs = siglip2_model(**sg2_inputs)
+    sg1_logits_per_image = sg1_outputs.logits_per_image
+    sg2_logits_per_image = sg2_outputs.logits_per_image
+    sg1_probs = torch.sigmoid(sg1_logits_per_image)
+    sg2_probs = torch.sigmoid(sg2_logits_per_image)
+    return sg1_probs, sg2_probs
+
+def infer(image, candidate_labels):
+    candidate_labels = [label.lstrip(" ") for label in candidate_labels.split(",")]
+    sg1_probs, sg2_probs = siglip_detector(image, candidate_labels)
+    return postprocess_siglip(sg1_probs, sg2_probs, labels=candidate_labels)
+
+# Build the Gradio Interface with two tabs
+with gr.Blocks() as demo:
+    gr.Markdown("# Multi-Model & Zero-Shot Classification Interface")
 
-    with gr.Row():
-        with gr.Column():
-            image_input = gr.Image(type="numpy", label="Upload Image")
-            analyze_btn = gr.Button("Classify / Predict")
+    with gr.Tabs():
+        # Tab 1: Multi-Model Classification
+        with gr.Tab("Multi-Model Classification"):
+            with gr.Sidebar():
+                gr.Markdown("# Choose Domain")
+                with gr.Row():
+                    age_btn = gr.Button("Age Classification", variant="primary")
+                    gender_btn = gr.Button("Gender Classification", variant="secondary")
+                    emotion_btn = gr.Button("Emotion Classification", variant="secondary")
+                    dog_breed_btn = gr.Button("Dog Breed Classification", variant="secondary")
+                    deepfake_btn = gr.Button("Deepfake vs Real", variant="secondary")
+                    gym_workout_btn = gr.Button("Gym Workout Classification", variant="secondary")
+                    waste_btn = gr.Button("Waste Classification", variant="secondary")
+                    mnist_btn = gr.Button("Digit Classify (0-9)", variant="secondary")
+                    fashion_mnist_btn = gr.Button("Fashion MNIST Classification", variant="secondary")
+                    food_btn = gr.Button("Indian/Western Food", variant="secondary")
+                    bird_btn = gr.Button("Bird Species Classification", variant="secondary")
+                    leaf_disease_btn = gr.Button("Rice Leaf Disease", variant="secondary")
+                    sign_language_btn = gr.Button("Alphabet Sign Language", variant="secondary")
+                    traffic_density_btn = gr.Button("Traffic Density", variant="secondary")
+
+                selected_model = gr.State("age")
+                gr.Markdown("### Current Model:")
+                model_display = gr.Textbox(value="age", interactive=False)
+                selected_model.change(lambda m: m, selected_model, model_display)
+
+                buttons = [gender_btn, emotion_btn, dog_breed_btn, deepfake_btn, gym_workout_btn, waste_btn, age_btn, mnist_btn, fashion_mnist_btn, food_btn, bird_btn, leaf_disease_btn, sign_language_btn, traffic_density_btn]
+                model_names = ["gender", "emotion", "dog breed", "deepfake", "gym workout", "waste", "age", "mnist", "fashion_mnist", "food", "bird", "leaf disease", "sign language", "traffic density"]
+
+                for btn, name in zip(buttons, model_names):
+                    btn.click(fn=lambda n=name: select_model(n), inputs=[], outputs=[selected_model] + buttons)
+
+            with gr.Row():
+                with gr.Column():
+                    image_input = gr.Image(type="numpy", label="Upload Image")
+                    analyze_btn = gr.Button("Classify / Predict")
+                    output_label = gr.Label(label="Prediction Scores")
+                    analyze_btn.click(fn=classify, inputs=[image_input, selected_model], outputs=output_label)
 
-            output_label = gr.Label(label="Prediction Scores")
-            analyze_btn.click(fn=classify, inputs=[image_input, selected_model], outputs=output_label)
+        # Tab 2: Zero-Shot Classification (SigLIP)
+        with gr.Tab("Zero-Shot Classification"):
+            gr.Markdown("## Compare SigLIP 1 and SigLIP 2 on Zero-Shot Classification")
+            with gr.Row():
+                with gr.Column():
+                    zs_image_input = gr.Image(type="pil", label="Upload Image")
+                    zs_text_input = gr.Textbox(label="Input a list of labels (comma separated)")
+                    zs_run_button = gr.Button("Run")
+                with gr.Column():
+                    siglip1_output = gr.Label(label="SigLIP 1 Output", num_top_classes=3)
+                    siglip2_output = gr.Label(label="SigLIP 2 Output", num_top_classes=3)
+            zs_run_button.click(fn=infer, inputs=[zs_image_input, zs_text_input], outputs=[siglip1_output, siglip2_output])
 
 demo.launch()
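
For reference, the zero-shot scoring this commit wires into the new tab can be reproduced standalone. A minimal sketch for one of the two checkpoints, mirroring the calls in siglip_detector() and postprocess_siglip() above; the image path and candidate labels are placeholders:

import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor

# Same checkpoint and preprocessing as the committed code.
ckpt = "google/siglip-so400m-patch14-384"
model = AutoModel.from_pretrained(ckpt).eval()
processor = AutoProcessor.from_pretrained(ckpt)

labels = ["a cat", "a dog", "a bird"]  # placeholder candidate labels
image = Image.open("example.jpg")      # placeholder test image

inputs = processor(
    text=labels, images=image, return_tensors="pt", padding="max_length", max_length=64
)
with torch.no_grad():
    logits_per_image = model(**inputs).logits_per_image  # shape (1, len(labels))

# SigLIP is trained with a sigmoid loss, so each label gets an independent
# probability rather than a softmax distribution over the label set.
probs = torch.sigmoid(logits_per_image)
print({label: probs[0][i].item() for i, label in enumerate(labels)})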