Spaces:
Running
Running
Update pages/17_RNN_News.py
Browse files- pages/17_RNN_News.py +3 -3
pages/17_RNN_News.py
CHANGED
@@ -6,6 +6,7 @@ from torchtext.data.utils import get_tokenizer
|
|
6 |
from torchtext.vocab import build_vocab_from_iterator
|
7 |
from torchtext.datasets import AG_NEWS
|
8 |
from torch.utils.data import DataLoader, random_split
|
|
|
9 |
import matplotlib.pyplot as plt
|
10 |
import pandas as pd
|
11 |
import numpy as np
|
@@ -28,12 +29,11 @@ class RNN(nn.Module):
|
|
28 |
|
29 |
# Create a custom collate function to pad sequences
|
30 |
def collate_batch(batch):
|
31 |
-
label_list, text_list
|
32 |
for _label, _text in batch:
|
33 |
label_list.append(label_pipeline(_label))
|
34 |
processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
|
35 |
text_list.append(processed_text)
|
36 |
-
lengths.append(processed_text.size(0))
|
37 |
labels = torch.tensor(label_list, dtype=torch.int64)
|
38 |
texts = pad_sequence(text_list, batch_first=True, padding_value=vocab["<pad>"])
|
39 |
return texts, labels
|
@@ -191,4 +191,4 @@ if 'trained_model' in st.session_state and st.sidebar.button('Show Test Results'
|
|
191 |
for i, (text, true_label, predicted) in enumerate(samples):
|
192 |
st.write(f'Sample {i+1}')
|
193 |
st.text(' '.join([vocab.get_itos()[token] for token in text]))
|
194 |
-
st.write(f'Ground Truth: {
|
|
|
6 |
from torchtext.vocab import build_vocab_from_iterator
|
7 |
from torchtext.datasets import AG_NEWS
|
8 |
from torch.utils.data import DataLoader, random_split
|
9 |
+
from torch.nn.utils.rnn import pad_sequence
|
10 |
import matplotlib.pyplot as plt
|
11 |
import pandas as pd
|
12 |
import numpy as np
|
|
|
29 |
|
30 |
# Create a custom collate function to pad sequences
|
31 |
def collate_batch(batch):
|
32 |
+
label_list, text_list = [], []
|
33 |
for _label, _text in batch:
|
34 |
label_list.append(label_pipeline(_label))
|
35 |
processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
|
36 |
text_list.append(processed_text)
|
|
|
37 |
labels = torch.tensor(label_list, dtype=torch.int64)
|
38 |
texts = pad_sequence(text_list, batch_first=True, padding_value=vocab["<pad>"])
|
39 |
return texts, labels
|
|
|
191 |
for i, (text, true_label, predicted) in enumerate(samples):
|
192 |
st.write(f'Sample {i+1}')
|
193 |
st.text(' '.join([vocab.get_itos()[token] for token in text]))
|
194 |
+
st.write(f'Ground Truth: {true_label.item()}, Predicted: {predicted.item()}')
|