ClassCat commited on
Commit
1757726
1 Parent(s): 08ccc4a

update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -21,7 +21,7 @@ model = SegResNet(
21
  )
22
 
23
  model.load_state_dict(
24
- torch.load("model.pt", map_location=torch.device('cpu'))
25
  )
26
 
27
  # define inference method
@@ -68,7 +68,7 @@ import torchvision
68
  def load_sample(index):
69
  #sample_index = index
70
 
71
- sample = torch.load(f"val{index-1}.pt")
72
  imgs = []
73
  for i in range(4):
74
  imgs.append(sample["image"][i, :, :, 70])
@@ -91,7 +91,7 @@ def load_sample(index):
91
 
92
  def predict(sample_index):
93
  print(sample_index)
94
- sample = torch.load(f"val{sample_index-1}.pt")
95
  model.eval()
96
  with torch.no_grad():
97
  # select one image to evaluate and visualize the model output
 
21
  )
22
 
23
  model.load_state_dict(
24
+ torch.load("weights/model.pt", map_location=torch.device('cpu'))
25
  )
26
 
27
  # define inference method
 
68
  def load_sample(index):
69
  #sample_index = index
70
 
71
+ sample = torch.load(f"samples/val{index-1}.pt")
72
  imgs = []
73
  for i in range(4):
74
  imgs.append(sample["image"][i, :, :, 70])
 
91
 
92
  def predict(sample_index):
93
  print(sample_index)
94
+ sample = torch.load(f"samples/val{sample_index-1}.pt")
95
  model.eval()
96
  with torch.no_grad():
97
  # select one image to evaluate and visualize the model output