AkashDataScience commited on
Commit
78f0cbc
·
1 Parent(s): c44d42e

Updated model loading method

Browse files
Files changed (1) hide show
  1. app.py +2 -5
app.py CHANGED
@@ -10,11 +10,8 @@ if torch.cuda.is_available():
10
  elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
11
  device = "mps"
12
 
13
- ckpt = torch.load("gpt2.pt", map_location=torch.device(device))
14
- config = GPTConfig(**ckpt['model_args'])
15
- model = GPT(config)
16
- state_dict = ckpt['model']
17
- model.load_state_dict(state_dict)
18
 
19
  model.to(device)
20
 
 
10
  elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
11
  device = "mps"
12
 
13
+ model = GPT(GPTConfig())
14
+ model.load_state_dict(torch.load("nanogpt.pth", map_location=torch.device(device)), strict=False)
 
 
 
15
 
16
  model.to(device)
17