import marimo

__generated_with = "0.11.9"
app = marimo.App()


@app.cell(hide_code=True)
def _():
    import marimo as mo
    import synalinks

    synalinks.backend.clear_session()
    return mo, synalinks


@app.cell(hide_code=True)
def _(mo):
    mo.md(
        r"""
        # Training Programs

        Like in machine learning, an LM application needs to be trained. In this case,
        we don't update the weights of a model; instead, we optimize the prompts by
        automatically picking the best examples or generating hints that help the
        program perform better on your dataset.

        For this lesson we are going to work on GSM8K, a well-known dataset of grade
        school math word problems. Nowadays, most (all?) public datasets have leaked,
        meaning that their test sets have been included in LM training data. This means
        that the baseline score won't tell you much about the reasoning abilities of
        the underlying language model (and more about its capability to memorize).
        However, it is still useful as a baseline to evaluate the progress of the
        program's training and the neuro-symbolic methods used, especially with small
        models like the ones used here.

        First, let's have a look at the dataset.
        """
    )
    return


@app.cell
def _(synalinks):
    gsm8k_input_data_model = synalinks.datasets.gsm8k.get_input_data_model()
    print("GSM8K input schema:\n")
    print(gsm8k_input_data_model.prettify_schema())
    return (gsm8k_input_data_model,)


@app.cell
def _(synalinks):
    gsm8k_output_data_model = synalinks.datasets.gsm8k.get_output_data_model()
    print("GSM8K output schema:\n")
    print(gsm8k_output_data_model.prettify_schema())
    return (gsm8k_output_data_model,)


@app.cell(hide_code=True)
def _(mo):
    mo.md(
        r"""
        ## Programming the pipeline

        Now let's build a simple baseline program, like in the first lessons.
        For this example we are going to use the data models from GSM8K.
        """
    )
    return


@app.cell
async def _(gsm8k_input_data_model, gsm8k_output_data_model, synalinks):
    language_model = synalinks.LanguageModel(
        model="openai/gpt-4o-mini",
    )

    _x0 = synalinks.Input(data_model=gsm8k_input_data_model)
    _x1 = await synalinks.Generator(
        data_model=gsm8k_output_data_model,
        language_model=language_model,
    )(_x0)

    program = synalinks.Program(
        inputs=_x0,
        outputs=_x1,
        name="chain_of_thought",
        description="Useful to answer in a step by step manner.",
    )
    return language_model, program


@app.cell(hide_code=True)
def _(mo):
    mo.md(
        r"""
        ## Compiling the program

        For this example, we are going to select the `RandomFewShot` optimizer.
        The reward function will be `ExactMatch`, masked to match only the numerical
        answer, while the additional metric will be `F1Score`, masked to process only
        the LM's thinking.

        This metric gives us an indication of whether the chain of thought matches
        the one from the dataset.
        """
    )
    return


@app.cell
def _(program, synalinks):
    program.compile(
        optimizer=synalinks.optimizers.RandomFewShot(),
        reward=synalinks.rewards.ExactMatch(in_mask=["answer"]),
        metrics=[
            synalinks.metrics.F1Score(in_mask=["thinking"]),
        ],
    )
    return
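

@app.cell(hide_code=True)
def _():
    # Hypothetical sketch, NOT synalinks' actual implementation: `in_mask`
    # conceptually keeps only the listed fields of a prediction before the
    # reward or metric compares it against the ground truth.
    def _apply_in_mask(data: dict, in_mask: list) -> dict:
        return {key: value for key, value in data.items() if key in in_mask}

    _prediction = {"thinking": "4 + 4 = 8 cookies", "answer": "8"}
    _apply_in_mask(_prediction, in_mask=["answer"])  # keeps only "answer"
    return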


@app.cell(hide_code=True)
def _(mo):
    mo.md(
        r"""
        ## Training

        ### What do "sample", "batch", and "epoch" mean?

        - **Sample**: A sample is one element of a dataset. For example, one DataModel
            is one sample.
        - **Batch**: A batch is a set of N samples. The samples in a batch are processed
            independently, in parallel. During training, a batch results in only one
            program update. A batch approximates the input distribution better than a
            single input: the larger the batch, the better the approximation; however, a
            larger batch takes longer to process and still results in only one update.
        - **Epoch**: An epoch is an arbitrary cutoff, generally defined as "one pass
            over the entire dataset", used to separate training into distinct phases,
            which is useful for logging and periodic evaluation. When using
            `validation_split` or `validation_data` with the `fit` method of Synalinks
            programs, evaluation runs at the end of every epoch.
        """
    )
    return
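

@app.cell(hide_code=True)
def _(mo):
    # Illustration (the numbers are examples, matching the default slider
    # values below): one epoch over N samples with batch size B performs
    # ceil(N / B) program updates.
    import math

    _n_samples, _batch = 50, 32
    mo.md(
        f"For example, {_n_samples} samples with a batch size of {_batch} "
        f"give {math.ceil(_n_samples / _batch)} updates per epoch."
    )
    return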


@app.cell(hide_code=True)
def _(mo):
    load_data = mo.ui.run_button(label="Load dataset")
    load_data.center()
    return (load_data,)


@app.cell
def _(load_data, mo, synalinks):
    mo.stop(not load_data.value, mo.md("Click on the load button above"))
    # Now we can load the dataset
    with mo.status.spinner(title="Loading dataset...") as _spinner:
        (x_train, y_train), (x_test, y_test) = synalinks.datasets.gsm8k.load_data()
        _spinner.update("Done.")
    return x_test, x_train, y_test, y_train


@app.cell(hide_code=True)
def _(mo, x_test, x_train):
    epochs = mo.ui.slider(start=1, stop=64, value=5, label="Epochs")
    batch_size = mo.ui.slider(start=1, stop=64, value=32, label="Batch size")
    train_samples = mo.ui.slider(
        start=1, stop=len(x_train), value=50, label="Train Samples"
    )
    test_samples = mo.ui.slider(start=1, stop=len(x_test), value=50, label="Test Samples")
    return batch_size, epochs, test_samples, train_samples


@app.cell(hide_code=True)
def _(epochs, mo):
    mo.hstack([epochs, mo.md(f"Epochs: {epochs.value}")])
    return

@app.cell(hide_code=True)
def _(batch_size, mo):
    mo.hstack([batch_size, mo.md(f"Batch size: {batch_size.value}")])
    return

@app.cell(hide_code=True)
def _(train_samples, mo):
    mo.hstack([train_samples, mo.md(f"Nb train samples: {train_samples.value}")])
    return

@app.cell(hide_code=True)
def _(test_samples, mo):
    mo.hstack([test_samples, mo.md(f"Nb test samples: {test_samples.value}")])
    return

@app.cell(hide_code=True)
def _(mo):
    openai_api_key = mo.ui.text_area(placeholder="Your OpenAI API key...").form()
    openai_api_key
    return (openai_api_key,)


@app.cell(hide_code=True)
def _(mo, openai_api_key):
    import litellm
    mo.stop(not openai_api_key.value)
    litellm.openai_key = openai_api_key.value
    return


@app.cell(hide_code=True)
def _(mo):
    train_button = mo.ui.run_button(label="Train")
    train_button.center()
    return (train_button,)


@app.cell
async def train(
    batch_size,
    epochs,
    mo,
    openai_api_key,
    program,
    synalinks,
    test_samples,
    train_button,
    train_samples,
    x_test,
    x_train,
    y_test,
    y_train,
):
    mo.stop(not openai_api_key.value, mo.md("Provide your OpenAI API key"))
    mo.stop(not train_button.value, mo.md("Click on the train button above"))
    # Where to save the best performing program
    checkpoint_filepath = "checkpoint.program.json"

    _program_checkpoint_callback = synalinks.callbacks.ProgramCheckpoint(
        filepath=checkpoint_filepath,
        monitor="val_reward",
        mode="max",
        save_best_only=True,
    )

    # For the purpose of the tutorial, we'll only train on the first N samples

    history = await program.fit(
        epochs=epochs.value,
        batch_size=batch_size.value,
        x=x_train[: train_samples.value],
        y=y_train[: train_samples.value],
        validation_data=(x_test[: test_samples.value], y_test[: test_samples.value]),
        callbacks=[_program_checkpoint_callback],
    )
    return checkpoint_filepath, history


@app.cell
def _(history, synalinks):
    synalinks.utils.plot_history(history)
    return


@app.cell(hide_code=True)
def _(mo):
    mo.md(
        r"""
        ## Evaluate Checkpoint
        """
    )
    return


@app.cell
async def _(
    batch_size,
    checkpoint_filepath,
    synalinks,
    test_samples,
    x_test,
    y_test,
):
    # Load the JSON-serialized program from disk
    loaded_program = synalinks.Program.load(checkpoint_filepath)

    metrics = await loaded_program.evaluate(
        x=x_test[: test_samples.value],
        y=y_test[: test_samples.value],
        batch_size=batch_size.value,
    )

    synalinks.utils.plot_metrics(metrics)
    return

@app.cell(hide_code=True)
def _(mo):
    mo.md(
        r"""
        ## Conclusion
        
        In this notebook, we explored the process of training Synalinks programs
        to optimize their performance on specific datasets. By leveraging the GSM8k
        dataset of grade school math word problems, we demonstrated how to train a
        language model application to improve its reasoning abilities and accuracy.
        
        ### Key Takeaways
        
        - **Rewards**: `Reward`s guide the reinforcement learning process by 
            providing feedback on the system's performance. They are typically
            float values that indicate how well the system performed a task, 
            with the goal of maximizing the reward function during training. 
            Synalinks offers built-in rewards and allows for custom reward 
            functions to suit specific tasks.
            
        - **Metrics**: `Metric`s are scalar values monitored during training
            and evaluation to determine the best-performing program. Unlike
            rewards, metrics are not used for backpropagation. They provide 
            additional insights for comparing different architectures and 
            saving the optimal model.
            
        - **Optimizers**: `Optimizer`s update the module's state to improve
            performance. They handle the backpropagation of rewards and 
            select or generate examples and hints for the language models.
            Proper configuration of optimizers is essential for effective
            training.
            
        - **Filtering Outputs**: When dealing with complex JSON outputs, 
            filtering predictions and ground truths using `out_mask` or 
            `in_mask` parameters ensures that only relevant fields are 
            evaluated. This is particularly useful when the training data 
            includes a subset of the JSON or when additional fields are
            used to aid the language models.
        """
    )
    return


if __name__ == "__main__":
    app.run()