update app.py
Browse files
app.py
CHANGED
@@ -21,7 +21,7 @@ model = SegResNet(
|
|
21 |
)
|
22 |
|
23 |
model.load_state_dict(
|
24 |
-
torch.load("model.pt", map_location=torch.device('cpu'))
|
25 |
)
|
26 |
|
27 |
# define inference method
|
@@ -68,7 +68,7 @@ import torchvision
|
|
68 |
def load_sample(index):
|
69 |
#sample_index = index
|
70 |
|
71 |
-
sample = torch.load(f"val{index-1}.pt")
|
72 |
imgs = []
|
73 |
for i in range(4):
|
74 |
imgs.append(sample["image"][i, :, :, 70])
|
@@ -91,7 +91,7 @@ def load_sample(index):
|
|
91 |
|
92 |
def predict(sample_index):
|
93 |
print(sample_index)
|
94 |
-
sample = torch.load(f"val{sample_index-1}.pt")
|
95 |
model.eval()
|
96 |
with torch.no_grad():
|
97 |
# select one image to evaluate and visualize the model output
|
|
|
21 |
)
|
22 |
|
23 |
model.load_state_dict(
|
24 |
+
torch.load("weights/model.pt", map_location=torch.device('cpu'))
|
25 |
)
|
26 |
|
27 |
# define inference method
|
|
|
68 |
def load_sample(index):
|
69 |
#sample_index = index
|
70 |
|
71 |
+
sample = torch.load(f"samples/val{index-1}.pt")
|
72 |
imgs = []
|
73 |
for i in range(4):
|
74 |
imgs.append(sample["image"][i, :, :, 70])
|
|
|
91 |
|
92 |
def predict(sample_index):
|
93 |
print(sample_index)
|
94 |
+
sample = torch.load(f"samples/val{sample_index-1}.pt")
|
95 |
model.eval()
|
96 |
with torch.no_grad():
|
97 |
# select one image to evaluate and visualize the model output
|