from src.config import Config
from src.tokenizer import Tokenizer
from src.model import TransformerModel
from src.dataset import TranslationDataset
from src.eval import evaluate

import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
from loguru import logger
import wandb

# os.environ["CUDA_VISIBLE_DEVICES"] = "5"
if torch.cuda.is_available():
    torch.cuda.set_device(5)
print(os.environ.get("CUDA_VISIBLE_DEVICES"))
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f'Using device: {DEVICE}')


def train_epoch(model, data_loader, optimizer, criterion):
    """Run one training epoch and return the average per-batch loss."""
    model.train()
    epoch_loss = 0
    for src, tgt in data_loader:
        src, tgt = src.to(DEVICE), tgt.to(DEVICE)
        # Teacher forcing: the decoder input is the target shifted right by one token.
        tgt_input = tgt[:, :-1]
        tgt_output = tgt[:, 1:]
        # Causal mask so each position only attends to earlier positions.
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt_input.size(1)).to(DEVICE)
        optimizer.zero_grad()
        output = model(src, tgt_input, tgt_mask=tgt_mask)
        output = output.reshape(-1, output.shape[-1])
        tgt_output = tgt_output.reshape(-1)
        loss = criterion(output, tgt_output)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        epoch_loss += loss.item()
    return epoch_loss / len(data_loader)


def train():
    config = Config('./config.yaml')
    exp_name = config.config['EXP']

    # Start a fresh log file for this experiment.
    log_path = f'./logs/{exp_name}.log'
    if os.path.exists(log_path):
        os.remove(log_path)
    logger.add(log_path)

    wandb.init(project="final_transformer_translation", name=exp_name)

    cn_tokenizer = Tokenizer('./wordtable/word2int_cn.json', './wordtable/int2word_cn.json')
    en_tokenizer = Tokenizer('./wordtable/word2int_en.json', './wordtable/int2word_en.json')
    model = TransformerModel(config, en_tokenizer, cn_tokenizer).to(DEVICE)

    train_dataset = TranslationDataset(file_path='./data/training.txt',
                                       src_tokenizer=en_tokenizer,
                                       tgt_tokenizer=cn_tokenizer,
                                       max_len=config.config['MAX_SEQ_LEN'])
    test_dataset = TranslationDataset(file_path='./data/testing.txt',
                                      src_tokenizer=en_tokenizer,
                                      tgt_tokenizer=cn_tokenizer,
                                      max_len=config.config['MAX_SEQ_LEN'])
    train_loader = DataLoader(train_dataset, batch_size=config.config['BATCH_SIZE'], shuffle=True)

    optimizer = optim.AdamW(model.parameters(), lr=config.config['LEARNING_RATE'])
    # The loss is computed over target (Chinese) tokens, so padding positions are
    # ignored via the target tokenizer's pad id ("<PAD>" is assumed to be the key
    # used in the word table).
    criterion = nn.CrossEntropyLoss(ignore_index=cn_tokenizer.word2int["<PAD>"])

    train_losses = []
    for epoch in tqdm(range(config.config['EPOCHS'] + 1)):
        train_loss = train_epoch(model, train_loader, optimizer, criterion)
        train_losses.append(train_loss)
        wandb.log({"train_loss": train_loss, "epoch": epoch})
        logger.info(f"epoch {epoch}: train_loss = {train_loss:.4f}")

        if epoch % config.config['EVAL_PER_EPOCH'] == 0 and epoch > 0:
            none_score, exp_score, floor_score, average_score = evaluate(model, config, 'eval')
            wandb.log({'none_score': none_score,
                       'exp_score': exp_score,
                       'floor_score': floor_score,
                       'average_score': average_score})
            model_path = f'./models/{exp_name}_epoch_{epoch}.pth'
            torch.save(model.state_dict(), model_path)
            logger.info(f"Model saved to {model_path}")
            model.train()

    none_score, exp_score, floor_score, average_score = evaluate(model, config, 'test')
    logger.info(f"test scores: none={none_score}, exp={exp_score}, "
                f"floor={floor_score}, average={average_score}")
    wandb.finish()


if __name__ == '__main__':
    train()