# Uploaded by Suburst in commit f8bd4d2 ("Upload 27 files", verified).
from src.config import Config
from src.tokenizer import Tokenizer
from src.model import TransformerModel
from src.dataset import TranslationDataset
from src.eval import evaluate
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import torch.optim as optim
import json
from torch.utils.data import Dataset, DataLoader
import wandb
from torch.optim.lr_scheduler import LambdaLR
from tqdm import tqdm
from loguru import logger
import matplotlib.pyplot as plt
import wandb
import os
# Index of the GPU to train on. It is selected with torch.cuda.set_device
# (instead of setting CUDA_VISIBLE_DEVICES) so every GPU stays visible.
GPU_INDEX = 5
if torch.cuda.is_available():
    # Only switch GPUs when the requested one exists; otherwise the hard-coded
    # index would crash on machines with fewer GPUs.
    if torch.cuda.device_count() > GPU_INDEX:
        torch.cuda.set_device(GPU_INDEX)
    else:
        logger.warning(
            f"GPU {GPU_INDEX} not available ({torch.cuda.device_count()} visible); "
            "using the default CUDA device"
        )
logger.info(f"CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')}")
# "cuda" resolves to the current CUDA device (the one selected above, if any).
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f'使用设备:{DEVICE}')
def train_epoch(model, data_loader, optimizer, criterion, max_grad_norm=1.0, device=None):
    """Run one teacher-forced training epoch and return the mean per-batch loss.

    Each batch is a ``(src, tgt)`` pair of 2-D tensors (presumably integer token
    ids -- TODO confirm against the dataset). The decoder is fed ``tgt[:, :-1]``
    and trained to predict ``tgt[:, 1:]`` under a causal (subsequent) mask.

    Args:
        model: module called as ``model(src, tgt_input, tgt_mask=...)``; the last
            dimension of its output is treated as the class (vocabulary) axis.
        data_loader: iterable of ``(src, tgt)`` batches. It does not need ``len``.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss taking flattened ``(N, C)`` logits and ``(N,)`` targets.
        max_grad_norm: threshold for gradient-norm clipping (default 1.0).
        device: device the batches are moved to; defaults to the module-level
            ``DEVICE``.

    Returns:
        The unweighted mean of the per-batch loss values, as a float.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    if device is None:
        device = DEVICE
    model.train()
    total_loss = 0.0
    num_batches = 0
    for src, tgt in data_loader:
        src, tgt = src.to(device), tgt.to(device)
        # Teacher forcing: the decoder input is shifted one step behind the target.
        tgt_input = tgt[:, :-1]
        tgt_output = tgt[:, 1:]
        # Causal mask: position i may not attend to later positions.
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt_input.size(1)).to(device)

        optimizer.zero_grad()
        output = model(src, tgt_input, tgt_mask=tgt_mask)
        loss = criterion(output.reshape(-1, output.shape[-1]), tgt_output.reshape(-1))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    # Count the batches actually seen, instead of using len(), so empty or
    # length-less loaders fail with a clear message, not ZeroDivisionError.
    if num_batches == 0:
        raise ValueError("data_loader yielded no batches")
    return total_loss / num_batches
def _evaluate_and_checkpoint(model, config, exp_name, epoch):
    """Score ``model`` on the 'eval' split, log the scores to wandb, and save a checkpoint.

    The model's ``state_dict`` is written to ``./models/{exp_name}_epoch_{epoch}.pth``.
    """
    none_score, exp_score, floor_score, average_score = evaluate(model, config, 'eval')
    wandb.log({'none_score': none_score, 'exp_score': exp_score,
               'floor_score': floor_score, 'average_score': average_score})
    model_path = f'./models/{exp_name}_epoch_{epoch}.pth'
    torch.save(model.state_dict(), model_path)
    logger.info(f"模型已保存至 {model_path}")


def train():
    """Train the English->Chinese Transformer configured by ``./config.yaml``.

    Reads ``EXP``, ``MAX_SEQ_LEN``, ``BATCH_SIZE``, ``LEARNING_RATE``, ``EPOCHS`` and
    ``EVAL_PER_EPOCH`` from the config and trains for ``EPOCHS`` epochs. Every
    ``EVAL_PER_EPOCH`` epochs it runs ``evaluate(..., 'eval')`` and saves a
    checkpoint to ``./models/{EXP}_epoch_{epoch}.pth``. It finishes with
    ``evaluate(..., 'test')``. Metrics are logged to Weights & Biases.
    """
    config = Config('./config.yaml')
    exp_name = config.config['EXP']
    # Remove any log file left by a previous run of this experiment
    # (presumably written by a logger sink configured elsewhere -- verify).
    log_path = f'./logs/{exp_name}.log'
    if os.path.exists(log_path):
        os.remove(log_path)
    wandb.init(project="final_transformer_translation", name=exp_name)

    cn_tokenizer = Tokenizer('./wordtable/word2int_cn.json', './wordtable/int2word_cn.json')
    en_tokenizer = Tokenizer('./wordtable/word2int_en.json', './wordtable/int2word_en.json')
    model = TransformerModel(config, en_tokenizer, cn_tokenizer).to(DEVICE)

    # English source sentences, Chinese target sentences.
    train_dataset = TranslationDataset(
        file_path='./data/training.txt',
        src_tokenizer=en_tokenizer,
        tgt_tokenizer=cn_tokenizer,
        max_len=config.config['MAX_SEQ_LEN'],
    )
    train_loader = DataLoader(train_dataset, batch_size=config.config['BATCH_SIZE'], shuffle=True)
    optimizer = optim.AdamW(model.parameters(), lr=config.config['LEARNING_RATE'])
    # The loss compares predictions with target (Chinese) tokens, so the padding
    # id must come from the target tokenizer, not the source one.
    criterion = nn.CrossEntropyLoss(ignore_index=cn_tokenizer.word2int["<PAD>"])

    # torch.save does not create missing directories.
    os.makedirs('./models', exist_ok=True)

    num_epochs = config.config['EPOCHS']
    eval_every = config.config['EVAL_PER_EPOCH']
    # Epochs are numbered 1..EPOCHS: exactly EPOCHS epochs are trained, and a
    # checkpoint's epoch number equals the number of epochs completed.
    for epoch in tqdm(range(1, num_epochs + 1)):
        train_loss = train_epoch(model, train_loader, optimizer, criterion)
        wandb.log({"train_loss": train_loss, "epoch": epoch})
        logger.info(f"epoch {epoch}: train_loss={train_loss:.4f}")
        if epoch % eval_every == 0:
            _evaluate_and_checkpoint(model, config, exp_name, epoch)
            # NOTE(review): evaluate() may leave the model in eval mode; restore it.
            model.train()

    none_score, exp_score, floor_score, average_score = evaluate(model, config, 'test')
    # Record the final test-split scores alongside the training metrics.
    wandb.log({'test_none_score': none_score, 'test_exp_score': exp_score,
               'test_floor_score': floor_score, 'test_average_score': average_score})
    logger.info(f"test scores: none={none_score}, exp={exp_score}, "
                f"floor={floor_score}, average={average_score}")
    wandb.finish()
if __name__ == '__main__':
    # Script entry point: run the full training pipeline when executed directly.
    train()