Spaces: Running

Commit: Update pages/19_RNN_Shakespeare.py  (Browse files)
File changed: pages/19_RNN_Shakespeare.py  (+11 −7)  CHANGED
@@ -66,11 +66,11 @@ if st.button("Train and Generate"):
  66
  67      X = np.reshape(dataX, (len(dataX), seq_length, 1))
  68      X = X / float(len(chars))
  69 -    Y = np.                          [removed line truncated in page capture]
  70
  71      # Convert to PyTorch tensors
  72      X_tensor = torch.tensor(X, dtype=torch.float32)
  73 -    Y_tensor = torch.tensor(         [removed line truncated in page capture]
  74
  75      # Model initialization
  76      model = LSTMModel(input_size=1, hidden_size=hidden_size, output_size=len(chars), num_layers=num_layers)

@@ -81,10 +81,11 @@ if st.button("Train and Generate"):
  81
  82      # Training the model
  83      for epoch in range(num_epochs):
  84 -        h = (torch.zeros(num_layers, [removed line truncated in page capture]
  85 -                                     [removed line truncated in page capture]
  86 -                                     [removed line truncated in page capture]
  87 -                                     [removed line truncated in page capture]
  88
  89          # Forward pass
  90          outputs, h = model(inputs, h)

@@ -96,7 +97,10 @@ if st.button("Train and Generate"):
  96          loss.backward()
  97          optimizer.step()
  98
  99 -                                     [removed line truncated in page capture]
 100
 101      # Text generation
 102      generated_text = generate_text(model, start_string, generate_length, char_to_int, int_to_char, num_layers, hidden_size)
|
|
New version (pages/19_RNN_Shakespeare.py, lines 66–106 after the commit):

  66
  67      X = np.reshape(dataX, (len(dataX), seq_length, 1))
  68      X = X / float(len(chars))
  69 +    Y = np.array(dataY)
  70
  71      # Convert to PyTorch tensors
  72      X_tensor = torch.tensor(X, dtype=torch.float32)
  73 +    Y_tensor = torch.tensor(Y, dtype=torch.long)
  74
  75      # Model initialization
  76      model = LSTMModel(input_size=1, hidden_size=hidden_size, output_size=len(chars), num_layers=num_layers)

  81
  82      # Training the model
  83      for epoch in range(num_epochs):
  84 +        h = (torch.zeros(num_layers, 1, hidden_size), torch.zeros(num_layers, 1, hidden_size))
  85 +        epoch_loss = 0
  86 +        for i in range(len(dataX)):
  87 +            inputs = X_tensor[i].unsqueeze(0)
  88 +            targets = Y_tensor[i].unsqueeze(0)
  89
  90              # Forward pass
  91              outputs, h = model(inputs, h)

  97              loss.backward()
  98              optimizer.step()
  99
 100 +            epoch_loss += loss.item()
 101 +
 102 +        avg_loss = epoch_loss / len(dataX)
 103 +        st.write(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {avg_loss:.4f}')
 104
 105      # Text generation
 106      generated_text = generate_text(model, start_string, generate_length, char_to_int, int_to_char, num_layers, hidden_size)