File size: 8,696 Bytes
bc5b02b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 |
#!/usr/bin/env python
"""
Evaluation script for Next Word Prediction model.
Loads the trained model and SentencePiece model,
prepares the validation dataset, and computes:
- Perplexity (using average loss)
- Top-k Accuracy (e.g., top-3 accuracy)
Usage:
python evaluate_next_word.py --data_path data.csv \
--sp_model_path spm.model --model_save_path best_model.pth \
[--batch_size 512] [--top_k 3]
"""
import os
import sys
import math
import argparse
import logging
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import sentencepiece as spm
# ---------------------- Logging Configuration ----------------------
logging.basicConfig(
stream=sys.stdout,
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
# ---------------------- Dataset Definition ----------------------
class NextWordSPDataset(Dataset):
    """Dataset of (prefix, next-token) pairs built from SentencePiece-encoded text.

    Each sentence of length T yields T-1 samples: the first i token ids as the
    input sequence and token i as the prediction target.
    """

    def __init__(self, sentences, sp):
        self.sp = sp
        self.samples = []
        self.prepare_samples(sentences)

    def prepare_samples(self, sentences):
        """Tokenize every sentence and expand it into (prefix, target) pairs."""
        for text in sentences:
            ids = self.sp.encode(text.strip(), out_type=int)
            for cut in range(1, len(ids)):
                prefix = torch.tensor(ids[:cut], dtype=torch.long)
                target = torch.tensor(ids[cut], dtype=torch.long)
                self.samples.append((prefix, target))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]
def sp_collate_fn(batch):
    """Collate (sequence, target) pairs: pad sequences to the batch max with 0
    (the embedding's padding index) and stack the scalar targets."""
    sequences = [pair[0] for pair in batch]
    labels = [pair[1] for pair in batch]
    return (
        pad_sequence(sequences, batch_first=True, padding_value=0),
        torch.stack(labels),
    )
# ---------------------- Model Definition ----------------------
class LSTMNextWordModel(nn.Module):
    """LSTM next-word model: embed token ids, run an LSTM, then project the
    final timestep through LayerNorm and a two-layer MLP to vocab logits."""

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers, dropout, fc_dropout=0.3):
        super(LSTMNextWordModel, self).__init__()
        # Index 0 is reserved for padding (matches the collate fn's pad value).
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(
            embed_dim,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout,
        )
        self.layer_norm = nn.LayerNorm(hidden_dim)
        self.dropout = nn.Dropout(fc_dropout)
        self.fc1 = nn.Linear(hidden_dim, hidden_dim // 2)
        self.fc2 = nn.Linear(hidden_dim // 2, vocab_size)

    def forward(self, x):
        """Map (batch, seq) token ids to (batch, vocab_size) logits."""
        hidden_seq, _ = self.lstm(self.embedding(x))
        # Only the last timestep's hidden state feeds the classifier head.
        final_step = self.dropout(self.layer_norm(hidden_seq[:, -1, :]))
        projected = self.dropout(torch.relu(self.fc1(final_step)))
        return self.fc2(projected)
# ---------------------- Evaluation Functions ----------------------
def evaluate_perplexity(model, dataloader, criterion, device):
    """Return exp(mean per-sample loss) of `model` over `dataloader`.

    Batch losses are weighted by batch size so the mean is per-sample even
    when the final batch is smaller.
    """
    model.eval()
    loss_sum = 0.0
    sample_count = 0
    with torch.no_grad():
        for batch_inputs, batch_targets in dataloader:
            batch_inputs = batch_inputs.to(device)
            batch_targets = batch_targets.to(device)
            batch_loss = criterion(model(batch_inputs), batch_targets)
            n = batch_inputs.size(0)
            loss_sum += batch_loss.item() * n
            sample_count += n
    return math.exp(loss_sum / sample_count)
def evaluate_topk_accuracy(model, dataloader, k, device):
    """Fraction of samples whose true next token is in the model's top-k picks.

    Args:
        model: module mapping (batch, seq) token ids to (batch, vocab) logits.
        dataloader: iterable of (inputs, targets) batches.
        k: number of highest-scoring predictions counted as a hit.
        device: torch device to evaluate on.

    Returns:
        Accuracy in [0, 1]; 0 if the loader yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            logits = model(inputs)
            # Top-k predicted token indices per sample: (batch, k).
            _, topk_indices = torch.topk(logits, k, dim=-1)
            # Vectorized membership test replaces the original per-sample
            # Python loop: broadcast targets against each row of top-k ids.
            hits = (topk_indices == targets.unsqueeze(1)).any(dim=1)
            correct += hits.sum().item()
            total += targets.size(0)
    return correct / total if total > 0 else 0
# ---------------------- Main Evaluation Routine ----------------------
def main(args):
    """Run the full evaluation pipeline.

    Loads the SentencePiece model, reads the CSV data, builds a validation
    split (last 10% of rows), restores the trained checkpoint, and logs
    validation perplexity and top-k accuracy. Exits with status 1 if any
    required file or the 'data' column is missing.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logging.info("Using device: %s", device)
    # Load SentencePiece model
    if not os.path.exists(args.sp_model_path):
        logging.error("SentencePiece model not found at %s", args.sp_model_path)
        sys.exit(1)
    sp = spm.SentencePieceProcessor()
    sp.load(args.sp_model_path)
    logging.info("Loaded SentencePiece model from %s", args.sp_model_path)
    # Load data and prepare validation set
    if not os.path.exists(args.data_path):
        logging.error("Data CSV file not found at %s", args.data_path)
        sys.exit(1)
    df = pd.read_csv(args.data_path)
    if 'data' not in df.columns:
        logging.error("CSV file must contain a 'data' column.")
        sys.exit(1)
    sentences = df['data'].tolist()
    # Use a portion for validation. Here, we assume last 10% is validation.
    # NOTE(review): this matches training only if training used the same
    # unshuffled 90/10 split — verify against the training script.
    split_index = int(len(sentences) * 0.9)
    valid_sentences = sentences[split_index:]
    logging.info("Validation sentences: %d", len(valid_sentences))
    valid_dataset = NextWordSPDataset(valid_sentences, sp)
    valid_loader = DataLoader(valid_dataset, batch_size=args.batch_size,
                              shuffle=False, collate_fn=sp_collate_fn)
    # Initialize model. You may need to adjust these parameters to match your training.
    vocab_size = sp.get_piece_size()
    embed_dim = args.embed_dim
    hidden_dim = args.hidden_dim
    num_layers = args.num_layers
    dropout = args.dropout
    model = LSTMNextWordModel(vocab_size, embed_dim, hidden_dim, num_layers, dropout)
    model.to(device)
    # Load the trained model weights
    if not os.path.exists(args.model_save_path):
        logging.error("Model checkpoint not found at %s", args.model_save_path)
        sys.exit(1)
    model.load_state_dict(torch.load(args.model_save_path, map_location=device))
    logging.info("Loaded model checkpoint from %s", args.model_save_path)
    # Define the loss criterion.
    # Note: If you used label smoothing during training, you can reuse that here.
    class LabelSmoothingLoss(nn.Module):
        # Cross-entropy against a smoothed target distribution.
        # NOTE(review): the smoothing mass smoothing/(V-1) is added to ALL
        # classes including the true one, so the target distribution sums
        # slightly above 1 — confirm this matches the training-time loss
        # before comparing perplexities across runs.
        def __init__(self, smoothing=0.1):
            super(LabelSmoothingLoss, self).__init__()
            self.smoothing = smoothing
        def forward(self, pred, target):
            # pred: (batch, vocab) logits; target: (batch,) class indices.
            confidence = 1.0 - self.smoothing
            vocab_size = pred.size(1)
            one_hot = torch.zeros_like(pred).scatter(1, target.unsqueeze(1), 1)
            smoothed_target = one_hot * confidence + self.smoothing / (vocab_size - 1)
            log_prob = torch.log_softmax(pred, dim=-1)
            loss = -(smoothed_target * log_prob).sum(dim=1).mean()
            return loss
    criterion = LabelSmoothingLoss(smoothing=args.label_smoothing)
    # Evaluate perplexity and top-k accuracy
    val_perplexity = evaluate_perplexity(model, valid_loader, criterion, device)
    topk_accuracy = evaluate_topk_accuracy(model, valid_loader, args.top_k, device)
    logging.info("Validation Perplexity: %.4f", val_perplexity)
    logging.info("Top-%d Accuracy: %.4f", args.top_k, topk_accuracy)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Evaluate Next Word Prediction Model")
    # (flag, type, default, help) — registered in the same order as before so
    # --help output and parsing behavior are unchanged.
    _OPTIONS = [
        ('--data_path', str, 'data.csv', "Path to CSV file with a 'data' column"),
        ('--sp_model_path', str, 'spm.model', "Path to the SentencePiece model file"),
        ('--model_save_path', str, 'best_model.pth', "Path to the trained model checkpoint"),
        ('--batch_size', int, 512, "Batch size for evaluation"),
        ('--top_k', int, 3, "Top-k value for computing accuracy"),
        # Model hyperparameters (should match those used in training)
        ('--embed_dim', int, 256, "Embedding dimension"),
        ('--hidden_dim', int, 256, "Hidden dimension"),
        ('--num_layers', int, 2, "Number of LSTM layers"),
        ('--dropout', float, 0.3, "Dropout rate"),
        ('--label_smoothing', float, 0.1, "Label smoothing factor"),
    ]
    for flag, flag_type, default, help_text in _OPTIONS:
        parser.add_argument(flag, type=flag_type, default=default, help=help_text)
    args = parser.parse_args()
    main(args)
|