euler314 commited on
Commit
2cad740
·
verified ·
1 Parent(s): 21251a0

Update cubic_cpp.cpp

Browse files
Files changed (1) hide show
  1. cubic_cpp.cpp +91 -1
cubic_cpp.cpp CHANGED
@@ -681,7 +681,94 @@ compute_derivatives(const std::vector<double>& curve, const std::vector<double>&
681
 
682
  return std::make_tuple(d1, d2);
683
  }
684
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
685
  // Python module definition
686
  PYBIND11_MODULE(cubic_cpp, m) {
687
  m.doc() = "C++ accelerated functions for cubic root analysis";
@@ -736,4 +823,7 @@ PYBIND11_MODULE(cubic_cpp, m) {
736
  m.def("generate_eigenvalue_distribution", &generate_eigenvalue_distribution,
737
  "Generate eigenvalue distribution for a specific beta",
738
  py::arg("beta"), py::arg("y"), py::arg("z_a"), py::arg("n"), py::arg("seed"));
 
 
 
739
  }
 
681
 
682
  return std::make_tuple(d1, d2);
683
  }
684
+ // Compute cubic equation roots
685
+ std::vector<std::complex<double>> compute_cubic_roots(double z, double beta, double z_a, double y) {
686
+ // Apply the condition for y
687
+ double y_effective = apply_y_condition(y);
688
+
689
+ // Coefficients in the form ax^3 + bx^2 + cx + d = 0
690
+ double a = z * z_a;
691
+ double b = z * z_a + z + z_a - z_a * y_effective;
692
+ double c = z + z_a + 1 - y_effective * (beta * z_a + 1 - beta);
693
+ double d = 1.0;
694
+
695
+ // Handle special cases
696
+ if (std::abs(a) < 1e-10) {
697
+ // Quadratic case or linear case
698
+ std::vector<std::complex<double>> roots(3);
699
+ if (std::abs(b) < 1e-10) {
700
+ // Linear case
701
+ roots[0] = std::complex<double>(-d / c, 0.0);
702
+ roots[1] = std::complex<double>(0.0, 0.0);
703
+ roots[2] = std::complex<double>(0.0, 0.0);
704
+ } else {
705
+ // Quadratic case: bx^2 + cx + d = 0
706
+ double discriminant = c*c - 4*b*d;
707
+ if (discriminant >= 0) {
708
+ double sqrt_disc = std::sqrt(discriminant);
709
+ roots[0] = std::complex<double>((-c + sqrt_disc) / (2 * b), 0.0);
710
+ roots[1] = std::complex<double>((-c - sqrt_disc) / (2 * b), 0.0);
711
+ } else {
712
+ double sqrt_disc = std::sqrt(-discriminant);
713
+ roots[0] = std::complex<double>(-c / (2 * b), sqrt_disc / (2 * b));
714
+ roots[1] = std::complex<double>(-c / (2 * b), -sqrt_disc / (2 * b));
715
+ }
716
+ roots[2] = std::complex<double>(0.0, 0.0);
717
+ }
718
+ return roots;
719
+ }
720
+
721
+ // Standard cubic case
722
+ // First, convert to depressed cubic t^3 + pt + q = 0
723
+ b /= a;
724
+ c /= a;
725
+ d /= a;
726
+
727
+ double p = c - b*b/3;
728
+ double q = d - b*c/3 + 2*b*b*b/27;
729
+ double disc = q*q/4 + p*p*p/27;
730
+
731
+ std::vector<std::complex<double>> roots(3);
732
+
733
+ // Handle different cases based on discriminant
734
+ if (std::abs(disc) < 1e-10) {
735
+ // Discriminant is zero, potential multiple roots
736
+ if (std::abs(p) < 1e-10 && std::abs(q) < 1e-10) {
737
+ // Triple root
738
+ roots[0] = roots[1] = roots[2] = std::complex<double>(-b/3, 0.0);
739
+ } else {
740
+ // One double root and one single root
741
+ double u;
742
+ if (q > 0) u = -std::cbrt(q/2);
743
+ else u = std::cbrt(-q/2);
744
+
745
+ roots[0] = std::complex<double>(2*u - b/3, 0.0);
746
+ roots[1] = roots[2] = std::complex<double>(-u - b/3, 0.0);
747
+ }
748
+ } else if (disc > 0) {
749
+ // One real root and two complex conjugate roots
750
+ double u = std::cbrt(-q/2 + std::sqrt(disc));
751
+ double v = std::cbrt(-q/2 - std::sqrt(disc));
752
+
753
+ roots[0] = std::complex<double>(u + v - b/3, 0.0);
754
+
755
+ double real_part = -(u + v)/2 - b/3;
756
+ double imag_part = std::sqrt(3) * (u - v) / 2;
757
+
758
+ roots[1] = std::complex<double>(real_part, imag_part);
759
+ roots[2] = std::complex<double>(real_part, -imag_part);
760
+ } else {
761
+ // Three distinct real roots
762
+ double theta = std::acos(-q/2 * std::sqrt(-27/(p*p*p)));
763
+ double coef = 2 * std::sqrt(-p/3);
764
+
765
+ roots[0] = std::complex<double>(coef * std::cos(theta/3) - b/3, 0.0);
766
+ roots[1] = std::complex<double>(coef * std::cos((theta + 2*M_PI)/3) - b/3, 0.0);
767
+ roots[2] = std::complex<double>(coef * std::cos((theta + 4*M_PI)/3) - b/3, 0.0);
768
+ }
769
+
770
+ return roots;
771
+ }
772
  // Python module definition
773
  PYBIND11_MODULE(cubic_cpp, m) {
774
  m.doc() = "C++ accelerated functions for cubic root analysis";
 
823
  m.def("generate_eigenvalue_distribution", &generate_eigenvalue_distribution,
824
  "Generate eigenvalue distribution for a specific beta",
825
  py::arg("beta"), py::arg("y"), py::arg("z_a"), py::arg("n"), py::arg("seed"));
826
+ m.def("compute_cubic_roots", &compute_cubic_roots,
827
+ "Compute the roots of the cubic equation",
828
+ py::arg("z"), py::arg("beta"), py::arg("z_a"), py::arg("y"));
829
  }