SeemG commited on
Commit
ce6ef21
·
verified ·
1 Parent(s): 32d93f1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -1
app.py CHANGED
@@ -12,7 +12,10 @@ class GPTConfig:
12
  n_head: int = 12 # number of heads
13
  n_embd: int = 768 # embedding dimension
14
 
15
-
 
 
 
16
 
17
  # Define generation function
18
  def generate_text(prompt, max_length=50, num_return_sequences=10):
 
12
  n_head: int = 12 # number of heads
13
  n_embd: int = 768 # embedding dimension
14
 
15
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
16
+ model2 = model.GPT(GPTConfig())
17
+ model2 = model2.to(device)
18
+ model2.load_state_dict(torch.load('gpt_124M_1.pth', map_location=torch.device(device)))
19
 
20
  # Define generation function
21
  def generate_text(prompt, max_length=50, num_return_sequences=10):