Spaces:
Running
Running
Update pages/17_RNN.py
Browse files- pages/17_RNN.py +39 -21
pages/17_RNN.py
CHANGED
@@ -3,6 +3,14 @@ import torch
|
|
3 |
import torch.nn as nn
|
4 |
import numpy as np
|
5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
# Define the dataset
|
7 |
sequence = "In the vast expanse of the digital realm."
|
8 |
chars = list(set(sequence))
|
@@ -36,11 +44,20 @@ class RNN(nn.Module):
|
|
36 |
# Hyperparameters
|
37 |
n_hidden = 128
|
38 |
learning_rate = 0.005
|
39 |
-
n_epochs = 500
|
40 |
|
41 |
# Initialize the model, loss function, and optimizer
|
42 |
rnn = RNN(vocab_size, n_hidden, vocab_size)
|
43 |
criterion = nn.NLLLoss()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
optimizer = torch.optim.Adam(rnn.parameters(), lr=learning_rate)
|
45 |
|
46 |
def char_tensor(char):
|
@@ -50,26 +67,31 @@ def char_tensor(char):
|
|
50 |
tensor[0][char_to_idx[char]] = 1
|
51 |
return tensor
|
52 |
|
53 |
-
# Training
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
|
|
58 |
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
|
|
|
|
|
|
65 |
|
66 |
-
|
67 |
-
|
68 |
|
69 |
-
|
70 |
-
print(f'Epoch {epoch} loss: {loss.item() / (data_size - 1)}')
|
71 |
|
72 |
-
|
|
|
|
|
73 |
|
74 |
def generate(start_char, predict_len=100):
|
75 |
if start_char not in char_to_idx:
|
@@ -88,10 +110,6 @@ def generate(start_char, predict_len=100):
|
|
88 |
|
89 |
return predicted_str
|
90 |
|
91 |
-
# Streamlit interface
|
92 |
-
st.title('RNN Character Prediction')
|
93 |
-
st.write('This app uses a Recurrent Neural Network (RNN) to predict the next character in a given string.')
|
94 |
-
|
95 |
start_char = st.text_input('Enter a starting character:', 'h')
|
96 |
predict_len = st.slider('Select the length of the generated text:', min_value=10, max_value=200, value=50)
|
97 |
|
|
|
3 |
import torch.nn as nn
|
4 |
import numpy as np
|
5 |
|
6 |
+
# Introduction
|
7 |
+
st.title('RNN Character Prediction')
|
8 |
+
st.write("""
|
9 |
+
This app demonstrates how to train a Recurrent Neural Network (RNN) to predict the next character in a given string.
|
10 |
+
The RNN learns the sequence of characters from a provided text and generates text based on the learned patterns.
|
11 |
+
You can choose different options for training the RNN and see how it affects the generated text.
|
12 |
+
""")
|
13 |
+
|
14 |
# Define the dataset
|
15 |
sequence = "In the vast expanse of the digital realm."
|
16 |
chars = list(set(sequence))
|
|
|
44 |
# Hyperparameters
|
45 |
n_hidden = 128
|
46 |
learning_rate = 0.005
|
|
|
47 |
|
48 |
# Initialize the model, loss function, and optimizer
|
49 |
rnn = RNN(vocab_size, n_hidden, vocab_size)
|
50 |
criterion = nn.NLLLoss()
|
51 |
+
|
52 |
+
# Define training options
|
53 |
+
options = {
|
54 |
+
'Quick Train (100 epochs)': 100,
|
55 |
+
'Medium Train (500 epochs)': 500,
|
56 |
+
'Long Train (1000 epochs)': 1000
|
57 |
+
}
|
58 |
+
|
59 |
+
train_option = st.selectbox('Select training option:', list(options.keys()))
|
60 |
+
n_epochs = options[train_option]
|
61 |
optimizer = torch.optim.Adam(rnn.parameters(), lr=learning_rate)
|
62 |
|
63 |
def char_tensor(char):
|
|
|
67 |
tensor[0][char_to_idx[char]] = 1
|
68 |
return tensor
|
69 |
|
70 |
+
# Training function
|
71 |
+
def train_model(n_epochs):
|
72 |
+
for epoch in range(n_epochs):
|
73 |
+
hidden = rnn.init_hidden()
|
74 |
+
rnn.zero_grad()
|
75 |
+
loss = 0
|
76 |
|
77 |
+
for i in range(data_size - 1):
|
78 |
+
input_char = char_tensor(sequence[i])
|
79 |
+
target_char = torch.tensor([char_to_idx[sequence[i + 1]]], dtype=torch.long)
|
80 |
+
|
81 |
+
output, hidden = rnn(input_char, hidden)
|
82 |
+
loss += criterion(output, target_char)
|
83 |
+
|
84 |
+
loss.backward()
|
85 |
+
optimizer.step()
|
86 |
|
87 |
+
if epoch % 10 == 0:
|
88 |
+
st.write(f'Epoch {epoch} loss: {loss.item() / (data_size - 1)}')
|
89 |
|
90 |
+
st.write("Training complete.")
|
|
|
91 |
|
92 |
+
# Train the model
|
93 |
+
if st.button('Train Model'):
|
94 |
+
train_model(n_epochs)
|
95 |
|
96 |
def generate(start_char, predict_len=100):
|
97 |
if start_char not in char_to_idx:
|
|
|
110 |
|
111 |
return predicted_str
|
112 |
|
|
|
|
|
|
|
|
|
113 |
start_char = st.text_input('Enter a starting character:', 'h')
|
114 |
predict_len = st.slider('Select the length of the generated text:', min_value=10, max_value=200, value=50)
|
115 |
|