Update TractionModel.py
Browse files- app/TractionModel.py +2 -2
app/TractionModel.py
CHANGED
@@ -53,7 +53,7 @@ def create_model():
|
|
53 |
return model
|
54 |
|
55 |
|
56 |
-
def load_weights(model, path='model.pt'):
|
57 |
-
checkpoint = torch.load(path, map_location=torch.device(
|
58 |
model.load_state_dict(checkpoint)
|
59 |
return model
|
|
|
53 |
return model
|
54 |
|
55 |
|
56 |
+
def load_weights(model, path='model.pt', device_='cpu'):
|
57 |
+
checkpoint = torch.load(path, map_location=torch.device(device_))
|
58 |
model.load_state_dict(checkpoint)
|
59 |
return model
|