Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -7,12 +7,17 @@ import sys
|
|
7 |
import FFV1MT_MS
|
8 |
import flow_tools
|
9 |
|
|
|
|
|
|
|
|
|
|
|
10 |
model = FFV1MT_MS.FFV1DNN()
|
11 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
12 |
|
13 |
print('Number fo parameters: {}'.format(model.num_parameters()))
|
14 |
model.to(device)
|
15 |
-
model_dict = torch.load('Model_example.pth.tar'
|
16 |
# save model
|
17 |
model.load_state_dict(model_dict, strict=True)
|
18 |
model.eval()
|
|
|
7 |
import FFV1MT_MS
|
8 |
import flow_tools
|
9 |
|
10 |
+
if torch.cuda.is_available():
|
11 |
+
generator = torch.Generator('cuda').manual_seed(seed)
|
12 |
+
else:
|
13 |
+
generator = torch.Generator().manual_seed(seed)
|
14 |
+
|
15 |
model = FFV1MT_MS.FFV1DNN()
|
16 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
17 |
|
18 |
print('Number fo parameters: {}'.format(model.num_parameters()))
|
19 |
model.to(device)
|
20 |
+
model_dict = torch.load('Model_example.pth.tar')['state_dict']
|
21 |
# save model
|
22 |
model.load_state_dict(model_dict, strict=True)
|
23 |
model.eval()
|