ved1beta commited on
Commit
7d86658
·
1 Parent(s): e99930d

i love you

Browse files
Files changed (1) hide show
  1. app.py +15 -10
app.py CHANGED
@@ -44,6 +44,9 @@ transform = transforms.Compose([
44
  ])
45
 
46
  def predict(img):
 
 
 
47
  # Convert to PIL Image if needed
48
  if not isinstance(img, Image.Image):
49
  img = Image.fromarray(img)
@@ -54,25 +57,27 @@ def predict(img):
54
  # Get predictions
55
  with torch.no_grad():
56
  outputs = model(img)
57
- probabilities = F.softmax(outputs, dim=1)
58
 
59
- # Get top 3 predictions
60
- probs, indices = torch.topk(probabilities, 3)
61
- predictions = []
62
- for prob, idx in zip(probs[0], indices[0]):
63
- predictions.append((classes[idx], float(prob)))
 
 
 
64
 
65
- # Format the results
66
- return {pred[0]: pred[1] for pred in predictions}
67
 
68
  # Create Gradio interface
69
  iface = gr.Interface(
70
  fn=predict,
71
  inputs=gr.Image(type="pil"),
72
- outputs=gr.Label(num_top_classes=3),
73
  examples=[["example1.jpg"], ["example2.jpg"]], # Optional: Add example images
74
  title="CIFAR-10 Image Classifier",
75
- description="Upload an image to classify it into one of these categories: plane, car, bird, cat, deer, dog, frog, horse, ship, or truck"
76
  )
77
 
78
  # Launch the app
 
44
  ])
45
 
46
  def predict(img):
47
+ if img is None:
48
+ return None
49
+
50
  # Convert to PIL Image if needed
51
  if not isinstance(img, Image.Image):
52
  img = Image.fromarray(img)
 
57
  # Get predictions
58
  with torch.no_grad():
59
  outputs = model(img)
60
+ probabilities = F.softmax(outputs, dim=1)[0]
61
 
62
+ # Create dictionary with all classes and their probabilities
63
+ predictions = {
64
+ classes[i]: float(probabilities[i]) * 100 # Convert to percentage
65
+ for i in range(len(classes))
66
+ }
67
+
68
+ # Sort predictions by probability
69
+ sorted_predictions = dict(sorted(predictions.items(), key=lambda x: x[1], reverse=True))
70
 
71
+ return sorted_predictions
 
72
 
73
  # Create Gradio interface
74
  iface = gr.Interface(
75
  fn=predict,
76
  inputs=gr.Image(type="pil"),
77
+ outputs=gr.Label(num_top_classes=10), # Show all 10 classes
78
  examples=[["example1.jpg"], ["example2.jpg"]], # Optional: Add example images
79
  title="CIFAR-10 Image Classifier",
80
+ description="Upload an image to classify it into one of these categories: plane, car, bird, cat, deer, dog, frog, horse, ship, or truck. Results show prediction confidence for all classes as percentages."
81
  )
82
 
83
  # Launch the app