neuralworm commited on
Commit
d2d8270
·
verified ·
1 Parent(s): b41e522

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -18
app.py CHANGED
@@ -2,11 +2,12 @@ import gradio as gr
2
  import torch
3
  import torch.nn as nn
4
  import torch.optim as optim
5
- from torch.utils.data import Dataset, DataLoader # For dummy training
6
  import os
7
  import re
8
- import time # For basic progress update
9
- from model import SWCKModel, SeedParser, EntropyEstimator # Assuming model.py is in the same directory
 
10
 
11
  # --- Vocabulary and Tokenizer Setup ---
12
  PAD_TOKEN_STR = "<pad>"; SOS_TOKEN_STR = "<sos>"; EOS_TOKEN_STR = "<eos>"; UNK_TOKEN_STR = "<unk>"
@@ -218,10 +219,10 @@ def run_short_training_session(num_epochs_app, batch_size_app, learning_rate_app
218
 
219
  if logits.size(1) != gold_standard_for_loss.size(1):
220
  min_len = min(logits.size(1), gold_standard_for_loss.size(1))
221
- logits_for_loss = logits[:, :min_len, :].contiguous() # ADDED .contiguous()
222
  gold_for_loss_aligned = gold_standard_for_loss[:, :min_len].contiguous()
223
  else:
224
- logits_for_loss = logits.contiguous() # ADDED .contiguous()
225
  gold_for_loss_aligned = gold_standard_for_loss.contiguous()
226
 
227
  main_loss = criterion_main_app(logits_for_loss.view(-1, logits_for_loss.size(-1)), gold_for_loss_aligned.view(-1))
@@ -230,7 +231,7 @@ def run_short_training_session(num_epochs_app, batch_size_app, learning_rate_app
230
  if entropy_report["block_output_entropies"]:
231
  for i, block_entropy_tensor in enumerate(entropy_report["block_output_entropies"]):
232
  target_entropy_val = swck_model_global.seed_parser.get_block_config(i)["target_entropy"]
233
- block_entropy_loss += F.mse_loss(block_entropy_tensor, torch.tensor(target_entropy_val, device=device_global))
234
  if entropy_report["block_output_entropies"]:
235
  block_entropy_loss = block_entropy_loss / len(entropy_report["block_output_entropies"])
236
 
@@ -303,11 +304,6 @@ def generate_text_for_app(prompt_str, max_len_gen, temperature_gen):
303
  swck_model_global.eval()
304
  swck_model_global.set_wiring_phase(False)
305
 
306
- # Temporarily re-enable debug for generation if you want to inspect Space logs
307
- # swck_model_global.debug_prints_enabled = True
308
- # if hasattr(swck_model_global, 'seed_parser'): swck_model_global.seed_parser.debug_prints_enabled = True
309
- # for blk in swck_model_global.adaptive_blocks: blk.debug_prints_enabled = True
310
-
311
  print(f"App: Generating for prompt: '{prompt_str}', max_len: {max_len_gen}, temp: {temperature_gen}")
312
 
313
  tokens = [SOS_TOKEN] + [word_to_idx_global.get(w, UNK_TOKEN) for w in prompt_str.lower().split()]
@@ -328,7 +324,7 @@ def generate_text_for_app(prompt_str, max_len_gen, temperature_gen):
328
  if temperature_gen == 0:
329
  next_token_id = torch.argmax(next_token_logits).item()
330
  else:
331
- probs = F.softmax(next_token_logits / temperature_gen, dim=-1)
332
  if probs.isnan().any() or probs.isinf().any() or torch.sum(probs).item() < 1e-9 :
333
  print(f"Warning: Invalid probabilities at step {i}. Using uniform.")
334
  probs = torch.ones_like(next_token_logits) / next_token_logits.size(-1)
@@ -361,15 +357,11 @@ def generate_text_for_app(prompt_str, max_len_gen, temperature_gen):
361
 
362
  debug_output_str = "\n".join(debug_info_lines)
363
 
364
- # swck_model_global.debug_prints_enabled = False # Disable after generation
365
- # if hasattr(swck_model_global, 'seed_parser'): swck_model_global.seed_parser.debug_prints_enabled = False
366
- # for blk in swck_model_global.adaptive_blocks: blk.debug_prints_enabled = False
367
  return final_text, debug_output_str
368
 
369
  initial_load_status = initialize_or_load_model_app()
370
 
371
  with gr.Blocks(title="SWCK Conceptual Demo") as demo:
372
- # Using a unique elem_id for the status Markdown
373
  model_status_md = gr.Markdown(value=f"**Model Status:** {initial_load_status}", elem_id="model_status_md_123")
374
 
375
  gr.Markdown(f"""
@@ -402,7 +394,6 @@ with gr.Blocks(title="SWCK Conceptual Demo") as demo:
402
  training_status_output = gr.Textbox(label="Training Log / Status:", lines=10, interactive=False,show_label=True )
403
 
404
  def update_status_text_for_ui():
405
- # This function will be called by .then() to get the new status string
406
  return f"**Model Status:** {model_load_status_global}"
407
 
408
  generate_button.click(
@@ -415,7 +406,7 @@ with gr.Blocks(title="SWCK Conceptual Demo") as demo:
415
  fn=run_short_training_session,
416
  inputs=[train_epochs_slider, train_batch_size_slider, train_lr_slider],
417
  outputs=[training_status_output]
418
- ).then(fn=update_status_text_for_ui, inputs=None, outputs=model_status_md) # Update the Markdown component
419
 
420
 
421
  if __name__ == "__main__":
 
2
  import torch
3
  import torch.nn as nn
4
  import torch.optim as optim
5
+ from torch.utils.data import Dataset, DataLoader
6
  import os
7
  import re
8
+ import time
9
+ import torch.nn.functional as F # <<<<<<<<<<<< ADDED THIS IMPORT
10
+ from model import SWCKModel, SeedParser, EntropyEstimator
11
 
12
  # --- Vocabulary and Tokenizer Setup ---
13
  PAD_TOKEN_STR = "<pad>"; SOS_TOKEN_STR = "<sos>"; EOS_TOKEN_STR = "<eos>"; UNK_TOKEN_STR = "<unk>"
 
219
 
220
  if logits.size(1) != gold_standard_for_loss.size(1):
221
  min_len = min(logits.size(1), gold_standard_for_loss.size(1))
222
+ logits_for_loss = logits[:, :min_len, :].contiguous()
223
  gold_for_loss_aligned = gold_standard_for_loss[:, :min_len].contiguous()
224
  else:
225
+ logits_for_loss = logits.contiguous()
226
  gold_for_loss_aligned = gold_standard_for_loss.contiguous()
227
 
228
  main_loss = criterion_main_app(logits_for_loss.view(-1, logits_for_loss.size(-1)), gold_for_loss_aligned.view(-1))
 
231
  if entropy_report["block_output_entropies"]:
232
  for i, block_entropy_tensor in enumerate(entropy_report["block_output_entropies"]):
233
  target_entropy_val = swck_model_global.seed_parser.get_block_config(i)["target_entropy"]
234
+ block_entropy_loss += F.mse_loss(block_entropy_tensor, torch.tensor(target_entropy_val, device=device_global)) # Used F here
235
  if entropy_report["block_output_entropies"]:
236
  block_entropy_loss = block_entropy_loss / len(entropy_report["block_output_entropies"])
237
 
 
304
  swck_model_global.eval()
305
  swck_model_global.set_wiring_phase(False)
306
 
 
 
 
 
 
307
  print(f"App: Generating for prompt: '{prompt_str}', max_len: {max_len_gen}, temp: {temperature_gen}")
308
 
309
  tokens = [SOS_TOKEN] + [word_to_idx_global.get(w, UNK_TOKEN) for w in prompt_str.lower().split()]
 
324
  if temperature_gen == 0:
325
  next_token_id = torch.argmax(next_token_logits).item()
326
  else:
327
+ probs = F.softmax(next_token_logits / temperature_gen, dim=-1) # Used F here
328
  if probs.isnan().any() or probs.isinf().any() or torch.sum(probs).item() < 1e-9 :
329
  print(f"Warning: Invalid probabilities at step {i}. Using uniform.")
330
  probs = torch.ones_like(next_token_logits) / next_token_logits.size(-1)
 
357
 
358
  debug_output_str = "\n".join(debug_info_lines)
359
 
 
 
 
360
  return final_text, debug_output_str
361
 
362
  initial_load_status = initialize_or_load_model_app()
363
 
364
  with gr.Blocks(title="SWCK Conceptual Demo") as demo:
 
365
  model_status_md = gr.Markdown(value=f"**Model Status:** {initial_load_status}", elem_id="model_status_md_123")
366
 
367
  gr.Markdown(f"""
 
394
  training_status_output = gr.Textbox(label="Training Log / Status:", lines=10, interactive=False,show_label=True )
395
 
396
  def update_status_text_for_ui():
 
397
  return f"**Model Status:** {model_load_status_global}"
398
 
399
  generate_button.click(
 
406
  fn=run_short_training_session,
407
  inputs=[train_epochs_slider, train_batch_size_slider, train_lr_slider],
408
  outputs=[training_status_output]
409
+ ).then(fn=update_status_text_for_ui, inputs=None, outputs=model_status_md)
410
 
411
 
412
  if __name__ == "__main__":