#!/usr/bin/env python
"""

Next Word Prediction using an LSTM model in PyTorch with advanced improvements.

---------------------------------------------------------------------------------

This script supports two modes:



Training Mode (with --train):

  - Loads data from CSV (must contain a 'data' column)

  - Trains a SentencePiece model for subword tokenization (if not already available)

  - Uses SentencePiece to tokenize text and create a Dataset of (input_sequence, target) pairs

  - Builds and trains an LSTM-based model enhanced with:

      * Extra fully connected layer (with ReLU and dropout)

      * Layer Normalization after LSTM outputs

      * Label Smoothing Loss for improved regularization

      * Gradient clipping, Adam optimizer with weight decay, and ReduceLROnPlateau scheduling

  - Saves training/validation loss graphs

  - Converts and saves the model to TorchScript for production deployment



Inference Mode (with --inference "Your sentence"):

  - Loads the saved SentencePiece model and the TorchScript (or checkpoint) model

  - Runs inference to predict the top 3 next words/subwords



Usage:

  Training mode:

      python next_word_prediction.py --data_path data.csv --train

  Inference mode:

      python next_word_prediction.py --inference "How do you"

"""

import os
import sys
import argparse
import logging
import random

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

# Import SentencePiece
import sentencepiece as spm

# ---------------------- Global Definitions ----------------------
PAD_TOKEN = '<PAD>'  # For padding (id 0; the SentencePiece model below is trained with pad_id=0)
UNK_TOKEN = '<UNK>'  # For unknown tokens (id 1)
# All other tokens come from the trained SentencePiece model

# Set up logging to stdout for Colab compatibility
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)

# ---------------------- Label Smoothing Loss ----------------------
class LabelSmoothingLoss(nn.Module):
    def __init__(self, smoothing=0.1):
        super(LabelSmoothingLoss, self).__init__()
        self.smoothing = smoothing

    def forward(self, pred, target):
        confidence = 1.0 - self.smoothing
        vocab_size = pred.size(1)
        one_hot = torch.zeros_like(pred).scatter(1, target.unsqueeze(1), 1)
        # Distribute the smoothing mass over the non-target classes so the target
        # distribution sums to 1: `confidence` for the true class and
        # smoothing / (V - 1) for every other class.
        smoothed_target = one_hot * confidence + (1.0 - one_hot) * self.smoothing / (vocab_size - 1)
        log_prob = torch.log_softmax(pred, dim=-1)
        loss = -(smoothed_target * log_prob).sum(dim=1).mean()
        return loss
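
# A minimal sketch of how the loss behaves (illustrative only, not used by the
# training pipeline): with smoothing=0.1 and a 5-class vocabulary, a target of
# class 2 corresponds to the distribution [0.025, 0.025, 0.9, 0.025, 0.025],
# and the loss is the cross-entropy between that distribution and the model's
# log-softmax output.
#
#   criterion = LabelSmoothingLoss(smoothing=0.1)
#   pred = torch.randn(4, 5)              # (batch, vocab) logits
#   target = torch.tensor([2, 0, 4, 1])   # true class indices
#   loss = criterion(pred, target)        # scalar tensor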

# ---------------------- SentencePiece Functions ----------------------
def train_sentencepiece(corpus, model_prefix, vocab_size):
    temp_file = "sp_temp.txt"
    with open(temp_file, "w", encoding="utf-8") as f:
        for sentence in corpus:
            f.write(sentence.strip() + "\n")
    spm.SentencePieceTrainer.train(
        input=temp_file,
        model_prefix=model_prefix,
        vocab_size=vocab_size,
        character_coverage=1.0,
        model_type='unigram',
        # Reserve id 0 for padding so it matches the embedding's padding_idx=0
        # and the padding_value used in sp_collate_fn.
        pad_id=0,
        unk_id=1,
        bos_id=2,
        eos_id=3
    )
    os.remove(temp_file)
    logging.info("SentencePiece model trained and saved with prefix '%s'", model_prefix)

def load_sentencepiece_model(model_path):
    sp = spm.SentencePieceProcessor()
    sp.load(model_path)
    logging.info("Loaded SentencePiece model from %s", model_path)
    return sp
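
# Quick sanity check (illustrative; assumes a trained model at 'spm.model'):
#
#   sp = load_sentencepiece_model("spm.model")
#   ids = sp.encode("How do you", out_type=int)   # list of subword ids
#   text = sp.decode(ids)                         # round-trips to "How do you"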

# ---------------------- Dataset using SentencePiece ----------------------
class NextWordSPDataset(Dataset):
    def __init__(self, sentences, sp):
        logging.info("Initializing NextWordSPDataset with %d sentences", len(sentences))
        self.sp = sp
        self.samples = []
        self.prepare_samples(sentences)
        logging.info("Total samples generated: %d", len(self.samples))
    
    def prepare_samples(self, sentences):
        for idx, sentence in enumerate(sentences):
            token_ids = self.sp.encode(sentence.strip(), out_type=int)
            for i in range(1, len(token_ids)):
                self.samples.append((
                    torch.tensor(token_ids[:i], dtype=torch.long),
                    torch.tensor(token_ids[i], dtype=torch.long)
                ))
            if (idx + 1) % 1000 == 0:
                logging.debug("Processed %d/%d sentences", idx + 1, len(sentences))
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        return self.samples[idx]

def sp_collate_fn(batch):
    inputs, targets = zip(*batch)
    # Right-pad variable-length prefixes with 0, which is reserved as the pad id
    # (and is the embedding's padding_idx).
    padded_inputs = pad_sequence(inputs, batch_first=True, padding_value=0)
    targets = torch.stack(targets)
    logging.debug("Batch collated: inputs shape %s, targets shape %s", padded_inputs.shape, targets.shape)
    return padded_inputs, targets
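
# Illustrative example of the (dataset, collate) contract: a sentence encoded to
# the subword ids [12, 7, 45] yields the samples ([12] -> 7) and ([12, 7] -> 45);
# sp_collate_fn then right-pads a batch of such prefixes with 0 into a
# (batch, max_len) tensor and stacks the targets into a (batch,) tensor.
# The ids above are made up purely for illustration.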

# ---------------------- Model Definition ----------------------
class LSTMNextWordModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers, dropout, fc_dropout=0.3):
        super(LSTMNextWordModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers=num_layers, 
                            batch_first=True, dropout=dropout)
        self.layer_norm = nn.LayerNorm(hidden_dim)
        self.dropout = nn.Dropout(fc_dropout)
        self.fc1 = nn.Linear(hidden_dim, hidden_dim // 2)
        self.fc2 = nn.Linear(hidden_dim // 2, vocab_size)

    def forward(self, x):
        # Logging calls removed to allow TorchScript conversion.
        emb = self.embedding(x)
        output, _ = self.lstm(emb)
        # Use the hidden state at the last non-padding position of each sequence
        # (pad id is 0), rather than the final time step, which may be padding
        # for shorter sequences in a padded batch.
        lengths = (x != 0).sum(dim=1).clamp(min=1)
        idx = (lengths - 1).view(-1, 1, 1).expand(-1, 1, output.size(2))
        last_output = output.gather(1, idx).squeeze(1)
        norm_output = self.layer_norm(last_output)
        norm_output = self.dropout(norm_output)
        fc1_out = torch.relu(self.fc1(norm_output))
        fc1_out = self.dropout(fc1_out)
        logits = self.fc2(fc1_out)
        return logits
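
# Tensor shapes through the model (illustrative; B = batch, T = padded length,
# E = embed_dim, H = hidden_dim, V = SentencePiece vocab size):
#   x:           (B, T)      token ids, 0 = padding
#   emb:         (B, T, E)
#   lstm output: (B, T, H)
#   last_output: (B, H)      hidden state at each sequence's last real token
#   logits:      (B, V)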

# ---------------------- Training and Evaluation ----------------------
def train_model(model, train_loader, valid_loader, optimizer, criterion, scheduler, device,
                num_epochs, patience, model_save_path, clip_value=5):
    best_val_loss = float('inf')
    patience_counter = 0
    train_losses = []
    val_losses = []
    logging.info("Starting training for %d epochs", num_epochs)

    for epoch in range(num_epochs):
        logging.info("Epoch %d started...", epoch + 1)
        model.train()
        total_loss = 0.0
        for batch_idx, (inputs, targets) in enumerate(train_loader):
            inputs = inputs.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_value)
            optimizer.step()
            total_loss += loss.item()
            if (batch_idx + 1) % 50 == 0:
                logging.debug("Epoch %d, Batch %d: Loss = %.4f", epoch + 1, batch_idx + 1, loss.item())
        avg_train_loss = total_loss / len(train_loader)
        train_losses.append(avg_train_loss)
        logging.info("Epoch %d training completed. Avg Train Loss: %.4f", epoch + 1, avg_train_loss)
        
        model.eval()
        total_val_loss = 0.0
        with torch.no_grad():
            for batch_idx, (inputs, targets) in enumerate(valid_loader):
                inputs = inputs.to(device)
                targets = targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                total_val_loss += loss.item()
                if (batch_idx + 1) % 50 == 0:
                    logging.debug("Validation Epoch %d, Batch %d: Loss = %.4f", epoch + 1, batch_idx + 1, loss.item())
        avg_val_loss = total_val_loss / len(valid_loader)
        val_losses.append(avg_val_loss)
        logging.info("Epoch %d validation completed. Avg Val Loss: %.4f", epoch + 1, avg_val_loss)
        
        scheduler.step(avg_val_loss)
        
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            patience_counter = 0
            torch.save(model.state_dict(), model_save_path)
            logging.info("Checkpoint saved at epoch %d with Val Loss: %.4f", epoch + 1, avg_val_loss)
        else:
            patience_counter += 1
            logging.info("No improvement in validation loss for %d consecutive epoch(s).", patience_counter)
            if patience_counter >= patience:
                logging.info("Early stopping triggered at epoch %d", epoch + 1)
                break

    plt.figure()
    plt.plot(range(1, len(train_losses)+1), train_losses, label="Train Loss")
    plt.plot(range(1, len(val_losses)+1), val_losses, label="Validation Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.title("Training and Validation Loss")
    plt.savefig("loss_graph.png")
    logging.info("Loss graph saved as loss_graph.png")
    
    return train_losses, val_losses

def predict_next_word(model, sentence, sp, device, topk=3):
    """

    Given a partial sentence, uses SentencePiece to tokenize and predicts the top k next words.

    """
    logging.info("Predicting top %d next words for input sentence: '%s'", topk, sentence)
    model.eval()
    token_ids = sp.encode(sentence.strip(), out_type=int)
    logging.debug("Token IDs for prediction: %s", token_ids)
    if len(token_ids) == 0:
        logging.warning("No tokens found in input sentence.")
        return []
    input_seq = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(input_seq)
        probabilities = torch.softmax(logits, dim=-1)
        topk_result = torch.topk(probabilities, k=topk, dim=-1)
        top_indices = topk_result.indices.squeeze(0).tolist()
    predicted_pieces = [sp.id_to_piece(idx) for idx in top_indices]
    cleaned_predictions = [piece.lstrip("▁") for piece in predicted_pieces]
    logging.info("Predicted top %d next words/subwords: %s", topk, cleaned_predictions)
    return cleaned_predictions
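
# Example call (illustrative; assumes a trained model and SentencePiece processor
# are already loaded; the predicted words depend entirely on the training data):
#
#   preds = predict_next_word(model, "How do you", sp, device, topk=3)
#   # e.g. preds == ["know", "feel", "think"]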

# ---------------------- Main Function ----------------------
def main(args):
    # Seed the RNGs so runs are reproducible with the --seed argument.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logging.info("Using device: %s", device)

    # Inference-only mode
    if args.inference is not None:
        logging.info("Running in inference-only mode with input: '%s'", args.inference)
        if not os.path.exists(args.sp_model_path):
            logging.error("SentencePiece model not found at %s. Cannot run inference.", args.sp_model_path)
            return
        sp = load_sentencepiece_model(args.sp_model_path)
        if os.path.exists(args.scripted_model_path):
            logging.info("Loading TorchScript model from %s", args.scripted_model_path)
            model = torch.jit.load(args.scripted_model_path, map_location=device)
        elif os.path.exists(args.model_save_path):
            logging.info("Loading model checkpoint from %s", args.model_save_path)
            model = LSTMNextWordModel(vocab_size=sp.get_piece_size(),
                                      embed_dim=args.embed_dim,
                                      hidden_dim=args.hidden_dim,
                                      num_layers=args.num_layers,
                                      dropout=args.dropout,
                                      fc_dropout=0.3)
            model.load_state_dict(torch.load(args.model_save_path, map_location=device))
            model.to(device)
        else:
            logging.error("No model checkpoint found. Exiting.")
            return
        predictions = predict_next_word(model, args.inference, sp, device, topk=3)
        logging.info("Input: '%s' -> Predicted next words: %s", args.inference, predictions)
        return

    # Training mode
    logging.info("Loading data from %s...", args.data_path)
    df = pd.read_csv(args.data_path)
    if 'data' not in df.columns:
        logging.error("CSV file must contain a 'data' column. Exiting.")
        return
    # Drop missing rows and coerce to str so tokenization does not fail on NaN values.
    sentences = df['data'].dropna().astype(str).tolist()
    logging.info("Total sentences loaded: %d", len(sentences))
    
    if not os.path.exists(args.sp_model_path):
        logging.info("SentencePiece model not found at %s. Training new model...", args.sp_model_path)
        train_sentencepiece(sentences, args.sp_model_prefix, args.vocab_size)
    sp = load_sentencepiece_model(args.sp_model_path)
    
    train_sentences = sentences[:int(len(sentences) * args.train_split)]
    valid_sentences = sentences[int(len(sentences) * args.train_split):]
    train_dataset = NextWordSPDataset(train_sentences, sp)
    valid_dataset = NextWordSPDataset(valid_sentences, sp)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, collate_fn=sp_collate_fn)
    valid_loader = DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False, collate_fn=sp_collate_fn)
    logging.info("DataLoaders created: %d training batches, %d validation batches",
                 len(train_loader), len(valid_loader))
    
    vocab_size = sp.get_piece_size()
    model = LSTMNextWordModel(vocab_size=vocab_size,
                              embed_dim=args.embed_dim,
                              hidden_dim=args.hidden_dim,
                              num_layers=args.num_layers,
                              dropout=args.dropout,
                              fc_dropout=0.3)
    model.to(device)
    
    criterion = LabelSmoothingLoss(smoothing=args.label_smoothing)
    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=1)
    logging.info("Loss function, optimizer, and scheduler initialized.")
    
    if args.train:
        logging.info("Training mode is ON.")
        if os.path.exists(args.model_save_path):
            logging.info("Existing checkpoint found at %s. Loading weights...", args.model_save_path)
            model.load_state_dict(torch.load(args.model_save_path, map_location=device))
        else:
            logging.info("No checkpoint found. Training from scratch.")
        train_losses, val_losses = train_model(model, train_loader, valid_loader, optimizer, criterion,
                                                scheduler, device, args.num_epochs, args.patience,
                                                args.model_save_path)
        # Reload the best checkpoint so the exported TorchScript model reflects
        # the best validation loss rather than the final epoch's weights.
        model.load_state_dict(torch.load(args.model_save_path, map_location=device))
        scripted_model = torch.jit.script(model)
        scripted_model.save(args.scripted_model_path)
        logging.info("Model converted to TorchScript and saved to %s", args.scripted_model_path)
    else:
        logging.info("Training flag not set. Skipping training; use --inference \"Your sentence\" to run predictions with a saved checkpoint.")
        if not os.path.exists(args.model_save_path):
            logging.error("No model checkpoint found. Exiting.")
        return
    

# ---------------------- Entry Point ----------------------
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Next Word Prediction using LSTM in PyTorch with SentencePiece and advanced techniques")
    parser.add_argument('--data_path', type=str, default='data.csv', help="Path to CSV file with a 'data' column (required for training)")
    parser.add_argument('--vocab_size', type=int, default=10000, help="Vocabulary size for SentencePiece")
    parser.add_argument('--train_split', type=float, default=0.9, help="Fraction of data to use for training")
    parser.add_argument('--batch_size', type=int, default=512, help="Batch size for training")
    parser.add_argument('--embed_dim', type=int, default=256, help="Dimension of word embeddings")
    parser.add_argument('--hidden_dim', type=int, default=256, help="Hidden dimension for LSTM")
    parser.add_argument('--num_layers', type=int, default=2, help="Number of LSTM layers")
    parser.add_argument('--dropout', type=float, default=0.3, help="Dropout rate in LSTM")
    parser.add_argument('--learning_rate', type=float, default=0.001, help="Learning rate for optimizer")
    parser.add_argument('--weight_decay', type=float, default=1e-5, help="Weight decay (L2 regularization) for optimizer")
    parser.add_argument('--num_epochs', type=int, default=25, help="Number of training epochs")
    parser.add_argument('--patience', type=int, default=5, help="Early stopping patience")
    parser.add_argument('--label_smoothing', type=float, default=0.1, help="Label smoothing factor")
    parser.add_argument('--model_save_path', type=str, default='best_model.pth', help="Path to save the best model checkpoint")
    parser.add_argument('--scripted_model_path', type=str, default='best_model_scripted.pt', help="Path to save the TorchScript model")
    parser.add_argument('--sp_model_prefix', type=str, default='spm', help="Prefix for SentencePiece model files")
    parser.add_argument('--sp_model_path', type=str, default='spm.model', help="Path to load/save the SentencePiece model")
    parser.add_argument('--seed', type=int, default=42, help="Random seed for reproducibility")
    parser.add_argument('--train', action='store_true', help="Flag to enable training mode. If not set, runs inference/demo using saved checkpoint.")
    parser.add_argument('--inference', type=str, default=None, help="Input sentence for inference-only mode")
    
    args, unknown = parser.parse_known_args()
    logging.info("Arguments parsed: %s", args)
    main(args)