import json
import os

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import wandb
from loguru import logger
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from src.config import Config
from src.dataset import TranslationDataset
from src.eval import evaluate
from src.model import TransformerModel
from src.tokenizer import Tokenizer


# Pin training to a specific GPU; the hard-coded index is machine-specific.
torch.cuda.set_device(5)
print(os.environ.get("CUDA_VISIBLE_DEVICES"))

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f'Using device: {DEVICE}')


def train_epoch(model, data_loader, optimizer, criterion):
    """Run one pass over the training data and return the mean batch loss."""
    model.train()
    epoch_loss = 0
    for src, tgt in data_loader:
        src, tgt = src.to(DEVICE), tgt.to(DEVICE)
        # Teacher forcing: feed the target shifted right, predict the next token.
        tgt_input = tgt[:, :-1]
        tgt_output = tgt[:, 1:]

        # Causal mask so each position only attends to earlier positions.
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt_input.size(1)).to(DEVICE)

        optimizer.zero_grad()
        output = model(src, tgt_input, tgt_mask=tgt_mask)
        output = output.reshape(-1, output.shape[-1])
        tgt_output = tgt_output.reshape(-1)

        loss = criterion(output, tgt_output)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        epoch_loss += loss.item()

    return epoch_loss / len(data_loader)
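

# Inference-time decoding differs from the teacher-forced loop above: the target
# is generated one token at a time. The sketch below is illustrative only
# (evaluation itself is handled by src.eval.evaluate); the "<BOS>" and "<EOS>"
# token names are assumptions, as only "<PAD>" appears in this file.
@torch.no_grad()
def greedy_decode(model, src, tgt_tokenizer, max_len=64):
    model.eval()
    src = src.to(DEVICE)
    bos = tgt_tokenizer.word2int["<BOS>"]
    eos = tgt_tokenizer.word2int["<EOS>"]
    # Start every sequence in the batch with <BOS>.
    ys = torch.full((src.size(0), 1), bos, dtype=torch.long, device=DEVICE)
    for _ in range(max_len - 1):
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(ys.size(1)).to(DEVICE)
        logits = model(src, ys, tgt_mask=tgt_mask)  # (batch, tgt_len, vocab)
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        ys = torch.cat([ys, next_token], dim=1)
        if (next_token == eos).all():
            break
    return ys

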
def train():
    config = Config('./config.yaml')
    exp_name = config.config['EXP']
    # Start each experiment with a fresh log file.
    if os.path.exists(f'./logs/{exp_name}.log'):
        os.remove(f'./logs/{exp_name}.log')
    wandb.init(project="final_transformer_translation", name=exp_name)

    cn_tokenizer = Tokenizer('./wordtable/word2int_cn.json', './wordtable/int2word_cn.json')
    en_tokenizer = Tokenizer('./wordtable/word2int_en.json', './wordtable/int2word_en.json')
    model = TransformerModel(config, en_tokenizer, cn_tokenizer).to(DEVICE)

    train_dataset = TranslationDataset(file_path='./data/training.txt', src_tokenizer=en_tokenizer,
                                       tgt_tokenizer=cn_tokenizer, max_len=config.config['MAX_SEQ_LEN'])
    test_dataset = TranslationDataset(file_path='./data/testing.txt', src_tokenizer=en_tokenizer,
                                      tgt_tokenizer=cn_tokenizer, max_len=config.config['MAX_SEQ_LEN'])
    train_loader = DataLoader(train_dataset, batch_size=config.config['BATCH_SIZE'], shuffle=True)

    optimizer = optim.AdamW(model.parameters(), lr=config.config['LEARNING_RATE'])
    # The loss is computed over target-language (Chinese) tokens, so padding is
    # masked with the target tokenizer's <PAD> id.
    criterion = nn.CrossEntropyLoss(ignore_index=cn_tokenizer.word2int["<PAD>"])
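
    # The LambdaLR import at the top suggests a warmup schedule was planned; a
    # minimal inverse-square-root ("Noam") warmup sketch is given below.
    # 'WARMUP_STEPS' is a hypothetical config key, not part of the original
    # config. To enable it, uncomment and call scheduler.step() after each
    # optimizer.step().
    # warmup_steps = config.config['WARMUP_STEPS']
    # scheduler = LambdaLR(
    #     optimizer,
    #     lr_lambda=lambda step: (warmup_steps ** 0.5)
    #     * min((step + 1) ** -0.5, (step + 1) * warmup_steps ** -1.5),
    # )
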
    train_losses = []

    for epoch in tqdm(range(config.config['EPOCHS'] + 1)):
        train_loss = train_epoch(model, train_loader, optimizer, criterion)
        train_losses.append(train_loss)
        wandb.log({"train_loss": train_loss, "epoch": epoch})
        logger.info(f'epoch {epoch}: train_loss = {train_loss:.4f}')
        # Periodically evaluate on the validation split and checkpoint the model.
        if epoch % config.config['EVAL_PER_EPOCH'] == 0 and epoch > 0:
            none_score, exp_score, floor_score, average_score = evaluate(model, config, 'eval')
            wandb.log({'none_score': none_score, 'exp_score': exp_score,
                       'floor_score': floor_score, 'average_score': average_score})
            model_path = f'./models/{exp_name}_epoch_{epoch}.pth'
            torch.save(model.state_dict(), model_path)
            logger.info(f"Model saved to {model_path}")
            model.train()

    # Final evaluation on the test split.
    none_score, exp_score, floor_score, average_score = evaluate(model, config, 'test')

    wandb.finish()
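    # Optional: persist the collected loss curve using the matplotlib import at
    # the top of the file; the output path below is illustrative only.
    # plt.figure()
    # plt.plot(train_losses)
    # plt.xlabel('epoch')
    # plt.ylabel('train loss')
    # plt.savefig(f'./logs/{exp_name}_loss_curve.png')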


if __name__ == '__main__':
    train()