eaglelandsonce commited on
Commit
b980d33
·
verified ·
1 Parent(s): b8c5c1a

Update pages/19_RNN_Shakespeare.py

Browse files
Files changed (1) hide show
  1. pages/19_RNN_Shakespeare.py +4 -4
pages/19_RNN_Shakespeare.py CHANGED
@@ -22,7 +22,7 @@ class LSTMModel(nn.Module):
22
  def generate_text(model, start_str, length, char_to_int, int_to_char, num_layers, hidden_size):
23
  model.eval()
24
  input_seq = [char_to_int[c] for c in start_str]
25
- input_seq = torch.tensor(input_seq, dtype=torch.float32).unsqueeze(0).unsqueeze(-1)
26
  h = (torch.zeros(num_layers, 1, hidden_size), torch.zeros(num_layers, 1, hidden_size))
27
  generated_text = start_str
28
 
@@ -31,7 +31,7 @@ def generate_text(model, start_str, length, char_to_int, int_to_char, num_layers
31
  _, predicted = torch.max(output, 1)
32
  predicted_char = int_to_char[predicted.item()]
33
  generated_text += predicted_char
34
- input_seq = torch.tensor([[char_to_int[predicted_char]]], dtype=torch.float32).unsqueeze(0).unsqueeze(-1)
35
 
36
  return generated_text
37
 
@@ -70,7 +70,7 @@ if st.button("Train and Generate"):
70
  if len(dataX) == 0:
71
  st.error("Not enough data to create input-output pairs. Please provide more text data.")
72
  else:
73
- X = np.reshape(dataX, (len(dataX), seq_length, 1))
74
  X = X / float(len(chars))
75
  Y = np.array(dataY)
76
 
@@ -90,7 +90,7 @@ if st.button("Train and Generate"):
90
  h = (torch.zeros(num_layers, 1, hidden_size), torch.zeros(num_layers, 1, hidden_size))
91
  epoch_loss = 0
92
  for i in range(len(dataX)):
93
- inputs = X_tensor[i].unsqueeze(0)
94
  targets = Y_tensor[i].unsqueeze(0)
95
 
96
  # Forward pass
 
22
  def generate_text(model, start_str, length, char_to_int, int_to_char, num_layers, hidden_size):
23
  model.eval()
24
  input_seq = [char_to_int[c] for c in start_str]
25
+ input_seq = torch.tensor(input_seq, dtype=torch.long).unsqueeze(0)
26
  h = (torch.zeros(num_layers, 1, hidden_size), torch.zeros(num_layers, 1, hidden_size))
27
  generated_text = start_str
28
 
 
31
  _, predicted = torch.max(output, 1)
32
  predicted_char = int_to_char[predicted.item()]
33
  generated_text += predicted_char
34
+ input_seq = torch.tensor([[char_to_int[predicted_char]]], dtype=torch.long)
35
 
36
  return generated_text
37
 
 
70
  if len(dataX) == 0:
71
  st.error("Not enough data to create input-output pairs. Please provide more text data.")
72
  else:
73
+ X = np.reshape(dataX, (len(dataX), seq_length))
74
  X = X / float(len(chars))
75
  Y = np.array(dataY)
76
 
 
90
  h = (torch.zeros(num_layers, 1, hidden_size), torch.zeros(num_layers, 1, hidden_size))
91
  epoch_loss = 0
92
  for i in range(len(dataX)):
93
+ inputs = X_tensor[i].unsqueeze(0).unsqueeze(-1)
94
  targets = Y_tensor[i].unsqueeze(0)
95
 
96
  # Forward pass