anjikum commited on
Commit
a90f86d
·
verified ·
1 Parent(s): 5bf1142

Adjusting CPU for inferencing

Browse files
Files changed (1) hide show
  1. app.py +11 -3
app.py CHANGED
@@ -4,9 +4,14 @@ from torchvision import models
4
  from PIL import Image
5
  import gradio as gr
6
 
 
 
 
7
  # Load your trained ResNet-50 model
8
- model = models.resnet50(pretrained=False) # Load the ResNet-50 architecture
9
- model.load_state_dict(torch.load("model.pth")) # Load the trained weights (.pth)
 
 
10
  model.eval() # Set model to evaluation mode
11
 
12
  # Define the transformation required for the input image
@@ -18,7 +23,6 @@ transform = transforms.Compose([
18
  ])
19
 
20
  # Define the labels for ImageNet (or your specific dataset labels)
21
- # This is typically a list of class labels for classification
22
  LABELS = ["class_1", "class_2", "class_3", "class_4", "class_5", # Replace with your classes
23
  "class_6", "class_7", "class_8", "class_9", "class_10"]
24
 
@@ -27,8 +31,12 @@ def predict(image):
27
  image = Image.open(image).convert("RGB") # Open the image and convert to RGB
28
  image = transform(image).unsqueeze(0) # Apply transformations and add batch dimension
29
 
 
 
 
30
  with torch.no_grad():
31
  outputs = model(image) # Get model predictions
 
32
  _, predicted = torch.max(outputs, 1) # Get the class with highest probability
33
  return LABELS[predicted.item()] # Return the predicted class label
34
 
 
4
  from PIL import Image
5
  import gradio as gr
6
 
7
+ # Force CPU usage
8
+ device = torch.device('cpu')
9
+
10
  # Load your trained ResNet-50 model
11
+ model = models.resnet50(pretrained=False) # Load the ResNet-50 architecture
12
+ model.load_state_dict(torch.load("model.pth", map_location=device)) # Load the trained weights (.pth)
13
+ model.to(device) # Move model to CPU (even if you have a GPU)
14
+
15
  model.eval() # Set model to evaluation mode
16
 
17
  # Define the transformation required for the input image
 
23
  ])
24
 
25
  # Define the labels for ImageNet (or your specific dataset labels)
 
26
  LABELS = ["class_1", "class_2", "class_3", "class_4", "class_5", # Replace with your classes
27
  "class_6", "class_7", "class_8", "class_9", "class_10"]
28
 
 
31
  image = Image.open(image).convert("RGB") # Open the image and convert to RGB
32
  image = transform(image).unsqueeze(0) # Apply transformations and add batch dimension
33
 
34
+ # Move the image tensor to CPU as well
35
+ image = image.to(device)
36
+
37
  with torch.no_grad():
38
  outputs = model(image) # Get model predictions
39
+
40
  _, predicted = torch.max(outputs, 1) # Get the class with highest probability
41
  return LABELS[predicted.item()] # Return the predicted class label
42