Spaces:
Running
Running
Update pages/17_RNN.py
Browse files- pages/17_RNN.py +11 -4
pages/17_RNN.py
CHANGED
@@ -10,7 +10,7 @@ data_size, vocab_size = len(sequence), len(chars)
|
|
10 |
|
11 |
# Create mappings from characters to indices and vice versa
|
12 |
char_to_idx = {ch: i for i, ch in enumerate(chars)}
|
13 |
-
idx_to_char = {i
|
14 |
|
15 |
# Convert the sequence to indices
|
16 |
indices = np.array([char_to_idx[ch] for ch in sequence])
|
@@ -44,6 +44,8 @@ criterion = nn.NLLLoss()
|
|
44 |
optimizer = torch.optim.Adam(rnn.parameters(), lr=learning_rate)
|
45 |
|
46 |
def char_tensor(char):
|
|
|
|
|
47 |
tensor = torch.zeros(1, vocab_size)
|
48 |
tensor[0][char_to_idx[char]] = 1
|
49 |
return tensor
|
@@ -70,6 +72,8 @@ for epoch in range(n_epochs):
|
|
70 |
print("Training complete.")
|
71 |
|
72 |
def generate(start_char, predict_len=100):
|
|
|
|
|
73 |
hidden = rnn.init_hidden()
|
74 |
input_char = char_tensor(start_char)
|
75 |
predicted_str = start_char
|
@@ -92,6 +96,9 @@ start_char = st.text_input('Enter a starting character:', 'h')
|
|
92 |
predict_len = st.slider('Select the length of the generated text:', min_value=10, max_value=200, value=50)
|
93 |
|
94 |
if st.button('Generate Text'):
|
95 |
-
|
96 |
-
|
97 |
-
|
|
|
|
|
|
|
|
10 |
|
11 |
# Create mappings from characters to indices and vice versa
|
12 |
char_to_idx = {ch: i for i, ch in enumerate(chars)}
|
13 |
+
idx_to_char = {i, ch for i, ch in enumerate(chars)}
|
14 |
|
15 |
# Convert the sequence to indices
|
16 |
indices = np.array([char_to_idx[ch] for ch in sequence])
|
|
|
44 |
optimizer = torch.optim.Adam(rnn.parameters(), lr=learning_rate)
|
45 |
|
46 |
def char_tensor(char):
|
47 |
+
if char not in char_to_idx:
|
48 |
+
raise ValueError(f"Character '{char}' not in vocabulary.")
|
49 |
tensor = torch.zeros(1, vocab_size)
|
50 |
tensor[0][char_to_idx[char]] = 1
|
51 |
return tensor
|
|
|
72 |
print("Training complete.")
|
73 |
|
74 |
def generate(start_char, predict_len=100):
|
75 |
+
if start_char not in char_to_idx:
|
76 |
+
raise ValueError(f"Start character '{start_char}' not in vocabulary.")
|
77 |
hidden = rnn.init_hidden()
|
78 |
input_char = char_tensor(start_char)
|
79 |
predicted_str = start_char
|
|
|
96 |
predict_len = st.slider('Select the length of the generated text:', min_value=10, max_value=200, value=50)
|
97 |
|
98 |
if st.button('Generate Text'):
|
99 |
+
try:
|
100 |
+
generated_text = generate(start_char, predict_len)
|
101 |
+
st.write('Generated Text:')
|
102 |
+
st.text(generated_text)
|
103 |
+
except ValueError as e:
|
104 |
+
st.error(str(e))
|