Spaces:
Running
Running
Update pages/17_RNN_News.py
Browse files- pages/17_RNN_News.py +5 -4
pages/17_RNN_News.py
CHANGED
@@ -52,6 +52,7 @@ def load_data():
|
|
52 |
vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>", "<pad>"])
|
53 |
vocab.set_default_index(vocab["<unk>"])
|
54 |
|
|
|
55 |
global text_pipeline, label_pipeline
|
56 |
text_pipeline = lambda x: vocab(tokenizer(x))
|
57 |
label_pipeline = lambda x: int(x) - 1
|
@@ -80,7 +81,7 @@ def train_network(net, iterator, optimizer, criterion, epochs):
|
|
80 |
for texts, labels in iterator:
|
81 |
texts, labels = texts.to(device), labels.to(device)
|
82 |
optimizer.zero_grad()
|
83 |
-
predictions = net(texts)
|
84 |
loss = criterion(predictions, labels)
|
85 |
loss.backward()
|
86 |
optimizer.step()
|
@@ -102,7 +103,7 @@ def evaluate_network(net, iterator, criterion):
|
|
102 |
with torch.no_grad():
|
103 |
for texts, labels in iterator:
|
104 |
texts, labels = texts.to(device), labels.to(device)
|
105 |
-
predictions = net(texts)
|
106 |
loss = criterion(predictions, labels)
|
107 |
epoch_loss += loss.item()
|
108 |
_, predicted = torch.max(predictions, 1)
|
@@ -168,8 +169,8 @@ if 'trained_model' in st.session_state and st.sidebar.button('Test Network'):
|
|
168 |
# Display results in a table
|
169 |
st.write('Ground Truth vs Predicted')
|
170 |
results = pd.DataFrame({
|
171 |
-
'Ground Truth':
|
172 |
-
'Predicted':
|
173 |
})
|
174 |
st.table(results.head(50)) # Display first 50 results for brevity
|
175 |
|
|
|
52 |
vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>", "<pad>"])
|
53 |
vocab.set_default_index(vocab["<unk>"])
|
54 |
|
55 |
+
# Define the text and label processing pipelines globally
|
56 |
global text_pipeline, label_pipeline
|
57 |
text_pipeline = lambda x: vocab(tokenizer(x))
|
58 |
label_pipeline = lambda x: int(x) - 1
|
|
|
81 |
for texts, labels in iterator:
|
82 |
texts, labels = texts.to(device), labels.to(device)
|
83 |
optimizer.zero_grad()
|
84 |
+
predictions = net(texts)
|
85 |
loss = criterion(predictions, labels)
|
86 |
loss.backward()
|
87 |
optimizer.step()
|
|
|
103 |
with torch.no_grad():
|
104 |
for texts, labels in iterator:
|
105 |
texts, labels = texts.to(device), labels.to(device)
|
106 |
+
predictions = net(texts)
|
107 |
loss = criterion(predictions, labels)
|
108 |
epoch_loss += loss.item()
|
109 |
_, predicted = torch.max(predictions, 1)
|
|
|
169 |
# Display results in a table
|
170 |
st.write('Ground Truth vs Predicted')
|
171 |
results = pd.DataFrame({
|
172 |
+
'Ground Truth': all_labels,
|
173 |
+
'Predicted': all_predictions
|
174 |
})
|
175 |
st.table(results.head(50)) # Display first 50 results for brevity
|
176 |
|