VoyagerYuan committed on
Commit
ac36d2b
·
1 Parent(s): fdb8b30

1st upload

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ wikitext-2/wiki.train.tokens filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch.nn.functional as F
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ import matplotlib.pyplot as plt
7
+ from torch.utils.data import Dataset, DataLoader
8
+ import re
9
+ import plotly.graph_objects as go
10
+
11
+ # Assuming TransformerEncoder and TransformerDecoder are defined above
12
+ EMBEDDING_DIM = 16
13
+ HIDDEN_DIM = 16
14
+ LATENT_DIM = 16 # Dimension of the latent space
15
+ SEQ_LEN = 16 # Max length of the sequence
16
+ NHEAD = 4 # Number of heads in multi-head attention
17
+ NUM_LAYERS = 2 # Number of transformer layers
18
+
19
+ # Gumbel softmax temperature
20
+ TAU = 1.0
21
+
22
+ LEARNING_RATE = 1e-3
23
+
24
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
+ torch.random.manual_seed(1024)
26
+
27
+ # Pass embeded into decoder instead of using the original x
28
class TransformerEncoder(nn.Module):
    """Token-sequence encoder.

    Embeds token ids, runs them through a Transformer encoder stack, and
    projects the final time-step's features to latent-space logits.

    forward(x) -> (logits, embedded)
        x:        (batch, seq_len) integer token ids
        logits:   (batch, LATENT_DIM) latent-code scores
        embedded: (seq_len, batch, d_model) embeddings, reused by the decoder
    """

    def __init__(self, d_model=EMBEDDING_DIM, nhead=NHEAD, num_layers=NUM_LAYERS):
        super().__init__()
        self.embedding = nn.Embedding(VOCAB_SIZE, d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers)
        self.fc_logits = nn.Linear(d_model, LATENT_DIM)

    def forward(self, x):
        # nn.Transformer modules expect (seq_len, batch, features).
        seq_first = self.embedding(x).permute(1, 0, 2)
        encoded = self.transformer_encoder(seq_first)
        # Latent logits are predicted from the last position's state only.
        logits = self.fc_logits(encoded[-1])
        return logits, seq_first
43
+
44
+
45
class TransformerDecoder(nn.Module):
    """Decodes a latent code z, conditioned on the encoder's embeddings, into
    per-position vocabulary logits.

    forward(embedded, z) -> (batch, seq_len, VOCAB_SIZE) logits.
    """

    def __init__(self, d_model=EMBEDDING_DIM, nhead=NHEAD, num_layers=NUM_LAYERS):
        super(TransformerDecoder, self).__init__()
        # NOTE(review): this embedding is never used in forward() — the
        # encoder's embeddings are passed in instead — so it only contributes
        # unused parameters to the optimizer. Confirm whether it can be removed.
        self.embedding = nn.Embedding(VOCAB_SIZE, d_model)
        self.transformer_decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model, nhead), num_layers
        )
        self.fc_out = nn.Linear(d_model, VOCAB_SIZE)
        self.fc_z = nn.Linear(LATENT_DIM, d_model)  # Convert z to feature size for transformer

    def forward(self, embedded, z):
        # embedded = self.embedding(x).permute(1, 0, 2)  # (Transformer expects
        # [seq_len, batch, features]; permute reorders tensor dimensions.)
        # Project z to d_model and add a length-1 sequence axis so it serves
        # as the decoder's "memory" input.
        z_adjusted = self.fc_z(z).unsqueeze(0)
        output = self.transformer_decoder(embedded, z_adjusted)
        # Back to batch-first: (batch, seq_len, VOCAB_SIZE).
        return self.fc_out(output.permute(1, 0, 2))
60
+
61
+
62
class TransformerCVAE(nn.Module):
    """Encoder/decoder pair with a Gumbel-softmax sampled discrete latent."""

    def __init__(self):
        super().__init__()
        self.encoder = TransformerEncoder()
        self.decoder = TransformerDecoder()

    def reparameterize(self, logits):
        # Soft (differentiable) sample from the categorical given by logits.
        return F.gumbel_softmax(logits, tau=TAU, hard=False, dim=-1)

    def forward(self, x):
        logits, emb = self.encoder(x)
        reconstruction = self.decoder(emb, self.reparameterize(logits))
        return reconstruction, logits
75
+
76
def load_and_preprocess_wikitext(file_path):
    """Read a WikiText file and split it into stripped sentence strings.

    Sentences are delimited by whitespace that follows '.', '!' or '?'.
    """
    with open(file_path, 'r', encoding='utf-8') as handle:
        raw_text = handle.read()
    # Split after sentence-ending punctuation, then trim surrounding space.
    return [chunk.strip() for chunk in re.split(r'(?<=[.!?])\s+', raw_text)]
85
+
86
# Paths to the raw WikiText-2 token files (train split is an LFS pointer in
# this commit; test/valid are added alongside).
train_file_path = "wikitext-2/wiki.train.tokens"
test_file_path = "wikitext-2/wiki.test.tokens"
val_file_path = "wikitext-2/wiki.valid.tokens"

wikitext_sentences_train = load_and_preprocess_wikitext(train_file_path)
wikitext_sentences_test = load_and_preprocess_wikitext(test_file_path)
wikitext_sentences_val = load_and_preprocess_wikitext(val_file_path)

# Hyperparameters
BATCH_SIZE = 32
PAD_TOKEN = "<PAD>"
UNK_TOKEN = "<UNK>"

# Tokenize the data (whitespace tokenization over the training sentences).
tokens = [word for sentence in wikitext_sentences_train for word in sentence.split()]

# Build vocabulary.
# NOTE(review): list(set(tokens)) gives a hash-dependent ordering, so token
# indices are not reproducible across interpreter runs — confirm acceptable.
vocab = [PAD_TOKEN, UNK_TOKEN] + list(set(tokens))
word_index = {word: index for index, word in enumerate(vocab)}
# Add special start/end-of-sentence tokens.
SOS_TOKEN = '<SOS>'
EOS_TOKEN = '<EOS>'
word_index[SOS_TOKEN] = len(word_index)
word_index[EOS_TOKEN] = len(word_index)
# Rebind `vocab` from a list to the inverse mapping: index -> word.
vocab = {v: k for k, v in word_index.items()}
# Convert a whitespace-tokenized string to a list of token ids; unknown
# words map to the <UNK> id.
def tokenize_and_encode(text):
    return [word_index.get(word, word_index[UNK_TOKEN]) for word in text.split()]

encoded_data_train = [tokenize_and_encode(sentence) for sentence in wikitext_sentences_train]
116
+
117
+ # Create a PyTorch Dataset
118
class WikiDataset(Dataset):
    """Fixed-length view over encoded sentences.

    Each item is a 1-D LongTensor of exactly `sequence_length` ids: shorter
    samples are right-padded with the <PAD> id, longer ones are truncated.
    """

    def __init__(self, data, sequence_length):
        # data: list of lists of token ids; sequence_length: output length.
        self.data = data
        self.sequence_length = sequence_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # BUG FIX: copy before padding. The original extended the stored list
        # in place, silently mutating the underlying dataset on every access.
        sample = list(self.data[idx])
        if len(sample) < self.sequence_length:
            sample.extend([word_index[PAD_TOKEN]] * (self.sequence_length - len(sample)))
        else:
            sample = sample[:self.sequence_length]
        return torch.tensor(sample)
133
+
134
# Split the data into train and validation sets (80/20).
# NOTE(review): these dataloaders are never consumed below —
# train_signal_game draws random token states instead; confirm intent.
dataset = WikiDataset(encoded_data_train, SEQ_LEN)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Vocabulary size after the special tokens were added; read by the model
# classes at instantiation time.
VOCAB_SIZE = len(vocab)
145
+
146
+
147
+
148
+
149
class MultiMultiSignalingGame:
    """Many-senders / many-receivers signaling game.

    Each round every sender encodes its state into latent logits, every
    receiver decodes every sender's Gumbel-softmax signal back into token
    logits, and all agents are updated jointly through a single optimizer.
    """

    def __init__(self, senders: list, receivers: list, optimizer, criterion):
        self.senders = senders        # encoder agents
        self.receivers = receivers    # decoder agents
        self.optimizer = optimizer    # one optimizer over all agents' parameters
        self.criterion = criterion    # per-token loss, e.g. CrossEntropyLoss

    def play_round(self, states):
        """Play one round on `states` (one (batch, seq) id tensor per sender).

        Returns (total_loss, recon_loss, kld_loss, interactions), where
        interactions is a list of (sender_idx, receiver_idx, input_str, output_str).
        """
        all_decoded_outputs = []
        all_logits = []
        interactions = []

        for i, sender in enumerate(self.senders):
            # Sender encodes the state into latent logits and embeddings.
            logits, emb = sender(states[i])
            all_logits.append(logits)
            z = F.gumbel_softmax(logits, tau=TAU, hard=False, dim=-1)

            # NOTE(review): states[i] holds token ids, so this max over dim=1
            # picks the largest id per row rather than an actual sentence —
            # the printed "input sentence" is only a rough visualization.
            _, input_sentence_ids = torch.max(states[i], dim=1)
            input_sentence_ids = input_sentence_ids.cpu().numpy()
            input_sentence = ' '.join([vocab[idx] for idx in input_sentence_ids])

            # Each receiver decodes the signal from this sender.
            for j, receiver in enumerate(self.receivers):
                decoded_output = receiver(emb, z)
                all_decoded_outputs.append(decoded_output)

                # Greedy argmax decode of the first batch element, for display.
                _, output_sentence_ids = torch.max(decoded_output[0], dim=1)
                output_sentence_ids = output_sentence_ids.cpu().numpy()
                output_sentence = ' '.join([vocab[idx] for idx in output_sentence_ids])

                interactions.append((i, j, input_sentence, output_sentence))

        # Calculate loss across all sender/receiver pairs.
        loss, recon_loss, kld_loss = self.compute_loss(states, all_decoded_outputs, all_logits, beta=1.0)

        # Update model parameters.
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item(), recon_loss.item(), kld_loss.item(), interactions

    def compute_loss(self, original_states, decoded_states, logits, beta):
        """Return (recon + beta*KLD, recon, KLD) summed over all pairs."""
        # decoded_states is sender-major: (s0,r0), (s0,r1), ..., (s1,r0), ...
        # BUG FIX: the original used `original_states * len(receivers)`, which
        # yields [s0, s1, s0, s1, ...] and misaligns targets with outputs
        # whenever there is more than one sender AND more than one receiver.
        # Repeat each state consecutively, once per receiver, instead.
        targets = [state for state in original_states for _ in range(len(self.receivers))]
        recon_loss = sum(self.criterion(decoded.view(-1, VOCAB_SIZE), target.view(-1))
                         for target, decoded in zip(targets, decoded_states))

        # Calculate KLD loss.
        # NOTE(review): this treats each logits tensor as (mean, logvar) halves
        # of a Gaussian, while sampling above uses a Gumbel-softmax — the two
        # latent formulations do not match; confirm which is intended.
        kld_losses = []
        for logit in logits:
            mean, logvar = torch.chunk(logit, 2, dim=-1)
            kld_loss = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
            kld_losses.append(kld_loss)

        return recon_loss + beta * sum(kld_losses), recon_loss, sum(kld_losses)
204
+
205
+
206
def train_signal_game(NUM_SENDERS, NUM_RECEIVERS, num_rounds):
    """Train the signaling game on random token states, live-plotting in Streamlit.

    Args:
        NUM_SENDERS: number of encoder agents.
        NUM_RECEIVERS: number of decoder agents.
        num_rounds: number of training rounds.

    Reads module-level hyperparameters (TAU, OPTMIZER, LEARNING_RATE,
    BATCH_SIZE, SEQ_LEN, VOCAB_SIZE, ...) set by the sidebar widgets.
    """
    para_checker = st.empty()
    para_checker.text(f"NUM_SENDERS: {NUM_SENDERS}, NUM_RECEIVERS: {NUM_RECEIVERS}, num_rounds: {num_rounds}, EMBEDDING_DIM: {EMBEDDING_DIM}, HIDDEN_DIM: {HIDDEN_DIM}, LATENT_DIM: {LATENT_DIM}, SEQ_LEN: {SEQ_LEN}, TAU: {TAU}, nhead: {NHEAD}, num_layers: {NUM_LAYERS}, BATCH_SIZE: {BATCH_SIZE}")
    senders = [TransformerEncoder().to(device) for _ in range(NUM_SENDERS)]
    receivers = [TransformerDecoder().to(device) for _ in range(NUM_RECEIVERS)]

    # One flat parameter list across every agent so a single optimizer
    # updates all of them jointly.
    params = [list(sender.parameters()) for sender in senders]
    params.extend([list(receiver.parameters()) for receiver in receivers])
    flat_params = [param for sublist in params for param in sublist]
    if OPTMIZER == "Adam":
        optimizer = torch.optim.Adam(flat_params, lr=LEARNING_RATE)
    elif OPTMIZER == "AdamW":
        optimizer = torch.optim.AdamW(flat_params, lr=LEARNING_RATE)
    elif OPTMIZER == "SGD":
        optimizer = torch.optim.SGD(flat_params, lr=LEARNING_RATE)
    else:
        # BUG FIX: an unknown choice previously fell through and crashed later
        # with a NameError on `optimizer`; fail fast with a clear message.
        raise ValueError(f"Unsupported optimizer: {OPTMIZER}")

    criterion = torch.nn.CrossEntropyLoss()

    game = MultiMultiSignalingGame(senders, receivers, optimizer, criterion)

    losses = []
    recon_losses = []
    kld_losses = []

    # Streamlit widgets: progress bar plus placeholders that are refreshed
    # in place every round.
    progress_bar = st.progress(0)
    loss_plot_placeholder = st.empty()    # shows the running loss curves
    interactions_placeholder = st.empty() # shows the exchanged sentences

    for round_idx in range(num_rounds):
        # BUG FIX: sequence length was hard-coded to 16; use SEQ_LEN so the
        # sidebar setting actually takes effect.
        states = [torch.randint(VOCAB_SIZE, (BATCH_SIZE, SEQ_LEN)).to(device) for _ in range(NUM_SENDERS)]
        loss, recon_loss, kld_loss, interactions = game.play_round(states)
        losses.append(loss)
        recon_losses.append(recon_loss)
        kld_losses.append(kld_loss)
        # Redraw the loss curves for this round.
        fig, ax = plt.subplots()
        ax.plot(losses, label='Total Losses', color='blue')
        ax.plot(recon_losses, label='Reconstruction Losses', color='green')
        ax.plot(kld_losses, label='KLD Losses', color='red')
        ax.set_xlabel('Round')
        ax.set_ylabel('Loss')
        ax.legend()
        loss_plot_placeholder.pyplot(fig)
        plt.close(fig)  # BUG FIX: avoid leaking one matplotlib figure per round
        # Refresh the sentences exchanged in this round.
        interaction_str = "\n\n".join([f"Sender {i} -> Receiver {j}\nSend(encode): {input_sentence}\nReceive(decode): {output_sentence}"
                                       for i, j, input_sentence, output_sentence in interactions])
        interactions_placeholder.text(interaction_str)

        # +1 so the bar reaches 100% on the final round.
        progress_bar.progress((round_idx + 1) / num_rounds)

    # Final static plot of the losses.
    fig, ax = plt.subplots()
    ax.plot(losses, label='Total Losses', color='blue')
    ax.plot(recon_losses, label='Reconstruction Losses', color='green')
    ax.plot(kld_losses, label='KLD Losses', color='red')
    ax.set_xlabel('Round')
    ax.set_ylabel('Loss')
    ax.legend()
    st.pyplot(fig)
    plt.close(fig)
268
+
269
# Streamlit UI
st.title('Multi-Agents Signaling Game')

# Agent counts and number of training rounds.
NUM_SENDERS = st.sidebar.slider("NUM_SENDERS", 1, 10, 2)
NUM_RECEIVERS = st.sidebar.slider("NUM_RECEIVERS", 1, 10, 2)
num_rounds = st.sidebar.slider("num_rounds", 1000, 100000, 10000, 1000)

# Optional Gumbel-softmax temperature annealing.
# NOTE(review): annealing_strategy and final_tau are collected here but never
# read by train_signal_game, so annealing currently has no effect — confirm.
use_cosine_annealing = st.sidebar.checkbox("Use Annealing")
if use_cosine_annealing:
    annealing_strategy = st.sidebar.selectbox("Annealing Strategy", ["linear", "cosine"])
    TAU = st.sidebar.slider("Start Temp.", 0.1, 10.0, 1.0)
    final_tau = st.sidebar.slider("Final Temp.", 0.1, 10.0, 1.0)
else:
    annealing_strategy = None
    TAU = st.sidebar.slider("TAU", 0.1, 10.0, 1.0)

# Optimizer selection ("OPTMIZER" spelling kept: train_signal_game reads it).
optimizer_options = ["Adam", "AdamW", "SGD"]
OPTMIZER = st.sidebar.selectbox("Optimizer", optimizer_options)
LEARNING_RATE = st.sidebar.slider("Learning Rate", 1e-5, 1e-2, 1e-3)


# NOTE(review): EMBEDDING_DIM / NHEAD / NUM_LAYERS are used as default
# argument values in the model classes, which were bound when the classes
# were defined above — reassigning them here does not change those defaults.
# LATENT_DIM and TAU are read at call time and do take effect; SEQ_LEN was
# already consumed when the dataset was built. Confirm intended behavior.
EMBEDDING_DIM = st.sidebar.slider("EMBEDDING_DIM", 1, 128, 16)
HIDDEN_DIM = st.sidebar.slider("HIDDEN_DIM", 1, 128, 16)
LATENT_DIM = st.sidebar.slider("LATENT_DIM", 1, 128, 16)
SEQ_LEN = st.sidebar.slider("SEQ_LEN", 1, 128, 16)
NHEAD = st.sidebar.slider("nhead", 1, 8, 4)
NUM_LAYERS = st.sidebar.slider("num_layers", 1, 6, 2)
BATCH_SIZE = st.sidebar.slider("BATCH_SIZE", 1, 128, 32)

# Kick off training when the user presses Start.
if st.sidebar.button('Start'):
    train_signal_game(NUM_SENDERS, NUM_RECEIVERS, num_rounds)
wikitext-2/wiki.test.tokens ADDED
The diff for this file is too large to render. See raw diff
 
wikitext-2/wiki.train.tokens ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9e9fa1ad55b1c2c95b08e37dd8e653f638fac2c6de904b79e813611eefbc985f
3
+ size 10797148
wikitext-2/wiki.valid.tokens ADDED
The diff for this file is too large to render. See raw diff