File size: 5,411 Bytes
7885a28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
/* Translated from Cython into C++ by SciPy developers in 2023.
 * Original header with Copyright information appears below.
 */

/* Implementation of the Lambert W function [1]. Based on MPMath
 *  Implementation [2], and documentation [3].
 *
 * Copyright: Yosef Meller, 2009
 * Author email: [email protected]
 *
 * Distributed under the same license as SciPy
 *
 *
 * References:
 * [1] On the Lambert W function, Adv. Comp. Math. 5 (1996) 329-359,
 *     available online: https://web.archive.org/web/20230123211413/https://cs.uwaterloo.ca/research/tr/1993/03/W.pdf
 * [2] mpmath source code,
 *     https://github.com/mpmath/mpmath/blob/c5939823669e1bcce151d89261b802fe0d8978b4/mpmath/functions/functions.py#L435-L461
 * [3] mpmath documentation of the Lambert W function,
 *     https://web.archive.org/web/20230504171447/https://mpmath.org/doc/current/functions/powers.html#lambert-w-function
 *
 * TODO: use a series expansion when extremely close to the branch point
 * at `-1/e` and make sure that the proper branch is chosen there.
 */

#pragma once

#include "config.h"
#include "error.h"
#include "evalpoly.h"

namespace xsf {
// exp(-1), i.e. 1/e. lambertw() uses it twice: the branch-point series is chosen
// as the k == 0 starting guess when |z + EXPN1| < 0.3, and the k == -1 real-axis
// starting guess is only used for |z| <= EXPN1.
constexpr double EXPN1 = 0.36787944117144232159553; // exp(-1)
// Value of the k == 0 branch at z == 1; lambertw() returns it directly for that
// input instead of iterating.
constexpr double OMEGA = 0.56714329040978387299997; // W(1, 0)

namespace detail {
    /* Starting guess for W(z, 0) close to the branch point.
     *
     * Evaluates the truncated series 4.22 of [1] in the variable
     * p = sqrt(2 * (e * z + 1)), using a degree-2 polynomial handed to
     * cevalpoly().
     *
     * z: point at which the series is evaluated.
     * Returns the series value, intended only as a seed for the iteration
     * in lambertw().
     */
    XSF_HOST_DEVICE inline std::complex<double> lambertw_branchpt(std::complex<double> z) {
        // NOTE(review): presumably ordered highest degree first, i.e.
        // -p^2/3 + p - 1; confirm against cevalpoly() in evalpoly.h.
        double series_coeffs[] = {-1.0 / 3.0, 1.0, -1.0};

        const std::complex<double> shifted = M_E * z + 1.0;
        const std::complex<double> p = std::sqrt(2.0 * shifted);

        return cevalpoly(series_coeffs, 2, p);
    }

    /* Starting guess for W(z, 0) close to the origin.
     *
     * Uses the (3, 2) Pade approximation around 0, written as
     * z * N(z) / D(z) with N and D both evaluated by cevalpoly() as
     * degree-2 polynomials.
     *
     * z: point at which the approximant is evaluated.
     * Returns the approximant value, intended only as a seed for the
     * iteration in lambertw().
     */
    XSF_HOST_DEVICE inline std::complex<double> lambertw_pade0(std::complex<double> z) {
        double numer_coeffs[] = {12.85106382978723404255, 12.34042553191489361902, 1.0};
        double denom_coeffs[] = {32.53191489361702127660, 14.34042553191489361702, 1.0};

        // Callers only reach this for z near 0, so the numerator cannot
        // overflow and no rescaling for large |z| is attempted.
        const std::complex<double> numerator = z * cevalpoly(numer_coeffs, 2, z);
        return numerator / cevalpoly(denom_coeffs, 2, z);
    }

    /* Starting guess from the asymptotic series of W(z, k).
     *
     * Keeps the first two terms of series 4.20 in [1]:
     *     L1 - log(L1),  with  L1 = log(z) + 2*pi*k*i.
     *
     * z: point at which the series is evaluated.
     * k: branch index.
     * Returns the two-term approximation.
     */
    XSF_HOST_DEVICE inline std::complex<double> lambertw_asy(std::complex<double> z, long k) {
        const std::complex<double> branch_shift = 2.0 * M_PI * k * std::complex<double>(0, 1);
        const std::complex<double> leading = std::log(z) + branch_shift;
        return leading - std::log(leading);
    }

} // namespace detail

/* Lambert W function W(z, k) for complex z on branch k.
 *
 * z:   argument.
 * k:   branch index.
 * tol: relative tolerance; iteration stops once
 *      |w_new - w_old| <= tol * |w_new|.
 *
 * Special cases handled before iterating:
 *   - a NaN in either component of z returns z unchanged;
 *   - Re(z) == +inf returns z + 2*pi*k*i;
 *   - Re(z) == -inf returns -z + (2*pi*k + pi)*i;
 *   - z == 0 returns z for k == 0; for any other branch it reports
 *     SF_ERROR_SINGULAR and returns -inf;
 *   - z == 1 with k == 0 returns the constant OMEGA.
 *
 * Otherwise a starting guess is picked (branch-point series, Pade
 * approximant, real logarithm, or asymptotic series) and refined with at
 * most 100 steps of Halley's method (5.9 in [1]). If the tolerance is
 * never met, SF_ERROR_SLOW is reported and NaN + NaN*i is returned.
 */
XSF_HOST_DEVICE inline std::complex<double> lambertw(std::complex<double> z, long k, double tol) {
    const double pos_inf = std::numeric_limits<double>::infinity();
    const std::complex<double> imag_unit(0, 1);
    const double re = z.real();
    const double im = z.imag();

    if (std::isnan(re) || std::isnan(im)) {
        return z;
    }
    if (re == pos_inf) {
        return z + 2.0 * M_PI * k * imag_unit;
    }
    if (re == -pos_inf) {
        return -z + (2.0 * M_PI * k + M_PI) * imag_unit;
    }
    if (z == 0.0) {
        if (k != 0) {
            set_error("lambertw", SF_ERROR_SINGULAR, NULL);
            return -pos_inf;
        }
        return z;
    }
    if (k == 0 && z == 1.0) {
        // Handled separately because the asymptotic series blows up here.
        return OMEGA;
    }

    // Pick the starting guess for Halley's method.
    std::complex<double> w;
    if (k == 0 && std::abs(z + EXPN1) < 0.3) {
        w = detail::lambertw_branchpt(z);
    } else if (k == 0 && re > -1.0 && re < 1.5 && std::abs(im) < 1.0 && re > -2.5 * std::abs(im) - 0.2) {
        // Empirically determined region in which the Pade approximation
        // is the more accurate seed.
        w = detail::lambertw_pade0(z);
    } else if (k == -1 && std::abs(z) <= EXPN1 && im == 0.0 && re < 0.0) {
        w = std::log(-re);
    } else {
        w = detail::lambertw_asy(z, k);
    }

    // The form of the update is fixed by the sign of the starting guess's
    // real part and is not reconsidered between steps. For Re(w) >= 0 the
    // formula is rearranged to use exp(-w), which avoids overflow in exp.
    const bool use_rearranged_form = w.real() >= 0;
    for (int step = 0; step < 100; ++step) {
        std::complex<double> residual;
        std::complex<double> slope;
        if (use_rearranged_form) {
            const std::complex<double> exp_neg_w = std::exp(-w);
            residual = w - z * exp_neg_w;
            slope = w + 1.0;
        } else {
            const std::complex<double> exp_w = std::exp(w);
            const std::complex<double> w_exp_w = w * exp_w;
            residual = w_exp_w - z;
            slope = w_exp_w + exp_w;
        }

        const std::complex<double> w_next = w - residual / (slope - (w + 2.0) * residual / (2.0 * w + 2.0));
        if (std::abs(w_next - w) <= tol * std::abs(w_next)) {
            return w_next;
        }
        w = w_next;
    }

    set_error("lambertw", SF_ERROR_SLOW, "iteration failed to converge: %g + %gj", z.real(), z.imag());
    const double quiet_nan = std::numeric_limits<double>::quiet_NaN();
    return {quiet_nan, quiet_nan};
}

/* Single-precision overload of lambertw().
 *
 * Widens z and tol to double precision, calls the double-precision
 * overload, and narrows the result back to std::complex<float>.
 */
XSF_HOST_DEVICE inline std::complex<float> lambertw(std::complex<float> z, long k, float tol) {
    const std::complex<double> z_wide = static_cast<std::complex<double>>(z);
    const double tol_wide = static_cast<double>(tol);
    const std::complex<double> w_wide = lambertw(z_wide, k, tol_wide);
    return static_cast<std::complex<float>>(w_wide);
}

} // namespace xsf