MilesCranmer commited on
Commit
07e3605
1 Parent(s): 24b8a6f

Add example for custom objective

Browse files
Files changed (1) hide show
  1. docs/examples.md +118 -1
docs/examples.md CHANGED
@@ -318,7 +318,124 @@ model.predict(X, -1)
318
 
319
  to make predictions with the most accurate expression.
320
 
321
- ## 9. Additional features
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
322
 
323
  For the many other features available in PySR, please
324
  read the [Options section](options.md).
 
318
 
319
  to make predictions with the most accurate expression.
320
 
321
+ ## 9. Custom objectives
322
+
323
+ You can also pass a custom objectives as a snippet of Julia code,
324
+ which might include symbolic manipulations or custom functional forms.
325
+ These do not even need to be differentiable! First, let's look at the
326
+ default objective used (a simplified version, without weights
327
+ and with mean square error), so that you can see how to write your own:
328
+
329
+ ```julia
330
+ function default_objective(tree, dataset::Dataset{T,L}, options)::L where {T,L}
331
+ (prediction, completion) = eval_tree_array(tree, dataset.X, options)
332
+ if !completion
333
+ return L(Inf)
334
+ end
335
+
336
+ diffs = prediction .- dataset.y
337
+
338
+ return sum(diffs .^ 2) / length(diffs)
339
+ end
340
+ ```
341
+
342
+ Here, the `where {T,L}` syntax defines the function for arbitrary types `T` and `L`.
343
+ If you have `precision=32` (default) and pass in regular floating point data,
344
+ then both `T` and `L` will be equal to `Float32`. If you pass in complex data,
345
+ then `T` will be `ComplexF32` and `L` will be `Float32` (since we need to return
346
+ a real number from the loss function). But, you don't need to worry about this, just
347
+ make sure to return a scalar number of type `L`.
348
+
349
+ The `tree` argument is the current expression being evaluated. You can read
350
+ about the `tree` fields [here](https://astroautomata.com/SymbolicRegression.jl/stable/types/).
351
+
352
+ For example, let's fix a symbolic form of an expression,
353
+ as a rational function. i.e., $P(X)/Q(X)$ for polynomials $P$ and $Q$.
354
+
355
+ ```python
356
+ objective = """
357
+ function my_custom_objective(tree, dataset::Dataset{T,L}, options) where {T,L}
358
+ # Require root node to be binary, so we can split it,
359
+ # otherwise return a large loss:
360
+ tree.degree != 2 && return L(Inf)
361
+
362
+ P = tree.l
363
+ Q = tree.r
364
+
365
+ # Evaluate numerator:
366
+ P_prediction, flag = eval_tree_array(P, dataset.X, options)
367
+ !flag && return L(Inf)
368
+
369
+ # Evaluate denominator:
370
+ Q_prediction, flag = eval_tree_array(Q, dataset.X, options)
371
+ !flag && return L(Inf)
372
+
373
+ # Impose functional form:
374
+ prediction = P_prediction ./ Q_prediction
375
+
376
+ diffs = prediction .- dataset.y
377
+
378
+ return sum(diffs .^ 2) / length(diffs)
379
+ end
380
+ """
381
+
382
+ model = PySRRegressor(
383
+ niterations=100,
384
+ binary_operators=["*", "+", "-"],
385
+ full_objective=objective,
386
+ )
387
+ ```
388
+
389
+ > **Warning**: When using a custom objective like this that performs symbolic
390
+ > manipulations, many functionalities of PySR will not work, such as `.sympy()`,
391
+ > `.predict()`, etc. This is because the SymPy parsing does not know about
392
+ > how you are manipulating the expression, so you will need to do this yourself.
393
+
394
+ Note how we did not pass `/` as a binary operator; it will just be implicit
395
+ in the functional form.
396
+
397
+ Let's generate an equation of the form $\frac{x_0^2 x_1 - 2}{x_2^2 + 1}$:
398
+
399
+ ```python
400
+ X = np.random.randn(1000, 3)
401
+ y = (X[:, 0]**2 * X[:, 1] - 2) / (X[:, 2]**2 + 1)
402
+ ```
403
+
404
+ Finally, let's fit:
405
+
406
+ ```python
407
+ model.fit(X, y)
408
+ ```
409
+
410
+ > Note that the printed equation is not the same as the evaluated equation,
411
+ > because the printing functionality does not know about the functional form.
412
+
413
+ We can get the string format with:
414
+
415
+ ```python
416
+ model.get_best().equation
417
+ ```
418
+
419
+ (or, you could use `model.equations_.iloc[-1].equation`)
420
+
421
+ For me, this equation was:
422
+
423
+ ```text
424
+ (((2.3554819 + -0.3554746) - (x1 * (x0 * x0))) - (-1.0000019 - (x2 * x2)))
425
+ ```
426
+
427
+ looking at the bracket structure of the equation, we can see that the outermost
428
+ bracket is split at the `-` operator (note that we ignore the root operator in
429
+ the evaluation, as we simply evaluated each argument and divided the result) into
430
+ `((2.3554819 + -0.3554746) - (x1 * (x0 * x0)))` and
431
+ `(-1.0000019 - (x2 * x2))`, meaning that our discovered equation is
432
+ equal to:
433
+ $\frac{x_0^2 x_1 - 2.0000073}{x_2^2 - 1.0000019}$, which
434
+ is nearly the same as the true equation!
435
+
436
+
437
+
438
+ ## 10. Additional features
439
 
440
  For the many other features available in PySR, please
441
  read the [Options section](options.md).