matteopilotto commited on
Commit
36b277a
·
1 Parent(s): 16d81eb

fix map_location=torch.device('cpu')

Browse files
Files changed (1) hide show
  1. app.py +1 -1
app.py CHANGED
@@ -10,7 +10,7 @@ class_names = ['pizza', 'steak', 'sushi']
10
  examples = [os.path.join('examples', img) for img in os.listdir('examples')]
11
 
12
  model, preprocess = create_effnetb2_model(num_classes=3, seed=42)
13
- model.load_state_dict(torch.load('effnetb2_20_percent.pth'), map_location=torch.device('cpu'))
14
 
15
  def predict(img: PIL.Image) -> Tuple[Dict, float]:
16
 
 
10
  examples = [os.path.join('examples', img) for img in os.listdir('examples')]
11
 
12
  model, preprocess = create_effnetb2_model(num_classes=3, seed=42)
13
+ model.load_state_dict(torch.load('effnetb2_20_percent.pth', map_location=torch.device('cpu')))
14
 
15
  def predict(img: PIL.Image) -> Tuple[Dict, float]:
16