bhaveshgoel07 commited on
Commit
43be314
·
1 Parent(s): 07e6e6d

Fixed errors

Browse files
Files changed (1) hide show
  1. app.py +11 -7
app.py CHANGED
@@ -37,18 +37,22 @@ transform = transforms.Compose([
37
 
38
  # Prediction function
39
  def predict(image):
40
- image = transform(image).unsqueeze(0) # Add batch dimension
41
- with torch.no_grad():
42
- output = model(image)
43
- probabilities = nn.Softmax(dim=1)(output)
44
- predicted_class = torch.argmax(probabilities, dim=1)
45
- return {str(i): probabilities[0][i].item() for i in range(10)}
 
 
 
 
46
 
47
  # Create the Gradio interface
48
  interface = gr.Interface(
49
  fn=predict,
50
  inputs=gr.Sketchpad(),
51
- outputs=gr.Label(num_top_classes=10)
52
  )
53
 
54
  # Launch the interface
 
37
 
38
  # Prediction function
39
  def predict(image):
40
+ try:
41
+ image = transform(image).unsqueeze(0) # Add batch dimension
42
+ with torch.no_grad():
43
+ output = model(image)
44
+ probabilities = nn.Softmax(dim=1)(output)
45
+ predicted_class = torch.argmax(probabilities, dim=1)
46
+ return {str(i): probabilities[0][i].item() for i in range(10)}
47
+ except Exception as e:
48
+ print(f"Error in predict function: {e}")
49
+ return {"error": str(e)}
50
 
51
  # Create the Gradio interface
52
  interface = gr.Interface(
53
  fn=predict,
54
  inputs=gr.Sketchpad(),
55
+ outputs=gr.Label()
56
  )
57
 
58
  # Launch the interface