File size: 10,466 Bytes
fe643f6 feb47ff fe643f6 feb47ff fe643f6 fa1e750 fe643f6 fa1e750 fe643f6 45d9b08 fe643f6 45d9b08 fe643f6 b08d128 fe643f6 35a9bb1 fe643f6 35a9bb1 fe643f6 35a9bb1 fe643f6 b08d128 fe643f6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 |
import marimo

# Version of marimo this notebook was generated with.
__generated_with = "0.11.9"
# The notebook application object; cells below register themselves on it.
app = marimo.App()
@app.cell(hide_code=True)
def _():
    # Shared imports for the whole notebook. The backend session is cleared
    # so repeated runs start from a fresh synalinks state.
    import marimo as mo
    import synalinks

    synalinks.backend.clear_session()
    return mo, synalinks
@app.cell(hide_code=True)
def _(mo):
    # Introductory markdown. Fixed typos in the user-visible text:
    # "Nowedays" -> "Nowadays", "interesing" -> "interesting",
    # "test set have been" -> "test set has been".
    mo.md(
        r"""
        # Training Programs

        Like in machine learning, a LM application needs to be trained. In that case, we
        don't update the weights of the model, but optimize the prompts by automatically
        picking the best examples or generate hints in order to help the program to
        perform better on your dataset.

        For this lesson we are going to work on GSM8k a well known dataset of grade school
        math word problems. Nowadays, most (all?) public datasets have been leaked, meaning
        that their test set has been included in the LM trainset. This basically means
        that the baseline score won't give you much information about the reasoning abilities
        of the underlying language model (but more about its capability to remember),
        however it is still interesting to have it as a baseline to evaluate the progress
        of the programs training and the neuro-symbolic methods used or if you use small
        models like here.

        First, let's have a look at the dataset.
        """
    )
    return
@app.cell
def _(synalinks):
    # Fetch the GSM8K input data model and show its prettified JSON schema.
    gsm8k_input_data_model = synalinks.datasets.gsm8k.get_input_data_model()
    _schema = gsm8k_input_data_model.prettify_schema()
    print(f"GSM8K input schema:\n\n{_schema}")
    return (gsm8k_input_data_model,)
@app.cell
def _(synalinks):
    # Fetch the GSM8K output data model and show its prettified JSON schema.
    gsm8k_output_data_model = synalinks.datasets.gsm8k.get_output_data_model()
    _schema = gsm8k_output_data_model.prettify_schema()
    print(f"GSM8K output schema:\n\n{_schema}")
    return (gsm8k_output_data_model,)
@app.cell(hide_code=True)
def _(mo):
    # Section header for the pipeline-building part. Fixed the missing
    # sentence period in the displayed text.
    mo.md(
        r"""
        ## Programming the pipeline

        Now let's make a simple baseline program like in the first lessons.
        For this example we are going to use the data models from GSM8k.
        """
    )
    return
@app.cell
async def _(gsm8k_input_data_model, gsm8k_output_data_model, synalinks):
    # The language model powering the generator below.
    language_model = synalinks.LanguageModel(
        model="openai/gpt-4o-mini",
    )

    # Single-step pipeline: a GSM8K-shaped input flows into one generator
    # constrained to emit the GSM8K output schema.
    _inputs = synalinks.Input(data_model=gsm8k_input_data_model)
    _outputs = await synalinks.Generator(
        data_model=gsm8k_output_data_model,
        language_model=language_model,
    )(_inputs)

    program = synalinks.Program(
        inputs=_inputs,
        outputs=_outputs,
        name="chain_of_thought",
        description="Useful to answer in a step by step manner.",
    )
    return language_model, program
@app.cell(hide_code=True)
def _(mo):
    # Explanation of the compilation step. Fixed typos in the displayed
    # text: "fucntion" -> "function", "match with" -> "matches".
    mo.md(
        r"""
        ## Compiling the program

        For this example, we are going to select the `RandomFewShot` optimizer.
        The reward function will be `ExactMatch` masked to match only the numerical answer.
        While the additional metric will be the `F1Score` masked to process only the LMs thinking.
        This metric will give us an indication to see if the chain of thought matches the dataset one.
        """
    )
    return
@app.cell
def _(program, synalinks):
    # Reward only the final numerical answer; additionally track how well
    # the model's chain of thought overlaps the dataset's via F1.
    _answer_reward = synalinks.rewards.ExactMatch(in_mask=["answer"])
    _thinking_f1 = synalinks.metrics.F1Score(in_mask=["thinking"])
    program.compile(
        optimizer=synalinks.optimizers.RandomFewShot(),
        reward=_answer_reward,
        metrics=[_thinking_f1],
    )
    return
@app.cell(hide_code=True)
def _(mo):
    # Glossary of training terms. Fixed grammar in the displayed text:
    # "a batch result in" -> "a batch results in",
    # "A epochs is an arbitrarly cutoff" -> "An epoch is an arbitrary cutoff".
    mo.md(
        r"""
        ## Training

        ### What do "sample", "batch", and "epoch" mean?

        - **Sample**: A sample is one element of a dataset. For example, one DataModel
          is one sample.
        - **Batch**: A batch is a set of N samples. The samples in a batch are processed
          independently, in parallel. During training, a batch results in only one
          program update. A batch approximates the input distribution better than a
          single input. The larger the batch, the better the approximation; however a
          larger batch will take longer to process and still result in only one update.
        - **Epochs**: An epoch is an arbitrary cutoff, generally defined as "one pass
          over the entire dataset", used to separate training into distinct phases,
          which is useful for logging and periodic evaluation. When using
          `validation_split` or `validation_data` with the `fit` method of Synalinks
          programs, evaluation will be run at the end of every epoch.
        """
    )
    return
@app.cell(hide_code=True)
def _(mo):
    # Button gating the (potentially slow) dataset download in the next cell.
    load_data = mo.ui.run_button(label="Load dataset")
    load_data.center()
    return (load_data,)
@app.cell
def _(load_data, mo, synalinks):
    # Halt execution until the user clicks the "Load dataset" button.
    mo.stop(not load_data.value, mo.md("Click on the load button above"))
    # Now we can load the dataset
    with mo.status.spinner(title="Loading dataset...") as _spinner:
        (x_train, y_train), (x_test, y_test) = synalinks.datasets.gsm8k.load_data()
        _spinner.update("Done.")
    return x_test, x_train, y_test, y_train
@app.cell(hide_code=True)
def _(mo, x_test, x_train):
    # Training hyperparameter controls; the sample-count sliders are capped
    # by the actual dataset split sizes.
    epochs = mo.ui.slider(start=1, stop=64, value=5, label="Epochs")
    batch_size = mo.ui.slider(start=1, stop=64, value=32, label="Batch size")
    train_samples = mo.ui.slider(
        start=1, stop=len(x_train), value=50, label="Train Samples"
    )
    test_samples = mo.ui.slider(start=1, stop=len(x_test), value=50, label="Test Samples")
    return batch_size, epochs, test_samples, train_samples
@app.cell(hide_code=True)
def _(epochs, mo):
    # Display the epochs slider next to its current value.
    # Fix: `mo` is used in the body but was missing from the cell signature,
    # which raises a NameError under marimo's dependency resolution.
    mo.hstack([epochs, mo.md(f"Epochs: {epochs.value}")])
    return
@app.cell(hide_code=True)
def _(batch_size, mo):
    # Display the batch-size slider next to its current value.
    # Fix: `mo` is used in the body but was missing from the cell signature,
    # which raises a NameError under marimo's dependency resolution.
    mo.hstack([batch_size, mo.md(f"Batch size: {batch_size.value}")])
    return
@app.cell(hide_code=True)
def _(mo, train_samples):
    # Display the train-samples slider next to its current value.
    # Fix: `mo` is used in the body but was missing from the cell signature,
    # which raises a NameError under marimo's dependency resolution.
    mo.hstack([train_samples, mo.md(f"Nb train samples: {train_samples.value}")])
    return
@app.cell(hide_code=True)
def _(mo, test_samples):
    # Display the test-samples slider next to its current value.
    # Fix: `mo` is used in the body but was missing from the cell signature,
    # which raises a NameError under marimo's dependency resolution.
    mo.hstack([test_samples, mo.md(f"Nb test samples: {test_samples.value}")])
    return
@app.cell(hide_code=True)
def _(mo):
    # Form for the user's OpenAI API key.
    # Fix: the form must be returned so downstream cells can declare
    # `openai_api_key` as a dependency — the original cell returned nothing,
    # leaving the name undefined for the cells that reference it.
    openai_api_key = mo.ui.text_area(placeholder="Your OpenAI API key...").form()
    openai_api_key
    return (openai_api_key,)
@app.cell(hide_code=True)
def _(mo, openai_api_key):
    # Wire the submitted API key into litellm; halt until a key is provided.
    import litellm

    mo.stop(not openai_api_key.value)
    litellm.openai_key = openai_api_key.value
    return
@app.cell(hide_code=True)
def _(mo):
    # Button gating the (long-running) training cell below.
    train_button = mo.ui.run_button(label="Train")
    train_button.center()
    return (train_button,)
@app.cell
async def train(
    batch_size,
    epochs,
    mo,
    openai_api_key,
    program,
    train_button,
    synalinks,
    test_samples,
    train_samples,
    x_test,
    x_train,
    y_test,
    y_train,
):
    # Train the program on a subset of GSM8K, checkpointing the best version.
    # Fix: `openai_api_key` is used below but was missing from the cell
    # signature, which raises a NameError under marimo's dependency resolution.
    mo.stop(not openai_api_key.value, mo.md("Provide your OpenAI API key"))
    mo.stop(not train_button.value, mo.md("Click on the train button above"))
    # Where to save the best performing program
    checkpoint_filepath = "checkpoint.program.json"
    _program_checkpoint_callback = synalinks.callbacks.ProgramCheckpoint(
        filepath=checkpoint_filepath,
        monitor="val_reward",
        mode="max",
        save_best_only=True,
    )
    # For the purpose of the tutorial, we'll only train on the first N samples
    history = await program.fit(
        epochs=epochs.value,
        batch_size=batch_size.value,
        x=x_train[: train_samples.value],
        y=y_train[: train_samples.value],
        validation_data=(x_test[: test_samples.value], y_test[: test_samples.value]),
        callbacks=[_program_checkpoint_callback],
    )
    return checkpoint_filepath, history
@app.cell
def _(history, synalinks):
    # Visualize the reward/metric curves recorded during training.
    synalinks.utils.plot_history(history)
    return
@app.cell(hide_code=True)
def _(mo):
    # Section header for checkpoint evaluation.
    # Fix: the original signature requested `synalinks` but the body uses
    # `mo`; declare the correct dependency so marimo can resolve it.
    mo.md(
        r"""
        ## Evaluate Checkpoint
        """
    )
    return
@app.cell
async def _(
    checkpoint_filepath,
    train,
    x_test,
    y_test,
    batch_size,
    test_samples,
    synalinks,
):
    # Load the JSON serialized program from disk
    loaded_program = synalinks.Program.load(checkpoint_filepath)
    # Fix: slicing must use the slider's `.value` — the original sliced with
    # the UI element itself (`x_test[: test_samples]`), which raises a
    # TypeError because a slider is not a valid slice index.
    metrics = await loaded_program.evaluate(
        x=x_test[: test_samples.value],
        y=y_test[: test_samples.value],
        batch_size=batch_size.value,
    )
    synalinks.utils.plot_metrics(metrics)
    return
@app.cell(hide_code=True)
def _(mo):
    # Closing summary of the training workflow covered in this notebook.
    mo.md(
        r"""
        ## Conclusion

        In this notebook, we explored the process of training Synalinks programs
        to optimize their performance on specific datasets. By leveraging the GSM8k
        dataset of grade school math word problems, we demonstrated how to train a
        language model application to improve its reasoning abilities and accuracy.

        ### Key Takeaways

        - **Rewards**: `Reward`s guide the reinforcement learning process by
          providing feedback on the system's performance. They are typically
          float values that indicate how well the system performed a task,
          with the goal of maximizing the reward function during training.
          Synalinks offers built-in rewards and allows for custom reward
          functions to suit specific tasks.
        - **Metrics**: `Metric`s are scalar values monitored during training
          and evaluation to determine the best-performing program. Unlike
          rewards, metrics are not used for backpropagation. They provide
          additional insights for comparing different architectures and
          saving the optimal model.
        - **Optimizers**: `Optimizer`s update the module's state to improve
          performance. They handle the backpropagation of rewards and
          select or generate examples and hints for the language models.
          Proper configuration of optimizers is essential for effective
          training.
        - **Filtering Outputs**: When dealing with complex JSON outputs,
          filtering predictions and ground truths using `out_mask` or
          `in_mask` parameters ensures that only relevant fields are
          evaluated. This is particularly useful when the training data
          includes a subset of the JSON or when additional fields are
          used to aid the language models.
        """
    )
    return
# Run the notebook as a standalone app when executed directly.
if __name__ == "__main__":
    app.run()
|