# LorentzEquation / app.py
# Author: saoter
# Commit bc5605c (verified): "Rename app2.py to app.py"
import streamlit as st
import numpy as np
from scipy.integrate import odeint
import matplotlib.pyplot as plt
import scipy.constants as sc
# Physical constants in SI units, taken from scipy.constants.
me = sc.electron_mass  # Electron mass [kg]
mp = sc.proton_mass  # Proton mass [kg]
# Elementary charge [C]. Numerically equal to sc.electron_volt (1 eV in J), which is
# why it also serves as the eV -> J conversion factor below, but the named charge
# constant states the intent (charges are built as multiples of it).
e = sc.elementary_charge
# Equations of motion: the full Lorentz-force ODE and the guiding-center (drift) ODE.
def lorentz_equation(solution_at_current_t, t, m, q, E, B):
    """Return d/dt of the state [x, y, z, vx, vy, vz] under the Lorentz force.

    Written in the ``func(y, t, *args)`` form expected by ``scipy.integrate.odeint``.
    ``t`` is accepted for that reason but is not used (the fields are constant).

    The result is the velocity followed by the acceleration
    ``(q / m) * (E + v x B)``, as a float array of length 6.
    """
    velocity = solution_at_current_t[3:]
    acceleration = (q / m) * (E + np.cross(velocity, B))
    # Position changes at the velocity; velocity changes at the Lorentz acceleration.
    return np.concatenate((velocity, acceleration)).astype(float)
def drift_equations(solution_at_current_t, t, m, q, E, B, grad_of_B, B_unit, B_norm):
    """Right-hand side of the guiding-center (drift) equations, for ``odeint``.

    The guiding center moves with
    ``v_gc = v_ExB + v_gradB + v_R + v_pol + v_gc_parallel * B_unit``.
    Its parallel speed changes through the parallel electric field and the
    mirror force. Its perpendicular speed follows from keeping
    ``v_perp**2 / |B|`` (the magnetic moment) constant.

    Parameters
    ----------
    solution_at_current_t : array_like, shape (5,)
        State ``[x, y, z, v_gc_parallel, v_gc_perp]``. The position entries
        are not used; the field arguments are the same at every point.
    t : float
        Time. Required by ``odeint``'s signature; not used.
    m, q : float
        Particle mass and charge.
    E, B : array_like, shape (3,)
        Electric and magnetic field vectors.
    grad_of_B : array_like, shape (3, 3)
        Magnetic-field gradient tensor. Assumed convention:
        ``grad_of_B[j, i] = dB_i/dx_j``. Under this convention
        ``grad_of_B @ B_unit`` (used below) equals grad|B|.
    B_unit : array_like, shape (3,)
        Unit vector along ``B``.
    B_norm : float
        Field magnitude ``|B|``; must be non-zero.

    Returns
    -------
    numpy.ndarray, shape (5,)
        ``[dx/dt, dy/dt, dz/dt, d(v_gc_parallel)/dt, d(v_gc_perp)/dt]``.
    """
    _, _, _, v_gc_parallel, v_gc_perp = solution_at_current_t
    grad_of_B = np.asarray(grad_of_B, dtype=float)

    # grad|B|:  d|B|/dx_j = b_i * dB_i/dx_j.
    grad_of_B_norm = np.dot(grad_of_B, B_unit)
    # Gradient of the unit vector:  db_i/dx_j = (dB_i/dx_j - b_i * d|B|/dx_j) / |B|.
    # The subtracted term is the OUTER product of grad|B| with b, and the whole
    # tensor is divided by |B|. An element-wise product broadcast over rows
    # yields a wrong tensor (and a spurious curvature drift) whenever grad|B| != 0.
    grad_of_B_unit = (grad_of_B - np.outer(grad_of_B_norm, B_unit)) / B_norm

    # E x B drift.
    v_ExB = np.cross(E, B) / B_norm**2.0
    # Grad-B drift, equivalent to (m v_perp^2 / (2 q |B|^3)) * (B x grad|B|).
    v_gradB = -(m * v_gc_perp**2.0 / (2.0 * B_norm)) * np.cross(grad_of_B_norm, B) / (q * B_norm**2.0)
    # Curvature drift, equivalent to (m v_par^2 / (q |B|^2)) * (B x (b . grad) b).
    v_R = -(m * v_gc_parallel**2.0) * np.cross(np.dot(B_unit, grad_of_B_unit), B) / (q * B_norm**2.0)
    # Polarization drift is neglected.
    v_pol = np.zeros(3)
    v_gc = v_ExB + v_gradB + v_R + v_pol + v_gc_parallel * B_unit

    # Explicit time dependence of |B| is taken as zero (fields are constant inputs).
    partial_t_of_B_norm = 0.0

    time_derivatives = np.zeros(5)
    time_derivatives[0:3] = v_gc
    # Parallel acceleration by E_parallel, minus the mirror term (v_perp^2 / (2|B|)) * b . grad|B|.
    time_derivatives[3] = (q / m) * np.dot(B_unit, E) - (v_gc_perp**2.0 / (2.0 * B_norm)) * np.dot(grad_of_B_norm, B_unit)
    # v_perp^2 / |B| constant  =>  dv_perp/dt = v_perp / (2|B|) * d|B|/dt along the guiding-center path.
    time_derivatives[4] = (v_gc_perp / (2.0 * B_norm)) * (partial_t_of_B_norm + np.dot(v_gc, grad_of_B_norm))
    return time_derivatives
# Solve both the Lorentz and the guiding-center equations for a particle in constant (homogeneous) fields.
def orbits_in_homogeneous_fields(m, q, r_0, v_0, E, B, grad_of_B, cycles,
                                 cyclotron_resolution=30.0, rtol=1e-5, atol=1e-6):
    """Integrate the full Lorentz orbit and the guiding-center orbit in constant fields.

    The time grid runs from 0 to ``cycles`` cyclotron periods with
    ``cyclotron_resolution`` output points per period. The guiding-center
    initial conditions are averages of the Lorentz solution over the first
    gyration.

    Parameters
    ----------
    m, q : float
        Particle mass (> 0) and charge (non-zero).
    r_0, v_0 : array_like, shape (3,)
        Initial position and velocity of the particle.
    E, B : array_like, shape (3,)
        Electric and magnetic field vectors; ``B`` must be non-zero.
    grad_of_B : array_like, shape (3, 3)
        Magnetic-field gradient tensor, passed through to ``drift_equations``.
    cycles : int or float
        Number of cyclotron periods to integrate (> 0).
    cyclotron_resolution : float, optional
        Output samples per cyclotron period (default 30).
    rtol, atol : float, optional
        ``odeint`` relative/absolute tolerances (defaults 1e-5 / 1e-6).

    Returns
    -------
    r_lorentz, v_lorentz : numpy.ndarray, shape (3, N)
        Particle position and velocity at each output time.
    r_gc : numpy.ndarray, shape (3, N)
        Guiding-center position at each output time.
    v_gc : numpy.ndarray, shape (2, N)
        Rows ``[v_gc_parallel, v_gc_perp]`` at each output time.

    Raises
    ------
    ValueError
        If ``B`` is the zero vector, ``q`` is zero, ``m`` is not positive or
        ``cycles`` is not positive. In those cases the cyclotron period that
        defines the time grid is undefined or the grid is empty.
    """
    B = np.asarray(B, dtype=float)
    B_norm = np.linalg.norm(B)
    if B_norm == 0.0:
        raise ValueError("The magnetic field must be non-zero.")
    if q == 0:
        raise ValueError("The charge must be non-zero.")
    if m <= 0:
        raise ValueError("The mass must be positive.")
    if cycles <= 0:
        raise ValueError("The number of cycles must be positive.")
    B_unit = B / B_norm

    # Cyclotron frequency and period set the time grid.
    omega_c = abs((B_norm * q) / m)
    tau_c = (2.0 * np.pi) / omega_c
    t_step = tau_c / cyclotron_resolution
    t_span = np.arange(0.0, cycles * tau_c, t_step)

    # Full Lorentz orbit.
    initial_conditions_lorentz = np.concatenate([r_0, v_0])
    solutions_at_output_times_lorentz = odeint(
        lorentz_equation, initial_conditions_lorentz, t_span,
        args=(m, q, E, B), rtol=rtol, atol=atol)
    r_lorentz = solutions_at_output_times_lorentz[:, :3].T
    v_lorentz = solutions_at_output_times_lorentz[:, 3:].T

    # Guiding-center initial conditions: averages over the first gyration.
    # t_span is ascending from 0, so the samples with t < tau_c form a prefix.
    first_orbit = slice(0, np.count_nonzero(t_span < tau_c))
    r_gc_0 = np.mean(r_lorentz[:, first_orbit], axis=1)
    v_parallel = np.dot(B_unit, v_lorentz)
    v_gc_parallel_0 = np.mean(v_parallel[first_orbit])
    v_rms_squared = np.mean(np.sum(v_lorentz[:, first_orbit]**2, axis=0))
    # Clamp at zero: round-off can make the difference slightly negative (-> NaN).
    # NOTE(review): this perpendicular speed is derived from the total speed, so it
    # includes any E x B drift speed; in drift_equations it only enters terms
    # multiplied by field-gradient quantities.
    v_gc_perp_0 = np.sqrt(max(v_rms_squared - v_gc_parallel_0**2, 0.0))

    # Guiding-center (drift) orbit.
    initial_conditions_drift = list(r_gc_0) + [v_gc_parallel_0, v_gc_perp_0]
    solutions_at_output_times_drift = odeint(
        drift_equations, initial_conditions_drift, t_span,
        args=(m, q, E, B, grad_of_B, B_unit, B_norm), rtol=rtol, atol=atol)
    r_gc = solutions_at_output_times_drift[:, :3].T
    v_gc = solutions_at_output_times_drift[:, 3:].T
    return r_lorentz, v_lorentz, r_gc, v_gc
# Plotting: build the figure and return it (no plt.show()) so Streamlit can render it.
def plot_figs(r_lorentz, v_lorentz, r_gc, v_gc):
    """Draw the particle orbit and its guiding-center orbit on a 2x2 grid.

    Top-left is a 3D view. The other three panels are the xy, xz and yz
    projections. The full orbit uses the default colour cycle; the guiding
    center is red. A dot marks each starting point. ``v_lorentz`` and
    ``v_gc`` are accepted for interface symmetry but are not plotted.

    Returns the matplotlib figure rather than showing it, so the caller
    (Streamlit) can render it.
    """
    fig = plt.figure(figsize=(12, 8))

    # 3D panel: rotatable view of both trajectories.
    ax_3d = fig.add_subplot(221, projection='3d')
    ax_3d.plot(r_lorentz[0], r_lorentz[1], r_lorentz[2], linewidth=2)
    ax_3d.scatter(r_lorentz[0, 0], r_lorentz[1, 0], r_lorentz[2, 0], marker='o')
    ax_3d.plot(r_gc[0], r_gc[1], r_gc[2], 'r', linewidth=3)
    ax_3d.scatter(r_gc[0, 0], r_gc[1, 0], r_gc[2, 0], color='red', marker='o')
    ax_3d.set_xlabel('x [m]')
    ax_3d.set_ylabel('y [m]')
    ax_3d.set_zlabel('z [m]')
    ax_3d.set_title('3D plot (rotate)')

    # 2D projections:
    # (subplot code, horizontal row index, vertical row index, x label, y label, title, orbit line kwargs).
    projection_panels = (
        (222, 0, 1, 'x', 'y', 'xy plot', {}),
        (223, 0, 2, 'x', 'z', 'xz plot', {'linewidth': 2}),
        (224, 1, 2, 'y', 'z', 'yz plot', {'linewidth': 2}),
    )
    for code, h, k, h_label, k_label, title, orbit_kwargs in projection_panels:
        panel = fig.add_subplot(code)
        panel.plot(r_lorentz[h], r_lorentz[k], **orbit_kwargs)
        panel.scatter(r_lorentz[h, 0], r_lorentz[k, 0], marker='o')
        panel.plot(r_gc[h], r_gc[k], 'r')
        panel.scatter(r_gc[h, 0], r_gc[k, 0], color='red', marker='o')
        panel.set_xlabel(h_label)
        panel.set_ylabel(k_label)
        panel.set_title(title)

    plt.tight_layout()
    return fig
##################
def reset_values():
    """Write the default value of every user input into Streamlit session state.

    The keys match the widget keys used in ``main``, so the widgets show
    these values on the next run.
    """
    defaults = {
        "mass_multiplier": 2.0,     # multiples of the proton mass
        "charge_multiplier": -1.0,  # multiples of the elementary charge
        "r_0_x": 0.0,               # initial position [m]
        "r_0_y": 0.0,
        "r_0_z": 0.0,
        "E_k_keV": 500.0,           # kinetic energy [keV]
        "E_x": 0.0,                 # electric field [V/m]
        "E_y": 1e5,
        "E_z": 0.0,
        "B_x": 0.0,                 # magnetic field [T]
        "B_y": 0.0,
        "B_z": 1.0,
        "cycles": 20,               # number of cyclotron periods
    }
    for widget_key, default in defaults.items():
        st.session_state[widget_key] = default
def _vector_input(title, labels, keys):
    """Render a bold section title and three number inputs in spaced columns; return them as an array."""
    st.markdown(title)
    # Five columns: input, spacer, input, spacer, input.
    columns = st.columns([1, 0.1, 1, 0.1, 1])
    components = []
    for column, label, key in zip(columns[::2], labels, keys):
        with column:
            # The widget value comes from session state via `key`.
            components.append(st.number_input(label, key=key))
    return np.array(components)


def main():
    """Render the Lorentz-equation solver UI and plot orbits on request.

    Inputs are bound to session-state keys (seeded by ``reset_values`` on the
    first run). "Calculate Orbits" integrates the Lorentz and guiding-center
    equations and plots them. "Reset" restores the defaults.
    """
    st.title("Lorentz Equation Solver")
    if "mass_multiplier" not in st.session_state:
        reset_values()  # First run: seed session state with the default values.

    # Widgets are bound to session state through `key` only. Also passing `value=`
    # makes Streamlit warn that a default was given while the value was set
    # via the Session State API.
    m = st.slider("Mass (in multiples of proton mass)", min_value=1.0, max_value=10.0, key="mass_multiplier") * mp
    q = st.slider("Charge (in multiples of elementary charge)", min_value=-10.0, max_value=10.0, key="charge_multiplier") * e
    st.markdown("""---""")  # Horizontal rule for visual separation

    r_0 = _vector_input("**Initial Position (m)**", ("**X**", "**Y**", "**Z**"), ("r_0_x", "r_0_y", "r_0_z"))
    # keV -> J (e equals 1 eV in joules numerically).
    E_k = st.slider("Kinetic Energy (keV)", min_value=100.0, max_value=1000.0, key="E_k_keV") * 1e3 * e
    st.markdown("""---""")  # Horizontal rule for visual separation

    E = _vector_input("**Electric Field (V/m)**", ("**E_X**", "**E_Y**", "**E_Z**"), ("E_x", "E_y", "E_z"))
    st.markdown("""---""")  # Horizontal rule for visual separation

    B = _vector_input("**Magnetic Field (T)**", ("**B_X**", "**B_Y**", "**B_Z**"), ("B_x", "B_y", "B_z"))
    cycles = st.slider("Number of Cycles", min_value=1, max_value=100, key="cycles")

    # Non-relativistic speed for the selected kinetic energy: E_k = m v^2 / 2.
    speed = np.sqrt((2.0 * E_k) / m)
    # Launch mostly along +y with a small +z component. Normalise the direction so
    # |v_0| == speed and the particle really carries E_k. The unnormalised
    # [0, 1, 0.1] direction gave ~1% more energy than selected.
    direction = np.array([0.0, 1.0, 0.1])
    v_0 = speed * direction / np.linalg.norm(direction)
    grad_of_B = np.zeros((3, 3))  # Homogeneous fields: no field gradient.

    if st.button("Calculate Orbits"):
        try:
            r_lorentz, v_lorentz, r_gc, v_gc = orbits_in_homogeneous_fields(m, q, r_0, v_0, E, B, grad_of_B, cycles)
        except ValueError as err:
            # E.g. zero charge or zero magnetic field: no gyration to resolve.
            st.error(f"Cannot compute orbits: {err}")
        else:
            fig = plot_figs(r_lorentz, v_lorentz, r_gc, v_gc)
            st.pyplot(fig)
            # Release the figure so pyplot does not keep every figure across reruns.
            plt.close(fig)

    # The on_click callback runs before the rerun caused by the click, so no
    # explicit rerun is needed (st.experimental_rerun is deprecated in favour of st.rerun).
    st.button("Reset", on_click=reset_values)
if __name__ == "__main__":
    # Build the UI only when this file is executed as the main script, not on import.
    main()