Update model.py
Browse files
model.py
CHANGED
@@ -52,7 +52,7 @@ class ViTRecognitionModel(nn.Module):
|
|
52 |
|
53 |
def load_model(model_path, device='cpu'):
|
54 |
model = ViTRecognitionModel(vocab_size=vocab_size, hidden_dim=768, max_length=20)
|
55 |
-
model.load_state_dict(torch.load(model_path, map_location=device))
|
56 |
model.to(device)
|
57 |
model.eval()
|
58 |
return model
|
|
|
52 |
|
53 |
def load_model(model_path, device='cpu'):
|
54 |
model = ViTRecognitionModel(vocab_size=vocab_size, hidden_dim=768, max_length=20)
|
55 |
+
model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
|
56 |
model.to(device)
|
57 |
model.eval()
|
58 |
return model
|