Spaces:
Sleeping
Sleeping
Update src/streamlit_app.py
Browse files- src/streamlit_app.py +40 -35
src/streamlit_app.py
CHANGED
@@ -38,9 +38,11 @@ model_type = st.sidebar.selectbox(
|
|
38 |
|
39 |
temperature = st.sidebar.slider("Sampling Temperature", 0.1, 2.0, 1.0)
|
40 |
|
41 |
-
# Context size slider (minimum 2)
|
42 |
context_size = st.sidebar.slider("Context Size (how many tokens to look back)", min_value=2, max_value=10, value=3, step=1)
|
43 |
|
|
|
|
|
|
|
44 |
train_button = st.sidebar.button("Train Model")
|
45 |
|
46 |
device = torch.device("cpu") # force CPU usage
|
@@ -74,10 +76,12 @@ def tokenize(text, tokenizer_type):
|
|
74 |
tokens = text.split()
|
75 |
return tokens
|
76 |
|
77 |
-
|
78 |
-
|
|
|
|
|
79 |
|
80 |
-
|
81 |
PAD_TOKEN = "<PAD>"
|
82 |
if PAD_TOKEN not in vocab:
|
83 |
vocab.append(PAD_TOKEN)
|
@@ -90,10 +94,6 @@ idx_to_token = {i: tok for tok, i in token_to_idx.items()}
|
|
90 |
###################################
|
91 |
|
92 |
def pad_context(context, size):
|
93 |
-
"""
|
94 |
-
Pads the context list at the front with PAD_TOKEN if length < size,
|
95 |
-
or truncates to last `size` tokens if longer.
|
96 |
-
"""
|
97 |
pad_len = size - len(context)
|
98 |
if pad_len > 0:
|
99 |
return [PAD_TOKEN]*pad_len + context
|
@@ -145,31 +145,33 @@ class FFNN(nn.Module):
|
|
145 |
|
146 |
def train_ffnn(tokens, context_size=3, epochs=3):
|
147 |
data = []
|
148 |
-
for i in range(len(tokens)):
|
149 |
-
|
150 |
-
context = tokens[start_idx:i] if start_idx >= 0 else tokens[0:i]
|
151 |
context = pad_context(context, context_size - 1)
|
152 |
-
target = tokens[i]
|
153 |
data.append((
|
154 |
torch.tensor([token_to_idx.get(t, token_to_idx[PAD_TOKEN]) for t in context], device=device),
|
155 |
token_to_idx.get(target, token_to_idx[PAD_TOKEN])
|
156 |
))
|
157 |
|
|
|
|
|
|
|
|
|
158 |
model = FFNN(len(vocab), context_size - 1).to(device)
|
159 |
optimizer = optim.Adam(model.parameters(), lr=0.01)
|
160 |
criterion = nn.CrossEntropyLoss()
|
161 |
|
162 |
progress_bar = st.progress(0)
|
163 |
-
total_steps =
|
164 |
step = 0
|
165 |
|
166 |
model.train()
|
167 |
-
|
168 |
for epoch in range(epochs):
|
169 |
total_loss = 0
|
170 |
random.shuffle(data)
|
171 |
for x, y in data:
|
172 |
-
x = x.unsqueeze(0)
|
173 |
y = torch.tensor([y], device=device)
|
174 |
|
175 |
optimizer.zero_grad()
|
@@ -201,11 +203,10 @@ def ffnn_predict(model, context, temperature=1.0):
|
|
201 |
|
202 |
def train_dt(tokens, context_size=3):
|
203 |
X, y = [], []
|
204 |
-
for i in range(len(tokens)):
|
205 |
-
|
206 |
-
context = tokens[start_idx:i] if start_idx >= 0 else tokens[0:i]
|
207 |
context = pad_context(context, context_size - 1)
|
208 |
-
target = tokens[i]
|
209 |
X.append([token_to_idx.get(t, token_to_idx[PAD_TOKEN]) for t in context])
|
210 |
y.append(token_to_idx.get(target, token_to_idx[PAD_TOKEN]))
|
211 |
|
@@ -226,11 +227,10 @@ def dt_predict(model, context):
|
|
226 |
|
227 |
def train_gbt(tokens, context_size=3):
|
228 |
X, y = [], []
|
229 |
-
for i in range(len(tokens)):
|
230 |
-
|
231 |
-
context = tokens[start_idx:i] if start_idx >= 0 else tokens[0:i]
|
232 |
context = pad_context(context, context_size - 1)
|
233 |
-
target = tokens[i]
|
234 |
X.append([token_to_idx.get(t, token_to_idx[PAD_TOKEN]) for t in context])
|
235 |
y.append(token_to_idx.get(target, token_to_idx[PAD_TOKEN]))
|
236 |
|
@@ -264,26 +264,28 @@ class RNNModel(nn.Module):
|
|
264 |
|
265 |
def train_rnn(tokens, context_size=3, epochs=3):
|
266 |
data = []
|
267 |
-
for i in range(len(tokens)):
|
268 |
-
|
269 |
-
context = tokens[start_idx:i] if start_idx >= 0 else tokens[0:i]
|
270 |
context = pad_context(context, context_size - 1)
|
271 |
-
target = tokens[i]
|
272 |
data.append((
|
273 |
torch.tensor([token_to_idx.get(t, token_to_idx[PAD_TOKEN]) for t in context], device=device),
|
274 |
token_to_idx.get(target, token_to_idx[PAD_TOKEN])
|
275 |
))
|
276 |
|
|
|
|
|
|
|
|
|
277 |
model = RNNModel(len(vocab)).to(device)
|
278 |
optimizer = optim.Adam(model.parameters(), lr=0.01)
|
279 |
criterion = nn.CrossEntropyLoss()
|
280 |
|
281 |
progress_bar = st.progress(0)
|
282 |
-
total_steps =
|
283 |
step = 0
|
284 |
|
285 |
model.train()
|
286 |
-
|
287 |
for epoch in range(epochs):
|
288 |
total_loss = 0
|
289 |
h = None
|
@@ -296,8 +298,8 @@ def train_rnn(tokens, context_size=3, epochs=3):
|
|
296 |
optimizer.zero_grad()
|
297 |
loss.backward()
|
298 |
optimizer.step()
|
299 |
-
total_loss += loss.item()
|
300 |
|
|
|
301 |
step += 1
|
302 |
progress_bar.progress(step / total_steps)
|
303 |
|
@@ -319,7 +321,7 @@ def rnn_predict(model, context, temperature=1.0):
|
|
319 |
###################################
|
320 |
|
321 |
if train_button:
|
322 |
-
st.write(f"Training **{model_type}** model with context size {context_size}...")
|
323 |
|
324 |
if model_type == "N-gram":
|
325 |
with st.spinner("Training N-gram model..."):
|
@@ -333,10 +335,13 @@ if train_button:
|
|
333 |
elif model_type == "RNN":
|
334 |
model = train_rnn(tokens, context_size=context_size)
|
335 |
|
336 |
-
|
337 |
-
|
338 |
-
|
339 |
-
|
|
|
|
|
|
|
340 |
|
341 |
###################################
|
342 |
# Chat interface
|
|
|
38 |
|
39 |
temperature = st.sidebar.slider("Sampling Temperature", 0.1, 2.0, 1.0)
|
40 |
|
|
|
41 |
context_size = st.sidebar.slider("Context Size (how many tokens to look back)", min_value=2, max_value=10, value=3, step=1)
|
42 |
|
43 |
+
# Number of tokens from dataset to use for training (minimum 100 tokens)
|
44 |
+
num_train_tokens = st.sidebar.slider("Number of tokens from dataset to train on", min_value=100, max_value=100000, value=1000, step=100)
|
45 |
+
|
46 |
train_button = st.sidebar.button("Train Model")
|
47 |
|
48 |
device = torch.device("cpu") # force CPU usage
|
|
|
76 |
tokens = text.split()
|
77 |
return tokens
|
78 |
|
79 |
+
tokens_all = tokenize(text_data, tokenizer_type)
|
80 |
+
|
81 |
+
# Cap tokens to requested number for training
|
82 |
+
tokens = tokens_all[:num_train_tokens]
|
83 |
|
84 |
+
vocab = list(set(tokens))
|
85 |
PAD_TOKEN = "<PAD>"
|
86 |
if PAD_TOKEN not in vocab:
|
87 |
vocab.append(PAD_TOKEN)
|
|
|
94 |
###################################
|
95 |
|
96 |
def pad_context(context, size):
|
|
|
|
|
|
|
|
|
97 |
pad_len = size - len(context)
|
98 |
if pad_len > 0:
|
99 |
return [PAD_TOKEN]*pad_len + context
|
|
|
145 |
|
146 |
def train_ffnn(tokens, context_size=3, epochs=3):
|
147 |
data = []
|
148 |
+
for i in range(len(tokens) - (context_size - 1)):
|
149 |
+
context = tokens[i : i + context_size - 1]
|
|
|
150 |
context = pad_context(context, context_size - 1)
|
151 |
+
target = tokens[i + context_size - 1]
|
152 |
data.append((
|
153 |
torch.tensor([token_to_idx.get(t, token_to_idx[PAD_TOKEN]) for t in context], device=device),
|
154 |
token_to_idx.get(target, token_to_idx[PAD_TOKEN])
|
155 |
))
|
156 |
|
157 |
+
if len(data) == 0:
|
158 |
+
st.warning("No training data generated. Increase dataset size or reduce context size.")
|
159 |
+
return None
|
160 |
+
|
161 |
model = FFNN(len(vocab), context_size - 1).to(device)
|
162 |
optimizer = optim.Adam(model.parameters(), lr=0.01)
|
163 |
criterion = nn.CrossEntropyLoss()
|
164 |
|
165 |
progress_bar = st.progress(0)
|
166 |
+
total_steps = len(data) * epochs
|
167 |
step = 0
|
168 |
|
169 |
model.train()
|
|
|
170 |
for epoch in range(epochs):
|
171 |
total_loss = 0
|
172 |
random.shuffle(data)
|
173 |
for x, y in data:
|
174 |
+
x = x.unsqueeze(0)
|
175 |
y = torch.tensor([y], device=device)
|
176 |
|
177 |
optimizer.zero_grad()
|
|
|
203 |
|
204 |
def train_dt(tokens, context_size=3):
|
205 |
X, y = [], []
|
206 |
+
for i in range(len(tokens) - (context_size - 1)):
|
207 |
+
context = tokens[i : i + context_size - 1]
|
|
|
208 |
context = pad_context(context, context_size - 1)
|
209 |
+
target = tokens[i + context_size - 1]
|
210 |
X.append([token_to_idx.get(t, token_to_idx[PAD_TOKEN]) for t in context])
|
211 |
y.append(token_to_idx.get(target, token_to_idx[PAD_TOKEN]))
|
212 |
|
|
|
227 |
|
228 |
def train_gbt(tokens, context_size=3):
|
229 |
X, y = [], []
|
230 |
+
for i in range(len(tokens) - (context_size - 1)):
|
231 |
+
context = tokens[i : i + context_size - 1]
|
|
|
232 |
context = pad_context(context, context_size - 1)
|
233 |
+
target = tokens[i + context_size - 1]
|
234 |
X.append([token_to_idx.get(t, token_to_idx[PAD_TOKEN]) for t in context])
|
235 |
y.append(token_to_idx.get(target, token_to_idx[PAD_TOKEN]))
|
236 |
|
|
|
264 |
|
265 |
def train_rnn(tokens, context_size=3, epochs=3):
|
266 |
data = []
|
267 |
+
for i in range(len(tokens) - (context_size - 1)):
|
268 |
+
context = tokens[i : i + context_size - 1]
|
|
|
269 |
context = pad_context(context, context_size - 1)
|
270 |
+
target = tokens[i + context_size - 1]
|
271 |
data.append((
|
272 |
torch.tensor([token_to_idx.get(t, token_to_idx[PAD_TOKEN]) for t in context], device=device),
|
273 |
token_to_idx.get(target, token_to_idx[PAD_TOKEN])
|
274 |
))
|
275 |
|
276 |
+
if len(data) == 0:
|
277 |
+
st.warning("No training data generated. Increase dataset size or reduce context size.")
|
278 |
+
return None
|
279 |
+
|
280 |
model = RNNModel(len(vocab)).to(device)
|
281 |
optimizer = optim.Adam(model.parameters(), lr=0.01)
|
282 |
criterion = nn.CrossEntropyLoss()
|
283 |
|
284 |
progress_bar = st.progress(0)
|
285 |
+
total_steps = len(data) * epochs
|
286 |
step = 0
|
287 |
|
288 |
model.train()
|
|
|
289 |
for epoch in range(epochs):
|
290 |
total_loss = 0
|
291 |
h = None
|
|
|
298 |
optimizer.zero_grad()
|
299 |
loss.backward()
|
300 |
optimizer.step()
|
|
|
301 |
|
302 |
+
total_loss += loss.item()
|
303 |
step += 1
|
304 |
progress_bar.progress(step / total_steps)
|
305 |
|
|
|
321 |
###################################
|
322 |
|
323 |
if train_button:
|
324 |
+
st.write(f"Training **{model_type}** model with context size {context_size} on {len(tokens)} tokens...")
|
325 |
|
326 |
if model_type == "N-gram":
|
327 |
with st.spinner("Training N-gram model..."):
|
|
|
335 |
elif model_type == "RNN":
|
336 |
model = train_rnn(tokens, context_size=context_size)
|
337 |
|
338 |
+
if model is not None:
|
339 |
+
st.session_state["model"] = model
|
340 |
+
st.session_state["model_type"] = model_type
|
341 |
+
st.session_state["context_size"] = context_size
|
342 |
+
st.success(f"{model_type} model trained.")
|
343 |
+
else:
|
344 |
+
st.error("Training failed due to no data.")
|
345 |
|
346 |
###################################
|
347 |
# Chat interface
|