|
|
|
""" |
|
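Evaluation script for the Next Word Prediction model.

Loads the trained model checkpoint and the SentencePiece model, builds the
validation dataset from the last 10% of rows of the CSV's ``data`` column,
and computes:

- Perplexity: exp of the average per-sample validation loss
- Top-k accuracy: fraction of samples whose true next token is among the
  model's k highest-scoring tokens (e.g. top-3 accuracy)

Usage:

    python evaluate_next_word.py --data_path data.csv \
        --sp_model_path spm.model --model_save_path best_model.pth \
        [--batch_size 512] [--top_k 3]

The architecture flags --embed_dim, --hidden_dim and --num_layers must match
the values the checkpoint was trained with, otherwise loading it fails.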
|
""" |
|
|
|
import os |
|
import sys |
|
import math |
|
import argparse |
|
import logging |
|
import pandas as pd |
|
|
|
import torch |
|
import torch.nn as nn |
|
from torch.utils.data import Dataset, DataLoader |
|
from torch.nn.utils.rnn import pad_sequence |
|
|
|
import sentencepiece as spm |
|
|
|
|
|
# Root logger configuration: INFO and above, timestamped, written to stdout
# (rather than basicConfig's default of stderr).
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
_LOG_DATE_FORMAT = '%Y-%m-%d %H:%M:%S'
logging.basicConfig(level=logging.INFO, stream=sys.stdout,
                    format=_LOG_FORMAT, datefmt=_LOG_DATE_FORMAT)
|
|
|
|
|
class NextWordSPDataset(Dataset):
    """Next-token prediction samples built from raw sentences.

    Each sentence is tokenized with ``sp``. For token ids ``t0 .. tn`` one
    sample ``(t0..t{i-1}, ti)`` is produced for every ``1 <= i <= n``: each
    prefix paired with the token that follows it. Sentences that tokenize to
    fewer than two tokens contribute no samples.
    """

    def __init__(self, sentences, sp):
        """
        Args:
            sentences: Iterable of sentences. Entries that are not ``str``
                (e.g. the NaN floats pandas yields for empty CSV cells, or
                None) are skipped.
            sp: Tokenizer exposing ``encode(text, out_type=int)``, such as a
                ``sentencepiece.SentencePieceProcessor``.
        """
        self.sp = sp
        # List of (context LongTensor of shape (i,), target 0-d LongTensor).
        self.samples = []
        self.prepare_samples(sentences)

    def prepare_samples(self, sentences):
        """Tokenize ``sentences`` and append their (prefix, next-token) pairs to ``self.samples``."""
        for sentence in sentences:
            if not isinstance(sentence, str):
                # Missing values have no text and would crash on ``.strip()``.
                continue
            token_ids = self.sp.encode(sentence.strip(), out_type=int)
            if len(token_ids) < 2:
                continue
            # One tensor per sentence; every prefix/target below is a view into
            # it, so memory is O(tokens) instead of O(tokens**2) per sentence.
            ids = torch.tensor(token_ids, dtype=torch.long)
            for i in range(1, len(token_ids)):
                self.samples.append((ids[:i], ids[i]))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]
|
|
|
def sp_collate_fn(batch):
    """Collate (context, target) samples into padded batch tensors.

    Contexts are padded at the end with id 0 up to the longest context in the
    batch, giving a (batch, max_len) tensor; the 0-d targets are stacked into
    a (batch,) tensor.
    """
    prefixes, next_ids = zip(*batch)
    return (
        pad_sequence(prefixes, batch_first=True, padding_value=0),
        torch.stack(next_ids),
    )
|
|
|
|
|
class LSTMNextWordModel(nn.Module):
    """LSTM language model that scores the token following a context.

    Pipeline: embedding (padding id 0) -> stacked unidirectional LSTM ->
    LayerNorm on the hidden state at each row's last real token -> dropout ->
    Linear(hidden, hidden // 2) + ReLU -> dropout -> Linear(hidden // 2, vocab).
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers, dropout, fc_dropout=0.3):
        """
        Args:
            vocab_size: Number of tokens; size of the embedding and output layer.
            embed_dim: Embedding dimension.
            hidden_dim: LSTM hidden size; the intermediate FC layer has
                ``hidden_dim // 2`` units.
            num_layers: Number of stacked LSTM layers.
            dropout: Dropout between stacked LSTM layers.
            fc_dropout: Dropout applied before each fully connected layer.
        """
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        # nn.LSTM applies dropout only between stacked layers and warns when it
        # is non-zero for a single layer, so 0.0 there is behaviorally identical.
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers=num_layers,
                            batch_first=True,
                            dropout=dropout if num_layers > 1 else 0.0)
        self.layer_norm = nn.LayerNorm(hidden_dim)
        self.dropout = nn.Dropout(fc_dropout)
        self.fc1 = nn.Linear(hidden_dim, hidden_dim // 2)
        self.fc2 = nn.Linear(hidden_dim // 2, vocab_size)

    def forward(self, x):
        """Return next-token logits of shape (batch, vocab_size).

        Args:
            x: LongTensor (batch, seq_len) of right-padded token ids, padded
               with the embedding's padding id (0), as built by ``sp_collate_fn``.
        """
        output, _ = self.lstm(self.embedding(x))  # (batch, seq_len, hidden_dim)

        # Read the hidden state at each row's last non-padding token. Taking
        # output[:, -1] would, for any row shorter than the longest in the
        # batch, read a state produced after consuming padding, making a
        # sample's prediction depend on which other samples share its batch.
        # The LSTM is unidirectional, so the state at that position depends
        # only on the real prefix. Rows that are entirely padding fall back to
        # position 0.
        # NOTE: assumes id 0 is used only as padding; a context whose final
        # real token is id 0 would be read at an earlier position.
        positions = torch.arange(x.size(1), device=x.device).expand_as(x)
        last_idx = (positions * (x != self.embedding.padding_idx)).max(dim=1).values
        last_output = output[torch.arange(x.size(0), device=x.device), last_idx]

        norm_output = self.dropout(self.layer_norm(last_output))
        fc1_out = self.dropout(torch.relu(self.fc1(norm_output)))
        return self.fc2(fc1_out)
|
|
|
|
|
def evaluate_perplexity(model, dataloader, criterion, device):
    """Return exp(mean per-sample loss) of ``model`` over ``dataloader``.

    ``criterion(logits, targets)`` is assumed to return the *mean* loss of a
    batch (PyTorch's default ``reduction='mean'``). Each batch mean is
    re-weighted by its batch size, so the result is a true per-sample average
    even when the last batch is smaller. For a genuine perplexity, pass an
    unsmoothed cross-entropy loss.

    Args:
        model: Module mapping an input batch to logits.
        dataloader: Iterable of ``(inputs, targets)`` batches.
        criterion: Loss function as described above.
        device: Device the batches are moved to.

    Returns:
        The perplexity as a float. Returns ``float('inf')`` if ``exp``
        overflows and ``float('nan')`` if the dataloader yields no samples.
    """
    model.eval()
    total_loss = 0.0
    total_samples = 0
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            batch_size = targets.size(0)
            loss = criterion(model(inputs), targets)
            total_loss += loss.item() * batch_size
            total_samples += batch_size

    if total_samples == 0:
        # Perplexity is undefined without samples; avoid ZeroDivisionError.
        return float("nan")
    avg_loss = total_loss / total_samples
    try:
        return math.exp(avg_loss)
    except OverflowError:
        # math.exp raises for arguments above ~709; the perplexity is then
        # effectively infinite.
        return float("inf")
|
|
|
def evaluate_topk_accuracy(model, dataloader, k, device):
    """Return the fraction of samples whose target is among the top-``k`` logits.

    Args:
        model: Module mapping an input batch to logits of shape (batch, vocab).
        dataloader: Iterable of ``(inputs, targets)`` batches; ``targets`` has
            shape (batch,).
        k: Number of highest-scoring tokens that count as a hit. Values larger
            than the vocabulary are clamped to it (every token is then a hit).
        device: Device the batches are moved to.

    Returns:
        Accuracy as a float in [0, 1]; 0.0 if the dataloader yields no samples.

    Raises:
        ValueError: If ``k`` is less than 1.
    """
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k}")
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            logits = model(inputs)

            # torch.topk raises when k exceeds the size of the scored dimension.
            effective_k = min(k, logits.size(-1))
            _, topk_indices = torch.topk(logits, effective_k, dim=-1)  # (batch, effective_k)
            # Vectorized membership test: one reduction per batch instead of a
            # Python loop that synchronizes with the device per sample.
            hits = (topk_indices == targets.unsqueeze(-1)).any(dim=-1)
            correct += int(hits.sum().item())
            total += targets.size(0)
    return correct / total if total > 0 else 0.0
|
|
|
|
|
def main(args):
    """Evaluate a trained next-word model on the held-out validation split.

    Loads the SentencePiece model and takes the last 10% of rows of the CSV's
    ``data`` column as validation data. Rebuilds ``LSTMNextWordModel`` from
    the architecture flags, loads the checkpoint, and logs validation
    perplexity and top-k accuracy.

    Args:
        args: ``argparse.Namespace`` providing ``data_path``, ``sp_model_path``,
            ``model_save_path``, ``batch_size``, ``top_k``, ``embed_dim``,
            ``hidden_dim``, ``num_layers`` and ``dropout``.

    Exits the process with status 1 if an input file is missing, the CSV
    lacks a ``data`` column, or the validation split yields no samples.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logging.info("Using device: %s", device)

    # --- Tokenizer -----------------------------------------------------------
    if not os.path.exists(args.sp_model_path):
        logging.error("SentencePiece model not found at %s", args.sp_model_path)
        sys.exit(1)
    sp = spm.SentencePieceProcessor()
    sp.load(args.sp_model_path)
    logging.info("Loaded SentencePiece model from %s", args.sp_model_path)

    # --- Validation data -----------------------------------------------------
    if not os.path.exists(args.data_path):
        logging.error("Data CSV file not found at %s", args.data_path)
        sys.exit(1)
    df = pd.read_csv(args.data_path)
    if 'data' not in df.columns:
        logging.error("CSV file must contain a 'data' column.")
        sys.exit(1)

    # Split on the raw row count first (the same 90/10 boundary over all rows)
    # and only then drop missing values. Dropping first would shift the
    # boundary and could move training rows into validation. Empty cells are
    # read by pandas as NaN floats, which would crash tokenization.
    split_index = int(len(df) * 0.9)
    valid_sentences = df['data'].iloc[split_index:].dropna().astype(str).tolist()
    logging.info("Validation sentences: %d", len(valid_sentences))

    valid_dataset = NextWordSPDataset(valid_sentences, sp)
    if len(valid_dataset) == 0:
        # Every sample needs at least two tokens (a context and a target).
        logging.error("Validation split produced no samples; nothing to evaluate.")
        sys.exit(1)
    valid_loader = DataLoader(valid_dataset, batch_size=args.batch_size,
                              shuffle=False, collate_fn=sp_collate_fn)

    # --- Model ---------------------------------------------------------------
    # Check the checkpoint before doing any model construction work.
    if not os.path.exists(args.model_save_path):
        logging.error("Model checkpoint not found at %s", args.model_save_path)
        sys.exit(1)
    model = LSTMNextWordModel(sp.get_piece_size(), args.embed_dim, args.hidden_dim,
                              args.num_layers, args.dropout)
    model.to(device)
    # NOTE: torch.load unpickles the file and can execute arbitrary code;
    # only evaluate checkpoints from trusted sources.
    model.load_state_dict(torch.load(args.model_save_path, map_location=device))
    logging.info("Loaded model checkpoint from %s", args.model_save_path)

    # --- Metrics -------------------------------------------------------------
    # Perplexity is exp(mean negative log-likelihood of the true next token),
    # i.e. plain cross-entropy. A label-smoothed loss is deliberately not used:
    # it also adds cross-entropy against every other token, so its exponent is
    # not a perplexity and overstates it. ``args.label_smoothing`` therefore
    # plays no role in evaluation.
    criterion = nn.CrossEntropyLoss()

    val_perplexity = evaluate_perplexity(model, valid_loader, criterion, device)
    topk_accuracy = evaluate_topk_accuracy(model, valid_loader, args.top_k, device)

    logging.info("Validation Perplexity: %.4f", val_perplexity)
    logging.info("Top-%d Accuracy: %.4f", args.top_k, topk_accuracy)
|
|
|
if __name__ == "__main__":
    # Command-line options as (flag, type, default, help), registered in this
    # order (which is also the order shown by --help).
    _CLI_OPTIONS = (
        # Inputs and evaluation settings.
        ('--data_path', str, 'data.csv', "Path to CSV file with a 'data' column"),
        ('--sp_model_path', str, 'spm.model', "Path to the SentencePiece model file"),
        ('--model_save_path', str, 'best_model.pth', "Path to the trained model checkpoint"),
        ('--batch_size', int, 512, "Batch size for evaluation"),
        ('--top_k', int, 3, "Top-k value for computing accuracy"),
        # Model / training hyperparameters.
        ('--embed_dim', int, 256, "Embedding dimension"),
        ('--hidden_dim', int, 256, "Hidden dimension"),
        ('--num_layers', int, 2, "Number of LSTM layers"),
        ('--dropout', float, 0.3, "Dropout rate"),
        ('--label_smoothing', float, 0.1, "Label smoothing factor"),
    )

    parser = argparse.ArgumentParser(description="Evaluate Next Word Prediction Model")
    for _flag, _type, _default, _help in _CLI_OPTIONS:
        parser.add_argument(_flag, type=_type, default=_default, help=_help)

    args = parser.parse_args()
    main(args)
|
|
|
|