ReacherTN commited on
Commit
d383e61
·
1 Parent(s): 8693834

Update predict.py

Browse files
Files changed (1) hide show
  1. predict.py +1 -1
predict.py CHANGED
@@ -10,7 +10,7 @@ def predict_one_image(path) :
10
  image = read_image(path)
11
  image = get_valid_augs()(image=image)['image']
12
  image = torch.tensor(image,dtype=torch.float)
13
- image = image.reshape((1,3,224,224))
14
  model = CustomModel()
15
  #loading ckpt
16
  model.load_state_dict(torch.load(CKPT,map_location=torch.device('cpu')))
 
10
  image = read_image(path)
11
  image = get_valid_augs()(image=image)['image']
12
  image = torch.tensor(image,dtype=torch.float)
13
+ image = image.reshape((1,3,512,512))
14
  model = CustomModel()
15
  #loading ckpt
16
  model.load_state_dict(torch.load(CKPT,map_location=torch.device('cpu')))