Tanzeer commited on
Commit
36945ed
·
1 Parent(s): e84f8ab

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -1
app.py CHANGED
@@ -4,9 +4,10 @@ import numpy as np
4
  import torch
5
  from torchvision import transforms
6
  from PIL import Image
 
7
 
8
  # Load the model and set the device
9
- model = TranSalNet() # Assuming you have defined your model
10
  model.load_state_dict(torch.load('pretrained_models/TranSalNet_Res.pth', map_location=torch.device('cpu')))
11
  model.eval()
12
  device = torch.device('cpu')
 
4
  import torch
5
  from torchvision import transforms
6
  from PIL import Image
7
+ from TranSalNet_Res import TranSalNet
8
 
9
  # Load the model and set the device
10
+ model = TranSalNet()
11
  model.load_state_dict(torch.load('pretrained_models/TranSalNet_Res.pth', map_location=torch.device('cpu')))
12
  model.eval()
13
  device = torch.device('cpu')