import streamlit as st
import torch
import torch.nn as nn
import numpy as np
# Introduction
st.title('RNN Character Prediction')
st.write("""
This app demonstrates how to train a Recurrent Neural Network (RNN) to predict the next character in a given string.
The RNN learns the sequence of characters from a provided text and generates text based on the learned patterns.
You can choose different training options and see how they affect the generated text.
""")
# User input for the training string
sequence = st.text_area('Enter the training string:', 'In the vast expanse of the digital realm.')
chars = sorted(set(sequence))  # sorted so the char-to-index mapping is deterministic across reruns
data_size, vocab_size = len(sequence), len(chars)
# Create mappings from characters to indices and vice versa
char_to_idx = {ch: i for i, ch in enumerate(chars)}
idx_to_char = {i: ch for i, ch in enumerate(chars)}
# Convert the sequence to indices
indices = np.array([char_to_idx[ch] for ch in sequence])
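# A minimal character-level RNN cell: at each step the one-hot input and the
# previous hidden state are concatenated and passed through two linear layers,
# one producing the next hidden state and one producing log-probabilities over
# the vocabulary.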
class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size
        self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
        self.i2o = nn.Linear(input_size + hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        combined = torch.cat((input, hidden), 1)
        hidden = self.i2h(combined)
        output = self.i2o(combined)
        output = self.softmax(output)
        return output, hidden

    def init_hidden(self):
        return torch.zeros(1, self.hidden_size)
# Hyperparameters
n_hidden = 128
learning_rate = 0.005
# Initialize the model, loss function, and optimizer
# Keep the model in the session state so trained weights survive Streamlit reruns;
# reinitialise it whenever the vocabulary (and hence the layer sizes) changes.
if 'rnn' not in st.session_state or st.session_state.get('vocab') != chars:
    st.session_state.rnn = RNN(vocab_size, n_hidden, vocab_size)
    st.session_state.vocab = chars
rnn = st.session_state.rnn
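# NLLLoss expects log-probabilities, which matches the LogSoftmax output of the model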
criterion = nn.NLLLoss()
# Define training options
options = {
    'Quick Train (100 epochs)': 100,
    'Medium Train (500 epochs)': 500,
    'Long Train (1000 epochs)': 1000
}
train_option = st.selectbox('Select training option:', list(options.keys()))
n_epochs = options[train_option]
optimizer = torch.optim.Adam(rnn.parameters(), lr=learning_rate)
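# One-hot encode a single character as a (1, vocab_size) tensor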
def char_tensor(char):
    if char not in char_to_idx:
        raise ValueError(f"Character '{char}' not in vocabulary.")
    tensor = torch.zeros(1, vocab_size)
    tensor[0][char_to_idx[char]] = 1
    return tensor
# Training function
def train_model(n_epochs):
    for epoch in range(n_epochs):
        hidden = rnn.init_hidden()
        rnn.zero_grad()
        loss = 0
        # Accumulate the loss over every (current char -> next char) pair in the sequence
        for i in range(data_size - 1):
            input_char = char_tensor(sequence[i])
            target_char = torch.tensor([char_to_idx[sequence[i + 1]]], dtype=torch.long)
            output, hidden = rnn(input_char, hidden)
            loss += criterion(output, target_char)
        loss.backward()
        optimizer.step()
        if epoch % 10 == 0:
            st.write(f'Epoch {epoch} loss: {loss.item() / (data_size - 1):.4f}')
    st.write("Training complete.")
# Train the model
if st.button('Train Model'):
    train_model(n_epochs)
def generate(start_char, predict_len=100):
    if start_char not in char_to_idx:
        raise ValueError(f"Start character '{start_char}' not in vocabulary.")
    hidden = rnn.init_hidden()
    input_char = char_tensor(start_char)
    predicted_str = start_char
    # Greedy decoding: always take the most likely next character and feed it back in
    with torch.no_grad():
        for _ in range(predict_len):
            output, hidden = rnn(input_char, hidden)
            topv, topi = output.topk(1)
            predicted_char_idx = topi[0][0].item()
            predicted_char = idx_to_char[predicted_char_idx]
            predicted_str += predicted_char
            input_char = char_tensor(predicted_char)
    return predicted_str
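# Note: generation above is greedy (argmax at every step). A common alternative,
# not used in this app, is temperature sampling from the predicted distribution:
#   probs = torch.softmax(output / temperature, dim=1)
#   predicted_char_idx = torch.multinomial(probs, 1).item()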
start_char = st.text_input('Enter a starting character:', 'h')
predict_len = st.slider('Select the length of the generated text:', min_value=10, max_value=200, value=50)
if st.button('Generate Text'):
    try:
        generated_text = generate(start_char, predict_len)
        st.write('Generated Text:')
        st.text(generated_text)
    except ValueError as e:
        st.error(str(e))