Commit 8ba60a0 (verified) by KittyCat00
Parent(s): 6fa3bf4

Update app.py

Files changed (1):
  1. app.py +20 -12
app.py CHANGED
@@ -485,19 +485,27 @@ def main(input_text, max_new_tokens):
     else:
         device = torch.device("cpu")
 
-    # weights = torch.load("model_and_optimizer.pth", map_location=torch.device(device))
-    weights = torch.load("model_and_optimizer.pth", weights_only=False)
-
-    model = GPTModel({
-        "vocab_size": 50257,      # Vocabulary size
-        "context_length": 512,    # Shortened context length (orig: 1024)
-        "emb_dim": 768,           # Embedding dimension
-        "n_heads": 12,            # Number of attention heads
-        "n_layers": 12,           # Number of layers
-        "drop_rate": 0.3,         # Dropout rate
-        "qkv_bias": False         # Query-key-value bias
-    }).to(device)
-    model.load_state_dict(weights['model_state_dict'])
+    checkpoint = torch.load("model_and_optimizer.pth", weights_only=True)
+
+    model = GPTModel(GPT_CONFIG_124M)
+    model.load_state_dict(checkpoint["model_state_dict"])
+
+    optimizer = torch.optim.AdamW(model.parameters(), lr=0.0005, weight_decay=0.1)
+    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
+
+    # weights = torch.load("model_and_optimizer.pth", map_location=torch.device(device))
+    # weights = torch.load("model_and_optimizer.pth", weights_only=False)
+
+    # model = GPTModel({
+    #     "vocab_size": 50257,      # Vocabulary size
+    #     "context_length": 512,    # Shortened context length (orig: 1024)
+    #     "emb_dim": 768,           # Embedding dimension
+    #     "n_heads": 12,            # Number of attention heads
+    #     "n_layers": 12,           # Number of layers
+    #     "drop_rate": 0.3,         # Dropout rate
+    #     "qkv_bias": False         # Query-key-value bias
+    # }).to(device)
+    # model.load_state_dict(weights['model_state_dict'])
     model.eval()
 
     context_size = model.pos_emb.weight.shape[0]
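
A note on the new load path (commentary, not part of the commit): weights_only=True restricts torch.load to deserializing plain tensors and standard containers, which avoids arbitrary code execution from a pickled checkpoint. The new call also drops the previously commented-out map_location argument, so a checkpoint saved on a GPU machine can fail to load on the CPU branch above. A minimal sketch that keeps both safeguards, assuming the device variable selected earlier in main:

    # Sketch only, not from this commit: combine weights_only with map_location
    # so a GPU-saved checkpoint still loads when device is "cpu".
    checkpoint = torch.load(
        "model_and_optimizer.pth",
        map_location=device,   # remap tensor storages to the chosen device
        weights_only=True,     # restrict unpickling to tensors/containers
    )

Note also that the new version omits the old .to(device) call, so the loaded model stays on the CPU unless it is moved explicitly.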