ved1beta
commited on
Commit
·
7d86658
1
Parent(s):
e99930d
i love you
Browse files
app.py
CHANGED
@@ -44,6 +44,9 @@ transform = transforms.Compose([
|
|
44 |
])
|
45 |
|
46 |
def predict(img):
|
|
|
|
|
|
|
47 |
# Convert to PIL Image if needed
|
48 |
if not isinstance(img, Image.Image):
|
49 |
img = Image.fromarray(img)
|
@@ -54,25 +57,27 @@ def predict(img):
|
|
54 |
# Get predictions
|
55 |
with torch.no_grad():
|
56 |
outputs = model(img)
|
57 |
-
probabilities = F.softmax(outputs, dim=1)
|
58 |
|
59 |
-
#
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
|
|
|
|
|
|
64 |
|
65 |
-
|
66 |
-
return {pred[0]: pred[1] for pred in predictions}
|
67 |
|
68 |
# Create Gradio interface
|
69 |
iface = gr.Interface(
|
70 |
fn=predict,
|
71 |
inputs=gr.Image(type="pil"),
|
72 |
-
outputs=gr.Label(num_top_classes=
|
73 |
examples=[["example1.jpg"], ["example2.jpg"]], # Optional: Add example images
|
74 |
title="CIFAR-10 Image Classifier",
|
75 |
-
description="Upload an image to classify it into one of these categories: plane, car, bird, cat, deer, dog, frog, horse, ship, or truck"
|
76 |
)
|
77 |
|
78 |
# Launch the app
|
|
|
44 |
])
|
45 |
|
46 |
def predict(img):
|
47 |
+
if img is None:
|
48 |
+
return None
|
49 |
+
|
50 |
# Convert to PIL Image if needed
|
51 |
if not isinstance(img, Image.Image):
|
52 |
img = Image.fromarray(img)
|
|
|
57 |
# Get predictions
|
58 |
with torch.no_grad():
|
59 |
outputs = model(img)
|
60 |
+
probabilities = F.softmax(outputs, dim=1)[0]
|
61 |
|
62 |
+
# Create dictionary with all classes and their probabilities
|
63 |
+
predictions = {
|
64 |
+
classes[i]: float(probabilities[i]) * 100 # Convert to percentage
|
65 |
+
for i in range(len(classes))
|
66 |
+
}
|
67 |
+
|
68 |
+
# Sort predictions by probability
|
69 |
+
sorted_predictions = dict(sorted(predictions.items(), key=lambda x: x[1], reverse=True))
|
70 |
|
71 |
+
return sorted_predictions
|
|
|
72 |
|
73 |
# Create Gradio interface
|
74 |
iface = gr.Interface(
|
75 |
fn=predict,
|
76 |
inputs=gr.Image(type="pil"),
|
77 |
+
outputs=gr.Label(num_top_classes=10), # Show all 10 classes
|
78 |
examples=[["example1.jpg"], ["example2.jpg"]], # Optional: Add example images
|
79 |
title="CIFAR-10 Image Classifier",
|
80 |
+
description="Upload an image to classify it into one of these categories: plane, car, bird, cat, deer, dog, frog, horse, ship, or truck. Results show prediction confidence for all classes as percentages."
|
81 |
)
|
82 |
|
83 |
# Launch the app
|