VoyagerYuan
commited on
Commit
·
530a351
1
Parent(s):
10ea276
Update app.py
Browse files
app.py
CHANGED
@@ -143,9 +143,6 @@ val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=True)
|
|
143 |
|
144 |
VOCAB_SIZE = len(vocab)
|
145 |
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
class MultiMultiSignalingGame:
|
150 |
def __init__(self, senders: list, receivers: list, optimizer, criterion):
|
151 |
self.senders = senders
|
@@ -230,8 +227,8 @@ def train_signal_game(NUM_SENDERS, NUM_RECEIVERS, num_rounds):
|
|
230 |
output_sentences = []
|
231 |
|
232 |
# Use Streamlit's progress bar
|
233 |
-
progress_bar = st.progress(0)
|
234 |
loss_plot_placeholder = st.empty() # 创建一个空位占位符来显示损失图
|
|
|
235 |
interactions_placeholder = st.empty() # 创建一个空位占位符来显示交互
|
236 |
|
237 |
for round in range(num_rounds):
|
@@ -249,13 +246,12 @@ def train_signal_game(NUM_SENDERS, NUM_RECEIVERS, num_rounds):
|
|
249 |
ax.set_ylabel('Loss')
|
250 |
ax.legend()
|
251 |
loss_plot_placeholder.pyplot(fig)
|
|
|
252 |
# 刷新显示每次交互的句子
|
253 |
interaction_str = "\n\n".join([f"Sender {i} -> Receiver {j}\nSend(encode): {input_sentence}\nReceive(decode): {output_sentence}"
|
254 |
for i, j, input_sentence, output_sentence in interactions])
|
255 |
interactions_placeholder.text(interaction_str)
|
256 |
|
257 |
-
progress_bar.progress(round / num_rounds)
|
258 |
-
|
259 |
# Dynamic plotting of the losses
|
260 |
fig, ax = plt.subplots()
|
261 |
ax.plot(losses, label='Total Losses', color='blue')
|
|
|
143 |
|
144 |
VOCAB_SIZE = len(vocab)
|
145 |
|
|
|
|
|
|
|
146 |
class MultiMultiSignalingGame:
|
147 |
def __init__(self, senders: list, receivers: list, optimizer, criterion):
|
148 |
self.senders = senders
|
|
|
227 |
output_sentences = []
|
228 |
|
229 |
# Use Streamlit's progress bar
|
|
|
230 |
loss_plot_placeholder = st.empty() # 创建一个空位占位符来显示损失图
|
231 |
+
progress_bar = st.progress(0)
|
232 |
interactions_placeholder = st.empty() # 创建一个空位占位符来显示交互
|
233 |
|
234 |
for round in range(num_rounds):
|
|
|
246 |
ax.set_ylabel('Loss')
|
247 |
ax.legend()
|
248 |
loss_plot_placeholder.pyplot(fig)
|
249 |
+
progress_bar.progress(round / num_rounds)
|
250 |
# 刷新显示每次交互的句子
|
251 |
interaction_str = "\n\n".join([f"Sender {i} -> Receiver {j}\nSend(encode): {input_sentence}\nReceive(decode): {output_sentence}"
|
252 |
for i, j, input_sentence, output_sentence in interactions])
|
253 |
interactions_placeholder.text(interaction_str)
|
254 |
|
|
|
|
|
255 |
# Dynamic plotting of the losses
|
256 |
fig, ax = plt.subplots()
|
257 |
ax.plot(losses, label='Total Losses', color='blue')
|