KabeerAmjad commited on
Commit
d9d7936
1 Parent(s): 6e9fd21

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -22
app.py CHANGED
@@ -27,6 +27,7 @@ model.eval() # Set model to evaluation mode
27
  try:
28
  state_dict = torch.load(model_path, map_location=torch.device('cpu'))
29
  model.load_state_dict(state_dict)
 
30
  except RuntimeError as e:
31
  print("Error loading state_dict:", e)
32
  print("Ensure that the saved model architecture matches ResNet50.")
@@ -43,32 +44,43 @@ preprocess = transforms.Compose([
43
  ])
44
 
45
  # Load labels
46
- with open("config.json") as f:
47
- labels = json.load(f)
 
 
 
 
48
 
49
  # Function to predict image class
50
  def predict(image):
51
- # Convert the uploaded file to a PIL image
52
- input_image = image.convert("RGB")
53
-
54
- # Preprocess the image
55
- input_tensor = preprocess(input_image)
56
- input_batch = input_tensor.unsqueeze(0) # Add batch dimension
57
-
58
- # Check if a GPU is available and move the input and model to GPU
59
- if torch.cuda.is_available():
60
- input_batch = input_batch.to('cuda')
61
- model.to('cuda')
62
-
63
- # Perform inference
64
- with torch.no_grad():
65
- output = model(input_batch)
66
-
67
- # Get the predicted class with the highest score
68
- _, predicted_idx = torch.max(output, 1)
69
- predicted_class = labels[str(predicted_idx.item())]
 
 
 
70
 
71
- return f"Predicted class: {predicted_class}"
 
 
 
 
72
 
73
  # Set up the Gradio interface
74
  iface = gr.Interface(
 
27
  try:
28
  state_dict = torch.load(model_path, map_location=torch.device('cpu'))
29
  model.load_state_dict(state_dict)
30
+ print("Model loaded successfully.")
31
  except RuntimeError as e:
32
  print("Error loading state_dict:", e)
33
  print("Ensure that the saved model architecture matches ResNet50.")
 
44
  ])
45
 
46
  # Load labels
47
+ try:
48
+ with open("config.json") as f:
49
+ labels = json.load(f)
50
+ print("Labels loaded successfully.")
51
+ except Exception as e:
52
+ print("Error loading labels:", e)
53
 
54
  # Function to predict image class
55
  def predict(image):
56
+ try:
57
+ # Convert the uploaded file to a PIL image
58
+ input_image = image.convert("RGB")
59
+
60
+ # Preprocess the image
61
+ input_tensor = preprocess(input_image)
62
+ input_batch = input_tensor.unsqueeze(0) # Add batch dimension
63
+
64
+ # Check if a GPU is available and move the input and model to GPU
65
+ if torch.cuda.is_available():
66
+ input_batch = input_batch.to('cuda')
67
+ model.to('cuda')
68
+ else:
69
+ print("GPU not available, using CPU.")
70
+
71
+ # Perform inference
72
+ with torch.no_grad():
73
+ output = model(input_batch)
74
+
75
+ # Get the predicted class with the highest score
76
+ _, predicted_idx = torch.max(output, 1)
77
+ predicted_class = labels[str(predicted_idx.item())]
78
 
79
+ return f"Predicted class: {predicted_class}"
80
+
81
+ except Exception as e:
82
+ print(f"Error during prediction: {e}")
83
+ return "An error occurred during prediction. Please try again."
84
 
85
  # Set up the Gradio interface
86
  iface = gr.Interface(