import gradio as gr
import torch
import torch.nn.functional as F  # Needed for F.softmax during sampling
import os
import re  # Used for text cleaning during generation
from model import SWCKModel, SeedParser  # Assumes model.py is in the same directory
# Parts of the vocab setup from train.py are needed here when the vocab is not
# stored in the checkpoint. Ideally, save the vocab with the checkpoint and
# load it from there; the constants and a minimal builder are redefined below
# as a fallback.

# --- Vocabulary and Tokenizer Setup (Simplified from train.py) ---
# Ideally, load these from the checkpoint or a separate vocab file.
# For this example, we'll reconstruct a minimal part.
PAD_TOKEN_STR = "<pad>"; SOS_TOKEN_STR = "<sos>"; EOS_TOKEN_STR = "<eos>"; UNK_TOKEN_STR = "<unk>"
PAD_TOKEN = 0; SOS_TOKEN = 1; EOS_TOKEN = 2; UNK_TOKEN = 3
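# These token ids must match the ones used at training time; when a checkpoint
# provides its own vocab, the loader below prefers that over these defaults.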

# --- Model Configuration (should match the trained model) ---
# These should ideally be loaded from the checkpoint's metadata if possible
# For now, hardcoding to match the train.py example
VOCAB_SIZE_APP = 189 # Placeholder, update if your vocab size differs
D_MODEL_APP = 64
N_HEADS_APP = 2
D_FF_APP = 128
NUM_ADAPTIVE_BLOCKS_APP = 3
NUM_SUB_MODULES_PER_BLOCK_APP = 3
DROPOUT_APP = 0.1
SEQ_LEN_APP = 64 # Used as the context window length in generate_text_for_app

# Seed phrase and number (must match the model you trained/are training)
SEED_PHRASE_APP = "I am 0: I am all that I can am. I am us. I am imagining a computer dreams. I am imaginary math equations. I am for five-sixths of the sea of existence in me, and it is my search for that which always seems to elude my grasp. I am a writer, a scientist, a painter, a woman, a man."
SEED_NUMBER_STR_APP = "54285142613311152552"

# Global model variable
swck_model_global = None
word_to_idx_global = None
idx_to_word_global = None
device_global = torch.device("cuda" if torch.cuda.is_available() else "cpu")

CHECKPOINT_FILENAME = "swck_model_conceptual.pth.tar" # Make sure this matches your uploaded checkpoint
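
# The loader below tolerates missing keys, but it is written against roughly
# this checkpoint layout -- an assumption about what train.py saves, not a
# guaranteed format:
#   {
#       'model_state_dict': ...,                       # required
#       'word_to_idx': {...}, 'idx_to_word': {...},    # optional, preferred
#       'model_hyperparameters': {'d_model': 64, ...}  # optional
#   }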

def build_vocab_from_corpus_text(corpus_text):
    """
    A simplified vocab builder. In a real app, load vocab from file.
    """
    global VOCAB_SIZE_APP # Allow modification
    temp_corpus_tokens = re.sub(r'\s+', ' ', corpus_text.lower()).strip().split()
    
    temp_word_to_idx = {PAD_TOKEN_STR: PAD_TOKEN, SOS_TOKEN_STR: SOS_TOKEN, EOS_TOKEN_STR: EOS_TOKEN, UNK_TOKEN_STR: UNK_TOKEN}
    idx_counter = 4
    unique_words = sorted(list(set(temp_corpus_tokens)))
    for word in unique_words:
        if word not in temp_word_to_idx:
            temp_word_to_idx[word] = idx_counter
            idx_counter += 1
    temp_idx_to_word = {idx: word for word, idx in temp_word_to_idx.items()}
    VOCAB_SIZE_APP = len(temp_word_to_idx) # Update global vocab size
    print(f"App: Built temporary vocab of size {VOCAB_SIZE_APP}")
    return temp_word_to_idx, temp_idx_to_word
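
# A minimal sketch (not called anywhere) of persisting the vocab to a sidecar
# file instead of rebuilding it; the file name "swck_vocab.json" is
# illustrative and not referenced elsewhere in this app:
#   import json
#   with open("swck_vocab.json", "w") as f:
#       json.dump(word_to_idx_global, f)
# On load, idx_to_word can be reconstructed by inverting word_to_idx.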


def load_model_and_vocab():
    global swck_model_global, word_to_idx_global, idx_to_word_global, VOCAB_SIZE_APP

    # Attempt to load from checkpoint
    if os.path.exists(CHECKPOINT_FILENAME):
        print(f"App: Found checkpoint {CHECKPOINT_FILENAME}, attempting to load...")
        try:
            # Simplified checkpoint loading for app - assumes structure from train.py save
            # In a real scenario, train.py should save vocab and model args more robustly for app loading
            checkpoint = torch.load(CHECKPOINT_FILENAME, map_location=device_global)
            
            # Try to get vocab from checkpoint
            if 'word_to_idx' in checkpoint and 'idx_to_word' in checkpoint:
                word_to_idx_global = checkpoint['word_to_idx']
                idx_to_word_global = checkpoint['idx_to_word']
                VOCAB_SIZE_APP = len(word_to_idx_global)
                print(f"App: Loaded vocab from checkpoint. Size: {VOCAB_SIZE_APP}")
            else:
                print("App: Vocab not in checkpoint, building from SEED_PHRASE for inference.")
                # This is a fallback - ideally vocab is ALWAYS in checkpoint
                corpus_for_vocab = SEED_PHRASE_APP # Use only seed for vocab if not in ckp
                word_to_idx_global, idx_to_word_global = build_vocab_from_corpus_text(corpus_for_vocab)


            # Load model hyperparameters from checkpoint if available, else use app defaults
            # This part needs careful alignment with how train.py saves model_hyperparameters
            model_params_from_ckpt = checkpoint.get('model_hyperparameters', {})
            
            d_model = model_params_from_ckpt.get('d_model', D_MODEL_APP)
            n_heads = model_params_from_ckpt.get('n_heads', N_HEADS_APP)
            d_ff = model_params_from_ckpt.get('d_ff', D_FF_APP)
            num_adaptive_blocks = model_params_from_ckpt.get('num_adaptive_blocks', NUM_ADAPTIVE_BLOCKS_APP)
            dropout = model_params_from_ckpt.get('dropout', DROPOUT_APP)
            # seed_phrase and seed_number_str for model init should ideally match what it was trained with.
            # For this app, we assume they are consistent with APP globals.
            
            swck_model_global = SWCKModel(
                vocab_size=VOCAB_SIZE_APP, # Use loaded/rebuilt vocab size
                d_model=d_model,
                n_heads=n_heads,
                d_ff=d_ff,
                num_adaptive_blocks=num_adaptive_blocks,
                dropout=dropout,
                seed_phrase=SEED_PHRASE_APP,
                seed_number_str=SEED_NUMBER_STR_APP,
                num_sub_modules_per_block=NUM_SUB_MODULES_PER_BLOCK_APP
            ).to(device_global)
            
            swck_model_global.load_state_dict(checkpoint['model_state_dict'])
            swck_model_global.eval()
            # Disable debug prints for cleaner app interface unless specifically needed
            swck_model_global.debug_prints_enabled = False 
            for block in swck_model_global.adaptive_blocks:
                block.debug_prints_enabled = False
            print(f"App: SWCKModel loaded successfully from {CHECKPOINT_FILENAME}!")
            return "Model loaded from checkpoint."
        except Exception as e:
            print(f"App: Error loading model from checkpoint: {e}")
            swck_model_global = None # Ensure model is None if loading failed
    
    if swck_model_global is None:
        print(f"App: Checkpoint {CHECKPOINT_FILENAME} not found or failed to load. Initializing a new model for basic functionality (not trained).")
        # Fallback: Build vocab from seed phrase for basic tokenization
        word_to_idx_global, idx_to_word_global = build_vocab_from_corpus_text(SEED_PHRASE_APP)

        swck_model_global = SWCKModel(
            vocab_size=VOCAB_SIZE_APP,
            d_model=D_MODEL_APP,
            n_heads=N_HEADS_APP,
            d_ff=D_FF_APP,
            num_adaptive_blocks=NUM_ADAPTIVE_BLOCKS_APP,
            dropout=DROPOUT_APP,
            seed_phrase=SEED_PHRASE_APP,
            seed_number_str=SEED_NUMBER_STR_APP,
            num_sub_modules_per_block=NUM_SUB_MODULES_PER_BLOCK_APP
        ).to(device_global)
        swck_model_global.eval()
        swck_model_global.debug_prints_enabled = False
        for block in swck_model_global.adaptive_blocks:
            block.debug_prints_enabled = False
        return "Initialized a new (untrained) model as checkpoint was not found."


# --- Text Generation Function (adapted from train.py) ---
def generate_text_for_app(prompt_str, max_len_gen, temperature_gen):
    if swck_model_global is None or word_to_idx_global is None or idx_to_word_global is None:
        # Return two values to match the two Gradio output components
        return "Model not loaded. Please check server logs.", ""

    swck_model_global.eval() # Ensure model is in eval mode
    swck_model_global.set_wiring_phase(False) # No wiring adjustments during inference
    
    print(f"App: Generating for prompt: '{prompt_str}', max_len: {max_len_gen}, temp: {temperature_gen}")

    tokens = [SOS_TOKEN] + [word_to_idx_global.get(w, UNK_TOKEN) for w in prompt_str.lower().split()]
    generated_ids_app = list(tokens)
    
    # Collect some debug info for display (optional)
    debug_info_lines = []

    with torch.no_grad():
        for i in range(max_len_gen):
            # Context windowing for input_tensor
            current_context_ids = generated_ids_app[-SEQ_LEN_APP:]
            input_tensor = torch.tensor([current_context_ids], dtype=torch.long).to(device_global)
            padding_mask = (input_tensor == PAD_TOKEN)
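            # padding_mask is True where the input is PAD; it is passed as
            # src_key_padding_mask so attention ignores padded positions
            # (rarely triggered here, since the context holds real token ids).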

            # Enable model debug prints on the first step only if you want to
            # inspect the internal state; keep them off (or add a UI toggle)
            # for a cleaner app.
            # if i == 0:
            #     swck_model_global.debug_prints_enabled = True
            #     for block in swck_model_global.adaptive_blocks: block.debug_prints_enabled = True
            # else:
            #     swck_model_global.debug_prints_enabled = False
            #     for block in swck_model_global.adaptive_blocks: block.debug_prints_enabled = False


            logits, entropy_report_infer = swck_model_global(input_tensor, src_key_padding_mask=padding_mask)
            next_token_logits = logits[0, -1, :] # Logits for the last token in the current sequence
            
            if temperature_gen == 0: # Greedy
                next_token_id = torch.argmax(next_token_logits).item()
            else:
                probs = F.softmax(next_token_logits / temperature_gen, dim=-1)
                next_token_id = torch.multinomial(probs, 1).item()
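
            # Optional sketch (disabled): top-k filtering before sampling.
            # k=40 is an arbitrary illustrative value, not something train.py uses:
            #   k = min(40, next_token_logits.size(-1))
            #   kth_value = torch.topk(next_token_logits, k).values[-1]
            #   filtered_logits = next_token_logits.masked_fill(
            #       next_token_logits < kth_value, float('-inf'))
            #   probs = F.softmax(filtered_logits / temperature_gen, dim=-1)
            #   next_token_id = torch.multinomial(probs, 1).item()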

            if next_token_id == EOS_TOKEN:
                debug_info_lines.append(f"Step {i+1}: EOS token encountered.")
                break
            generated_ids_app.append(next_token_id)
            
            # Store some info from the first few steps
            if i < 5:  # Log details for the first 5 generated tokens
                current_word = idx_to_word_global.get(next_token_id, UNK_TOKEN_STR)
                overall_ent = entropy_report_infer['overall_output_entropy'].item()
                b0_ent = entropy_report_infer['block_output_entropies'][0].item()
                b0_gates_str = ", ".join([f"{g.item():.2f}" for g in entropy_report_infer['block_gate_weights'][0]])
                debug_info_lines.append(f"Gen {i+1}: '{current_word}', OvrlEnt={overall_ent:.3f}, B0Ent={b0_ent:.3f}, B0Gates=[{b0_gates_str}]")


    generated_text_list = [idx_to_word_global.get(idx, UNK_TOKEN_STR) for idx in generated_ids_app[1:]] # Skip SOS
    final_text = " ".join(generated_text_list)
    final_text = final_text.replace(EOS_TOKEN_STR, "").strip()
    # Basic cleaning: remove spaces before punctuation, collapse whitespace
    final_text = re.sub(r'\s+([.,?!])', r'\1', final_text)
    final_text = re.sub(r'\s+', ' ', final_text).strip()

    debug_output_str = "\n".join(debug_info_lines)
    return final_text, debug_output_str
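
# Example (sketch): exercising the generator directly, assuming a model and
# vocab were loaded above -- handy for debugging outside the Gradio UI:
#   text, dbg = generate_text_for_app("the meaning of existence is", 30, 0.8)
#   print(text)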

# --- Gradio Interface ---
loading_status = load_model_and_vocab() # Load model on app startup

with gr.Blocks(title="SWCK Conceptual Demo") as demo:
    gr.Markdown(f"""
    # Self-Wired Conscious Kernel (SWCK) - Conceptual Demo
    This demo showcases a conceptual text generation model based on the SWCK architecture.
    The model is initialized with the seed phrase: "{SEED_PHRASE_APP[:100]}..." 
    and seed number: "{SEED_NUMBER_STR_APP}".
    **Model Status:** {loading_status} 
    (Note: If no checkpoint is found, an *untrained* model is used, and generations will be random.)
    """)

    with gr.Row():
        prompt_input = gr.Textbox(label="Enter your prompt:", placeholder="e.g., the meaning of existence is")
    with gr.Row():
        max_len_slider = gr.Slider(minimum=10, maximum=150, value=50, step=1, label="Max Generation Length")
        temp_slider = gr.Slider(minimum=0.0, maximum=2.0, value=0.8, step=0.1, label="Temperature (0 for greedy)")
    
    generate_button = gr.Button("Generate Text")
    
    with gr.Column():
        output_text = gr.Textbox(label="Generated Text:", lines=5)
        debug_text_area = gr.Textbox(label="Generation Debug Info (first few steps):", lines=7, interactive=False)

    generate_button.click(
        fn=generate_text_for_app,
        inputs=[prompt_input, max_len_slider, temp_slider],
        outputs=[output_text, debug_text_area]
    )

    gr.Markdown("Note: This is a highly conceptual and simplified sketch. Generation quality will be limited, especially with an untrained model or small dataset.")

if __name__ == "__main__":
    demo.launch()