Spaces:
Runtime error
Runtime error
Update predict.py
Browse files- predict.py +1 -1
predict.py
CHANGED
@@ -10,7 +10,7 @@ def predict_one_image(path) :
|
|
10 |
image = read_image(path)
|
11 |
image = get_valid_augs()(image=image)['image']
|
12 |
image = torch.tensor(image,dtype=torch.float)
|
13 |
-
image = image.reshape((1,3,
|
14 |
model = CustomModel()
|
15 |
#loading ckpt
|
16 |
model.load_state_dict(torch.load(CKPT,map_location=torch.device('cpu')))
|
|
|
10 |
image = read_image(path)
|
11 |
image = get_valid_augs()(image=image)['image']
|
12 |
image = torch.tensor(image,dtype=torch.float)
|
13 |
+
image = image.reshape((1,3,512,512))
|
14 |
model = CustomModel()
|
15 |
#loading ckpt
|
16 |
model.load_state_dict(torch.load(CKPT,map_location=torch.device('cpu')))
|