Trickshotblaster commited on
Commit
6c242fa
·
1 Parent(s): 01db43e
Files changed (1) hide show
  1. gpt.py +1 -1
gpt.py CHANGED
@@ -113,7 +113,7 @@ torch.set_float32_matmul_precision('high')
113
  my_GPT = GPT(enc.n_vocab, block_size, n_layers, n_heads, d_model, dropout=0.1) #enc.n_vocab
114
  my_GPT = my_GPT.to(device)
115
  my_GPT = torch.compile(my_GPT)
116
- my_GPT.load_state_dict(torch.load('latest_model_finetune.pth'))
117
  my_GPT.eval()
118
 
119
  eot = enc._special_tokens['<|endoftext|>']
 
113
  my_GPT = GPT(enc.n_vocab, block_size, n_layers, n_heads, d_model, dropout=0.1) #enc.n_vocab
114
  my_GPT = my_GPT.to(device)
115
  my_GPT = torch.compile(my_GPT)
116
+ my_GPT.load_state_dict(torch.load('latest_model_finetune.pth', map_location=torch.device('cpu')))
117
  my_GPT.eval()
118
 
119
  eot = enc._special_tokens['<|endoftext|>']