zakerytclarke committed on
Commit b68a80c · verified · 1 Parent(s): 2d7d97f

Update src/streamlit_app.py

Files changed (1)
  1. src/streamlit_app.py +40 -35
src/streamlit_app.py CHANGED
@@ -38,9 +38,11 @@ model_type = st.sidebar.selectbox(
 
 temperature = st.sidebar.slider("Sampling Temperature", 0.1, 2.0, 1.0)
 
-# Context size slider (minimum 2)
 context_size = st.sidebar.slider("Context Size (how many tokens to look back)", min_value=2, max_value=10, value=3, step=1)
 
+# Number of tokens from dataset to use for training (minimum 100 tokens)
+num_train_tokens = st.sidebar.slider("Number of tokens from dataset to train on", min_value=100, max_value=100000, value=1000, step=100)
+
 train_button = st.sidebar.button("Train Model")
 
 device = torch.device("cpu")  # force CPU usage
@@ -74,10 +76,12 @@ def tokenize(text, tokenizer_type):
         tokens = text.split()
     return tokens
 
-tokens = tokenize(text_data, tokenizer_type)
-vocab = list(set(tokens))
+tokens_all = tokenize(text_data, tokenizer_type)
+
+# Cap tokens to requested number for training
+tokens = tokens_all[:num_train_tokens]
 
-# Add PAD token to vocab for padding contexts shorter than context_size - 1
+vocab = list(set(tokens))
 PAD_TOKEN = "<PAD>"
 if PAD_TOKEN not in vocab:
     vocab.append(PAD_TOKEN)
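Note on the additions above: capping the token stream before building the vocabulary also means the vocab only covers tokens that actually appear in the training slice. A minimal sketch of the same flow outside Streamlit, assuming the whitespace-tokenizer branch, with `num_train_tokens` standing in for the slider value:

```python
# Minimal sketch of the new capping flow, outside Streamlit.
# num_train_tokens stands in for the sidebar slider value.
text_data = "the quick brown fox jumps over the lazy dog " * 50

num_train_tokens = 100
tokens_all = text_data.split()           # whitespace-tokenizer branch
tokens = tokens_all[:num_train_tokens]   # cap to the requested slice

vocab = list(set(tokens))
PAD_TOKEN = "<PAD>"
if PAD_TOKEN not in vocab:
    vocab.append(PAD_TOKEN)

print(len(tokens_all), len(tokens), len(vocab))  # 450 100 9
```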
@@ -90,10 +94,6 @@ idx_to_token = {i: tok for tok, i in token_to_idx.items()}
 ###################################
 
 def pad_context(context, size):
-    """
-    Pads the context list at the front with PAD_TOKEN if length < size,
-    or truncates to last `size` tokens if longer.
-    """
     pad_len = size - len(context)
     if pad_len > 0:
         return [PAD_TOKEN]*pad_len + context
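The removed docstring described behavior worth keeping in mind: pad on the left when the context is short, keep the last `size` tokens when it is long. A self-contained check (only the padding branch is visible in this hunk; the truncation branch follows the removed docstring):

```python
PAD_TOKEN = "<PAD>"

def pad_context(context, size):
    # Left-pad with PAD_TOKEN when the context is short...
    pad_len = size - len(context)
    if pad_len > 0:
        return [PAD_TOKEN] * pad_len + context
    # ...otherwise keep only the last `size` tokens (per the removed docstring).
    return context[-size:]

print(pad_context(["a"], 3))                 # ['<PAD>', '<PAD>', 'a']
print(pad_context(["a", "b", "c", "d"], 3))  # ['b', 'c', 'd']
```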
@@ -145,31 +145,33 @@ class FFNN(nn.Module):
 
 def train_ffnn(tokens, context_size=3, epochs=3):
     data = []
-    for i in range(len(tokens)):
-        start_idx = i - (context_size - 1)
-        context = tokens[start_idx:i] if start_idx >= 0 else tokens[0:i]
+    for i in range(len(tokens) - (context_size - 1)):
+        context = tokens[i : i + context_size - 1]
         context = pad_context(context, context_size - 1)
-        target = tokens[i]
+        target = tokens[i + context_size - 1]
         data.append((
             torch.tensor([token_to_idx.get(t, token_to_idx[PAD_TOKEN]) for t in context], device=device),
             token_to_idx.get(target, token_to_idx[PAD_TOKEN])
         ))
 
+    if len(data) == 0:
+        st.warning("No training data generated. Increase dataset size or reduce context size.")
+        return None
+
     model = FFNN(len(vocab), context_size - 1).to(device)
     optimizer = optim.Adam(model.parameters(), lr=0.01)
     criterion = nn.CrossEntropyLoss()
 
     progress_bar = st.progress(0)
-    total_steps = epochs * len(data)
+    total_steps = len(data) * epochs
     step = 0
 
     model.train()
-
     for epoch in range(epochs):
         total_loss = 0
         random.shuffle(data)
         for x, y in data:
-            x = x.unsqueeze(0)  # batch size 1
+            x = x.unsqueeze(0)
             y = torch.tensor([y], device=device)
 
             optimizer.zero_grad()
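The loop rewrite changes which training pairs exist. The old loop made every position a target, so the earliest samples carried partially or fully padded contexts (an all-PAD context for i = 0); the new loop emits only full windows of context_size - 1 tokens plus a target, which is also why `data` can now be empty on a tiny dataset. A toy side-by-side, as a standalone sketch:

```python
# Toy comparison of the old vs. new window construction (context_size = 3).
PAD_TOKEN = "<PAD>"
tokens = ["a", "b", "c", "d"]
context_size = 3

def pad_context(context, size):
    pad_len = size - len(context)
    return [PAD_TOKEN] * pad_len + context if pad_len > 0 else context[-size:]

# Old: every position is a target; early contexts get padded.
old = []
for i in range(len(tokens)):
    start_idx = i - (context_size - 1)
    context = tokens[start_idx:i] if start_idx >= 0 else tokens[0:i]
    old.append((pad_context(context, context_size - 1), tokens[i]))

# New: full windows only, so short inputs can yield zero pairs.
new = []
for i in range(len(tokens) - (context_size - 1)):
    context = tokens[i : i + context_size - 1]
    new.append((pad_context(context, context_size - 1), tokens[i + context_size - 1]))

print(old)  # [(['<PAD>', '<PAD>'], 'a'), (['<PAD>', 'a'], 'b'), (['a', 'b'], 'c'), (['b', 'c'], 'd')]
print(new)  # [(['a', 'b'], 'c'), (['b', 'c'], 'd')]
```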
@@ -201,11 +203,10 @@ def ffnn_predict(model, context, temperature=1.0):
 
 def train_dt(tokens, context_size=3):
     X, y = [], []
-    for i in range(len(tokens)):
-        start_idx = i - (context_size - 1)
-        context = tokens[start_idx:i] if start_idx >= 0 else tokens[0:i]
+    for i in range(len(tokens) - (context_size - 1)):
+        context = tokens[i : i + context_size - 1]
         context = pad_context(context, context_size - 1)
-        target = tokens[i]
+        target = tokens[i + context_size - 1]
         X.append([token_to_idx.get(t, token_to_idx[PAD_TOKEN]) for t in context])
         y.append(token_to_idx.get(target, token_to_idx[PAD_TOKEN]))
 
@@ -226,11 +227,10 @@ def dt_predict(model, context):
 
 def train_gbt(tokens, context_size=3):
     X, y = [], []
-    for i in range(len(tokens)):
-        start_idx = i - (context_size - 1)
-        context = tokens[start_idx:i] if start_idx >= 0 else tokens[0:i]
+    for i in range(len(tokens) - (context_size - 1)):
+        context = tokens[i : i + context_size - 1]
         context = pad_context(context, context_size - 1)
-        target = tokens[i]
+        target = tokens[i + context_size - 1]
         X.append([token_to_idx.get(t, token_to_idx[PAD_TOKEN]) for t in context])
         y.append(token_to_idx.get(target, token_to_idx[PAD_TOKEN]))
 
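train_dt and train_gbt receive the identical loop rewrite; the only difference from the FFNN path is that pairs are collected as plain integer lists for the scikit-learn estimators. A sketch of the resulting X/y shapes, assuming a toy token_to_idx mapping (pad_context is a no-op on these full windows and is omitted):

```python
# Sketch of the X / y arrays the tree-based trainers now build
# (toy token_to_idx mapping assumed, context_size = 3).
PAD_TOKEN = "<PAD>"
tokens = ["a", "b", "c", "d"]
token_to_idx = {"a": 0, "b": 1, "c": 2, "d": 3, PAD_TOKEN: 4}
context_size = 3

X, y = [], []
for i in range(len(tokens) - (context_size - 1)):
    context = tokens[i : i + context_size - 1]
    target = tokens[i + context_size - 1]
    X.append([token_to_idx.get(t, token_to_idx[PAD_TOKEN]) for t in context])
    y.append(token_to_idx.get(target, token_to_idx[PAD_TOKEN]))

print(X)  # [[0, 1], [1, 2]]
print(y)  # [2, 3]
```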
@@ -264,26 +264,28 @@ class RNNModel(nn.Module):
 
 def train_rnn(tokens, context_size=3, epochs=3):
     data = []
-    for i in range(len(tokens)):
-        start_idx = i - (context_size - 1)
-        context = tokens[start_idx:i] if start_idx >= 0 else tokens[0:i]
+    for i in range(len(tokens) - (context_size - 1)):
+        context = tokens[i : i + context_size - 1]
         context = pad_context(context, context_size - 1)
-        target = tokens[i]
+        target = tokens[i + context_size - 1]
         data.append((
             torch.tensor([token_to_idx.get(t, token_to_idx[PAD_TOKEN]) for t in context], device=device),
             token_to_idx.get(target, token_to_idx[PAD_TOKEN])
         ))
 
+    if len(data) == 0:
+        st.warning("No training data generated. Increase dataset size or reduce context size.")
+        return None
+
     model = RNNModel(len(vocab)).to(device)
     optimizer = optim.Adam(model.parameters(), lr=0.01)
     criterion = nn.CrossEntropyLoss()
 
     progress_bar = st.progress(0)
-    total_steps = epochs * len(data)
+    total_steps = len(data) * epochs
     step = 0
 
     model.train()
-
     for epoch in range(epochs):
         total_loss = 0
         h = None
@@ -296,8 +298,8 @@ def train_rnn(tokens, context_size=3, epochs=3):
             optimizer.zero_grad()
             loss.backward()
             optimizer.step()
-            total_loss += loss.item()
 
+            total_loss += loss.item()
             step += 1
             progress_bar.progress(step / total_steps)
 
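Moving `total_loss += loss.item()` below the blank line is cosmetic; the value is the same wherever it is read after the loss is computed. The substantive point across these training hunks is that total_steps is now derived from len(data) after the empty-data guard, so `step / total_steps` never divides by zero and the bar lands exactly on 1.0. A quick check of the arithmetic:

```python
# The progress arithmetic behind the bar: with the empty-data guard,
# total_steps is always positive, and the fraction ends exactly at 1.0.
data = [("ctx", "tgt")] * 5   # stand-in for the (context, target) pairs
epochs = 3

total_steps = len(data) * epochs
step = 0
fraction = 0.0
for epoch in range(epochs):
    for _ in data:
        step += 1
        fraction = step / total_steps
print(step, total_steps, fraction)  # 15 15 1.0
```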
@@ -319,7 +321,7 @@ def rnn_predict(model, context, temperature=1.0):
 ###################################
 
 if train_button:
-    st.write(f"Training **{model_type}** model with context size {context_size}...")
+    st.write(f"Training **{model_type}** model with context size {context_size} on {len(tokens)} tokens...")
 
     if model_type == "N-gram":
         with st.spinner("Training N-gram model..."):
@@ -333,10 +335,13 @@ if train_button:
     elif model_type == "RNN":
         model = train_rnn(tokens, context_size=context_size)
 
-    st.session_state["model"] = model
-    st.session_state["model_type"] = model_type
-    st.session_state["context_size"] = context_size
-    st.success(f"{model_type} model trained.")
+    if model is not None:
+        st.session_state["model"] = model
+        st.session_state["model_type"] = model_type
+        st.session_state["context_size"] = context_size
+        st.success(f"{model_type} model trained.")
+    else:
+        st.error("Training failed due to no data.")
 
 ###################################
 # Chat interface
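The new None check matters because train_ffnn and train_rnn can now return None on empty data; without the guard, a None model would be stored in st.session_state and only fail later in the chat interface. A minimal sketch of the pattern with the session keys from this diff (the literal values are illustrative):

```python
# Minimal sketch of the guard around session storage; the key names match
# the diff, the literal values are illustrative.
import streamlit as st

model = None  # e.g. train_ffnn(...) bailed out on empty data

if model is not None:
    st.session_state["model"] = model
    st.session_state["model_type"] = "FFNN"
    st.session_state["context_size"] = 3
    st.success("FFNN model trained.")
else:
    st.error("Training failed due to no data.")

# Later reruns can then trust whatever they read back:
if "model" in st.session_state:
    model = st.session_state["model"]
```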