aarohanverma committed
Commit bc5b02b · verified · 1 Parent(s): 9657ade

Upload 7 files

Files changed (7)
  1. analyzer.py +211 -0
  2. app.py +70 -0
  3. best_model.pth +3 -0
  4. best_model_scripted.pt +3 -0
  5. next_word_prediction.py +365 -0
  6. spm.model +3 -0
  7. spm.vocab +0 -0
analyzer.py ADDED
@@ -0,0 +1,211 @@
+ #!/usr/bin/env python
+ """
+ Evaluation script for the Next Word Prediction model.
+ Loads the trained model and SentencePiece model,
+ prepares the validation dataset, and computes:
+  - Perplexity (using average loss)
+  - Top-k Accuracy (e.g., top-3 accuracy)
+ Usage:
+     python evaluate_next_word.py --data_path data.csv \
+         --sp_model_path spm.model --model_save_path best_model.pth \
+         [--batch_size 512] [--top_k 3]
+ """
+
+ import os
+ import sys
+ import math
+ import argparse
+ import logging
+ import pandas as pd
+
+ import torch
+ import torch.nn as nn
+ from torch.utils.data import Dataset, DataLoader
+ from torch.nn.utils.rnn import pad_sequence
+
+ import sentencepiece as spm
+
+ # ---------------------- Logging Configuration ----------------------
+ logging.basicConfig(
+     stream=sys.stdout,
+     level=logging.INFO,
+     format='%(asctime)s - %(levelname)s - %(message)s',
+     datefmt='%Y-%m-%d %H:%M:%S'
+ )
+
+ # ---------------------- Dataset Definition ----------------------
+ class NextWordSPDataset(Dataset):
+     def __init__(self, sentences, sp):
+         self.sp = sp
+         self.samples = []
+         self.prepare_samples(sentences)
+
+     def prepare_samples(self, sentences):
+         for sentence in sentences:
+             token_ids = self.sp.encode(sentence.strip(), out_type=int)
+             # For each sentence, create (input_sequence, target) pairs.
+             for i in range(1, len(token_ids)):
+                 self.samples.append((
+                     torch.tensor(token_ids[:i], dtype=torch.long),
+                     torch.tensor(token_ids[i], dtype=torch.long)
+                 ))
+
+     def __len__(self):
+         return len(self.samples)
+
+     def __getitem__(self, idx):
+         return self.samples[idx]
+
+ def sp_collate_fn(batch):
+     inputs, targets = zip(*batch)
+     padded_inputs = pad_sequence(inputs, batch_first=True, padding_value=0)
+     targets = torch.stack(targets)
+     return padded_inputs, targets
+
+ # ---------------------- Model Definition ----------------------
+ class LSTMNextWordModel(nn.Module):
+     def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers, dropout, fc_dropout=0.3):
+         super(LSTMNextWordModel, self).__init__()
+         self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
+         self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers=num_layers,
+                             batch_first=True, dropout=dropout)
+         self.layer_norm = nn.LayerNorm(hidden_dim)
+         self.dropout = nn.Dropout(fc_dropout)
+         self.fc1 = nn.Linear(hidden_dim, hidden_dim // 2)
+         self.fc2 = nn.Linear(hidden_dim // 2, vocab_size)
+
+     def forward(self, x):
+         emb = self.embedding(x)
+         output, _ = self.lstm(emb)
+         last_output = output[:, -1, :]
+         norm_output = self.layer_norm(last_output)
+         norm_output = self.dropout(norm_output)
+         fc1_out = torch.relu(self.fc1(norm_output))
+         fc1_out = self.dropout(fc1_out)
+         logits = self.fc2(fc1_out)
+         return logits
+
+ # ---------------------- Evaluation Functions ----------------------
+ def evaluate_perplexity(model, dataloader, criterion, device):
+     model.eval()
+     total_loss = 0.0
+     total_samples = 0
+     with torch.no_grad():
+         for inputs, targets in dataloader:
+             inputs = inputs.to(device)
+             targets = targets.to(device)
+             logits = model(inputs)
+             loss = criterion(logits, targets)
+             total_loss += loss.item() * inputs.size(0)
+             total_samples += inputs.size(0)
+     avg_loss = total_loss / total_samples
+     perplexity = math.exp(avg_loss)
+     return perplexity
+
+ def evaluate_topk_accuracy(model, dataloader, k, device):
+     model.eval()
+     correct = 0
+     total = 0
+     with torch.no_grad():
+         for inputs, targets in dataloader:
+             inputs = inputs.to(device)
+             targets = targets.to(device)
+             logits = model(inputs)
+             # Get top-k predictions for each sample
+             _, topk_indices = torch.topk(logits, k, dim=-1)
+             for i in range(len(targets)):
+                 if targets[i] in topk_indices[i]:
+                     correct += 1
+             total += targets.size(0)
+     accuracy = correct / total if total > 0 else 0
+     return accuracy
+
+ # ---------------------- Main Evaluation Routine ----------------------
+ def main(args):
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     logging.info("Using device: %s", device)
+
+     # Load SentencePiece model
+     if not os.path.exists(args.sp_model_path):
+         logging.error("SentencePiece model not found at %s", args.sp_model_path)
+         sys.exit(1)
+     sp = spm.SentencePieceProcessor()
+     sp.load(args.sp_model_path)
+     logging.info("Loaded SentencePiece model from %s", args.sp_model_path)
+
+     # Load data and prepare validation set
+     if not os.path.exists(args.data_path):
+         logging.error("Data CSV file not found at %s", args.data_path)
+         sys.exit(1)
+     df = pd.read_csv(args.data_path)
+     if 'data' not in df.columns:
+         logging.error("CSV file must contain a 'data' column.")
+         sys.exit(1)
+     sentences = df['data'].tolist()
+     # Use a portion for validation. Here, we assume the last 10% is validation.
+     split_index = int(len(sentences) * 0.9)
+     valid_sentences = sentences[split_index:]
+     logging.info("Validation sentences: %d", len(valid_sentences))
+
+     valid_dataset = NextWordSPDataset(valid_sentences, sp)
+     valid_loader = DataLoader(valid_dataset, batch_size=args.batch_size,
+                               shuffle=False, collate_fn=sp_collate_fn)
+
+     # Initialize model. These parameters must match those used during training.
+     vocab_size = sp.get_piece_size()
+     embed_dim = args.embed_dim
+     hidden_dim = args.hidden_dim
+     num_layers = args.num_layers
+     dropout = args.dropout
+     model = LSTMNextWordModel(vocab_size, embed_dim, hidden_dim, num_layers, dropout)
+     model.to(device)
+
+     # Load the trained model weights
+     if not os.path.exists(args.model_save_path):
+         logging.error("Model checkpoint not found at %s", args.model_save_path)
+         sys.exit(1)
+     model.load_state_dict(torch.load(args.model_save_path, map_location=device))
+     logging.info("Loaded model checkpoint from %s", args.model_save_path)
+
+     # Define the loss criterion.
+     # Note: If you used label smoothing during training, you can reuse that here.
+     class LabelSmoothingLoss(nn.Module):
+         def __init__(self, smoothing=0.1):
+             super(LabelSmoothingLoss, self).__init__()
+             self.smoothing = smoothing
+
+         def forward(self, pred, target):
+             confidence = 1.0 - self.smoothing
+             vocab_size = pred.size(1)
+             one_hot = torch.zeros_like(pred).scatter(1, target.unsqueeze(1), 1)
+             smoothed_target = one_hot * confidence + self.smoothing / (vocab_size - 1)
+             log_prob = torch.log_softmax(pred, dim=-1)
+             loss = -(smoothed_target * log_prob).sum(dim=1).mean()
+             return loss
+
+     criterion = LabelSmoothingLoss(smoothing=args.label_smoothing)
+
+     # Evaluate perplexity and top-k accuracy
+     val_perplexity = evaluate_perplexity(model, valid_loader, criterion, device)
+     topk_accuracy = evaluate_topk_accuracy(model, valid_loader, args.top_k, device)
+
+     logging.info("Validation Perplexity: %.4f", val_perplexity)
+     logging.info("Top-%d Accuracy: %.4f", args.top_k, topk_accuracy)
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="Evaluate Next Word Prediction Model")
+     parser.add_argument('--data_path', type=str, default='data.csv', help="Path to CSV file with a 'data' column")
+     parser.add_argument('--sp_model_path', type=str, default='spm.model', help="Path to the SentencePiece model file")
+     parser.add_argument('--model_save_path', type=str, default='best_model.pth', help="Path to the trained model checkpoint")
+     parser.add_argument('--batch_size', type=int, default=512, help="Batch size for evaluation")
+     parser.add_argument('--top_k', type=int, default=3, help="Top-k value for computing accuracy")
+     # Model hyperparameters (should match those used in training)
+     parser.add_argument('--embed_dim', type=int, default=256, help="Embedding dimension")
+     parser.add_argument('--hidden_dim', type=int, default=256, help="Hidden dimension")
+     parser.add_argument('--num_layers', type=int, default=2, help="Number of LSTM layers")
+     parser.add_argument('--dropout', type=float, default=0.3, help="Dropout rate")
+     parser.add_argument('--label_smoothing', type=float, default=0.1, help="Label smoothing factor")
+
+     args = parser.parse_args()
+     main(args)
+
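
Note (not part of the uploaded file): evaluate_perplexity above reports exp(average validation loss), and evaluate_topk_accuracy counts a prediction as correct when the target id appears among the k highest-scoring vocabulary entries, using a per-sample Python loop. The same statistic can be computed batch-wise; a minimal sketch, assuming logits of shape (batch, vocab_size) and integer targets of shape (batch,), with a hypothetical helper name:

    import torch

    def topk_accuracy_from_logits(logits, targets, k=3):
        # Vectorized equivalent of the membership check in evaluate_topk_accuracy.
        _, topk_indices = torch.topk(logits, k, dim=-1)          # (batch, k)
        hits = topk_indices.eq(targets.unsqueeze(1)).any(dim=1)  # (batch,)
        return hits.float().mean().item()                        # fraction of hits
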
app.py ADDED
@@ -0,0 +1,70 @@
+ import streamlit as st
+ import torch
+ import sentencepiece as spm
+
+ # ---------------------- Model & SentencePiece Loading ----------------------
+ @st.cache_resource
+ def load_model():
+     """Load the TorchScript model for inference."""
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     model = torch.jit.load("best_model_scripted.pt", map_location=device)
+     model.to(device)
+     return model, device
+
+ @st.cache_resource
+ def load_sp_model():
+     """Load the SentencePiece model."""
+     sp = spm.SentencePieceProcessor()
+     sp.load("spm.model")
+     return sp
+
+ # ---------------------- Prediction Function ----------------------
+ def predict_next_words(model, sp, device, text, topk=3):
+     if not text.strip():
+         return []
+     token_ids = sp.encode(text.strip(), out_type=int)
+     if len(token_ids) == 0:
+         return []
+     input_seq = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0).to(device)
+     with torch.no_grad():
+         logits = model(input_seq)
+     probabilities = torch.softmax(logits, dim=-1)
+     topk_result = torch.topk(probabilities, k=topk, dim=-1)
+     top_indices = topk_result.indices.squeeze(0).tolist()
+     predicted_pieces = [sp.id_to_piece(idx).lstrip("▁") for idx in top_indices]
+     return predicted_pieces
+
+ # ---------------------- Streamlit App Layout ----------------------
+ def main():
+     st.title("Real-Time Next Word Prediction")
+     st.write(
+         """
+         Start typing your sentence below. When you finish a word (i.e., type a space at the end),
+         the app will suggest three possible next words. Click on a suggestion to auto-complete your sentence.
+         """
+     )
+
+     model, device = load_model()
+     sp = load_sp_model()
+
+     if "input_text" not in st.session_state:
+         st.session_state.input_text = ""
+
+     user_input = st.text_input("Enter your sentence:", st.session_state.input_text, key="text_input")
+     st.session_state.input_text = user_input
+
+     if user_input.endswith(" "):
+         predictions = predict_next_words(model, sp, device, user_input, topk=3)
+         if predictions:
+             st.markdown("### Predictions:")
+             cols = st.columns(len(predictions))
+             for i, word in enumerate(predictions):
+                 if cols[i].button(word):
+                     st.session_state.input_text = user_input + word + " "
+                     st.rerun()  # Rerun immediately so the accepted suggestion appears in the text box
+     else:
+         st.write("Type a space at the end of your sentence to get next-word suggestions.")
+
+ if __name__ == "__main__":
+     main()
+
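
Note (not part of the uploaded file): app.py only surfaces suggestions once the input ends with a space, and an accepted suggestion is appended through st.session_state before a rerun. For checking the exported artifacts without the Streamlit layer, a minimal sketch that mirrors predict_next_words, assuming spm.model and best_model_scripted.pt sit in the working directory (the suggest helper is hypothetical):

    import torch
    import sentencepiece as spm

    def suggest(text, topk=3):
        # Same inference path as app.py: SentencePiece ids -> TorchScript model -> top-k pieces.
        sp = spm.SentencePieceProcessor()
        sp.load("spm.model")
        model = torch.jit.load("best_model_scripted.pt", map_location="cpu")
        model.eval()
        ids = sp.encode(text.strip(), out_type=int)
        if not ids:
            return []
        with torch.no_grad():
            logits = model(torch.tensor(ids, dtype=torch.long).unsqueeze(0))
        top = torch.topk(torch.softmax(logits, dim=-1), k=topk, dim=-1)
        return [sp.id_to_piece(i).lstrip("▁") for i in top.indices.squeeze(0).tolist()]

    print(suggest("How do you "))   # prints candidate continuations
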
best_model.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:64a7b488dfce765aa9e59aa16eba1353409db2fecbe7de66c6059ce5f9667433
+ size 19748260
best_model_scripted.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:80ac9a35fe8c8f1bc0f2cde2d9fced1064b97cfbd3cc424c20bb36f902a455d7
+ size 19769323
next_word_prediction.py ADDED
@@ -0,0 +1,365 @@
+ #!/usr/bin/env python
+ """
+ Next Word Prediction using an LSTM model in PyTorch with advanced improvements.
+ ---------------------------------------------------------------------------------
+ This script supports two modes:
+
+ Training Mode (with --train):
+   - Loads data from CSV (must contain a 'data' column)
+   - Trains a SentencePiece model for subword tokenization (if not already available)
+   - Uses SentencePiece to tokenize text and create a Dataset of (input_sequence, target) pairs
+   - Builds and trains an LSTM-based model enhanced with:
+       * Extra fully connected layer (with ReLU and dropout)
+       * Layer Normalization after LSTM outputs
+       * Label Smoothing Loss for improved regularization
+       * Gradient clipping, Adam optimizer with weight decay, and ReduceLROnPlateau scheduling
+   - Saves training/validation loss graphs
+   - Converts and saves the model to TorchScript for production deployment
+
+ Inference Mode (with --inference "Your sentence"):
+   - Loads the saved SentencePiece model and the TorchScript (or checkpoint) model
+   - Runs inference to predict the top 3 next words/subwords
+
+ Usage:
+   Training mode:
+     python next_word_prediction.py --data_path data.csv --train
+   Inference mode:
+     python next_word_prediction.py --inference "How do you"
+ """
+
+ import os
+ import sys
+ import argparse
+ import logging
+ import random
+ import pickle
+ from collections import Counter
+
+ import numpy as np
+ import pandas as pd
+ import matplotlib.pyplot as plt
+
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ from torch.utils.data import Dataset, DataLoader
+ from torch.nn.utils.rnn import pad_sequence
+
+ # Import SentencePiece
+ import sentencepiece as spm
+
+ # ---------------------- Global Definitions ----------------------
+ PAD_TOKEN = '<PAD>'  # For padding (id will be 0)
+ UNK_TOKEN = '<UNK>'
+ # We use SentencePiece so our tokens come from the trained model
+
+ # Set up logging to stdout for Colab compatibility
+ logging.basicConfig(
+     stream=sys.stdout,
+     level=logging.INFO,
+     format='%(asctime)s - %(levelname)s - %(message)s',
+     datefmt='%Y-%m-%d %H:%M:%S'
+ )
+
+ # ---------------------- Label Smoothing Loss ----------------------
+ class LabelSmoothingLoss(nn.Module):
+     def __init__(self, smoothing=0.1):
+         super(LabelSmoothingLoss, self).__init__()
+         self.smoothing = smoothing
+
+     def forward(self, pred, target):
+         confidence = 1.0 - self.smoothing
+         vocab_size = pred.size(1)
+         one_hot = torch.zeros_like(pred).scatter(1, target.unsqueeze(1), 1)
+         smoothed_target = one_hot * confidence + self.smoothing / (vocab_size - 1)
+         log_prob = torch.log_softmax(pred, dim=-1)
+         loss = -(smoothed_target * log_prob).sum(dim=1).mean()
+         return loss
+
+ # ---------------------- SentencePiece Functions ----------------------
+ def train_sentencepiece(corpus, model_prefix, vocab_size):
+     temp_file = "sp_temp.txt"
+     with open(temp_file, "w", encoding="utf-8") as f:
+         for sentence in corpus:
+             f.write(sentence.strip() + "\n")
+     spm.SentencePieceTrainer.train(
+         input=temp_file,
+         model_prefix=model_prefix,
+         vocab_size=vocab_size,
+         character_coverage=1.0,
+         model_type='unigram'
+     )
+     os.remove(temp_file)
+     logging.info("SentencePiece model trained and saved with prefix '%s'", model_prefix)
+
+ def load_sentencepiece_model(model_path):
+     sp = spm.SentencePieceProcessor()
+     sp.load(model_path)
+     logging.info("Loaded SentencePiece model from %s", model_path)
+     return sp
+
+ # ---------------------- Dataset using SentencePiece ----------------------
+ class NextWordSPDataset(Dataset):
+     def __init__(self, sentences, sp):
+         logging.info("Initializing NextWordSPDataset with %d sentences", len(sentences))
+         self.sp = sp
+         self.samples = []
+         self.prepare_samples(sentences)
+         logging.info("Total samples generated: %d", len(self.samples))
+
+     def prepare_samples(self, sentences):
+         for idx, sentence in enumerate(sentences):
+             token_ids = self.sp.encode(sentence.strip(), out_type=int)
+             for i in range(1, len(token_ids)):
+                 self.samples.append((
+                     torch.tensor(token_ids[:i], dtype=torch.long),
+                     torch.tensor(token_ids[i], dtype=torch.long)
+                 ))
+             if (idx + 1) % 1000 == 0:
+                 logging.debug("Processed %d/%d sentences", idx + 1, len(sentences))
+
+     def __len__(self):
+         return len(self.samples)
+
+     def __getitem__(self, idx):
+         return self.samples[idx]
+
+ def sp_collate_fn(batch):
+     inputs, targets = zip(*batch)
+     padded_inputs = pad_sequence(inputs, batch_first=True, padding_value=0)
+     targets = torch.stack(targets)
+     logging.debug("Batch collated: inputs shape %s, targets shape %s", padded_inputs.shape, targets.shape)
+     return padded_inputs, targets
+
+ # ---------------------- Model Definition ----------------------
+ class LSTMNextWordModel(nn.Module):
+     def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers, dropout, fc_dropout=0.3):
+         super(LSTMNextWordModel, self).__init__()
+         self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
+         self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers=num_layers,
+                             batch_first=True, dropout=dropout)
+         self.layer_norm = nn.LayerNorm(hidden_dim)
+         self.dropout = nn.Dropout(fc_dropout)
+         self.fc1 = nn.Linear(hidden_dim, hidden_dim // 2)
+         self.fc2 = nn.Linear(hidden_dim // 2, vocab_size)
+
+     def forward(self, x):
+         # Logging calls removed to allow TorchScript conversion.
+         emb = self.embedding(x)
+         output, _ = self.lstm(emb)
+         last_output = output[:, -1, :]
+         norm_output = self.layer_norm(last_output)
+         norm_output = self.dropout(norm_output)
+         fc1_out = torch.relu(self.fc1(norm_output))
+         fc1_out = self.dropout(fc1_out)
+         logits = self.fc2(fc1_out)
+         return logits
+
+ # ---------------------- Training and Evaluation ----------------------
+ def train_model(model, train_loader, valid_loader, optimizer, criterion, scheduler, device,
+                 num_epochs, patience, model_save_path, clip_value=5):
+     best_val_loss = float('inf')
+     patience_counter = 0
+     train_losses = []
+     val_losses = []
+     logging.info("Starting training for %d epochs", num_epochs)
+
+     for epoch in range(num_epochs):
+         logging.info("Epoch %d started...", epoch + 1)
+         model.train()
+         total_loss = 0.0
+         for batch_idx, (inputs, targets) in enumerate(train_loader):
+             inputs = inputs.to(device)
+             targets = targets.to(device)
+             optimizer.zero_grad()
+             outputs = model(inputs)
+             loss = criterion(outputs, targets)
+             loss.backward()
+             torch.nn.utils.clip_grad_norm_(model.parameters(), clip_value)
+             optimizer.step()
+             total_loss += loss.item()
+             if (batch_idx + 1) % 50 == 0:
+                 logging.debug("Epoch %d, Batch %d: Loss = %.4f", epoch + 1, batch_idx + 1, loss.item())
+         avg_train_loss = total_loss / len(train_loader)
+         train_losses.append(avg_train_loss)
+         logging.info("Epoch %d training completed. Avg Train Loss: %.4f", epoch + 1, avg_train_loss)
+
+         model.eval()
+         total_val_loss = 0.0
+         with torch.no_grad():
+             for batch_idx, (inputs, targets) in enumerate(valid_loader):
+                 inputs = inputs.to(device)
+                 targets = targets.to(device)
+                 outputs = model(inputs)
+                 loss = criterion(outputs, targets)
+                 total_val_loss += loss.item()
+                 if (batch_idx + 1) % 50 == 0:
+                     logging.debug("Validation Epoch %d, Batch %d: Loss = %.4f", epoch + 1, batch_idx + 1, loss.item())
+         avg_val_loss = total_val_loss / len(valid_loader)
+         val_losses.append(avg_val_loss)
+         logging.info("Epoch %d validation completed. Avg Val Loss: %.4f", epoch + 1, avg_val_loss)
+
+         scheduler.step(avg_val_loss)
+
+         if avg_val_loss < best_val_loss:
+             best_val_loss = avg_val_loss
+             patience_counter = 0
+             torch.save(model.state_dict(), model_save_path)
+             logging.info("Checkpoint saved at epoch %d with Val Loss: %.4f", epoch + 1, avg_val_loss)
+         else:
+             patience_counter += 1
+             logging.info("No improvement in validation loss for %d consecutive epoch(s).", patience_counter)
+             if patience_counter >= patience:
+                 logging.info("Early stopping triggered at epoch %d", epoch + 1)
+                 break
+
+     plt.figure()
+     plt.plot(range(1, len(train_losses)+1), train_losses, label="Train Loss")
+     plt.plot(range(1, len(val_losses)+1), val_losses, label="Validation Loss")
+     plt.xlabel("Epoch")
+     plt.ylabel("Loss")
+     plt.legend()
+     plt.title("Training and Validation Loss")
+     plt.savefig("loss_graph.png")
+     logging.info("Loss graph saved as loss_graph.png")
+
+     return train_losses, val_losses
+
+ def predict_next_word(model, sentence, sp, device, topk=3):
+     """
+     Given a partial sentence, uses SentencePiece to tokenize and predicts the top k next words.
+     """
+     logging.info("Predicting top %d next words for input sentence: '%s'", topk, sentence)
+     model.eval()
+     token_ids = sp.encode(sentence.strip(), out_type=int)
+     logging.debug("Token IDs for prediction: %s", token_ids)
+     if len(token_ids) == 0:
+         logging.warning("No tokens found in input sentence.")
+         return []
+     input_seq = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0).to(device)
+     with torch.no_grad():
+         logits = model(input_seq)
+     probabilities = torch.softmax(logits, dim=-1)
+     topk_result = torch.topk(probabilities, k=topk, dim=-1)
+     top_indices = topk_result.indices.squeeze(0).tolist()
+     predicted_pieces = [sp.id_to_piece(idx) for idx in top_indices]
+     cleaned_predictions = [piece.lstrip("▁") for piece in predicted_pieces]
+     logging.info("Predicted top %d next words/subwords: %s", topk, cleaned_predictions)
+     return cleaned_predictions
+
+ # ---------------------- Main Function ----------------------
+ def main(args):
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     logging.info("Using device: %s", device)
+
+     # Inference-only mode
+     if args.inference is not None:
+         logging.info("Running in inference-only mode with input: '%s'", args.inference)
+         if not os.path.exists(args.sp_model_path):
+             logging.error("SentencePiece model not found at %s. Cannot run inference.", args.sp_model_path)
+             return
+         sp = load_sentencepiece_model(args.sp_model_path)
+         if os.path.exists(args.scripted_model_path):
+             logging.info("Loading TorchScript model from %s", args.scripted_model_path)
+             model = torch.jit.load(args.scripted_model_path, map_location=device)
+         elif os.path.exists(args.model_save_path):
+             logging.info("Loading model checkpoint from %s", args.model_save_path)
+             model = LSTMNextWordModel(vocab_size=sp.get_piece_size(),
+                                       embed_dim=args.embed_dim,
+                                       hidden_dim=args.hidden_dim,
+                                       num_layers=args.num_layers,
+                                       dropout=args.dropout,
+                                       fc_dropout=0.3)
+             model.load_state_dict(torch.load(args.model_save_path, map_location=device))
+             model.to(device)
+         else:
+             logging.error("No model checkpoint found. Exiting.")
+             return
+         predictions = predict_next_word(model, args.inference, sp, device, topk=1)
+         logging.info("Input: '%s' -> Predicted next words: %s", args.inference, predictions)
+         return
+
+     # Training mode
+     logging.info("Loading data from %s...", args.data_path)
+     df = pd.read_csv(args.data_path)
+     if 'data' not in df.columns:
+         logging.error("CSV file must contain a 'data' column. Exiting.")
+         return
+     sentences = df['data'].tolist()
+     logging.info("Total sentences loaded: %d", len(sentences))
+
+     if not os.path.exists(args.sp_model_path):
+         logging.info("SentencePiece model not found at %s. Training new model...", args.sp_model_path)
+         train_sentencepiece(sentences, args.sp_model_prefix, args.vocab_size)
+     sp = load_sentencepiece_model(args.sp_model_path)
+
+     train_sentences = sentences[:int(len(sentences) * args.train_split)]
+     valid_sentences = sentences[int(len(sentences) * args.train_split):]
+     train_dataset = NextWordSPDataset(train_sentences, sp)
+     valid_dataset = NextWordSPDataset(valid_sentences, sp)
+     train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, collate_fn=sp_collate_fn)
+     valid_loader = DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False, collate_fn=sp_collate_fn)
+     logging.info("DataLoaders created: %d training batches, %d validation batches",
+                  len(train_loader), len(valid_loader))
+
+     vocab_size = sp.get_piece_size()
+     model = LSTMNextWordModel(vocab_size=vocab_size,
+                               embed_dim=args.embed_dim,
+                               hidden_dim=args.hidden_dim,
+                               num_layers=args.num_layers,
+                               dropout=args.dropout,
+                               fc_dropout=0.3)
+     model.to(device)
+
+     criterion = LabelSmoothingLoss(smoothing=args.label_smoothing)
+     optimizer = optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
+     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=1, verbose=True)
+     logging.info("Loss function, optimizer, and scheduler initialized.")
+
+     if args.train:
+         logging.info("Training mode is ON.")
+         if os.path.exists(args.model_save_path):
+             logging.info("Existing checkpoint found at %s. Loading weights...", args.model_save_path)
+             model.load_state_dict(torch.load(args.model_save_path, map_location=device))
+         else:
+             logging.info("No checkpoint found. Training from scratch.")
+         train_losses, val_losses = train_model(model, train_loader, valid_loader, optimizer, criterion,
+                                                scheduler, device, args.num_epochs, args.patience,
+                                                args.model_save_path)
+         scripted_model = torch.jit.script(model)
+         scripted_model.save(args.scripted_model_path)
+         logging.info("Model converted to TorchScript and saved to %s", args.scripted_model_path)
+     else:
+         logging.info("Training flag not set. Skipping training and running inference demo.")
+         if not os.path.exists(args.model_save_path):
+             logging.error("No model checkpoint found. Exiting.")
+             return
+
+
+ # ---------------------- Entry Point ----------------------
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="Next Word Prediction using LSTM in PyTorch with SentencePiece and advanced techniques")
+     parser.add_argument('--data_path', type=str, default='data.csv', help="Path to CSV file with a 'data' column (required for training)")
+     parser.add_argument('--vocab_size', type=int, default=10000, help="Vocabulary size for SentencePiece")
+     parser.add_argument('--train_split', type=float, default=0.9, help="Fraction of data to use for training")
+     parser.add_argument('--batch_size', type=int, default=512, help="Batch size for training")
+     parser.add_argument('--embed_dim', type=int, default=256, help="Dimension of word embeddings")
+     parser.add_argument('--hidden_dim', type=int, default=256, help="Hidden dimension for LSTM")
+     parser.add_argument('--num_layers', type=int, default=2, help="Number of LSTM layers")
+     parser.add_argument('--dropout', type=float, default=0.3, help="Dropout rate in LSTM")
+     parser.add_argument('--learning_rate', type=float, default=0.001, help="Learning rate for optimizer")
+     parser.add_argument('--weight_decay', type=float, default=1e-5, help="Weight decay (L2 regularization) for optimizer")
+     parser.add_argument('--num_epochs', type=int, default=25, help="Number of training epochs")
+     parser.add_argument('--patience', type=int, default=5, help="Early stopping patience")
+     parser.add_argument('--label_smoothing', type=float, default=0.1, help="Label smoothing factor")
+     parser.add_argument('--model_save_path', type=str, default='best_model.pth', help="Path to save the best model checkpoint")
+     parser.add_argument('--scripted_model_path', type=str, default='best_model_scripted.pt', help="Path to save the TorchScript model")
+     parser.add_argument('--sp_model_prefix', type=str, default='spm', help="Prefix for SentencePiece model files")
+     parser.add_argument('--sp_model_path', type=str, default='spm.model', help="Path to load/save the SentencePiece model")
+     parser.add_argument('--seed', type=int, default=42, help="Random seed for reproducibility")
+     parser.add_argument('--train', action='store_true', help="Flag to enable training mode. If not set, runs inference/demo using saved checkpoint.")
+     parser.add_argument('--inference', type=str, default=None, help="Input sentence for inference-only mode")
+
+     args, unknown = parser.parse_known_args()
+     logging.info("Arguments parsed: %s", args)
+     main(args)
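
Note (not part of the uploaded file): the dataset above expands each sentence into growing token-id prefixes, sp_collate_fn pads them with id 0, and LSTMNextWordModel maps a (batch, seq_len) tensor of ids to (batch, vocab_size) logits for the next token. A minimal shape check, assuming the class is imported from next_word_prediction.py; the hyperparameter values below are illustrative, not the training defaults:

    import torch
    from next_word_prediction import LSTMNextWordModel

    model = LSTMNextWordModel(vocab_size=100, embed_dim=32, hidden_dim=64,
                              num_layers=2, dropout=0.3)
    model.eval()
    batch = torch.randint(1, 100, (4, 7))    # four padded prefixes, 7 tokens each
    with torch.no_grad():
        logits = model(batch)
    assert logits.shape == (4, 100)           # one score per vocabulary entry
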
spm.model ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fe3060038cf9883da1a90d9a4770b57e82c537903000dcb7c07cee5acd7e68e8
+ size 411288
spm.vocab ADDED
The diff for this file is too large to render. See raw diff