Commit
·
ac36d2b
1
Parent(s):
fdb8b30
1st upload
Browse files- .gitattributes +1 -0
- app.py +300 -0
- wikitext-2/wiki.test.tokens +0 -0
- wikitext-2/wiki.train.tokens +3 -0
- wikitext-2/wiki.valid.tokens +0 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
wikitext-2/wiki.train.tokens filter=lfs diff=lfs merge=lfs -text
|
app.py
ADDED
@@ -0,0 +1,300 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import torch.nn.functional as F
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import numpy as np
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
from torch.utils.data import Dataset, DataLoader
|
8 |
+
import re
|
9 |
+
import plotly.graph_objects as go
|
10 |
+
|
11 |
+
# Assuming TransformerEncoder and TransformerDecoder are defined above
# Default model hyperparameters. NOTE(review): the Streamlit sidebar at the
# bottom of the file rebinds most of these names at startup; the models read
# them lazily (at instantiation), so the sidebar values are the effective ones.
EMBEDDING_DIM = 16
HIDDEN_DIM = 16
LATENT_DIM = 16  # Dimension of the latent space
SEQ_LEN = 16  # Max length of the sequence
NHEAD = 4  # Number of heads in multi-head attention
NUM_LAYERS = 2  # Number of transformer layers

# Gumbel softmax temperature
TAU = 1.0

LEARNING_RATE = 1e-3

# Prefer CUDA when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Fixed seed so runs are reproducible.
torch.random.manual_seed(1024)
|
26 |
+
|
27 |
+
# The encoder hands its token embeddings to the decoder rather than the raw x.
class TransformerEncoder(nn.Module):
    """Encode a batch of token ids into latent logits plus its embeddings."""

    def __init__(self, d_model=EMBEDDING_DIM, nhead=NHEAD, num_layers=NUM_LAYERS):
        super().__init__()
        self.embedding = nn.Embedding(VOCAB_SIZE, d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers)
        self.fc_logits = nn.Linear(d_model, LATENT_DIM)

    def forward(self, x):
        # nn.Transformer* modules default to (seq_len, batch, features) layout.
        embedded = self.embedding(x).permute(1, 0, 2)
        encoded = self.transformer_encoder(embedded)
        # Latent logits come from the hidden state at the final position.
        logits = self.fc_logits(encoded[-1])
        return logits, embedded
|
43 |
+
|
44 |
+
|
45 |
+
class TransformerDecoder(nn.Module):
    """Decode target embeddings conditioned on a latent code z."""

    def __init__(self, d_model=EMBEDDING_DIM, nhead=NHEAD, num_layers=NUM_LAYERS):
        super().__init__()
        self.embedding = nn.Embedding(VOCAB_SIZE, d_model)
        decoder_layer = nn.TransformerDecoderLayer(d_model, nhead)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers)
        self.fc_out = nn.Linear(d_model, VOCAB_SIZE)
        # Project the latent code up to the transformer feature size.
        self.fc_z = nn.Linear(LATENT_DIM, d_model)

    def forward(self, embedded, z):
        # The latent becomes a length-1 "memory" sequence: (1, batch, d_model).
        memory = self.fc_z(z).unsqueeze(0)
        decoded = self.transformer_decoder(embedded, memory)
        # Back to (batch, seq_len, features) before projecting onto the vocab.
        return self.fc_out(decoded.permute(1, 0, 2))
|
60 |
+
|
61 |
+
|
62 |
+
class TransformerCVAE(nn.Module):
    """Discrete CVAE: transformer encoder, Gumbel-softmax latent, decoder."""

    def __init__(self):
        super().__init__()
        self.encoder = TransformerEncoder()
        self.decoder = TransformerDecoder()

    def reparameterize(self, logits):
        # Differentiable (soft) sample from the categorical latent.
        return F.gumbel_softmax(logits, tau=TAU, hard=False, dim=-1)

    def forward(self, x):
        logits, embedded = self.encoder(x)
        latent = self.reparameterize(logits)
        reconstruction = self.decoder(embedded, latent)
        return reconstruction, logits
|
75 |
+
|
76 |
+
def load_and_preprocess_wikitext(file_path):
    """Read a WikiText token file and split it into stripped sentences.

    A sentence boundary is any run of whitespace that follows '.', '!'
    or '?'; the punctuation stays attached to its sentence.
    """
    with open(file_path, 'r', encoding='utf-8') as handle:
        raw_text = handle.read()

    # Lookbehind split keeps the terminating punctuation with each sentence.
    pieces = re.split(r'(?<=[.!?])\s+', raw_text)
    return [piece.strip() for piece in pieces]
|
85 |
+
|
86 |
+
# Paths to the WikiText-2 splits (wiki.train.tokens is stored via git-lfs).
train_file_path = "wikitext-2/wiki.train.tokens"
test_file_path = "wikitext-2/wiki.test.tokens"
val_file_path = "wikitext-2/wiki.valid.tokens"

# Sentence lists per split; only the train split is tokenized further below.
wikitext_sentences_train = load_and_preprocess_wikitext(train_file_path)
wikitext_sentences_test = load_and_preprocess_wikitext(test_file_path)
wikitext_sentences_val = load_and_preprocess_wikitext(val_file_path)
|
93 |
+
|
94 |
+
# Hyperparameters
BATCH_SIZE = 32
PAD_TOKEN = "<PAD>"
UNK_TOKEN = "<UNK>"

# Tokenize the data (plain whitespace tokenization of the training sentences)
tokens = [word for sentence in wikitext_sentences_train for word in sentence.split()]

# Build vocabulary: special tokens first, then the unique training words
vocab = [PAD_TOKEN, UNK_TOKEN] + list(set(tokens))
word_index = {word: index for index, word in enumerate(vocab)}
# Add new tokens
SOS_TOKEN = '<SOS>'
EOS_TOKEN = '<EOS>'
word_index[SOS_TOKEN] = len(word_index)
word_index[EOS_TOKEN] = len(word_index)
# NOTE(review): `vocab` is rebound here from a word list to an id -> word
# dict (the inverse of word_index); later code indexes it by id.
vocab = {v: k for k, v in word_index.items()}
|
111 |
+
# Convert tokens to integers
def tokenize_and_encode(text):
    """Map each whitespace-separated word of *text* to its vocabulary id.

    Unknown words fall back to the <UNK> id.
    """
    unknown_id = word_index[UNK_TOKEN]
    encoded = []
    for word in text.split():
        encoded.append(word_index.get(word, unknown_id))
    return encoded

encoded_data_train = [tokenize_and_encode(sentence) for sentence in wikitext_sentences_train]
|
116 |
+
|
117 |
+
# Create a PyTorch Dataset
class WikiDataset(Dataset):
    """Fixed-length view over encoded sentences.

    Each item is a 1-D LongTensor of exactly ``sequence_length`` token ids:
    shorter sentences are right-padded with the <PAD> id, longer ones are
    truncated.
    """

    def __init__(self, data, sequence_length):
        self.data = data  # list of lists of token ids
        self.sequence_length = sequence_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # BUG FIX: work on a copy. The previous code called .extend() on
        # self.data[idx] directly, permanently padding the caller's stored
        # sample the first time a short sentence was fetched.
        sample = list(self.data[idx])
        if len(sample) < self.sequence_length:
            sample.extend([word_index[PAD_TOKEN]] * (self.sequence_length - len(sample)))
        else:
            sample = sample[:self.sequence_length]
        return torch.tensor(sample)
|
133 |
+
|
134 |
+
# dataset = WikiDataset(encoded_data_train, SEQUENCE_LENGTH)
# dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
# Split the data into train and validation sets
dataset = WikiDataset(encoded_data_train, SEQ_LEN)
# 80/20 random train/validation split.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Vocabulary size (includes <PAD>/<UNK>/<SOS>/<EOS>); read lazily by the
# encoder/decoder embedding and output layers.
VOCAB_SIZE = len(vocab)
|
145 |
+
|
146 |
+
|
147 |
+
|
148 |
+
|
149 |
+
class MultiMultiSignalingGame:
    """Signaling game between several sender encoders and receiver decoders.

    Each round every sender encodes its own state; every receiver decodes
    every sender's message. The joint loss (reconstruction + KLD) is
    backpropagated through all agents via the shared optimizer.
    """

    def __init__(self, senders: list, receivers: list, optimizer, criterion):
        self.senders = senders
        self.receivers = receivers
        self.optimizer = optimizer  # jointly optimizes all agents' parameters
        self.criterion = criterion  # per-token classification loss

    def play_round(self, states):
        """Play one round over `states` (one token batch per sender).

        Returns (total_loss, recon_loss, kld_loss, interactions), where
        interactions is a list of (sender_idx, receiver_idx, input_sentence,
        output_sentence) tuples for display.
        """
        all_decoded_outputs = []
        all_logits = []
        interactions = []

        for i, sender in enumerate(self.senders):
            # Sender encodes the state
            logits, emb = sender(states[i])
            all_logits.append(logits)
            # Differentiable discrete message via Gumbel-softmax
            z = F.gumbel_softmax(logits, tau=TAU, hard=False, dim=-1)

            # NOTE(review): this takes per-position maxima over the batch of
            # token ids rather than one sample's tokens — confirm this is the
            # intended "input sentence" for display.
            _, input_sentence_ids = torch.max(states[i], dim=1)
            input_sentence_ids = input_sentence_ids.cpu().numpy()
            input_sentence = ' '.join([vocab[idx] for idx in input_sentence_ids])

            # Each receiver decodes the signal from the sender
            for j, receiver in enumerate(self.receivers):
                decoded_output = receiver(emb, z)
                all_decoded_outputs.append(decoded_output)

                _, output_sentence_ids = torch.max(decoded_output[0], dim=1)
                output_sentence_ids = output_sentence_ids.cpu().numpy()
                output_sentence = ' '.join([vocab[idx] for idx in output_sentence_ids])

                interactions.append((i, j, input_sentence, output_sentence))

        # Calculate loss over every (sender state, decoded output) pair
        loss, recon_loss, kld_loss = self.compute_loss(states, all_decoded_outputs, all_logits, beta=1.0)

        # Update model parameters
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item(), recon_loss.item(), kld_loss.item(), interactions

    def compute_loss(self, original_states, decoded_states, logits, beta):
        """Return (total, reconstruction, KLD) losses.

        ``decoded_states`` is sender-major: [s0r0, s0r1, ..., s1r0, s1r1, ...].
        """
        # BUG FIX: decoded_states is sender-major, so each original state must
        # be repeated consecutively (s0, s0, ..., s1, s1, ...). The previous
        # `original_states * len(self.receivers)` tiling produced
        # (s0, s1, s0, s1, ...), pairing outputs with the wrong sender whenever
        # there were multiple senders and multiple receivers.
        expanded_states = [state for state in original_states for _ in self.receivers]
        recon_loss = sum([self.criterion(decoded_state.view(-1, VOCAB_SIZE), original_state.view(-1))
                          for original_state, decoded_state in zip(expanded_states, decoded_states)])

        # Calculate KLD loss.
        # NOTE(review): the latent is sampled with Gumbel-softmax (categorical),
        # yet this KLD treats the logits as a Gaussian (mean, logvar) pair —
        # confirm this mixed formulation is intentional.
        kld_losses = []
        for logit in logits:
            mean, logvar = torch.chunk(logit, 2, dim=-1)
            kld_loss = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
            kld_losses.append(kld_loss)

        return recon_loss + beta * sum(kld_losses), recon_loss, sum(kld_losses)
|
204 |
+
|
205 |
+
|
206 |
+
def train_signal_game(NUM_SENDERS, NUM_RECEIVERS, num_rounds):
    """Run `num_rounds` of the signaling game and stream results to Streamlit.

    Builds NUM_SENDERS encoders and NUM_RECEIVERS decoders, optimizes all of
    them jointly with the sidebar-selected optimizer, and live-updates a loss
    plot plus the exchanged sentences after every round.
    """
    # Echo the effective hyperparameters so the run configuration is visible.
    para_checker = st.empty()
    para_checker.text(f"NUM_SENDERS: {NUM_SENDERS}, NUM_RECEIVERS: {NUM_RECEIVERS}, num_rounds: {num_rounds}, EMBEDDING_DIM: {EMBEDDING_DIM}, HIDDEN_DIM: {HIDDEN_DIM}, LATENT_DIM: {LATENT_DIM}, SEQ_LEN: {SEQ_LEN}, TAU: {TAU}, nhead: {NHEAD}, num_layers: {NUM_LAYERS}, BATCH_SIZE: {BATCH_SIZE}")
    senders = [TransformerEncoder().to(device) for _ in range(NUM_SENDERS)]
    receivers = [TransformerDecoder().to(device) for _ in range(NUM_RECEIVERS)]

    # One flat parameter list covering every agent.
    params = [list(sender.parameters()) for sender in senders]
    params.extend([list(receiver.parameters()) for receiver in receivers])
    flat_params = [param for sublist in params for param in sublist]
    if OPTMIZER == "Adam":
        optimizer = torch.optim.Adam(flat_params, lr=LEARNING_RATE)
    elif OPTMIZER == "AdamW":
        optimizer = torch.optim.AdamW(flat_params, lr=LEARNING_RATE)
    elif OPTMIZER == "SGD":
        optimizer = torch.optim.SGD(flat_params, lr=LEARNING_RATE)
    else:
        # BUG FIX: an unrecognized name previously fell through and crashed
        # later with UnboundLocalError; fail fast with a clear message.
        raise ValueError(f"Unsupported optimizer: {OPTMIZER}")

    criterion = torch.nn.CrossEntropyLoss()

    game = MultiMultiSignalingGame(senders, receivers, optimizer, criterion)

    losses = []
    recon_losses = []
    kld_losses = []

    # Use Streamlit's progress bar
    progress_bar = st.progress(0)
    loss_plot_placeholder = st.empty()  # placeholder refreshed with the loss plot
    interactions_placeholder = st.empty()  # placeholder refreshed with interactions

    for round_idx in range(num_rounds):
        # NOTE(review): states are random token batches — the WikiText
        # dataloaders built above are never consumed here; confirm intended.
        # BUG FIX: the sequence length was hard-coded to 16; use SEQ_LEN so
        # the sidebar setting actually takes effect.
        states = [torch.randint(VOCAB_SIZE, (BATCH_SIZE, SEQ_LEN)).to(device) for _ in range(NUM_SENDERS)]
        loss, recon_loss, kld_loss, interactions = game.play_round(states)
        losses.append(loss)
        recon_losses.append(recon_loss)
        kld_losses.append(kld_loss)
        # Refresh the per-round loss plot
        fig, ax = plt.subplots()
        ax.plot(losses, label='Total Losses', color='blue')
        ax.plot(recon_losses, label='Reconstruction Losses', color='green')
        ax.plot(kld_losses, label='KLD Losses', color='red')
        ax.set_xlabel('Round')
        ax.set_ylabel('Loss')
        ax.legend()
        loss_plot_placeholder.pyplot(fig)
        plt.close(fig)  # release the figure; one leaked per round otherwise
        # Refresh the sentences exchanged this round
        interaction_str = "\n\n".join([f"Sender {i} -> Receiver {j}\nSend(encode): {input_sentence}\nReceive(decode): {output_sentence}"
                                       for i, j, input_sentence, output_sentence in interactions])
        interactions_placeholder.text(interaction_str)

        # BUG FIX: was `round / num_rounds`, which never reached 100%
        # (and `round` shadowed the builtin).
        progress_bar.progress((round_idx + 1) / num_rounds)

    # Final static plot of the losses
    fig, ax = plt.subplots()
    ax.plot(losses, label='Total Losses', color='blue')
    ax.plot(recon_losses, label='Reconstruction Losses', color='green')
    ax.plot(kld_losses, label='KLD Losses', color='red')
    ax.set_xlabel('Round')
    ax.set_ylabel('Loss')
    ax.legend()
    st.pyplot(fig)
    plt.close(fig)
|
268 |
+
|
269 |
+
# Streamlit UI
st.title('Multi-Agents Signaling Game')

# Sidebar controls. These rebind the module-level constants defined near the
# top of the file; the model classes read them lazily at instantiation time,
# so the sidebar values take effect when "Start" is pressed.
NUM_SENDERS = st.sidebar.slider("NUM_SENDERS", 1, 10, 2)
NUM_RECEIVERS = st.sidebar.slider("NUM_RECEIVERS", 1, 10, 2)
num_rounds = st.sidebar.slider("num_rounds", 1000, 100000, 10000, 1000)

# NOTE(review): annealing_strategy and final_tau are collected but not
# consumed anywhere in this file — confirm whether annealing is implemented.
use_cosine_annealing = st.sidebar.checkbox("Use Annealing")
if use_cosine_annealing:
    annealing_strategy = st.sidebar.selectbox("Annealing Strategy", ["linear", "cosine"])
    TAU = st.sidebar.slider("Start Temp.", 0.1, 10.0, 1.0)
    final_tau = st.sidebar.slider("Final Temp.", 0.1, 10.0, 1.0)
else:
    annealing_strategy = None
    TAU = st.sidebar.slider("TAU", 0.1, 10.0, 1.0)

optimizer_options = ["Adam", "AdamW", "SGD"]
OPTMIZER = st.sidebar.selectbox("Optimizer", optimizer_options)
LEARNING_RATE = st.sidebar.slider("Learning Rate", 1e-5, 1e-2, 1e-3)


EMBEDDING_DIM = st.sidebar.slider("EMBEDDING_DIM", 1, 128, 16)
HIDDEN_DIM = st.sidebar.slider("HIDDEN_DIM", 1, 128, 16)
LATENT_DIM = st.sidebar.slider("LATENT_DIM", 1, 128, 16)
SEQ_LEN = st.sidebar.slider("SEQ_LEN", 1, 128, 16)
# TAU = st.sidebar.slider("TAU", 0.1, 10.0, 1.0)
NHEAD = st.sidebar.slider("nhead", 1, 8, 4)
NUM_LAYERS = st.sidebar.slider("num_layers", 1, 6, 2)
BATCH_SIZE = st.sidebar.slider("BATCH_SIZE", 1, 128, 32)

# Kick off training only on explicit user action.
if st.sidebar.button('Start'):
    train_signal_game(NUM_SENDERS, NUM_RECEIVERS, num_rounds)
|
wikitext-2/wiki.test.tokens
ADDED
The diff for this file is too large to render.
See raw diff
|
|
wikitext-2/wiki.train.tokens
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9e9fa1ad55b1c2c95b08e37dd8e653f638fac2c6de904b79e813611eefbc985f
|
3 |
+
size 10797148
|
wikitext-2/wiki.valid.tokens
ADDED
The diff for this file is too large to render.
See raw diff
|
|