euler314 commited on
Commit
ef6361a
·
verified ·
1 Parent(s): ba9a480

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +335 -311
app.py CHANGED
@@ -6,9 +6,9 @@ from scipy.optimize import fsolve
6
 
7
  # Configure Streamlit for Hugging Face Spaces
8
  st.set_page_config(
9
- page_title="Cubic Root Analysis",
10
- layout="wide",
11
- initial_sidebar_state="collapsed"
12
  )
13
 
14
  #############################
@@ -26,8 +26,8 @@ d_sym = 1
26
 
27
  # Symbolic expression for the standard cubic discriminant
28
  Delta_expr = (
29
- ( (b_sym*c_sym)/(6*a_sym**2) - (b_sym**3)/(27*a_sym**3) - d_sym/(2*a_sym) )**2
30
- + ( c_sym/(3*a_sym) - (b_sym**2)/(9*a_sym**2) )**3
31
  )
32
 
33
  # Turn that into a fast numeric function:
@@ -35,329 +35,352 @@ discriminant_func = sp.lambdify((z_sym, beta_sym, z_a_sym, y_sym), Delta_expr, "
35
 
36
  @st.cache_data
37
  def find_z_at_discriminant_zero(z_a, y, beta, z_min, z_max, steps=20000):
38
- """
39
- Numerically scan z in [z_min, z_max] looking for sign changes of
40
- Delta(z) = 0. Returns all roots found via bisection.
41
- """
42
- z_grid = np.linspace(z_min, z_max, steps)
43
- disc_vals = discriminant_func(z_grid, beta, z_a, y)
44
-
45
- roots_found = []
46
-
47
- # Scan for sign changes
48
- for i in range(len(z_grid) - 1):
49
- f1, f2 = disc_vals[i], disc_vals[i+1]
50
- if np.isnan(f1) or np.isnan(f2):
51
- continue
52
-
53
- if f1 == 0.0:
54
- roots_found.append(z_grid[i])
55
- elif f2 == 0.0:
56
- roots_found.append(z_grid[i+1])
57
- elif f1*f2 < 0:
58
- zl = z_grid[i]
59
- zr = z_grid[i+1]
60
- for _ in range(50):
61
- mid = 0.5*(zl + zr)
62
- fm = discriminant_func(mid, beta, z_a, y)
63
- if fm == 0:
64
- zl = zr = mid
65
- break
66
- if np.sign(fm) == np.sign(f1):
67
- zl = mid
68
- f1 = fm
69
- else:
70
- zr = mid
71
- f2 = fm
72
- root_approx = 0.5*(zl + zr)
73
- roots_found.append(root_approx)
74
-
75
- return np.array(roots_found)
76
 
77
  @st.cache_data
78
  def sweep_beta_and_find_z_bounds(z_a, y, z_min, z_max, beta_steps=51):
79
- """
80
- For each beta, find both the largest and smallest z where discriminant=0.
81
- Returns (betas, z_min_values, z_max_values).
82
- """
83
- betas = np.linspace(0, 1, beta_steps)
84
- z_min_values = []
85
- z_max_values = []
86
-
87
- for b in betas:
88
- roots = find_z_at_discriminant_zero(z_a, y, b, z_min, z_max)
89
- if len(roots) == 0:
90
- z_min_values.append(np.nan)
91
- z_max_values.append(np.nan)
92
- else:
93
- z_min_values.append(np.min(roots))
94
- z_max_values.append(np.max(roots))
95
-
96
- return betas, np.array(z_min_values), np.array(z_max_values)
97
 
98
  @st.cache_data
99
  def compute_low_y_curve(betas, z_a, y):
100
- """
101
- Compute the additional curve with proper handling of divide by zero cases
102
- """
103
- betas = np.array(betas)
104
- with np.errstate(invalid='ignore', divide='ignore'):
105
- sqrt_term = y * betas * (z_a - 1)
106
- sqrt_term = np.where(sqrt_term < 0, np.nan, np.sqrt(sqrt_term))
107
-
108
- term = (-1 + sqrt_term)/z_a
109
- numerator = (y - 2)*term + y * betas * ((z_a - 1)/z_a) - 1/z_a - 1
110
- denominator = term**2 + term
111
-
112
- # Handle division by zero and invalid values
113
- mask = (denominator != 0) & ~np.isnan(denominator) & ~np.isnan(numerator)
114
- return np.where(mask, numerator/denominator, np.nan)
115
 
116
  @st.cache_data
117
  def compute_high_y_curve(betas, z_a, y):
118
- """
119
- Compute the expression: ((4y + 12)(4 - a) + 16y*β*(a - 1))/(3(4 - a))
120
- """
121
- betas = np.array(betas)
122
- denominator = 3*(4 - z_a)
123
-
124
- if denominator == 0:
125
- return np.full_like(betas, np.nan)
126
-
127
- numerator = (4*y + 12)*(4 - z_a) + 16*y*betas*(z_a - 1)
128
- return numerator/denominator
129
 
130
  def generate_z_vs_beta_plot(z_a, y, z_min, z_max):
131
- if z_a <= 0 or y <= 0 or z_min >= z_max:
132
- st.error("Invalid input parameters.")
133
- return None
134
-
135
- beta_steps = 101
136
- betas = np.linspace(0, 1, beta_steps)
137
-
138
- betas, z_mins, z_maxs = sweep_beta_and_find_z_bounds(z_a, y, z_min, z_max, beta_steps=beta_steps)
139
- low_y_curve = compute_low_y_curve(betas, z_a, y)
140
- high_y_curve = compute_high_y_curve(betas, z_a, y)
141
-
142
- fig = go.Figure()
143
-
144
- fig.add_trace(
145
- go.Scatter(
146
- x=betas,
147
- y=z_maxs,
148
- mode="markers+lines",
149
- name="Upper z*(β)",
150
- marker=dict(size=5, color='blue'),
151
- line=dict(color='blue'),
152
- )
153
- )
154
-
155
- fig.add_trace(
156
- go.Scatter(
157
- x=betas,
158
- y=z_mins,
159
- mode="markers+lines",
160
- name="Lower z*(β)",
161
- marker=dict(size=5, color='lightblue'),
162
- line=dict(color='lightblue'),
163
- )
164
- )
165
-
166
- fig.add_trace(
167
- go.Scatter(
168
- x=betas,
169
- y=low_y_curve,
170
- mode="markers+lines",
171
- name="Low y Expression",
172
- marker=dict(size=5, color='red'),
173
- line=dict(color='red'),
174
- )
175
- )
176
-
177
- fig.add_trace(
178
- go.Scatter(
179
- x=betas,
180
- y=high_y_curve,
181
- mode="markers+lines",
182
- name="High y Expression",
183
- marker=dict(size=5, color='green'),
184
- line=dict(color='green'),
185
- )
186
- )
187
-
188
- fig.update_layout(
189
- title="Curves vs β: z*(β) boundaries and Asymptotic Expressions",
190
- xaxis_title="β",
191
- yaxis_title="Value",
192
- hovermode="x unified",
193
- )
194
- return fig
195
 
196
  def compute_cubic_roots(z, beta, z_a, y):
197
- """
198
- Compute the roots of the cubic equation for given parameters.
199
- """
200
- a = z * z_a
201
- b = z * z_a + z + z_a - z_a*y
202
- c = z + z_a + 1 - y*(beta*z_a + 1 - beta)
203
- d = 1
204
-
205
- coeffs = [a, b, c, d]
206
- roots = np.roots(coeffs)
207
- return roots
208
 
209
- def generate_ims_vs_z_plot(beta, y, z_a, z_min, z_max):
210
- if z_a <= 0 or y <= 0 or z_min >= z_max:
211
- st.error("Invalid input parameters.")
212
- return None
213
-
214
- z_points = np.linspace(z_min, z_max, 1000)
215
- ims = []
216
-
217
- for z in z_points:
218
- roots = compute_cubic_roots(z, beta, z_a, y)
219
- roots = sorted(roots, key=lambda x: abs(x.imag))
220
- ims.append([root.imag for root in roots])
221
-
222
- ims = np.array(ims)
223
-
224
- fig = go.Figure()
225
-
226
- for i in range(3):
227
- fig.add_trace(
228
- go.Scatter(
229
- x=z_points,
230
- y=ims[:,i],
231
- mode="lines",
232
- name=f"Im{{s{i+1}}}",
233
- line=dict(width=2),
234
- )
235
- )
236
-
237
- fig.update_layout(
238
- title=f"Im{{s}} vs. z (β={beta:.3f}, y={y:.3f}, z_a={z_a:.3f})",
239
- xaxis_title="z",
240
- yaxis_title="Im{s}",
241
- hovermode="x unified",
242
- )
243
- return fig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244
 
245
  def curve1(s, z, y):
246
- """First curve: z*s^2 + (z-y+1)*s + 1"""
247
- return z*s**2 + (z-y+1)*s + 1
248
 
249
  def curve2(s, y, beta, a):
250
- """Second curve: y*β*((a-1)*s)/(a*s+1)"""
251
- return y*beta*((a-1)*s)/(a*s+1)
252
 
253
  def find_intersections(z, y, beta, a, s_range):
254
- """Find intersections between the two curves with improved accuracy"""
255
- def equation(s):
256
- return curve1(s, z, y) - curve2(s, y, beta, a)
257
-
258
- # Create a finer grid of initial guesses
259
- s_guesses = np.linspace(s_range[0], s_range[1], 200)
260
- intersections = []
261
-
262
- # Parameters for accuracy
263
- tolerance = 1e-10
264
-
265
- # First pass: find all potential intersections
266
- for s_guess in s_guesses:
267
- try:
268
- s_sol = fsolve(equation, s_guess, full_output=True, xtol=tolerance)
269
- if s_sol[2] == 1: # Check if convergence was achieved
270
- s_val = s_sol[0][0]
271
- if (s_range[0] <= s_val <= s_range[1] and
272
- not any(abs(s_val - s_prev) < tolerance for s_prev in intersections)):
273
- if abs(equation(s_val)) < tolerance:
274
- intersections.append(s_val)
275
- except:
276
- continue
277
-
278
- # Sort intersections
279
- intersections = np.sort(np.array(intersections))
280
-
281
- # Ensure even number of intersections by checking for missed ones
282
- if len(intersections) % 2 != 0:
283
- refined_intersections = []
284
- for i in range(len(intersections)-1):
285
- mid_point = (intersections[i] + intersections[i+1])/2
286
- try:
287
- s_sol = fsolve(equation, mid_point, full_output=True, xtol=tolerance)
288
- if s_sol[2] == 1:
289
- s_val = s_sol[0][0]
290
- if (intersections[i] < s_val < intersections[i+1] and
291
- abs(equation(s_val)) < tolerance):
292
- refined_intersections.append(s_val)
293
- except:
294
- continue
295
-
296
- intersections = np.sort(np.append(intersections, refined_intersections))
297
-
298
- return intersections
299
 
300
  def generate_curves_plot(z, y, beta, a, s_range):
301
- s = np.linspace(s_range[0], s_range[1], 2000)
302
-
303
- # Compute curves
304
- y1 = curve1(s, z, y)
305
- y2 = curve2(s, y, beta, a)
306
-
307
- # Find intersections with improved accuracy
308
- intersections = find_intersections(z, y, beta, a, s_range)
309
-
310
- fig = go.Figure()
311
-
312
- fig.add_trace(
313
- go.Scatter(
314
- x=s, y=y1,
315
- mode='lines',
316
- name='z*s² + (z-y+1)*s + 1',
317
- line=dict(color='blue', width=2)
318
- )
319
- )
320
-
321
- fig.add_trace(
322
- go.Scatter(
323
- x=s, y=y2,
324
- mode='lines',
325
- name='y*β*((a-1)*s)/(a*s+1)',
326
- line=dict(color='red', width=2)
327
- )
328
- )
329
-
330
- if len(intersections) > 0:
331
- fig.add_trace(
332
- go.Scatter(
333
- x=intersections,
334
- y=curve1(intersections, z, y),
335
- mode='markers',
336
- name='Intersections',
337
- marker=dict(
338
- size=12,
339
- color='green',
340
- symbol='x',
341
- line=dict(width=2)
342
- )
343
- )
344
- )
345
-
346
- fig.update_layout(
347
- title=f"Curve Intersection Analysis (y={y:.4f}, β={beta:.4f}, a={a:.4f})",
348
- xaxis_title="s",
349
- yaxis_title="Value",
350
- hovermode="closest",
351
- showlegend=True,
352
- legend=dict(
353
- yanchor="top",
354
- y=0.99,
355
- xanchor="left",
356
- x=0.01
357
- )
358
- )
359
-
360
- return fig, intersections
361
 
362
  # Streamlit UI
363
  st.title("Cubic Root Analysis")
@@ -397,7 +420,7 @@ with tab1:
397
  """)
398
 
399
  with tab2:
400
- st.header("Plot Imaginary Parts of Roots vs. z")
401
 
402
  col1, col2 = st.columns([1, 2])
403
 
@@ -408,11 +431,12 @@ with tab2:
408
  z_min_2 = st.number_input("z_min", value=-10.0, key="z_min_2")
409
  z_max_2 = st.number_input("z_max", value=10.0, key="z_max_2")
410
 
411
- if st.button("Compute Im{s} vs. z"):
412
  with col2:
413
- fig = generate_ims_vs_z_plot(beta, y_2, z_a_2, z_min_2, z_max_2)
414
- if fig is not None:
415
- st.plotly_chart(fig, use_container_width=True)
 
416
 
417
  with tab3:
418
  st.header("Curve Intersection Analysis")
 
6
 
7
  # Configure Streamlit for Hugging Face Spaces
8
  st.set_page_config(
9
+ page_title="Cubic Root Analysis",
10
+ layout="wide",
11
+ initial_sidebar_state="collapsed"
12
  )
13
 
14
  #############################
 
26
 
27
  # Symbolic expression for the standard cubic discriminant
28
  Delta_expr = (
29
+ ( (b_sym*c_sym)/(6*a_sym**2) - (b_sym**3)/(27*a_sym**3) - d_sym/(2*a_sym) )**2
30
+ + ( c_sym/(3*a_sym) - (b_sym**2)/(9*a_sym**2) )**3
31
  )
32
 
33
  # Turn that into a fast numeric function:
 
35
 
36
  @st.cache_data
37
  def find_z_at_discriminant_zero(z_a, y, beta, z_min, z_max, steps=20000):
38
+ """
39
+ Numerically scan z in [z_min, z_max] looking for sign changes of
40
+ Delta(z) = 0. Returns all roots found via bisection.
41
+ """
42
+ z_grid = np.linspace(z_min, z_max, steps)
43
+ disc_vals = discriminant_func(z_grid, beta, z_a, y)
44
+
45
+ roots_found = []
46
+
47
+ # Scan for sign changes
48
+ for i in range(len(z_grid) - 1):
49
+ f1, f2 = disc_vals[i], disc_vals[i+1]
50
+ if np.isnan(f1) or np.isnan(f2):
51
+ continue
52
+
53
+ if f1 == 0.0:
54
+ roots_found.append(z_grid[i])
55
+ elif f2 == 0.0:
56
+ roots_found.append(z_grid[i+1])
57
+ elif f1*f2 < 0:
58
+ zl = z_grid[i]
59
+ zr = z_grid[i+1]
60
+ for _ in range(50):
61
+ mid = 0.5*(zl + zr)
62
+ fm = discriminant_func(mid, beta, z_a, y)
63
+ if fm == 0:
64
+ zl = zr = mid
65
+ break
66
+ if np.sign(fm) == np.sign(f1):
67
+ zl = mid
68
+ f1 = fm
69
+ else:
70
+ zr = mid
71
+ f2 = fm
72
+ root_approx = 0.5*(zl + zr)
73
+ roots_found.append(root_approx)
74
+
75
+ return np.array(roots_found)
76
 
77
  @st.cache_data
78
  def sweep_beta_and_find_z_bounds(z_a, y, z_min, z_max, beta_steps=51):
79
+ """
80
+ For each beta, find both the largest and smallest z where discriminant=0.
81
+ Returns (betas, z_min_values, z_max_values).
82
+ """
83
+ betas = np.linspace(0, 1, beta_steps)
84
+ z_min_values = []
85
+ z_max_values = []
86
+
87
+ for b in betas:
88
+ roots = find_z_at_discriminant_zero(z_a, y, b, z_min, z_max)
89
+ if len(roots) == 0:
90
+ z_min_values.append(np.nan)
91
+ z_max_values.append(np.nan)
92
+ else:
93
+ z_min_values.append(np.min(roots))
94
+ z_max_values.append(np.max(roots))
95
+
96
+ return betas, np.array(z_min_values), np.array(z_max_values)
97
 
98
  @st.cache_data
99
  def compute_low_y_curve(betas, z_a, y):
100
+ """
101
+ Compute the additional curve with proper handling of divide by zero cases
102
+ """
103
+ betas = np.array(betas)
104
+ with np.errstate(invalid='ignore', divide='ignore'):
105
+ sqrt_term = y * betas * (z_a - 1)
106
+ sqrt_term = np.where(sqrt_term < 0, np.nan, np.sqrt(sqrt_term))
107
+
108
+ term = (-1 + sqrt_term)/z_a
109
+ numerator = (y - 2)*term + y * betas * ((z_a - 1)/z_a) - 1/z_a - 1
110
+ denominator = term**2 + term
111
+
112
+ # Handle division by zero and invalid values
113
+ mask = (denominator != 0) & ~np.isnan(denominator) & ~np.isnan(numerator)
114
+ return np.where(mask, numerator/denominator, np.nan)
115
 
116
  @st.cache_data
117
  def compute_high_y_curve(betas, z_a, y):
118
+ """
119
+ Compute the expression: ((4y + 12)(4 - a) + 16y*β*(a - 1))/(3(4 - a))
120
+ """
121
+ betas = np.array(betas)
122
+ denominator = 3*(4 - z_a)
123
+
124
+ if denominator == 0:
125
+ return np.full_like(betas, np.nan)
126
+
127
+ numerator = (4*y + 12)*(4 - z_a) + 16*y*betas*(z_a - 1)
128
+ return numerator/denominator
129
 
130
  def generate_z_vs_beta_plot(z_a, y, z_min, z_max):
131
+ if z_a <= 0 or y <= 0 or z_min >= z_max:
132
+ st.error("Invalid input parameters.")
133
+ return None
134
+
135
+ beta_steps = 101
136
+ betas = np.linspace(0, 1, beta_steps)
137
+
138
+ betas, z_mins, z_maxs = sweep_beta_and_find_z_bounds(z_a, y, z_min, z_max, beta_steps=beta_steps)
139
+ low_y_curve = compute_low_y_curve(betas, z_a, y)
140
+ high_y_curve = compute_high_y_curve(betas, z_a, y)
141
+
142
+ fig = go.Figure()
143
+
144
+ fig.add_trace(
145
+ go.Scatter(
146
+ x=betas,
147
+ y=z_maxs,
148
+ mode="markers+lines",
149
+ name="Upper z*(β)",
150
+ marker=dict(size=5, color='blue'),
151
+ line=dict(color='blue'),
152
+ )
153
+ )
154
+
155
+ fig.add_trace(
156
+ go.Scatter(
157
+ x=betas,
158
+ y=z_mins,
159
+ mode="markers+lines",
160
+ name="Lower z*(β)",
161
+ marker=dict(size=5, color='lightblue'),
162
+ line=dict(color='lightblue'),
163
+ )
164
+ )
165
+
166
+ fig.add_trace(
167
+ go.Scatter(
168
+ x=betas,
169
+ y=low_y_curve,
170
+ mode="markers+lines",
171
+ name="Low y Expression",
172
+ marker=dict(size=5, color='red'),
173
+ line=dict(color='red'),
174
+ )
175
+ )
176
+
177
+ fig.add_trace(
178
+ go.Scatter(
179
+ x=betas,
180
+ y=high_y_curve,
181
+ mode="markers+lines",
182
+ name="High y Expression",
183
+ marker=dict(size=5, color='green'),
184
+ line=dict(color='green'),
185
+ )
186
+ )
187
+
188
+ fig.update_layout(
189
+ title="Curves vs β: z*(β) boundaries and Asymptotic Expressions",
190
+ xaxis_title="β",
191
+ yaxis_title="Value",
192
+ hovermode="x unified",
193
+ )
194
+ return fig
195
 
196
  def compute_cubic_roots(z, beta, z_a, y):
197
+ """
198
+ Compute the roots of the cubic equation for given parameters.
199
+ """
200
+ a = z * z_a
201
+ b = z * z_a + z + z_a - z_a*y
202
+ c = z + z_a + 1 - y*(beta*z_a + 1 - beta)
203
+ d = 1
204
+
205
+ coeffs = [a, b, c, d]
206
+ roots = np.roots(coeffs)
207
+ return roots
208
 
209
+ def generate_root_plots(beta, y, z_a, z_min, z_max):
210
+ """Generate both Im(s) and Re(s) vs. z plots"""
211
+ if z_a <= 0 or y <= 0 or z_min >= z_max:
212
+ st.error("Invalid input parameters.")
213
+ return None, None
214
+
215
+ z_points = np.linspace(z_min, z_max, 1000)
216
+ ims = []
217
+ res = []
218
+
219
+ for z in z_points:
220
+ roots = compute_cubic_roots(z, beta, z_a, y)
221
+ roots = sorted(roots, key=lambda x: abs(x.imag))
222
+ ims.append([root.imag for root in roots])
223
+ res.append([root.real for root in roots])
224
+
225
+ ims = np.array(ims)
226
+ res = np.array(res)
227
+
228
+ # Create Im(s) plot
229
+ fig_im = go.Figure()
230
+ for i in range(3):
231
+ fig_im.add_trace(
232
+ go.Scatter(
233
+ x=z_points,
234
+ y=ims[:,i],
235
+ mode="lines",
236
+ name=f"Im{{s{i+1}}}",
237
+ line=dict(width=2),
238
+ )
239
+ )
240
+ fig_im.update_layout(
241
+ title=f"Im{{s}} vs. z (β={beta:.3f}, y={y:.3f}, z_a={z_a:.3f})",
242
+ xaxis_title="z",
243
+ yaxis_title="Im{s}",
244
+ hovermode="x unified",
245
+ )
246
+
247
+ # Create Re(s) plot
248
+ fig_re = go.Figure()
249
+ for i in range(3):
250
+ fig_re.add_trace(
251
+ go.Scatter(
252
+ x=z_points,
253
+ y=res[:,i],
254
+ mode="lines",
255
+ name=f"Re{{s{i+1}}}",
256
+ line=dict(width=2),
257
+ )
258
+ )
259
+ fig_re.update_layout(
260
+ title=f"Re{{s}} vs. z (β={beta:.3f}, y={y:.3f}, z_a={z_a:.3f})",
261
+ xaxis_title="z",
262
+ yaxis_title="Re{s}",
263
+ hovermode="x unified",
264
+ )
265
+
266
+ return fig_im, fig_re
267
 
268
  def curve1(s, z, y):
269
+ """First curve: z*s^2 + (z-y+1)*s + 1"""
270
+ return z*s**2 + (z-y+1)*s + 1
271
 
272
  def curve2(s, y, beta, a):
273
+ """Second curve: y*β*((a-1)*s)/(a*s+1)"""
274
+ return y*beta*((a-1)*s)/(a*s+1)
275
 
276
  def find_intersections(z, y, beta, a, s_range):
277
+ """Find intersections between the two curves with improved accuracy"""
278
+ def equation(s):
279
+ return curve1(s, z, y) - curve2(s, y, beta, a)
280
+
281
+ # Create a finer grid of initial guesses
282
+ s_guesses = np.linspace(s_range[0], s_range[1], 200)
283
+ intersections = []
284
+
285
+ # Parameters for accuracy
286
+ tolerance = 1e-10
287
+
288
+ # First pass: find all potential intersections
289
+ for s_guess in s_guesses:
290
+ try:
291
+ s_sol = fsolve(equation, s_guess, full_output=True, xtol=tolerance)
292
+ if s_sol[2] == 1: # Check if convergence was achieved
293
+ s_val = s_sol[0][0]
294
+ if (s_range[0] <= s_val <= s_range[1] and
295
+ not any(abs(s_val - s_prev) < tolerance for s_prev in intersections)):
296
+ if abs(equation(s_val)) < tolerance:
297
+ intersections.append(s_val)
298
+ except:
299
+ continue
300
+
301
+ # Sort intersections
302
+ intersections = np.sort(np.array(intersections))
303
+
304
+ # Ensure even number of intersections by checking for missed ones
305
+ if len(intersections) % 2 != 0:
306
+ refined_intersections = []
307
+ for i in range(len(intersections)-1):
308
+ mid_point = (intersections[i] + intersections[i+1])/2
309
+ try:
310
+ s_sol = fsolve(equation, mid_point, full_output=True, xtol=tolerance)
311
+ if s_sol[2] == 1:
312
+ s_val = s_sol[0][0]
313
+ if (intersections[i] < s_val < intersections[i+1] and
314
+ abs(equation(s_val)) < tolerance):
315
+ refined_intersections.append(s_val)
316
+ except:
317
+ continue
318
+
319
+ intersections = np.sort(np.append(intersections, refined_intersections))
320
+
321
+ return intersections
322
 
323
  def generate_curves_plot(z, y, beta, a, s_range):
324
+ s = np.linspace(s_range[0], s_range[1], 2000)
325
+
326
+ # Compute curves
327
+ y1 = curve1(s, z, y)
328
+ y2 = curve2(s, y, beta, a)
329
+
330
+ # Find intersections with improved accuracy
331
+ intersections = find_intersections(z, y, beta, a, s_range)
332
+
333
+ fig = go.Figure()
334
+
335
+ fig.add_trace(
336
+ go.Scatter(
337
+ x=s, y=y1,
338
+ mode='lines',
339
+ name='z*s² + (z-y+1)*s + 1',
340
+ line=dict(color='blue', width=2)
341
+ )
342
+ )
343
+
344
+ fig.add_trace(
345
+ go.Scatter(
346
+ x=s, y=y2,
347
+ mode='lines',
348
+ name='y*β*((a-1)*s)/(a*s+1)',
349
+ line=dict(color='red', width=2)
350
+ )
351
+ )
352
+
353
+ if len(intersections) > 0:
354
+ fig.add_trace(
355
+ go.Scatter(
356
+ x=intersections,
357
+ y=curve1(intersections, z, y),
358
+ mode='markers',
359
+ name='Intersections',
360
+ marker=dict(
361
+ size=12,
362
+ color='green',
363
+ symbol='x',
364
+ line=dict(width=2)
365
+ )
366
+ )
367
+ )
368
+
369
+ fig.update_layout(
370
+ title=f"Curve Intersection Analysis (y={y:.4f}, β={beta:.4f}, a={a:.4f})",
371
+ xaxis_title="s",
372
+ yaxis_title="Value",
373
+ hovermode="closest",
374
+ showlegend=True,
375
+ legend=dict(
376
+ yanchor="top",
377
+ y=0.99,
378
+ xanchor="left",
379
+ x=0.01
380
+ )
381
+ )
382
+
383
+ return fig, intersections
384
 
385
  # Streamlit UI
386
  st.title("Cubic Root Analysis")
 
420
  """)
421
 
422
  with tab2:
423
+ st.header("Plot Complex Roots vs. z")
424
 
425
  col1, col2 = st.columns([1, 2])
426
 
 
431
  z_min_2 = st.number_input("z_min", value=-10.0, key="z_min_2")
432
  z_max_2 = st.number_input("z_max", value=10.0, key="z_max_2")
433
 
434
+ if st.button("Compute Complex Roots vs. z"):
435
  with col2:
436
+ fig_im, fig_re = generate_root_plots(beta, y_2, z_a_2, z_min_2, z_max_2)
437
+ if fig_im is not None and fig_re is not None:
438
+ st.plotly_chart(fig_im, use_container_width=True)
439
+ st.plotly_chart(fig_re, use_container_width=True)
440
 
441
  with tab3:
442
  st.header("Curve Intersection Analysis")