import torch
import torch.nn as nn
from torch.nn import Transformer
from loguru import logger
import os
DEVICE = torch.device("cuda:6" if torch.cuda.is_available() else "cpu")
def get_sinusoid_encoding_table(max_len, d_model):
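    # Standard sinusoidal positional-encoding table, returned with shape (1, max_len, d_model)
    # as a frozen nn.Parameter so it moves with the model but is not trained.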
pos_encoding = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))
pos_encoding[:, 0::2] = torch.sin(position * div_term)
pos_encoding[:, 1::2] = torch.cos(position * div_term)
pos_encoding = pos_encoding.unsqueeze(0)
return nn.Parameter(pos_encoding, requires_grad=False)
class TransformerModel(nn.Module):
def __init__(self, config, src_tokenizer, tgt_tokenizer):
super(TransformerModel, self).__init__()
        # Read model hyperparameters from the config
global DEVICE
config = config.config
DEVICE = torch.device(config['DEVICE'])
self.batch_size = config.get('BATCH_SIZE')
self.epochs = config.get('EPOCHS')
self.learning_rate = config.get('LEARNING_RATE')
self.max_seq_len = config.get('MAX_SEQ_LEN')
self.d_model = config.get('D_MODEL')
self.n_head = config.get('N_HEAD')
self.num_layers = config.get('NUM_LAYERS')
self.dim_feedforward = config.get('DIM_FEEDFORWARD')
self.dropout = config.get('DROPOUT')
self.src_tokenizer = src_tokenizer
self.tgt_tokenizer = tgt_tokenizer
self.transformer = Transformer(d_model=self.d_model,
nhead=self.n_head,
num_encoder_layers=self.num_layers,
num_decoder_layers=self.num_layers,
dim_feedforward=self.dim_feedforward,
                                       dropout=self.dropout
)
        # Embedding layers and the final output projection can be added here, e.g.
        # self.embedding = nn.Embedding(self.max_seq_len, self.d_model)
src_vocab_size = len(src_tokenizer.word2int)
tgt_vocab_size = len(tgt_tokenizer.word2int)
self.src_embedding = nn.Embedding(src_vocab_size, self.d_model)
self.tgt_embedding = nn.Embedding(tgt_vocab_size, self.d_model)
self.fc = nn.Linear(self.d_model, tgt_vocab_size)
        # Add the positional encoding table
self.pos_encoding = get_sinusoid_encoding_table(self.max_seq_len, self.d_model)
self.to(DEVICE)
def forward(self, src, tgt, src_mask=None, tgt_mask=None, src_padding_mask=None, tgt_padding_mask=None):
        # Embed the inputs (scaled by sqrt(d_model))
src_emb = self.src_embedding(src) * (self.d_model ** 0.5)
tgt_emb = self.tgt_embedding(tgt) * (self.d_model ** 0.5)
        # Add positional encodings
        src_emb = src_emb + self.pos_encoding[0, :src_emb.size(1)]
        tgt_emb = tgt_emb + self.pos_encoding[0, :tgt_emb.size(1)]
        # nn.Transformer expects src and tgt shaped (seq_len, batch_size, d_model)
output = self.transformer(
src_emb.permute(1, 0, 2),
tgt_emb.permute(1, 0, 2),
src_mask=src_mask,
tgt_mask=tgt_mask,
src_key_padding_mask=src_padding_mask,
tgt_key_padding_mask=tgt_padding_mask,
)
output = self.fc(output.permute(1, 0, 2))
return output
def decode_sentence(self, ids):
int2word = self.tgt_tokenizer.int2word
word2int = self.tgt_tokenizer.word2int
tokens = [int2word[id] for id in ids if id not in {word2int["<PAD>"], word2int["<BOS>"], word2int["<EOS>"]}]
        return tokens, " ".join(tokens)  # Chinese output does not need the spaces
def encode_sentence(self, sentence):
tokens = sentence.split()
word2int = self.src_tokenizer.word2int
max_len = self.max_seq_len
ids = [word2int.get(token, word2int["<UNK>"]) for token in tokens]
ids = [word2int["<BOS>"]] + ids[:max_len - 2] + [word2int["<EOS>"]]
ids += [word2int["<PAD>"]] * (max_len - len(ids))
        return torch.tensor(ids, dtype=torch.long).unsqueeze(0).to(DEVICE)  # add the batch dimension
def translate(self, sentence):
src_tensor = self.encode_sentence(sentence).to(DEVICE)
        tgt_tensor = torch.tensor([self.tgt_tokenizer.word2int["<BOS>"]], dtype=torch.long).unsqueeze(0).to(DEVICE)
for _ in range(self.max_seq_len):
            # Generate the causal mask for the target sequence
tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt_tensor.size(1)).to(DEVICE)
            # Run the forward pass
output = self.forward(src_tensor, tgt_tensor, tgt_mask=tgt_mask)
            # Take the prediction at the last time step
next_token = output[:, -1, :].argmax(dim=-1).item()
            # Append the predicted token to the target sequence
tgt_tensor = torch.cat([tgt_tensor, torch.tensor([[next_token]], dtype=torch.long).to(DEVICE)], dim=1)
            # Stop generating once <EOS> is predicted
if next_token == self.tgt_tokenizer.word2int["<EOS>"]:
break
return self.decode_sentence(tgt_tensor.squeeze(0).tolist())
def translate_no_unk(self, sentence):
src_tensor = self.encode_sentence(sentence)
tgt_tensor = torch.tensor([self.tgt_tokenizer.word2int["<BOS>"]], dtype=torch.long).unsqueeze(0).to(DEVICE)
for _ in range(self.max_seq_len):
            # Generate the causal mask for the target sequence
tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt_tensor.size(1)).to(DEVICE)
            # Run the forward pass
output = self.forward(src_tensor, tgt_tensor, tgt_mask=tgt_mask)
            # Take the prediction at the last time step
logits = output[:, -1, :]
sorted_indices = torch.argsort(logits, dim=-1, descending=True)
next_token = sorted_indices[0, 0].item()
unk_id = self.tgt_tokenizer.word2int["<UNK>"]
            # If the most likely token is <UNK>, fall back to the second most likely one
if next_token == unk_id:
next_token = sorted_indices[0, 1].item()
            # Append the predicted token to the target sequence
tgt_tensor = torch.cat([tgt_tensor, torch.tensor([[next_token]], dtype=torch.long).to(DEVICE)], dim=1)
            # Stop generating once <EOS> is predicted
if next_token == self.tgt_tokenizer.word2int["<EOS>"]:
break
return self.decode_sentence(tgt_tensor.squeeze(0).tolist())
def translate_beam_search(self, sentence, beam_size=3):
src_tensor = self.encode_sentence(sentence)
bos_id = self.tgt_tokenizer.word2int["<BOS>"]
eos_id = self.tgt_tokenizer.word2int["<EOS>"]
unk_id = self.tgt_tokenizer.word2int["<UNK>"]
        # Initialise the beam with <BOS>
        beams = [([bos_id], 0)]  # (sequence, score)
for _ in range(self.max_seq_len):
new_beams = []
for seq, score in beams:
tgt_tensor = torch.tensor(seq, dtype=torch.long).unsqueeze(0).to(DEVICE)
tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt_tensor.size(1)).to(DEVICE)
output = self.forward(src_tensor, tgt_tensor, tgt_mask=tgt_mask)
logits = output[:, -1, :]
probs = torch.softmax(logits, dim=-1).squeeze(0)
sorted_indices = torch.argsort(probs, dim=-1, descending=True)
top_indices = sorted_indices[:beam_size]
for idx in top_indices:
next_token = idx.item()
if next_token == unk_id:
continue
new_seq = seq + [next_token]
new_score = score + torch.log(probs[next_token]).item()
new_beams.append((new_seq, new_score))
            # Sort by score and keep the top beam_size candidates
new_beams.sort(key=lambda x: x[1], reverse=True)
beams = new_beams[:beam_size]
            # Check whether any beam has ended with <EOS>
ended_beams = [beam for beam in beams if beam[0][-1] == eos_id]
if ended_beams:
best_beam = max(ended_beams, key=lambda x: x[1])
return self.decode_sentence(best_beam[0])
        # No beam ended with <EOS>; return the highest-scoring one
best_beam = max(beams, key=lambda x: x[1])
return self.decode_sentence(best_beam[0])
def translate_no_unk_beam_search(self, sentence, beam_size=3):
src_tensor = self.encode_sentence(sentence)
bos_id = self.tgt_tokenizer.word2int["<BOS>"]
eos_id = self.tgt_tokenizer.word2int["<EOS>"]
unk_id = self.tgt_tokenizer.word2int["<UNK>"]
        # Initialise the beam with <BOS>
        beams = [([bos_id], 0)]  # (sequence, score)
for _ in range(self.max_seq_len):
new_beams = []
for seq, score in beams:
tgt_tensor = torch.tensor(seq, dtype=torch.long).unsqueeze(0).to(DEVICE)
tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt_tensor.size(1)).to(DEVICE)
output = self.forward(src_tensor, tgt_tensor, tgt_mask=tgt_mask)
logits = output[:, -1, :]
probs = torch.softmax(logits, dim=-1).squeeze(0)
sorted_indices = torch.argsort(probs, dim=-1, descending=True)
                top_indices = sorted_indices[:beam_size + 1]  # take one extra in case the top candidate is <UNK>
valid_count = 0
for idx in top_indices:
next_token = idx.item()
if next_token == unk_id:
continue
new_seq = seq + [next_token]
new_score = score + torch.log(probs[next_token]).item()
new_beams.append((new_seq, new_score))
valid_count += 1
if valid_count >= beam_size:
break
            # Sort by score and keep the top beam_size candidates
new_beams.sort(key=lambda x: x[1], reverse=True)
beams = new_beams[:beam_size]
            # Check whether any beam has ended with <EOS>
ended_beams = [beam for beam in beams if beam[0][-1] == eos_id]
if ended_beams:
best_beam = max(ended_beams, key=lambda x: x[1])
return self.decode_sentence(best_beam[0])
        # No beam ended with <EOS>; return the highest-scoring one
best_beam = max(beams, key=lambda x: x[1])
return self.decode_sentence(best_beam[0])
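

# Minimal usage sketch (assumptions, not defined in this file): a `Config` wrapper
# exposing a `.config` dict with DEVICE, MAX_SEQ_LEN, D_MODEL, N_HEAD, NUM_LAYERS,
# DIM_FEEDFORWARD and DROPOUT keys, and tokenizers exposing `word2int`/`int2word`
# dicts that contain <PAD>, <BOS>, <EOS> and <UNK>. Adjust the imports, class names
# and paths to match the rest of the project before uncommenting.
if __name__ == "__main__":
    # from config import Config            # hypothetical module
    # from tokenizer import Tokenizer      # hypothetical module
    # config = Config("config.yaml")
    # src_tokenizer = Tokenizer.load("src_vocab.json")
    # tgt_tokenizer = Tokenizer.load("tgt_vocab.json")
    # model = TransformerModel(config, src_tokenizer, tgt_tokenizer)
    # model.load_state_dict(torch.load("checkpoint.pt", map_location=DEVICE))
    # model.eval()
    # with torch.no_grad():
    #     tokens, text = model.translate("hello world")
    #     logger.info(text)
    pass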