Update pages/17_RNN.py
pages/17_RNN.py (+13 -10)
@@ -3,7 +3,7 @@ import torch
 import torch.nn as nn
 import torch.optim as optim
 from torchtext.data.utils import get_tokenizer
-from torchtext.vocab import build_vocab_from_iterator
+from torchtext.vocab import build_vocab_from_iterator
 from torchtext.datasets import IMDB
 from torch.utils.data import DataLoader, random_split
 import matplotlib.pyplot as plt
@@ -29,6 +29,13 @@ class RNN(nn.Module):
         out = self.fc(out[:, -1, :])
         return out
 
+# Create a custom collate function to pad sequences
+def collate_batch(batch):
+    texts, labels = zip(*batch)
+    text_lengths = [len(text) for text in texts]
+    texts_padded = pad_sequence(texts, batch_first=True, padding_value=vocab["<pad>"])
+    return texts_padded, torch.tensor(labels, dtype=torch.float), text_lengths
+
 # Function to load the data
 @st.cache_data
 def load_data():
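This hunk hoists collate_batch out of load_data (the removal appears below), likely so the DataLoaders returned by the cached load_data no longer capture a locally defined closure. Note the hoisted function now reads the module-level vocab, which only exists after load_data() returns; that works because DataLoaders call collate_fn lazily, during iteration. A self-contained sketch of how the collate function behaves on a toy batch (PAD_IDX, the tensors, and the labels are all made up; PAD_IDX stands in for vocab["<pad>"]):

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

PAD_IDX = 1  # stand-in for vocab["<pad>"]

def collate_batch(batch):
    texts, labels = zip(*batch)
    text_lengths = [len(text) for text in texts]
    # Pad every sequence in the batch to the longest one
    texts_padded = pad_sequence(texts, batch_first=True, padding_value=PAD_IDX)
    return texts_padded, torch.tensor(labels, dtype=torch.float), text_lengths

toy_dataset = [
    (torch.tensor([4, 9, 2]), 1.0),  # length-3 "review", positive label
    (torch.tensor([7, 3]), 0.0),     # length-2 "review", negative label
]
loader = DataLoader(toy_dataset, batch_size=2, collate_fn=collate_batch)
texts, labels, lengths = next(iter(loader))
print(texts)    # second row padded with PAD_IDX up to length 3
print(lengths)  # [3, 2]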
@@ -39,7 +46,7 @@ def load_data():
         for _, text in data_iter:
             yield tokenizer(text)
 
-    vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>"])
+    vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>", "<pad>"])
     vocab.set_default_index(vocab["<unk>"])
 
     # Define the text and label processing pipelines
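Adding "<pad>" to specials is what makes vocab["<pad>"] in collate_batch resolve to a dedicated index. Before this change, with the default index set to "<unk>", that lookup silently returned the "<unk>" index, so padding positions and unknown words shared an id. A minimal sketch of the ordering guarantee (the toy corpus stands in for the tokenized IMDB stream):

from torchtext.vocab import build_vocab_from_iterator

tokens = [["great", "movie"], ["terrible", "movie"]]

vocab = build_vocab_from_iterator(tokens, specials=["<unk>", "<pad>"])
vocab.set_default_index(vocab["<unk>"])

print(vocab["<unk>"])   # 0 -- specials come first, in the order given
print(vocab["<pad>"])   # 1
print(vocab["unseen"])  # 0, falls back to the default index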
@@ -57,13 +64,6 @@ def load_data():
     train_texts, train_labels = process_data(train_iter)
     test_texts, test_labels = process_data(test_iter)
 
-    # Create a custom collate function to pad sequences
-    def collate_batch(batch):
-        texts, labels = zip(*batch)
-        text_lengths = [len(text) for text in texts]
-        texts_padded = pad_sequence(texts, batch_first=True, padding_value=vocab["<pad>"])
-        return texts_padded, torch.tensor(labels, dtype=torch.float), text_lengths
-
     # Create DataLoaders
     train_dataset = list(zip(train_texts, train_labels))
     test_dataset = list(zip(test_texts, test_labels))
@@ -124,6 +124,9 @@ def evaluate_network(net, iterator, criterion):
 
 # Load the data
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+# Display a loading message with some vertical space
+st.markdown("<div style='margin-top: 50px;'><b>Loading data...</b></div>", unsafe_allow_html=True)
 vocab, train_loader, valid_loader, test_loader = load_data()
 
 # Streamlit interface
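The raw-HTML spacer works, but it stays on the page after loading finishes. If that ever becomes a nuisance, st.empty() gives a placeholder that can be cleared; a sketch of that alternative (assuming load_data as defined above, with st already imported):

placeholder = st.empty()                     # reserve a slot in the page
placeholder.markdown('**Loading data...**')  # show the message
vocab, train_loader, valid_loader, test_loader = load_data()
placeholder.empty()                          # clear it once the data is ready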
@@ -198,5 +201,5 @@ if 'trained_model' in st.session_state and st.sidebar.button('Show Test Results'):
     st.write('Ground Truth vs Predicted for Sample Texts')
     for i, (text, true_label, predicted) in enumerate(samples):
         st.write(f'Sample {i+1}')
-        st.text(' '.join([vocab.
+        st.text(' '.join([vocab.get_itos()[token] for token in text]))
         st.write(f'Ground Truth: {true_label.item()}, Predicted: {predicted.item()}')
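The restored line decodes token ids through the full index-to-string list from vocab.get_itos(); indexing that list with the 0-d integer tensors yielded by iterating text works because they implement __index__. vocab.lookup_tokens does the same mapping in one call, at the cost of a .tolist(). Both forms, assuming text is a 1-D LongTensor of token ids as produced by the text pipeline:

itos = vocab.get_itos()                                 # index -> token string
decoded = ' '.join(itos[token] for token in text)       # what the commit does

decoded = ' '.join(vocab.lookup_tokens(text.tolist()))  # equivalent Vocab API call
st.text(decoded)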