eaglelandsonce commited on
Commit
8a73c6f
·
verified ·
1 Parent(s): 0f1974a

Update pages/17_RNN.py

Browse files
Files changed (1) hide show
  1. pages/17_RNN.py +39 -21
pages/17_RNN.py CHANGED
@@ -3,6 +3,14 @@ import torch
3
  import torch.nn as nn
4
  import numpy as np
5
 
 
 
 
 
 
 
 
 
6
  # Define the dataset
7
  sequence = "In the vast expanse of the digital realm."
8
  chars = list(set(sequence))
@@ -36,11 +44,20 @@ class RNN(nn.Module):
36
  # Hyperparameters
37
  n_hidden = 128
38
  learning_rate = 0.005
39
- n_epochs = 500
40
 
41
  # Initialize the model, loss function, and optimizer
42
  rnn = RNN(vocab_size, n_hidden, vocab_size)
43
  criterion = nn.NLLLoss()
 
 
 
 
 
 
 
 
 
 
44
  optimizer = torch.optim.Adam(rnn.parameters(), lr=learning_rate)
45
 
46
  def char_tensor(char):
@@ -50,26 +67,31 @@ def char_tensor(char):
50
  tensor[0][char_to_idx[char]] = 1
51
  return tensor
52
 
53
- # Training loop
54
- for epoch in range(n_epochs):
55
- hidden = rnn.init_hidden()
56
- rnn.zero_grad()
57
- loss = 0
 
58
 
59
- for i in range(data_size - 1):
60
- input_char = char_tensor(sequence[i])
61
- target_char = torch.tensor([char_to_idx[sequence[i + 1]]], dtype=torch.long)
62
-
63
- output, hidden = rnn(input_char, hidden)
64
- loss += criterion(output, target_char)
 
 
 
65
 
66
- loss.backward()
67
- optimizer.step()
68
 
69
- if epoch % 10 == 0:
70
- print(f'Epoch {epoch} loss: {loss.item() / (data_size - 1)}')
71
 
72
- print("Training complete.")
 
 
73
 
74
  def generate(start_char, predict_len=100):
75
  if start_char not in char_to_idx:
@@ -88,10 +110,6 @@ def generate(start_char, predict_len=100):
88
 
89
  return predicted_str
90
 
91
- # Streamlit interface
92
- st.title('RNN Character Prediction')
93
- st.write('This app uses a Recurrent Neural Network (RNN) to predict the next character in a given string.')
94
-
95
  start_char = st.text_input('Enter a starting character:', 'h')
96
  predict_len = st.slider('Select the length of the generated text:', min_value=10, max_value=200, value=50)
97
 
 
3
  import torch.nn as nn
4
  import numpy as np
5
 
6
+ # Introduction
7
+ st.title('RNN Character Prediction')
8
+ st.write("""
9
+ This app demonstrates how to train a Recurrent Neural Network (RNN) to predict the next character in a given string.
10
+ The RNN learns the sequence of characters from a provided text and generates text based on the learned patterns.
11
+ You can choose different options for training the RNN and see how it affects the generated text.
12
+ """)
13
+
14
  # Define the dataset
15
  sequence = "In the vast expanse of the digital realm."
16
  chars = list(set(sequence))
 
44
  # Hyperparameters
45
  n_hidden = 128
46
  learning_rate = 0.005
 
47
 
48
  # Initialize the model, loss function, and optimizer
49
  rnn = RNN(vocab_size, n_hidden, vocab_size)
50
  criterion = nn.NLLLoss()
51
+
52
+ # Define training options
53
+ options = {
54
+ 'Quick Train (100 epochs)': 100,
55
+ 'Medium Train (500 epochs)': 500,
56
+ 'Long Train (1000 epochs)': 1000
57
+ }
58
+
59
+ train_option = st.selectbox('Select training option:', list(options.keys()))
60
+ n_epochs = options[train_option]
61
  optimizer = torch.optim.Adam(rnn.parameters(), lr=learning_rate)
62
 
63
  def char_tensor(char):
 
67
  tensor[0][char_to_idx[char]] = 1
68
  return tensor
69
 
70
+ # Training function
71
+ def train_model(n_epochs):
72
+ for epoch in range(n_epochs):
73
+ hidden = rnn.init_hidden()
74
+ rnn.zero_grad()
75
+ loss = 0
76
 
77
+ for i in range(data_size - 1):
78
+ input_char = char_tensor(sequence[i])
79
+ target_char = torch.tensor([char_to_idx[sequence[i + 1]]], dtype=torch.long)
80
+
81
+ output, hidden = rnn(input_char, hidden)
82
+ loss += criterion(output, target_char)
83
+
84
+ loss.backward()
85
+ optimizer.step()
86
 
87
+ if epoch % 10 == 0:
88
+ st.write(f'Epoch {epoch} loss: {loss.item() / (data_size - 1)}')
89
 
90
+ st.write("Training complete.")
 
91
 
92
+ # Train the model
93
+ if st.button('Train Model'):
94
+ train_model(n_epochs)
95
 
96
  def generate(start_char, predict_len=100):
97
  if start_char not in char_to_idx:
 
110
 
111
  return predicted_str
112
 
 
 
 
 
113
  start_char = st.text_input('Enter a starting character:', 'h')
114
  predict_len = st.slider('Select the length of the generated text:', min_value=10, max_value=200, value=50)
115