File size: 4,020 Bytes
de0d854
 
 
 
 
8a73c6f
 
 
 
 
 
 
 
68c5fd5
 
1f5dc84
 
 
 
 
c0bfa85
1f5dc84
 
 
 
de0d854
1f5dc84
de0d854
1f5dc84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8a73c6f
 
 
 
 
 
 
 
 
 
1f5dc84
 
 
da75a25
 
1f5dc84
 
 
 
8a73c6f
 
 
 
 
 
1f5dc84
8a73c6f
 
 
 
 
 
 
 
 
1f5dc84
8a73c6f
 
1f5dc84
8a73c6f
1f5dc84
8a73c6f
 
 
1f5dc84
 
da75a25
 
1f5dc84
 
 
 
 
 
 
 
 
 
 
 
 
de0d854
1f5dc84
 
 
 
da75a25
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import streamlit as st
import torch
import torch.nn as nn
import numpy as np

# Page header: the description is kept in a named constant so the rendering
# calls below stay short.
APP_DESCRIPTION = """
This app demonstrates how to train a Recurrent Neural Network (RNN) to predict the next character in a given string. 
The RNN learns the sequence of characters from a provided text and generates text based on the learned patterns.
You can choose different options for training the RNN and see how it affects the generated text.
"""

st.title('RNN Character Prediction')
st.write(APP_DESCRIPTION)

# User input for the training string
sequence = st.text_area('Enter the training string:', 'In the vast expanse of the digital realm.')

# Sort the distinct characters so that every character always receives the
# same index. Iterating a bare set of strings follows hash order, which is
# randomised per interpreter process, so an unsorted vocabulary would map
# characters to different indices from one server start to the next.
chars = sorted(set(sequence))
data_size, vocab_size = len(sequence), len(chars)

# Create mappings from characters to indices and vice versa
char_to_idx = {ch: i for i, ch in enumerate(chars)}
idx_to_char = dict(enumerate(chars))

# Convert the sequence to indices
indices = np.array([char_to_idx[ch] for ch in sequence])

class RNN(nn.Module):
    """Single-step recurrent cell built from two linear layers.

    Both layers read the concatenation of the current input and the previous
    hidden state: ``i2h`` produces the next hidden state and ``i2o`` produces
    the output scores, which are passed through a log-softmax over dim 1.
    """

    def __init__(self, input_size, hidden_size, output_size):
        """Create the layers.

        Args:
            input_size: Width of the input vector.
            hidden_size: Width of the hidden state.
            output_size: Number of output classes.
        """
        super().__init__()
        self.hidden_size = hidden_size
        combined_size = input_size + hidden_size
        self.i2h = nn.Linear(combined_size, hidden_size)
        self.i2o = nn.Linear(combined_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        """Run one time step.

        Args:
            input: 2-D tensor whose second dimension is ``input_size``.
            hidden: 2-D tensor whose second dimension is ``hidden_size``; it
                is concatenated with ``input`` along dim 1, so the first
                dimensions must match.

        Returns:
            Tuple ``(output, hidden)``: the log-probabilities over the output
            classes and the next hidden state.
        """
        joined = torch.cat([input, hidden], dim=1)
        next_hidden = self.i2h(joined)
        log_probs = self.softmax(self.i2o(joined))
        return log_probs, next_hidden

    def init_hidden(self):
        """Return an all-zero initial hidden state of shape ``(1, hidden_size)``."""
        return torch.zeros(1, self.hidden_size)

# Hyperparameters
n_hidden = 128
learning_rate = 0.005

# Initialize the model, loss function, and optimizer.
# Streamlit re-executes this script from the top on every widget interaction.
# Building the model unconditionally here would therefore replace a trained
# model with a freshly initialised one as soon as 'Generate Text' is clicked.
# The model is kept in session state instead and rebuilt only when the
# training string, its vocabulary ordering, or the hidden size changes.
model_key = (sequence, tuple(chars), n_hidden)
if 'rnn_model_key' not in st.session_state or st.session_state['rnn_model_key'] != model_key:
    st.session_state['rnn_model'] = RNN(vocab_size, n_hidden, vocab_size)
    st.session_state['rnn_model_key'] = model_key
rnn = st.session_state['rnn_model']
criterion = nn.NLLLoss()

# Define training options
options = {
    'Quick Train (100 epochs)': 100,
    'Medium Train (500 epochs)': 500,
    'Long Train (1000 epochs)': 1000,
}

train_option = st.selectbox('Select training option:', list(options.keys()))
n_epochs = options[train_option]
# The optimizer is rebuilt on every rerun and bound to the persisted model, so
# each click on the train button continues from the current weights with
# fresh Adam statistics.
optimizer = torch.optim.Adam(rnn.parameters(), lr=learning_rate)

def char_tensor(char):
    """Return the one-hot encoding of ``char`` as a ``(1, vocab_size)`` tensor.

    Raises:
        ValueError: If ``char`` is not a key of ``char_to_idx``.
    """
    # Indices come from enumerate(), so None can only mean "missing key".
    position = char_to_idx.get(char)
    if position is None:
        raise ValueError(f"Character '{char}' not in vocabulary.")

    one_hot = torch.zeros(1, vocab_size)
    one_hot[0, position] = 1
    return one_hot

# Training function
def train_model(n_epochs):
    """Train the module-level ``rnn`` on ``sequence`` for ``n_epochs`` epochs.

    Each epoch feeds the whole string through the network one character at a
    time, sums the per-step loss of predicting the following character, and
    applies a single optimizer step. The mean per-step loss is written to the
    page every 10 epochs.

    Args:
        n_epochs: Number of passes over the training string.
    """
    n_steps = data_size - 1
    if n_steps < 1:
        # With fewer than two characters there is no (input, target) pair:
        # the loss would remain the integer 0 and ``loss.backward()`` would
        # fail, so report the problem instead of crashing.
        st.warning('The training string must contain at least two characters.')
        return

    # The encoded inputs and targets are identical in every epoch; build them
    # once instead of once per epoch.
    inputs = [char_tensor(ch) for ch in sequence[:n_steps]]
    targets = [
        torch.tensor([char_to_idx[ch]], dtype=torch.long)
        for ch in sequence[1:data_size]
    ]

    for epoch in range(n_epochs):
        hidden = rnn.init_hidden()
        rnn.zero_grad()
        loss = 0

        for input_char, target_char in zip(inputs, targets):
            output, hidden = rnn(input_char, hidden)
            loss += criterion(output, target_char)

        loss.backward()
        optimizer.step()

        if epoch % 10 == 0:
            st.write(f'Epoch {epoch} loss: {loss.item() / n_steps}')

    st.write("Training complete.")

# Run training only in the script run triggered by a click on the button.
train_clicked = st.button('Train Model')
if train_clicked:
    train_model(n_epochs)

def generate(start_char, predict_len=100):
    """Generate text greedily from the model, starting at ``start_char``.

    At every step the highest-scoring character of the model output is
    appended to the result and fed back in as the next input.

    Args:
        start_char: Seed character; must be a key of ``char_to_idx``.
        predict_len: Number of characters to generate after the seed.

    Returns:
        ``start_char`` followed by the ``predict_len`` predicted characters.

    Raises:
        ValueError: If ``start_char`` is not in the vocabulary.
    """
    if start_char not in char_to_idx:
        raise ValueError(f"Start character '{start_char}' not in vocabulary.")

    pieces = [start_char]
    # Inference only: without no_grad, autograd would record every step and
    # keep the whole chain alive through the carried hidden state.
    with torch.no_grad():
        hidden = rnn.init_hidden()
        input_char = char_tensor(start_char)

        for _ in range(predict_len):
            output, hidden = rnn(input_char, hidden)
            _, topi = output.topk(1)
            predicted_char = idx_to_char[topi[0][0].item()]
            pieces.append(predicted_char)
            input_char = char_tensor(predicted_char)

    return ''.join(pieces)

# Generation controls: seed character and output length.
start_char = st.text_input('Enter a starting character:', 'h')
predict_len = st.slider(
    'Select the length of the generated text:',
    min_value=10,
    max_value=200,
    value=50,
)

generate_clicked = st.button('Generate Text')
if generate_clicked:
    try:
        generated_text = generate(start_char, predict_len)
        st.write('Generated Text:')
        st.text(generated_text)
    except ValueError as err:
        # Raised when the seed is not part of the training vocabulary.
        st.error(str(err))