VoyagerYuan commited on
Commit
4ea15bc
·
1 Parent(s): 530a351

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -2
app.py CHANGED
@@ -200,7 +200,7 @@ class MultiMultiSignalingGame:
200
  return recon_loss + beta * sum(kld_losses), recon_loss, sum(kld_losses)
201
 
202
 
203
- def train_signal_game(NUM_SENDERS, NUM_RECEIVERS, num_rounds):
204
  # para_checker = st.empty()
205
  # para_checker.text(f"NUM_SENDERS: {NUM_SENDERS}, NUM_RECEIVERS: {NUM_RECEIVERS}, num_rounds: {num_rounds}, EMBEDDING_DIM: {EMBEDDING_DIM}, HIDDEN_DIM: {HIDDEN_DIM}, LATENT_DIM: {LATENT_DIM}, SEQ_LEN: {SEQ_LEN}, TAU: {TAU}, nhead: {NHEAD}, num_layers: {NUM_LAYERS}, BATCH_SIZE: {BATCH_SIZE}")
206
  senders = [TransformerEncoder().to(device) for _ in range(NUM_SENDERS)]
@@ -246,6 +246,9 @@ def train_signal_game(NUM_SENDERS, NUM_RECEIVERS, num_rounds):
246
  ax.set_ylabel('Loss')
247
  ax.legend()
248
  loss_plot_placeholder.pyplot(fig)
 
 
 
249
  progress_bar.progress(round / num_rounds)
250
  # 刷新显示每次交互的句子
251
  interaction_str = "\n\n".join([f"Sender {i} -> Receiver {j}\nSend(encode): {input_sentence}\nReceive(decode): {output_sentence}"
@@ -293,4 +296,4 @@ with advanced_settings:
293
  BATCH_SIZE = st.slider("BATCH_SIZE", 1, 128, 32)
294
 
295
  if st.sidebar.button('Start'):
296
- train_signal_game(NUM_SENDERS, NUM_RECEIVERS, num_rounds)
 
200
  return recon_loss + beta * sum(kld_losses), recon_loss, sum(kld_losses)
201
 
202
 
203
+ def run_signal_game(NUM_SENDERS, NUM_RECEIVERS, num_rounds):
204
  # para_checker = st.empty()
205
  # para_checker.text(f"NUM_SENDERS: {NUM_SENDERS}, NUM_RECEIVERS: {NUM_RECEIVERS}, num_rounds: {num_rounds}, EMBEDDING_DIM: {EMBEDDING_DIM}, HIDDEN_DIM: {HIDDEN_DIM}, LATENT_DIM: {LATENT_DIM}, SEQ_LEN: {SEQ_LEN}, TAU: {TAU}, nhead: {NHEAD}, num_layers: {NUM_LAYERS}, BATCH_SIZE: {BATCH_SIZE}")
206
  senders = [TransformerEncoder().to(device) for _ in range(NUM_SENDERS)]
 
246
  ax.set_ylabel('Loss')
247
  ax.legend()
248
  loss_plot_placeholder.pyplot(fig)
249
+ # Close the figure to free up memory
250
+ plt.close(fig)
251
+
252
  progress_bar.progress(round / num_rounds)
253
  # 刷新显示每次交互的句子
254
  interaction_str = "\n\n".join([f"Sender {i} -> Receiver {j}\nSend(encode): {input_sentence}\nReceive(decode): {output_sentence}"
 
296
  BATCH_SIZE = st.slider("BATCH_SIZE", 1, 128, 32)
297
 
298
  if st.sidebar.button('Start'):
299
+ run_signal_game(NUM_SENDERS, NUM_RECEIVERS, num_rounds)