File size: 8,696 Bytes
bc5b02b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
#!/usr/bin/env python
"""
Evaluation script for Next Word Prediction model.
Loads the trained model and SentencePiece model,
prepares the validation dataset, and computes:
    - Perplexity (using average loss)
    - Top-k Accuracy (e.g., top-3 accuracy)
Usage:
    python evaluate_next_word.py --data_path data.csv \
         --sp_model_path spm.model --model_save_path best_model.pth \
         [--batch_size 512] [--top_k 3]
"""

import os
import sys
import math
import argparse
import logging
import pandas as pd

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

import sentencepiece as spm

# ---------------------- Logging Configuration ----------------------
# Log to stdout (not the default stderr) at INFO level with timestamped lines,
# so evaluation metrics can be captured by simple output redirection.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)

# ---------------------- Dataset Definition ----------------------
class NextWordSPDataset(Dataset):
    """Next-token prediction pairs built from SentencePiece-encoded text.

    Each sentence is tokenized to ids [t0, ..., tn]; every position i >= 1
    contributes one sample: (prefix [t0..t_{i-1}], target t_i).
    """

    def __init__(self, sentences, sp):
        self.sp = sp
        self.samples = []
        self.prepare_samples(sentences)

    def prepare_samples(self, sentences):
        """Tokenize every sentence and expand it into (prefix, next-id) pairs."""
        for text in sentences:
            ids = self.sp.encode(text.strip(), out_type=int)
            # A sentence with fewer than 2 tokens yields no samples.
            self.samples.extend(
                (torch.tensor(ids[:cut], dtype=torch.long),
                 torch.tensor(ids[cut], dtype=torch.long))
                for cut in range(1, len(ids))
            )

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

def sp_collate_fn(batch):
    """Collate (input_ids, target) pairs: zero-pad inputs, stack targets.

    Returns a (batch, max_len) LongTensor of padded inputs (pad id 0)
    and a (batch,) LongTensor of target token ids.
    """
    input_seqs = [item[0] for item in batch]
    target_ids = [item[1] for item in batch]
    batch_inputs = pad_sequence(input_seqs, batch_first=True, padding_value=0)
    batch_targets = torch.stack(target_ids)
    return batch_inputs, batch_targets

# ---------------------- Model Definition ----------------------
class LSTMNextWordModel(nn.Module):
    """LSTM language-model head: embed -> LSTM -> LayerNorm -> 2-layer MLP.

    Predicts next-token logits from the LSTM output at the final time step.
    Submodule names must stay fixed so training checkpoints load via
    ``load_state_dict``.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers, dropout, fc_dropout=0.3):
        super(LSTMNextWordModel, self).__init__()
        # padding_idx=0 keeps the pad-token embedding frozen at zero.
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers=num_layers,
                            batch_first=True, dropout=dropout)
        self.layer_norm = nn.LayerNorm(hidden_dim)
        self.dropout = nn.Dropout(fc_dropout)
        self.fc1 = nn.Linear(hidden_dim, hidden_dim // 2)
        self.fc2 = nn.Linear(hidden_dim // 2, vocab_size)

    def forward(self, x):
        """Map (batch, seq) token ids to (batch, vocab_size) logits."""
        hidden_states, _ = self.lstm(self.embedding(x))
        # Only the last time step feeds the classifier head.
        final_state = self.dropout(self.layer_norm(hidden_states[:, -1, :]))
        projected = self.dropout(torch.relu(self.fc1(final_state)))
        return self.fc2(projected)

# ---------------------- Evaluation Functions ----------------------
def evaluate_perplexity(model, dataloader, criterion, device):
    """Compute perplexity over ``dataloader`` as exp(average per-sample loss).

    Args:
        model: Module mapping a padded LongTensor batch to per-class logits.
        dataloader: Iterable yielding ``(inputs, targets)`` tensor pairs.
        criterion: Loss callable taking ``(logits, targets)``. NOTE: with a
            label-smoothed criterion the result is exp of the *smoothed*
            loss, not the plain cross-entropy perplexity.
        device: Device to move batches to.

    Returns:
        float: exp(average loss); ``inf`` for an empty dataloader or when
        the exponential overflows.
    """
    model.eval()
    total_loss = 0.0
    total_samples = 0
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            logits = model(inputs)
            loss = criterion(logits, targets)
            # Weight by batch size so the average is per-sample, not
            # per-batch (the last batch is usually smaller).
            total_loss += loss.item() * inputs.size(0)
            total_samples += inputs.size(0)
    # Guard the empty validation set (original raised ZeroDivisionError).
    if total_samples == 0:
        return float("inf")
    avg_loss = total_loss / total_samples
    try:
        return math.exp(avg_loss)
    except OverflowError:
        # math.exp overflows for avg_loss >~ 709; report infinite perplexity.
        return float("inf")

def evaluate_topk_accuracy(model, dataloader, k, device):
    """Fraction of samples whose target appears in the model's top-k logits.

    Args:
        model: Module mapping a padded LongTensor batch to per-class logits.
        dataloader: Iterable yielding ``(inputs, targets)`` tensor pairs.
        k: Number of highest-scoring predictions to consider per sample.
        device: Device to move batches to.

    Returns:
        float: accuracy in [0, 1]; 0 for an empty dataloader.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            logits = model(inputs)
            # Clamp k so torch.topk never requests more classes than exist.
            topk = min(k, logits.size(-1))
            _, topk_indices = torch.topk(logits, topk, dim=-1)
            # Vectorized membership test replaces the original per-sample
            # Python loop: a hit when the target id equals any top-k index.
            hits = (topk_indices == targets.unsqueeze(1)).any(dim=1)
            correct += hits.sum().item()
            total += targets.size(0)
    accuracy = correct / total if total > 0 else 0
    return accuracy

# ---------------------- Main Evaluation Routine ----------------------
def main(args):
    """Evaluate a trained next-word model: perplexity and top-k accuracy.

    Loads the SentencePiece tokenizer, takes the last 10% of the CSV's
    'data' column as the validation split, restores the checkpoint, and
    logs validation perplexity and top-k accuracy.

    Exits with status 1 when any required file (SentencePiece model, data
    CSV, checkpoint) is missing or the CSV lacks a 'data' column.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logging.info("Using device: %s", device)

    # Load SentencePiece model
    if not os.path.exists(args.sp_model_path):
        logging.error("SentencePiece model not found at %s", args.sp_model_path)
        sys.exit(1)
    sp = spm.SentencePieceProcessor()
    sp.load(args.sp_model_path)
    logging.info("Loaded SentencePiece model from %s", args.sp_model_path)

    # Load data and prepare validation set
    if not os.path.exists(args.data_path):
        logging.error("Data CSV file not found at %s", args.data_path)
        sys.exit(1)
    df = pd.read_csv(args.data_path)
    if 'data' not in df.columns:
        logging.error("CSV file must contain a 'data' column.")
        sys.exit(1)
    sentences = df['data'].tolist()
    # Last 10% is the validation split; this must match the training-time
    # split so no training sentences leak into evaluation.
    split_index = int(len(sentences) * 0.9)
    valid_sentences = sentences[split_index:]
    logging.info("Validation sentences: %d", len(valid_sentences))

    valid_dataset = NextWordSPDataset(valid_sentences, sp)
    valid_loader = DataLoader(valid_dataset, batch_size=args.batch_size,
                              shuffle=False, collate_fn=sp_collate_fn)

    # Initialize model; hyperparameters must match those used in training,
    # otherwise load_state_dict below fails on shape mismatches.
    vocab_size = sp.get_piece_size()
    model = LSTMNextWordModel(vocab_size, args.embed_dim, args.hidden_dim,
                              args.num_layers, args.dropout)
    model.to(device)

    # Load the trained model weights
    if not os.path.exists(args.model_save_path):
        logging.error("Model checkpoint not found at %s", args.model_save_path)
        sys.exit(1)
    model.load_state_dict(torch.load(args.model_save_path, map_location=device))
    logging.info("Loaded model checkpoint from %s", args.model_save_path)

    # Label-smoothed cross entropy, matching training.  NOTE: perplexity is
    # then exp of the *smoothed* loss, not the plain cross-entropy value.
    class LabelSmoothingLoss(nn.Module):
        """Cross entropy against a label-smoothed target distribution."""

        def __init__(self, smoothing=0.1):
            super(LabelSmoothingLoss, self).__init__()
            self.smoothing = smoothing

        def forward(self, pred, target):
            confidence = 1.0 - self.smoothing
            vocab_size = pred.size(1)
            # The true class gets `confidence`; the smoothing mass is spread
            # over the remaining classes only.  (The original also added
            # smoothing/(V-1) to the true class, so each target row summed
            # to slightly more than 1 instead of exactly 1.)
            smoothed_target = torch.full_like(pred, self.smoothing / (vocab_size - 1))
            smoothed_target.scatter_(1, target.unsqueeze(1), confidence)
            log_prob = torch.log_softmax(pred, dim=-1)
            return -(smoothed_target * log_prob).sum(dim=1).mean()

    criterion = LabelSmoothingLoss(smoothing=args.label_smoothing)

    # Evaluate perplexity and top-k accuracy
    val_perplexity = evaluate_perplexity(model, valid_loader, criterion, device)
    topk_accuracy = evaluate_topk_accuracy(model, valid_loader, args.top_k, device)

    logging.info("Validation Perplexity: %.4f", val_perplexity)
    logging.info("Top-%d Accuracy: %.4f", args.top_k, topk_accuracy)

if __name__ == "__main__":
    # CLI definition; model hyperparameter defaults must mirror the run
    # that produced the checkpoint being evaluated.
    cli = argparse.ArgumentParser(description="Evaluate Next Word Prediction Model")
    cli.add_argument('--data_path', type=str, default='data.csv', help="Path to CSV file with a 'data' column")
    cli.add_argument('--sp_model_path', type=str, default='spm.model', help="Path to the SentencePiece model file")
    cli.add_argument('--model_save_path', type=str, default='best_model.pth', help="Path to the trained model checkpoint")
    cli.add_argument('--batch_size', type=int, default=512, help="Batch size for evaluation")
    cli.add_argument('--top_k', type=int, default=3, help="Top-k value for computing accuracy")
    # Model hyperparameters (should match those used in training)
    cli.add_argument('--embed_dim', type=int, default=256, help="Embedding dimension")
    cli.add_argument('--hidden_dim', type=int, default=256, help="Hidden dimension")
    cli.add_argument('--num_layers', type=int, default=2, help="Number of LSTM layers")
    cli.add_argument('--dropout', type=float, default=0.3, help="Dropout rate")
    cli.add_argument('--label_smoothing', type=float, default=0.1, help="Label smoothing factor")
    main(cli.parse_args())