sunana commited on
Commit
e4929da
1 Parent(s): b72463f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -1
app.py CHANGED
@@ -7,12 +7,17 @@ import sys
7
  import FFV1MT_MS
8
  import flow_tools
9
 
 
 
 
 
 
10
  model = FFV1MT_MS.FFV1DNN()
11
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
 
13
  print('Number fo parameters: {}'.format(model.num_parameters()))
14
  model.to(device)
15
- model_dict = torch.load('Model_example.pth.tar', map_location=device)['state_dict']
16
  # save model
17
  model.load_state_dict(model_dict, strict=True)
18
  model.eval()
 
7
  import FFV1MT_MS
8
  import flow_tools
9
 
10
+ if torch.cuda.is_available():
11
+ generator = torch.Generator('cuda').manual_seed(seed)
12
+ else:
13
+ generator = torch.Generator().manual_seed(seed)
14
+
15
  model = FFV1MT_MS.FFV1DNN()
16
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
 
18
  print('Number fo parameters: {}'.format(model.num_parameters()))
19
  model.to(device)
20
+ model_dict = torch.load('Model_example.pth.tar')['state_dict']
21
  # save model
22
  model.load_state_dict(model_dict, strict=True)
23
  model.eval()