Spaces:
Sleeping
Sleeping
import streamlit as st | |
import numpy as np | |
from scipy.integrate import odeint | |
import matplotlib.pyplot as plt | |
import scipy.constants as sc | |
# Physical constants (SI units)
me = sc.electron_mass  # Electron mass [kg]
mp = sc.proton_mass  # Proton mass [kg]
# Elementary charge [C]. Numerically equal to sc.electron_volt (J per eV), which is
# why it also converts the keV slider value to joules.
e = sc.elementary_charge
# Equations of motion, written as right-hand sides f(y, t, *args) for odeint
def lorentz_equation(solution_at_current_t, t, m, q, E, B):
    """Return d/dt of the state [x, y, z, vx, vy, vz] under the Lorentz force.

    Parameters
    ----------
    solution_at_current_t : ndarray, shape (6,)
        Position followed by velocity.
    t : float
        Current time. Unused (the fields are passed in as constants), but
        required by odeint's call signature.
    m, q : float
        Particle mass and charge.
    E, B : ndarray, shape (3,)
        Electric and magnetic field vectors.

    Returns
    -------
    ndarray, shape (6,)
        Velocity followed by the acceleration (q / m) * (E + v x B).
    """
    velocity = solution_at_current_t[3:]
    acceleration = (q / m) * (E + np.cross(velocity, B))
    rates = np.empty(6)
    rates[:3] = velocity       # d(position)/dt
    rates[3:] = acceleration   # d(velocity)/dt
    return rates
def drift_equations(solution_at_current_t, t, m, q, E, B, grad_of_B, B_unit, B_norm):
    """Return d/dt of the guiding-centre state [x, y, z, v_parallel, v_perp].

    Parameters
    ----------
    solution_at_current_t : array_like, shape (5,)
        Guiding-centre position followed by the parallel and perpendicular speed.
    t : float
        Current time. Unused, but required by odeint's call signature.
    m, q : float
        Particle mass and charge.
    E, B : ndarray, shape (3,)
        Electric and magnetic field vectors (constant; not evaluated at the
        position).
    grad_of_B : ndarray, shape (3, 3)
        Gradient of B. The formulas below treat ``grad_of_B[i, j]`` as
        dB_j/dx_i, so that ``grad_of_B @ B_unit`` is the gradient of |B|.
    B_unit : ndarray, shape (3,)
        Unit vector b = B / |B|.
    B_norm : float
        Field magnitude |B|.

    Returns
    -------
    ndarray, shape (5,)
        Guiding-centre velocity (E x B, grad-B, curvature drifts plus parallel
        motion), then d(v_parallel)/dt and d(v_perp)/dt.
    """
    # Positions are not needed (fields are given, not evaluated at x, y, z);
    # unpacking still checks the state has exactly five entries.
    _x, _y, _z, v_gc_parallel, v_gc_perp = solution_at_current_t

    # Gradient of the field magnitude: d|B|/dx_i = sum_j b_j dB_j/dx_i.
    grad_of_B_norm = np.dot(grad_of_B, B_unit)
    # Gradient of the unit vector b = B/|B|:
    #   db_j/dx_i = (dB_j/dx_i - b_j d|B|/dx_i) / |B|
    # The second term is an outer product; an elementwise product here would
    # create a spurious curvature drift whenever grad|B| != 0.
    grad_of_B_unit = (grad_of_B - np.outer(grad_of_B_norm, B_unit)) / B_norm
    # Field-line curvature vector kappa = (b . grad) b.
    curvature = np.dot(B_unit, grad_of_B_unit)

    v_ExB = np.cross(E, B) / B_norm**2.0
    v_gradB = -(m * v_gc_perp**2.0 / (2.0 * B_norm)) * np.cross(grad_of_B_norm, B) / (q * B_norm**2.0)
    v_R = -(m * v_gc_parallel**2.0) * np.cross(curvature, B) / (q * B_norm**2.0)
    # Polarization drift is not modelled.
    v_pol = np.zeros(3)
    v_gc = v_ExB + v_gradB + v_R + v_pol + v_gc_parallel * B_unit

    # E and B are passed in as constants, so |B| has no explicit time dependence.
    partial_t_of_B_norm = 0.0

    time_derivatives = np.zeros(5)
    time_derivatives[0:3] = v_gc
    # Parallel acceleration by E and by the mirror force -mu * (b . grad|B|) / m.
    time_derivatives[3] = (q / m) * np.dot(B_unit, E) - (v_gc_perp**2.0 / (2.0 * B_norm)) * np.dot(grad_of_B_norm, B_unit)
    # Conservation of mu = m v_perp^2 / (2|B|) => dv_perp/dt = v_perp/(2|B|) d|B|/dt.
    time_derivatives[4] = (v_gc_perp / (2.0 * B_norm)) * (partial_t_of_B_norm + np.dot(v_gc, grad_of_B_norm))
    return time_derivatives
# Orbit calculation
def orbits_in_homogeneous_fields(m, q, r_0, v_0, E, B, grad_of_B, cycles):
    """Integrate the full Lorentz orbit and the matching guiding-centre orbit.

    The time grid spans ``cycles`` cyclotron periods with 30 samples per
    period. The guiding-centre initial state is obtained by averaging the
    Lorentz orbit over its first gyration.

    Parameters
    ----------
    m, q : float
        Particle mass and charge.
    r_0, v_0 : ndarray, shape (3,)
        Initial particle position and velocity.
    E, B : ndarray, shape (3,)
        Electric and magnetic field vectors.
    grad_of_B : ndarray, shape (3, 3)
        Gradient of B, forwarded to ``drift_equations``.
    cycles : int or float
        Number of cyclotron periods to integrate.

    Returns
    -------
    r_lorentz, v_lorentz : ndarray, shape (3, N)
        Particle position and velocity at each output time.
    r_gc : ndarray, shape (3, N)
        Guiding-centre position at each output time.
    v_gc : ndarray, shape (2, N)
        Guiding-centre parallel (row 0) and perpendicular (row 1) speed.

    Raises
    ------
    ValueError
        If |B| or the charge is zero. The cyclotron period that sets the time
        grid is then undefined.
    """
    B_norm = np.linalg.norm(B)
    if B_norm == 0.0:
        raise ValueError("the magnetic field must be non-zero")
    omega_c = abs((B_norm * q) / m)
    if omega_c == 0.0:
        raise ValueError("the particle charge must be non-zero")
    B_unit = B / B_norm
    tau_c = (2.0 * np.pi) / omega_c

    # Time grid: cyclotron_resolution samples per cyclotron period.
    cyclotron_resolution = 30.0
    t_step = tau_c / cyclotron_resolution
    t_span = np.arange(0.0, cycles * tau_c, t_step)

    # Tolerances for the odeint solver.
    options = {'rtol': 1e-5, 'atol': 1e-6}

    # Full particle orbit.
    initial_conditions_lorentz = np.concatenate([r_0, v_0])
    solutions_at_output_times_lorentz = odeint(
        lorentz_equation, initial_conditions_lorentz, t_span,
        args=(m, q, E, B), rtol=options['rtol'], atol=options['atol'],
    )
    r_lorentz = solutions_at_output_times_lorentz[:, :3].T
    v_lorentz = solutions_at_output_times_lorentz[:, 3:].T

    # Guiding-centre initial conditions: averages over the first gyration.
    # t_span is increasing, so this mask selects the leading samples.
    in_first_orbit = t_span < tau_c
    r_gc_0 = np.mean(r_lorentz[:, in_first_orbit], axis=1)
    v_parallel = np.dot(B_unit, v_lorentz)
    v_gc_parallel_0 = np.mean(v_parallel[in_first_orbit])
    mean_speed_squared = np.mean(np.sum(v_lorentz[:, in_first_orbit] ** 2, axis=0))
    # <|v|^2> >= <v_par>^2 mathematically; clamp so rounding for purely
    # field-aligned motion cannot give sqrt(negative) = NaN.
    v_gc_perp_0 = np.sqrt(max(mean_speed_squared - v_gc_parallel_0**2, 0.0))

    # Guiding-centre orbit.
    initial_conditions_drift = np.concatenate([r_gc_0, [v_gc_parallel_0, v_gc_perp_0]])
    solutions_at_output_times_drift = odeint(
        drift_equations, initial_conditions_drift, t_span,
        args=(m, q, E, B, grad_of_B, B_unit, B_norm),
        rtol=options['rtol'], atol=options['atol'],
    )
    r_gc = solutions_at_output_times_drift[:, :3].T
    v_gc = solutions_at_output_times_drift[:, 3:].T
    return r_lorentz, v_lorentz, r_gc, v_gc
# Plotting: returns the figure instead of calling plt.show(), so Streamlit can render it
def plot_figs(r_lorentz, v_lorentz, r_gc, v_gc):
    """Draw the particle orbit and the guiding-centre orbit (red) in four panels.

    The panels are one 3-D view plus the xy, xz and yz projections. The starting
    point of each curve is marked with a dot.

    Parameters
    ----------
    r_lorentz, r_gc : ndarray, shape (3, N)
        Particle and guiding-centre positions, one row per coordinate.
    v_lorentz, v_gc : ndarray
        Not used. Accepted so the full output of the orbit solver can be passed.

    Returns
    -------
    matplotlib.figure.Figure
        The figure, left open for the caller to render.
    """
    fig = plt.figure(figsize=(12, 8))

    # 3-D view
    ax_3d = fig.add_subplot(221, projection='3d')
    ax_3d.plot(*r_lorentz[:3], linewidth=2)
    ax_3d.scatter(*r_lorentz[:3, 0], marker='o')
    ax_3d.plot(*r_gc[:3], 'r', linewidth=3)
    ax_3d.scatter(*r_gc[:3, 0], color='red', marker='o')
    ax_3d.set_xlabel('x [m]')
    ax_3d.set_ylabel('y [m]')
    ax_3d.set_zlabel('z [m]')
    ax_3d.set_title('3D plot (rotate)')

    # 2-D projections: (subplot, horizontal row, vertical row, title, particle-curve kwargs)
    axis_names = 'xyz'
    projections = (
        (222, 0, 1, 'xy plot', {}),
        (223, 0, 2, 'xz plot', {'linewidth': 2}),
        (224, 1, 2, 'yz plot', {'linewidth': 2}),
    )
    for subplot, i_h, i_v, title, orbit_kwargs in projections:
        ax = fig.add_subplot(subplot)
        ax.plot(r_lorentz[i_h], r_lorentz[i_v], **orbit_kwargs)
        ax.scatter(r_lorentz[i_h, 0], r_lorentz[i_v, 0], marker='o')
        ax.plot(r_gc[i_h], r_gc[i_v], 'r')
        ax.scatter(r_gc[i_h, 0], r_gc[i_v, 0], color='red', marker='o')
        ax.set_xlabel(axis_names[i_h])
        ax.set_ylabel(axis_names[i_v])
        ax.set_title(title)

    plt.tight_layout()
    return fig
# Default widget values
def reset_values():
    """Reset user inputs to their initial values.

    Writes each input widget's default into ``st.session_state`` under that
    widget's key.
    """
    defaults = {
        "mass_multiplier": 2.0,     # multiples of the proton mass
        "charge_multiplier": -1.0,  # multiples of the elementary charge
        "r_0_x": 0.0,               # initial position [m]
        "r_0_y": 0.0,
        "r_0_z": 0.0,
        "E_k_keV": 500.0,           # kinetic energy [keV]
        "E_x": 0.0,                 # electric field [V/m]
        "E_y": 1e5,
        "E_z": 0.0,
        "B_x": 0.0,                 # magnetic field [T]
        "B_y": 0.0,
        "B_z": 1.0,
        "cycles": 20,               # number of cyclotron periods
    }
    for widget_key, default_value in defaults.items():
        st.session_state[widget_key] = default_value
def _vector_input(labels, keys):
    """Render three number inputs side by side and return their values as an array.

    ``labels`` and ``keys`` are the three widget labels and the matching
    session-state keys. The narrow columns between the inputs are spacers.
    """
    columns = st.columns([1, 0.1, 1, 0.1, 1])
    values = []
    for column, label, key in zip(columns[::2], labels, keys):
        with column:
            values.append(st.number_input(label, key=key))
    return np.array(values)


def main():
    """Build the Streamlit page: parameter widgets, orbit calculation and plots."""
    st.title("Lorentz Equation Solver")
    if "mass_multiplier" not in st.session_state:
        reset_values()  # Initialize session state with default values

    # Widgets are bound by key only (no value=): their values come from the
    # session state filled by reset_values(). Passing both a value and a key
    # that was set through the Session State API triggers a Streamlit warning.
    m = st.slider("Mass (in multiples of proton mass)", min_value=1.0, max_value=10.0, key="mass_multiplier") * mp
    q = st.slider("Charge (in multiples of elementary charge)", min_value=-10.0, max_value=10.0, key="charge_multiplier") * e
    st.markdown("""---""")  # Horizontal rule for visual separation

    st.markdown("**Initial Position (m)**")
    r_0 = _vector_input(("**X**", "**Y**", "**Z**"), ("r_0_x", "r_0_y", "r_0_z"))
    E_k = st.slider("Kinetic Energy (keV)", min_value=100.0, max_value=1000.0, key="E_k_keV") * 1e3 * e
    st.markdown("""---""")

    st.markdown("**Electric Field (V/m)**")
    E = _vector_input(("**E_X**", "**E_Y**", "**E_Z**"), ("E_x", "E_y", "E_z"))
    st.markdown("""---""")

    st.markdown("**Magnetic Field (T)**")
    B = _vector_input(("**B_X**", "**B_Y**", "**B_Z**"), ("B_x", "B_y", "B_z"))
    cycles = st.slider("Number of Cycles", min_value=1, max_value=100, key="cycles")

    # Speed corresponding to the chosen kinetic energy.
    v_therm = np.sqrt((2.0 * E_k) / m)
    # NOTE(review): |v_0| = v_therm * sqrt(1.01), so the launched kinetic energy is
    # about 1% above the slider value. Confirm whether that is intended.
    v_0 = np.array([0, v_therm, v_therm / 10])
    grad_of_B = np.zeros((3, 3))  # Homogeneous fields: no gradient of B

    if st.button("Calculate Orbits"):
        try:
            r_lorentz, v_lorentz, r_gc, v_gc = orbits_in_homogeneous_fields(m, q, r_0, v_0, E, B, grad_of_B, cycles)
        except ValueError as err:
            # E.g. zero magnetic field or zero charge: no cyclotron period exists.
            st.error(f"Cannot compute orbits: {err}")
        else:
            fig = plot_figs(r_lorentz, v_lorentz, r_gc, v_gc)
            st.pyplot(fig)
            # Release the pyplot figure so reruns do not accumulate open figures.
            plt.close(fig)

    # The on_click callback restores the defaults, and the click itself triggers
    # a rerun. No explicit rerun is needed; st.experimental_rerun is deprecated
    # and removed in newer Streamlit releases.
    st.button("Reset", on_click=reset_values)


if __name__ == "__main__":
    main()