eaglelandsonce commited on
Commit
da75a25
·
verified ·
1 Parent(s): 1f5dc84

Update pages/17_RNN.py

Browse files
Files changed (1) hide show
  1. pages/17_RNN.py +11 -4
pages/17_RNN.py CHANGED
@@ -10,7 +10,7 @@ data_size, vocab_size = len(sequence), len(chars)
10
 
11
  # Create mappings from characters to indices and vice versa
12
  char_to_idx = {ch: i for i, ch in enumerate(chars)}
13
- idx_to_char = {i: ch for i, ch in enumerate(chars)}
14
 
15
  # Convert the sequence to indices
16
  indices = np.array([char_to_idx[ch] for ch in sequence])
@@ -44,6 +44,8 @@ criterion = nn.NLLLoss()
44
  optimizer = torch.optim.Adam(rnn.parameters(), lr=learning_rate)
45
 
46
  def char_tensor(char):
 
 
47
  tensor = torch.zeros(1, vocab_size)
48
  tensor[0][char_to_idx[char]] = 1
49
  return tensor
@@ -70,6 +72,8 @@ for epoch in range(n_epochs):
70
  print("Training complete.")
71
 
72
  def generate(start_char, predict_len=100):
 
 
73
  hidden = rnn.init_hidden()
74
  input_char = char_tensor(start_char)
75
  predicted_str = start_char
@@ -92,6 +96,9 @@ start_char = st.text_input('Enter a starting character:', 'h')
92
  predict_len = st.slider('Select the length of the generated text:', min_value=10, max_value=200, value=50)
93
 
94
  if st.button('Generate Text'):
95
- generated_text = generate(start_char, predict_len)
96
- st.write('Generated Text:')
97
- st.text(generated_text)
 
 
 
 
10
 
11
  # Create mappings from characters to indices and vice versa
12
  char_to_idx = {ch: i for i, ch in enumerate(chars)}
13
+ idx_to_char = {i, ch for i, ch in enumerate(chars)}
14
 
15
  # Convert the sequence to indices
16
  indices = np.array([char_to_idx[ch] for ch in sequence])
 
44
  optimizer = torch.optim.Adam(rnn.parameters(), lr=learning_rate)
45
 
46
  def char_tensor(char):
47
+ if char not in char_to_idx:
48
+ raise ValueError(f"Character '{char}' not in vocabulary.")
49
  tensor = torch.zeros(1, vocab_size)
50
  tensor[0][char_to_idx[char]] = 1
51
  return tensor
 
72
  print("Training complete.")
73
 
74
  def generate(start_char, predict_len=100):
75
+ if start_char not in char_to_idx:
76
+ raise ValueError(f"Start character '{start_char}' not in vocabulary.")
77
  hidden = rnn.init_hidden()
78
  input_char = char_tensor(start_char)
79
  predicted_str = start_char
 
96
  predict_len = st.slider('Select the length of the generated text:', min_value=10, max_value=200, value=50)
97
 
98
  if st.button('Generate Text'):
99
+ try:
100
+ generated_text = generate(start_char, predict_len)
101
+ st.write('Generated Text:')
102
+ st.text(generated_text)
103
+ except ValueError as e:
104
+ st.error(str(e))