Zeel's picture
upload first version
a5b88b1
raw
history blame
3.93 kB
import streamlit as st
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
def parabola_fn(x):
    """Upper branch of the parabola x = y**2, i.e. return y = sqrt(x).

    Computed as ``x ** 0.5`` so it works for Python floats and JAX arrays alike.
    """
    return pow(x, 0.5)
def circle_fn(x):
    """Upper half of the unit circle: return sqrt(1 - x**2)."""
    radicand = 1 - x**2
    return radicand ** 0.5
# Derivatives dy/dx of the two curves via JAX automatic differentiation; each
# takes a scalar x and returns the slope of that curve's tangent at x.
d_parabola_fn, d_circle_fn = jax.grad(parabola_fn), jax.grad(circle_fn)
def loss_fn(params):
    """Loss that is zero when a circle fits between the parabola, the circle and x = 0.

    At x1 on the parabola and x2 on the unit circle, the normal lines
    (perpendicular to each tangent) are built; their intersection
    (x_star, y_star) is the candidate centre. Three squared distances are
    measured from that centre:

    * d1: to the parabola point (x1, parabola_fn(x1)),
    * d2: to the circle point (x2, circle_fn(x2)),
    * d3: to the line x = 0.

    The loss penalises pairwise differences between them, so it is zero
    exactly when all three are equal.

    Args:
        params: dict with scalar entries "x1" and "x2".

    Returns:
        (loss, aux) where aux has the keys "x_star", "y_star",
        "perpendicular_parabola_fn", "perpendicular_circle_fn" (callables
        evaluating each normal line at x) and "r" = sqrt(d1).

    NOTE: x2 == 0 gives d_circle_fn(x2) == 0, an infinite normal slope, and a
    NaN loss. When this function is called under jax.grad, the returned
    closures capture traced values, so only call them on the result of a
    plain call.
    """

    def normal_line(curve, slope_of, x0):
        # The line through (x0, curve(x0)) perpendicular to the tangent has the
        # negative-reciprocal slope; return it as (slope, intercept).
        slope = -1 / slope_of(x0)
        return slope, curve(x0) - slope * x0

    x1, x2 = params["x1"], params["x2"]
    m1, c1 = normal_line(parabola_fn, d_parabola_fn, x1)
    m2, c2 = normal_line(circle_fn, d_circle_fn, x2)

    def perpendicular_parabola_fn(x):
        return m1 * x + c1

    def perpendicular_circle_fn(x):
        return m2 * x + c2

    # Candidate centre: where the two normal lines cross.
    x_star = (c2 - c1) / (m1 - m2)
    y_star = m1 * x_star + c1

    def squared_gap(px, py):
        return (x_star - px) ** 2 + (y_star - py) ** 2

    d1 = squared_gap(x1, parabola_fn(x1))
    d2 = squared_gap(x2, circle_fn(x2))
    d3 = x_star**2  # squared horizontal distance from the centre to x = 0

    loss = (d1 - d2) ** 2 + (d1 - d3) ** 2 + (d2 - d3) ** 2
    return loss, {
        "x_star": x_star,
        "y_star": y_star,
        "perpendicular_parabola_fn": perpendicular_parabola_fn,
        "perpendicular_circle_fn": perpendicular_circle_fn,
        "r": d1**0.5,
    }
# Grid on [0, 1] used to draw the curves and the two normal lines.
x = jnp.linspace(0, 1, 100)

st.title("Radius of the Circle: Optimization Playground")
col1, col2 = st.columns(2)
x1 = col1.slider("initial x1 (x intersection with parabola)", 0.0, 1.0, 0.5)
x2 = col1.slider("initial x2 (x intersection with the circle)", 0.0, 1.0, 0.5)
n_epochs = col2.slider("n_epochs", 0, 1000, 50)
lr = col2.slider("lr", 0.0, 1.0, value=0.1, step=0.01)
submit = st.button("submit")

params = {"x1": x1, "x2": x2}
losses = []
value_and_grad_fn = jax.value_and_grad(loss_fn, has_aux=True)

# Evaluate the starting point once; it seeds both panels of the figure.
value, aux = loss_fn(params)
if not jnp.isfinite(value):
    # e.g. x2 == 0: the circle's normal slope is -1/0 and the loss is NaN.
    # Without this guard, set_ylim below would raise on a NaN limit.
    st.error(
        "The loss is not finite for these initial values; "
        "try x1 and x2 strictly between 0 and 1."
    )
    st.stop()

fig, axes = plt.subplots(1, 2, figsize=(15, 5))
curve_ax, loss_ax = axes

# Left panel: the two curves are static; the normal lines and the candidate
# circle are updated in place every epoch.
curve_ax.set_xlim(0, 1)
curve_ax.set_ylim(0, 1)
curve_ax.plot(x, parabola_fn(x), color="red")
(pbola_perpendicular_plot,) = curve_ax.plot(
    x, aux["perpendicular_parabola_fn"](x), color="red", linestyle="--"
)
curve_ax.plot(x, circle_fn(x), color="blue")
(circle_perpendicular_plot,) = curve_ax.plot(
    x, aux["perpendicular_circle_fn"](x), color="blue", linestyle="--"
)
circle_patch = curve_ax.add_patch(
    plt.Circle((float(aux["x_star"]), float(aux["y_star"])), float(aux["r"]), fill=False)
)

# Right panel: loss history, with the y-range fixed to [0, initial loss].
loss_ax.set_xlim(0, max(n_epochs, 1))  # avoid identical limits when n_epochs == 0
loss_ax.set_ylim(0, float(value))
(loss_plot,) = loss_ax.plot(losses, color="black")

pbar = st.progress(0)
diverged = False
# st.empty() holds a single element, so each st.pyplot call replaces the previous frame.
with st.empty():
    st.pyplot(fig)
    # When the submit button is clicked, run gradient descent and redraw each epoch.
    if submit:
        for i in range(n_epochs):
            (value, _), grad = value_and_grad_fn(params)
            if not jnp.isfinite(value):
                diverged = True
                break
            params["x1"] -= lr * grad["x1"]
            params["x2"] -= lr * grad["x2"]
            losses.append(float(value))

            # Geometry at the updated parameters; a plain (non-grad) call, so the
            # returned line closures hold concrete values.
            _, aux = loss_fn(params)
            radius = float(aux["r"])
            pbola_perpendicular_plot.set_data(x, aux["perpendicular_parabola_fn"](x))
            circle_perpendicular_plot.set_data(x, aux["perpendicular_circle_fn"](x))
            # Move the single circle patch instead of stacking a new one every epoch.
            circle_patch.center = (float(aux["x_star"]), float(aux["y_star"]))
            circle_patch.set_radius(radius)
            loss_plot.set_data(range(len(losses)), losses)

            pbar.progress((i + 1) / n_epochs)  # reaches 1.0 on the last epoch
            curve_ax.set_title(
                f"x1: {params['x1']:.3f}, x2: {params['x2']:.3f} \n r: {radius:.4f}"
            )
            loss_ax.set_title(f"epoch: {i}, loss: {float(value):.5f}")
            st.pyplot(fig)

# Shown outside the st.empty() container so it does not replace the figure.
if diverged:
    st.warning("The loss became non-finite; try a smaller lr or different initial values.")

# pyplot keeps figures alive until closed; without this, every Streamlit rerun leaks one.
plt.close(fig)