reach-vb HF staff commited on
Commit
eb90369
·
1 Parent(s): d2e7957

5196c2cb84e1a787c43794229370aa2a1975ce16c5a8ae4ded7470fd1bfe6153

Browse files
Files changed (50) hide show
  1. lib/python3.11/site-packages/mpmath/functions/functions.py +645 -0
  2. lib/python3.11/site-packages/mpmath/functions/hypergeometric.py +1413 -0
  3. lib/python3.11/site-packages/mpmath/functions/orthogonal.py +493 -0
  4. lib/python3.11/site-packages/mpmath/functions/qfunctions.py +280 -0
  5. lib/python3.11/site-packages/mpmath/functions/rszeta.py +1403 -0
  6. lib/python3.11/site-packages/mpmath/functions/signals.py +32 -0
  7. lib/python3.11/site-packages/mpmath/functions/theta.py +1049 -0
  8. lib/python3.11/site-packages/mpmath/functions/zeta.py +1154 -0
  9. lib/python3.11/site-packages/mpmath/functions/zetazeros.py +1018 -0
  10. lib/python3.11/site-packages/mpmath/identification.py +844 -0
  11. lib/python3.11/site-packages/mpmath/libmp/__init__.py +77 -0
  12. lib/python3.11/site-packages/mpmath/libmp/__pycache__/__init__.cpython-311.pyc +0 -0
  13. lib/python3.11/site-packages/mpmath/libmp/__pycache__/backend.cpython-311.pyc +0 -0
  14. lib/python3.11/site-packages/mpmath/libmp/__pycache__/gammazeta.cpython-311.pyc +0 -0
  15. lib/python3.11/site-packages/mpmath/libmp/__pycache__/libelefun.cpython-311.pyc +0 -0
  16. lib/python3.11/site-packages/mpmath/libmp/__pycache__/libhyper.cpython-311.pyc +0 -0
  17. lib/python3.11/site-packages/mpmath/libmp/__pycache__/libintmath.cpython-311.pyc +0 -0
  18. lib/python3.11/site-packages/mpmath/libmp/__pycache__/libmpc.cpython-311.pyc +0 -0
  19. lib/python3.11/site-packages/mpmath/libmp/__pycache__/libmpf.cpython-311.pyc +0 -0
  20. lib/python3.11/site-packages/mpmath/libmp/__pycache__/libmpi.cpython-311.pyc +0 -0
  21. lib/python3.11/site-packages/mpmath/libmp/backend.py +115 -0
  22. lib/python3.11/site-packages/mpmath/libmp/gammazeta.py +2167 -0
  23. lib/python3.11/site-packages/mpmath/libmp/libelefun.py +1428 -0
  24. lib/python3.11/site-packages/mpmath/libmp/libhyper.py +1150 -0
  25. lib/python3.11/site-packages/mpmath/libmp/libintmath.py +584 -0
  26. lib/python3.11/site-packages/mpmath/libmp/libmpc.py +835 -0
  27. lib/python3.11/site-packages/mpmath/libmp/libmpf.py +1414 -0
  28. lib/python3.11/site-packages/mpmath/libmp/libmpi.py +935 -0
  29. lib/python3.11/site-packages/mpmath/math2.py +672 -0
  30. lib/python3.11/site-packages/mpmath/matrices/__init__.py +2 -0
  31. lib/python3.11/site-packages/mpmath/matrices/__pycache__/__init__.cpython-311.pyc +0 -0
  32. lib/python3.11/site-packages/mpmath/matrices/__pycache__/calculus.cpython-311.pyc +0 -0
  33. lib/python3.11/site-packages/mpmath/matrices/__pycache__/eigen.cpython-311.pyc +0 -0
  34. lib/python3.11/site-packages/mpmath/matrices/__pycache__/eigen_symmetric.cpython-311.pyc +0 -0
  35. lib/python3.11/site-packages/mpmath/matrices/__pycache__/linalg.cpython-311.pyc +0 -0
  36. lib/python3.11/site-packages/mpmath/matrices/__pycache__/matrices.cpython-311.pyc +0 -0
  37. lib/python3.11/site-packages/mpmath/matrices/calculus.py +531 -0
  38. lib/python3.11/site-packages/mpmath/matrices/eigen.py +877 -0
  39. lib/python3.11/site-packages/mpmath/matrices/eigen_symmetric.py +1807 -0
  40. lib/python3.11/site-packages/mpmath/matrices/linalg.py +790 -0
  41. lib/python3.11/site-packages/mpmath/matrices/matrices.py +1005 -0
  42. lib/python3.11/site-packages/mpmath/rational.py +240 -0
  43. lib/python3.11/site-packages/mpmath/tests/__init__.py +0 -0
  44. lib/python3.11/site-packages/mpmath/tests/__pycache__/__init__.cpython-311.pyc +0 -0
  45. lib/python3.11/site-packages/mpmath/tests/__pycache__/extratest_gamma.cpython-311.pyc +0 -0
  46. lib/python3.11/site-packages/mpmath/tests/__pycache__/extratest_zeta.cpython-311.pyc +0 -0
  47. lib/python3.11/site-packages/mpmath/tests/__pycache__/runtests.cpython-311.pyc +0 -0
  48. lib/python3.11/site-packages/mpmath/tests/__pycache__/test_basic_ops.cpython-311.pyc +0 -0
  49. lib/python3.11/site-packages/mpmath/tests/__pycache__/test_bitwise.cpython-311.pyc +0 -0
  50. lib/python3.11/site-packages/mpmath/tests/__pycache__/test_calculus.cpython-311.pyc +0 -0
lib/python3.11/site-packages/mpmath/functions/functions.py ADDED
@@ -0,0 +1,645 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ..libmp.backend import xrange
2
+
3
+ class SpecialFunctions(object):
4
+ """
5
+ This class implements special functions using high-level code.
6
+
7
+ Elementary and some other functions (e.g. gamma function, basecase
8
+ hypergeometric series) are assumed to be predefined by the context as
9
+ "builtins" or "low-level" functions.
10
+ """
11
+ defined_functions = {}
12
+
13
+ # The series for the Jacobi theta functions converge for |q| < 1;
14
+ # in the current implementation they throw a ValueError for
15
+ # abs(q) > THETA_Q_LIM
16
+ THETA_Q_LIM = 1 - 10**-7
17
+
18
+ def __init__(self):
19
+ cls = self.__class__
20
+ for name in cls.defined_functions:
21
+ f, wrap = cls.defined_functions[name]
22
+ cls._wrap_specfun(name, f, wrap)
23
+
24
+ self.mpq_1 = self._mpq((1,1))
25
+ self.mpq_0 = self._mpq((0,1))
26
+ self.mpq_1_2 = self._mpq((1,2))
27
+ self.mpq_3_2 = self._mpq((3,2))
28
+ self.mpq_1_4 = self._mpq((1,4))
29
+ self.mpq_1_16 = self._mpq((1,16))
30
+ self.mpq_3_16 = self._mpq((3,16))
31
+ self.mpq_5_2 = self._mpq((5,2))
32
+ self.mpq_3_4 = self._mpq((3,4))
33
+ self.mpq_7_4 = self._mpq((7,4))
34
+ self.mpq_5_4 = self._mpq((5,4))
35
+ self.mpq_1_3 = self._mpq((1,3))
36
+ self.mpq_2_3 = self._mpq((2,3))
37
+ self.mpq_4_3 = self._mpq((4,3))
38
+ self.mpq_1_6 = self._mpq((1,6))
39
+ self.mpq_5_6 = self._mpq((5,6))
40
+ self.mpq_5_3 = self._mpq((5,3))
41
+
42
+ self._misc_const_cache = {}
43
+
44
+ self._aliases.update({
45
+ 'phase' : 'arg',
46
+ 'conjugate' : 'conj',
47
+ 'nthroot' : 'root',
48
+ 'polygamma' : 'psi',
49
+ 'hurwitz' : 'zeta',
50
+ #'digamma' : 'psi0',
51
+ #'trigamma' : 'psi1',
52
+ #'tetragamma' : 'psi2',
53
+ #'pentagamma' : 'psi3',
54
+ 'fibonacci' : 'fib',
55
+ 'factorial' : 'fac',
56
+ })
57
+
58
+ self.zetazero_memoized = self.memoize(self.zetazero)
59
+
60
+ # Default -- do nothing
61
+ @classmethod
62
+ def _wrap_specfun(cls, name, f, wrap):
63
+ setattr(cls, name, f)
64
+
65
+ # Optional fast versions of common functions in common cases.
66
+ # If not overridden, default (generic hypergeometric series)
67
+ # implementations will be used
68
+ def _besselj(ctx, n, z): raise NotImplementedError
69
+ def _erf(ctx, z): raise NotImplementedError
70
+ def _erfc(ctx, z): raise NotImplementedError
71
+ def _gamma_upper_int(ctx, z, a): raise NotImplementedError
72
+ def _expint_int(ctx, n, z): raise NotImplementedError
73
+ def _zeta(ctx, s): raise NotImplementedError
74
+ def _zetasum_fast(ctx, s, a, n, derivatives, reflect): raise NotImplementedError
75
+ def _ei(ctx, z): raise NotImplementedError
76
+ def _e1(ctx, z): raise NotImplementedError
77
+ def _ci(ctx, z): raise NotImplementedError
78
+ def _si(ctx, z): raise NotImplementedError
79
+ def _altzeta(ctx, s): raise NotImplementedError
80
+
81
+ def defun_wrapped(f):
82
+ SpecialFunctions.defined_functions[f.__name__] = f, True
83
+ return f
84
+
85
+ def defun(f):
86
+ SpecialFunctions.defined_functions[f.__name__] = f, False
87
+ return f
88
+
89
+ def defun_static(f):
90
+ setattr(SpecialFunctions, f.__name__, f)
91
+ return f
92
+
93
+ @defun_wrapped
94
+ def cot(ctx, z): return ctx.one / ctx.tan(z)
95
+
96
+ @defun_wrapped
97
+ def sec(ctx, z): return ctx.one / ctx.cos(z)
98
+
99
+ @defun_wrapped
100
+ def csc(ctx, z): return ctx.one / ctx.sin(z)
101
+
102
+ @defun_wrapped
103
+ def coth(ctx, z): return ctx.one / ctx.tanh(z)
104
+
105
+ @defun_wrapped
106
+ def sech(ctx, z): return ctx.one / ctx.cosh(z)
107
+
108
+ @defun_wrapped
109
+ def csch(ctx, z): return ctx.one / ctx.sinh(z)
110
+
111
+ @defun_wrapped
112
+ def acot(ctx, z):
113
+ if not z:
114
+ return ctx.pi * 0.5
115
+ else:
116
+ return ctx.atan(ctx.one / z)
117
+
118
+ @defun_wrapped
119
+ def asec(ctx, z): return ctx.acos(ctx.one / z)
120
+
121
+ @defun_wrapped
122
+ def acsc(ctx, z): return ctx.asin(ctx.one / z)
123
+
124
+ @defun_wrapped
125
+ def acoth(ctx, z):
126
+ if not z:
127
+ return ctx.pi * 0.5j
128
+ else:
129
+ return ctx.atanh(ctx.one / z)
130
+
131
+
132
+ @defun_wrapped
133
+ def asech(ctx, z): return ctx.acosh(ctx.one / z)
134
+
135
+ @defun_wrapped
136
+ def acsch(ctx, z): return ctx.asinh(ctx.one / z)
137
+
138
+ @defun
139
+ def sign(ctx, x):
140
+ x = ctx.convert(x)
141
+ if not x or ctx.isnan(x):
142
+ return x
143
+ if ctx._is_real_type(x):
144
+ if x > 0:
145
+ return ctx.one
146
+ else:
147
+ return -ctx.one
148
+ return x / abs(x)
149
+
150
+ @defun
151
+ def agm(ctx, a, b=1):
152
+ if b == 1:
153
+ return ctx.agm1(a)
154
+ a = ctx.convert(a)
155
+ b = ctx.convert(b)
156
+ return ctx._agm(a, b)
157
+
158
+ @defun_wrapped
159
+ def sinc(ctx, x):
160
+ if ctx.isinf(x):
161
+ return 1/x
162
+ if not x:
163
+ return x+1
164
+ return ctx.sin(x)/x
165
+
166
+ @defun_wrapped
167
+ def sincpi(ctx, x):
168
+ if ctx.isinf(x):
169
+ return 1/x
170
+ if not x:
171
+ return x+1
172
+ return ctx.sinpi(x)/(ctx.pi*x)
173
+
174
+ # TODO: tests; improve implementation
175
+ @defun_wrapped
176
+ def expm1(ctx, x):
177
+ if not x:
178
+ return ctx.zero
179
+ # exp(x) - 1 ~ x
180
+ if ctx.mag(x) < -ctx.prec:
181
+ return x + 0.5*x**2
182
+ # TODO: accurately eval the smaller of the real/imag parts
183
+ return ctx.sum_accurately(lambda: iter([ctx.exp(x),-1]),1)
184
+
185
+ @defun_wrapped
186
+ def log1p(ctx, x):
187
+ if not x:
188
+ return ctx.zero
189
+ if ctx.mag(x) < -ctx.prec:
190
+ return x - 0.5*x**2
191
+ return ctx.log(ctx.fadd(1, x, prec=2*ctx.prec))
192
+
193
+ @defun_wrapped
194
+ def powm1(ctx, x, y):
195
+ mag = ctx.mag
196
+ one = ctx.one
197
+ w = x**y - one
198
+ M = mag(w)
199
+ # Only moderate cancellation
200
+ if M > -8:
201
+ return w
202
+ # Check for the only possible exact cases
203
+ if not w:
204
+ if (not y) or (x in (1, -1, 1j, -1j) and ctx.isint(y)):
205
+ return w
206
+ x1 = x - one
207
+ magy = mag(y)
208
+ lnx = ctx.ln(x)
209
+ # Small y: x^y - 1 ~ log(x)*y + O(log(x)^2 * y^2)
210
+ if magy + mag(lnx) < -ctx.prec:
211
+ return lnx*y + (lnx*y)**2/2
212
+ # TODO: accurately eval the smaller of the real/imag part
213
+ return ctx.sum_accurately(lambda: iter([x**y, -1]), 1)
214
+
215
+ @defun
216
+ def _rootof1(ctx, k, n):
217
+ k = int(k)
218
+ n = int(n)
219
+ k %= n
220
+ if not k:
221
+ return ctx.one
222
+ elif 2*k == n:
223
+ return -ctx.one
224
+ elif 4*k == n:
225
+ return ctx.j
226
+ elif 4*k == 3*n:
227
+ return -ctx.j
228
+ return ctx.expjpi(2*ctx.mpf(k)/n)
229
+
230
+ @defun
231
+ def root(ctx, x, n, k=0):
232
+ n = int(n)
233
+ x = ctx.convert(x)
234
+ if k:
235
+ # Special case: there is an exact real root
236
+ if (n & 1 and 2*k == n-1) and (not ctx.im(x)) and (ctx.re(x) < 0):
237
+ return -ctx.root(-x, n)
238
+ # Multiply by root of unity
239
+ prec = ctx.prec
240
+ try:
241
+ ctx.prec += 10
242
+ v = ctx.root(x, n, 0) * ctx._rootof1(k, n)
243
+ finally:
244
+ ctx.prec = prec
245
+ return +v
246
+ return ctx._nthroot(x, n)
247
+
248
+ @defun
249
+ def unitroots(ctx, n, primitive=False):
250
+ gcd = ctx._gcd
251
+ prec = ctx.prec
252
+ try:
253
+ ctx.prec += 10
254
+ if primitive:
255
+ v = [ctx._rootof1(k,n) for k in range(n) if gcd(k,n) == 1]
256
+ else:
257
+ # TODO: this can be done *much* faster
258
+ v = [ctx._rootof1(k,n) for k in range(n)]
259
+ finally:
260
+ ctx.prec = prec
261
+ return [+x for x in v]
262
+
263
+ @defun
264
+ def arg(ctx, x):
265
+ x = ctx.convert(x)
266
+ re = ctx._re(x)
267
+ im = ctx._im(x)
268
+ return ctx.atan2(im, re)
269
+
270
+ @defun
271
+ def fabs(ctx, x):
272
+ return abs(ctx.convert(x))
273
+
274
+ @defun
275
+ def re(ctx, x):
276
+ x = ctx.convert(x)
277
+ if hasattr(x, "real"): # py2.5 doesn't have .real/.imag for all numbers
278
+ return x.real
279
+ return x
280
+
281
+ @defun
282
+ def im(ctx, x):
283
+ x = ctx.convert(x)
284
+ if hasattr(x, "imag"): # py2.5 doesn't have .real/.imag for all numbers
285
+ return x.imag
286
+ return ctx.zero
287
+
288
+ @defun
289
+ def conj(ctx, x):
290
+ x = ctx.convert(x)
291
+ try:
292
+ return x.conjugate()
293
+ except AttributeError:
294
+ return x
295
+
296
+ @defun
297
+ def polar(ctx, z):
298
+ return (ctx.fabs(z), ctx.arg(z))
299
+
300
+ @defun_wrapped
301
+ def rect(ctx, r, phi):
302
+ return r * ctx.mpc(*ctx.cos_sin(phi))
303
+
304
+ @defun
305
+ def log(ctx, x, b=None):
306
+ if b is None:
307
+ return ctx.ln(x)
308
+ wp = ctx.prec + 20
309
+ return ctx.ln(x, prec=wp) / ctx.ln(b, prec=wp)
310
+
311
+ @defun
312
+ def log10(ctx, x):
313
+ return ctx.log(x, 10)
314
+
315
+ @defun
316
+ def fmod(ctx, x, y):
317
+ return ctx.convert(x) % ctx.convert(y)
318
+
319
+ @defun
320
+ def degrees(ctx, x):
321
+ return x / ctx.degree
322
+
323
+ @defun
324
+ def radians(ctx, x):
325
+ return x * ctx.degree
326
+
327
+ def _lambertw_special(ctx, z, k):
328
+ # W(0,0) = 0; all other branches are singular
329
+ if not z:
330
+ if not k:
331
+ return z
332
+ return ctx.ninf + z
333
+ if z == ctx.inf:
334
+ if k == 0:
335
+ return z
336
+ else:
337
+ return z + 2*k*ctx.pi*ctx.j
338
+ if z == ctx.ninf:
339
+ return (-z) + (2*k+1)*ctx.pi*ctx.j
340
+ # Some kind of nan or complex inf/nan?
341
+ return ctx.ln(z)
342
+
343
+ import math
344
+ import cmath
345
+
346
+ def _lambertw_approx_hybrid(z, k):
347
+ imag_sign = 0
348
+ if hasattr(z, "imag"):
349
+ x = float(z.real)
350
+ y = z.imag
351
+ if y:
352
+ imag_sign = (-1) ** (y < 0)
353
+ y = float(y)
354
+ else:
355
+ x = float(z)
356
+ y = 0.0
357
+ imag_sign = 0
358
+ # hack to work regardless of whether Python supports -0.0
359
+ if not y:
360
+ y = 0.0
361
+ z = complex(x,y)
362
+ if k == 0:
363
+ if -4.0 < y < 4.0 and -1.0 < x < 2.5:
364
+ if imag_sign:
365
+ # Taylor series in upper/lower half-plane
366
+ if y > 1.00: return (0.876+0.645j) + (0.118-0.174j)*(z-(0.75+2.5j))
367
+ if y > 0.25: return (0.505+0.204j) + (0.375-0.132j)*(z-(0.75+0.5j))
368
+ if y < -1.00: return (0.876-0.645j) + (0.118+0.174j)*(z-(0.75-2.5j))
369
+ if y < -0.25: return (0.505-0.204j) + (0.375+0.132j)*(z-(0.75-0.5j))
370
+ # Taylor series near -1
371
+ if x < -0.5:
372
+ if imag_sign >= 0:
373
+ return (-0.318+1.34j) + (-0.697-0.593j)*(z+1)
374
+ else:
375
+ return (-0.318-1.34j) + (-0.697+0.593j)*(z+1)
376
+ # return real type
377
+ r = -0.367879441171442
378
+ if (not imag_sign) and x > r:
379
+ z = x
380
+ # Singularity near -1/e
381
+ if x < -0.2:
382
+ return -1 + 2.33164398159712*(z-r)**0.5 - 1.81218788563936*(z-r)
383
+ # Taylor series near 0
384
+ if x < 0.5: return z
385
+ # Simple linear approximation
386
+ return 0.2 + 0.3*z
387
+ if (not imag_sign) and x > 0.0:
388
+ L1 = math.log(x); L2 = math.log(L1)
389
+ else:
390
+ L1 = cmath.log(z); L2 = cmath.log(L1)
391
+ elif k == -1:
392
+ # return real type
393
+ r = -0.367879441171442
394
+ if (not imag_sign) and r < x < 0.0:
395
+ z = x
396
+ if (imag_sign >= 0) and y < 0.1 and -0.6 < x < -0.2:
397
+ return -1 - 2.33164398159712*(z-r)**0.5 - 1.81218788563936*(z-r)
398
+ if (not imag_sign) and -0.2 <= x < 0.0:
399
+ L1 = math.log(-x)
400
+ return L1 - math.log(-L1)
401
+ else:
402
+ if imag_sign == -1 and (not y) and x < 0.0:
403
+ L1 = cmath.log(z) - 3.1415926535897932j
404
+ else:
405
+ L1 = cmath.log(z) - 6.2831853071795865j
406
+ L2 = cmath.log(L1)
407
+ return L1 - L2 + L2/L1 + L2*(L2-2)/(2*L1**2)
408
+
409
+ def _lambertw_series(ctx, z, k, tol):
410
+ """
411
+ Return rough approximation for W_k(z) from an asymptotic series,
412
+ sufficiently accurate for the Halley iteration to converge to
413
+ the correct value.
414
+ """
415
+ magz = ctx.mag(z)
416
+ if (-10 < magz < 900) and (-1000 < k < 1000):
417
+ # Near the branch point at -1/e
418
+ if magz < 1 and abs(z+0.36787944117144) < 0.05:
419
+ if k == 0 or (k == -1 and ctx._im(z) >= 0) or \
420
+ (k == 1 and ctx._im(z) < 0):
421
+ delta = ctx.sum_accurately(lambda: [z, ctx.exp(-1)])
422
+ cancellation = -ctx.mag(delta)
423
+ ctx.prec += cancellation
424
+ # Use series given in Corless et al.
425
+ p = ctx.sqrt(2*(ctx.e*z+1))
426
+ ctx.prec -= cancellation
427
+ u = {0:ctx.mpf(-1), 1:ctx.mpf(1)}
428
+ a = {0:ctx.mpf(2), 1:ctx.mpf(-1)}
429
+ if k != 0:
430
+ p = -p
431
+ s = ctx.zero
432
+ # The series converges, so we could use it directly, but unless
433
+ # *extremely* close, it is better to just use the first few
434
+ # terms to get a good approximation for the iteration
435
+ for l in xrange(max(2,cancellation)):
436
+ if l not in u:
437
+ a[l] = ctx.fsum(u[j]*u[l+1-j] for j in xrange(2,l))
438
+ u[l] = (l-1)*(u[l-2]/2+a[l-2]/4)/(l+1)-a[l]/2-u[l-1]/(l+1)
439
+ term = u[l] * p**l
440
+ s += term
441
+ if ctx.mag(term) < -tol:
442
+ return s, True
443
+ l += 1
444
+ ctx.prec += cancellation//2
445
+ return s, False
446
+ if k == 0 or k == -1:
447
+ return _lambertw_approx_hybrid(z, k), False
448
+ if k == 0:
449
+ if magz < -1:
450
+ return z*(1-z), False
451
+ L1 = ctx.ln(z)
452
+ L2 = ctx.ln(L1)
453
+ elif k == -1 and (not ctx._im(z)) and (-0.36787944117144 < ctx._re(z) < 0):
454
+ L1 = ctx.ln(-z)
455
+ return L1 - ctx.ln(-L1), False
456
+ else:
457
+ # This holds both as z -> 0 and z -> inf.
458
+ # Relative error is O(1/log(z)).
459
+ L1 = ctx.ln(z) + 2j*ctx.pi*k
460
+ L2 = ctx.ln(L1)
461
+ return L1 - L2 + L2/L1 + L2*(L2-2)/(2*L1**2), False
462
+
463
+ @defun
464
+ def lambertw(ctx, z, k=0):
465
+ z = ctx.convert(z)
466
+ k = int(k)
467
+ if not ctx.isnormal(z):
468
+ return _lambertw_special(ctx, z, k)
469
+ prec = ctx.prec
470
+ ctx.prec += 20 + ctx.mag(k or 1)
471
+ wp = ctx.prec
472
+ tol = wp - 5
473
+ w, done = _lambertw_series(ctx, z, k, tol)
474
+ if not done:
475
+ # Use Halley iteration to solve w*exp(w) = z
476
+ two = ctx.mpf(2)
477
+ for i in xrange(100):
478
+ ew = ctx.exp(w)
479
+ wew = w*ew
480
+ wewz = wew-z
481
+ wn = w - wewz/(wew+ew-(w+two)*wewz/(two*w+two))
482
+ if ctx.mag(wn-w) <= ctx.mag(wn) - tol:
483
+ w = wn
484
+ break
485
+ else:
486
+ w = wn
487
+ if i == 100:
488
+ ctx.warn("Lambert W iteration failed to converge for z = %s" % z)
489
+ ctx.prec = prec
490
+ return +w
491
+
492
+ @defun_wrapped
493
+ def bell(ctx, n, x=1):
494
+ x = ctx.convert(x)
495
+ if not n:
496
+ if ctx.isnan(x):
497
+ return x
498
+ return type(x)(1)
499
+ if ctx.isinf(x) or ctx.isinf(n) or ctx.isnan(x) or ctx.isnan(n):
500
+ return x**n
501
+ if n == 1: return x
502
+ if n == 2: return x*(x+1)
503
+ if x == 0: return ctx.sincpi(n)
504
+ return _polyexp(ctx, n, x, True) / ctx.exp(x)
505
+
506
+ def _polyexp(ctx, n, x, extra=False):
507
+ def _terms():
508
+ if extra:
509
+ yield ctx.sincpi(n)
510
+ t = x
511
+ k = 1
512
+ while 1:
513
+ yield k**n * t
514
+ k += 1
515
+ t = t*x/k
516
+ return ctx.sum_accurately(_terms, check_step=4)
517
+
518
+ @defun_wrapped
519
+ def polyexp(ctx, s, z):
520
+ if ctx.isinf(z) or ctx.isinf(s) or ctx.isnan(z) or ctx.isnan(s):
521
+ return z**s
522
+ if z == 0: return z*s
523
+ if s == 0: return ctx.expm1(z)
524
+ if s == 1: return ctx.exp(z)*z
525
+ if s == 2: return ctx.exp(z)*z*(z+1)
526
+ return _polyexp(ctx, s, z)
527
+
528
+ @defun_wrapped
529
+ def cyclotomic(ctx, n, z):
530
+ n = int(n)
531
+ if n < 0:
532
+ raise ValueError("n cannot be negative")
533
+ p = ctx.one
534
+ if n == 0:
535
+ return p
536
+ if n == 1:
537
+ return z - p
538
+ if n == 2:
539
+ return z + p
540
+ # Use divisor product representation. Unfortunately, this sometimes
541
+ # includes singularities for roots of unity, which we have to cancel out.
542
+ # Matching zeros/poles pairwise, we have (1-z^a)/(1-z^b) ~ a/b + O(z-1).
543
+ a_prod = 1
544
+ b_prod = 1
545
+ num_zeros = 0
546
+ num_poles = 0
547
+ for d in range(1,n+1):
548
+ if not n % d:
549
+ w = ctx.moebius(n//d)
550
+ # Use powm1 because it is important that we get 0 only
551
+ # if it really is exactly 0
552
+ b = -ctx.powm1(z, d)
553
+ if b:
554
+ p *= b**w
555
+ else:
556
+ if w == 1:
557
+ a_prod *= d
558
+ num_zeros += 1
559
+ elif w == -1:
560
+ b_prod *= d
561
+ num_poles += 1
562
+ #print n, num_zeros, num_poles
563
+ if num_zeros:
564
+ if num_zeros > num_poles:
565
+ p *= 0
566
+ else:
567
+ p *= a_prod
568
+ p /= b_prod
569
+ return p
570
+
571
+ @defun
572
+ def mangoldt(ctx, n):
573
+ r"""
574
+ Evaluates the von Mangoldt function `\Lambda(n) = \log p`
575
+ if `n = p^k` a power of a prime, and `\Lambda(n) = 0` otherwise.
576
+
577
+ **Examples**
578
+
579
+ >>> from mpmath import *
580
+ >>> mp.dps = 25; mp.pretty = True
581
+ >>> [mangoldt(n) for n in range(-2,3)]
582
+ [0.0, 0.0, 0.0, 0.0, 0.6931471805599453094172321]
583
+ >>> mangoldt(6)
584
+ 0.0
585
+ >>> mangoldt(7)
586
+ 1.945910149055313305105353
587
+ >>> mangoldt(8)
588
+ 0.6931471805599453094172321
589
+ >>> fsum(mangoldt(n) for n in range(101))
590
+ 94.04531122935739224600493
591
+ >>> fsum(mangoldt(n) for n in range(10001))
592
+ 10013.39669326311478372032
593
+
594
+ """
595
+ n = int(n)
596
+ if n < 2:
597
+ return ctx.zero
598
+ if n % 2 == 0:
599
+ # Must be a power of two
600
+ if n & (n-1) == 0:
601
+ return +ctx.ln2
602
+ else:
603
+ return ctx.zero
604
+ # TODO: the following could be generalized into a perfect
605
+ # power testing function
606
+ # ---
607
+ # Look for a small factor
608
+ for p in (3,5,7,11,13,17,19,23,29,31):
609
+ if not n % p:
610
+ q, r = n // p, 0
611
+ while q > 1:
612
+ q, r = divmod(q, p)
613
+ if r:
614
+ return ctx.zero
615
+ return ctx.ln(p)
616
+ if ctx.isprime(n):
617
+ return ctx.ln(n)
618
+ # Obviously, we could use arbitrary-precision arithmetic for this...
619
+ if n > 10**30:
620
+ raise NotImplementedError
621
+ k = 2
622
+ while 1:
623
+ p = int(n**(1./k) + 0.5)
624
+ if p < 2:
625
+ return ctx.zero
626
+ if p ** k == n:
627
+ if ctx.isprime(p):
628
+ return ctx.ln(p)
629
+ k += 1
630
+
631
+ @defun
632
+ def stirling1(ctx, n, k, exact=False):
633
+ v = ctx._stirling1(int(n), int(k))
634
+ if exact:
635
+ return int(v)
636
+ else:
637
+ return ctx.mpf(v)
638
+
639
+ @defun
640
+ def stirling2(ctx, n, k, exact=False):
641
+ v = ctx._stirling2(int(n), int(k))
642
+ if exact:
643
+ return int(v)
644
+ else:
645
+ return ctx.mpf(v)
lib/python3.11/site-packages/mpmath/functions/hypergeometric.py ADDED
@@ -0,0 +1,1413 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ..libmp.backend import xrange
2
+ from .functions import defun, defun_wrapped
3
+
4
+ def _check_need_perturb(ctx, terms, prec, discard_known_zeros):
5
+ perturb = recompute = False
6
+ extraprec = 0
7
+ discard = []
8
+ for term_index, term in enumerate(terms):
9
+ w_s, c_s, alpha_s, beta_s, a_s, b_s, z = term
10
+ have_singular_nongamma_weight = False
11
+ # Avoid division by zero in leading factors (TODO:
12
+ # also check for near division by zero?)
13
+ for k, w in enumerate(w_s):
14
+ if not w:
15
+ if ctx.re(c_s[k]) <= 0 and c_s[k]:
16
+ perturb = recompute = True
17
+ have_singular_nongamma_weight = True
18
+ pole_count = [0, 0, 0]
19
+ # Check for gamma and series poles and near-poles
20
+ for data_index, data in enumerate([alpha_s, beta_s, b_s]):
21
+ for i, x in enumerate(data):
22
+ n, d = ctx.nint_distance(x)
23
+ # Poles
24
+ if n > 0:
25
+ continue
26
+ if d == ctx.ninf:
27
+ # OK if we have a polynomial
28
+ # ------------------------------
29
+ ok = False
30
+ if data_index == 2:
31
+ for u in a_s:
32
+ if ctx.isnpint(u) and u >= int(n):
33
+ ok = True
34
+ break
35
+ if ok:
36
+ continue
37
+ pole_count[data_index] += 1
38
+ # ------------------------------
39
+ #perturb = recompute = True
40
+ #return perturb, recompute, extraprec
41
+ elif d < -4:
42
+ extraprec += -d
43
+ recompute = True
44
+ if discard_known_zeros and pole_count[1] > pole_count[0] + pole_count[2] \
45
+ and not have_singular_nongamma_weight:
46
+ discard.append(term_index)
47
+ elif sum(pole_count):
48
+ perturb = recompute = True
49
+ return perturb, recompute, extraprec, discard
50
+
51
+ _hypercomb_msg = """
52
+ hypercomb() failed to converge to the requested %i bits of accuracy
53
+ using a working precision of %i bits. The function value may be zero or
54
+ infinite; try passing zeroprec=N or infprec=M to bound finite values between
55
+ 2^(-N) and 2^M. Otherwise try a higher maxprec or maxterms.
56
+ """
57
+
58
+ @defun
59
+ def hypercomb(ctx, function, params=[], discard_known_zeros=True, **kwargs):
60
+ orig = ctx.prec
61
+ sumvalue = ctx.zero
62
+ dist = ctx.nint_distance
63
+ ninf = ctx.ninf
64
+ orig_params = params[:]
65
+ verbose = kwargs.get('verbose', False)
66
+ maxprec = kwargs.get('maxprec', ctx._default_hyper_maxprec(orig))
67
+ kwargs['maxprec'] = maxprec # For calls to hypsum
68
+ zeroprec = kwargs.get('zeroprec')
69
+ infprec = kwargs.get('infprec')
70
+ perturbed_reference_value = None
71
+ hextra = 0
72
+ try:
73
+ while 1:
74
+ ctx.prec += 10
75
+ if ctx.prec > maxprec:
76
+ raise ValueError(_hypercomb_msg % (orig, ctx.prec))
77
+ orig2 = ctx.prec
78
+ params = orig_params[:]
79
+ terms = function(*params)
80
+ if verbose:
81
+ print()
82
+ print("ENTERING hypercomb main loop")
83
+ print("prec =", ctx.prec)
84
+ print("hextra", hextra)
85
+ perturb, recompute, extraprec, discard = \
86
+ _check_need_perturb(ctx, terms, orig, discard_known_zeros)
87
+ ctx.prec += extraprec
88
+ if perturb:
89
+ if "hmag" in kwargs:
90
+ hmag = kwargs["hmag"]
91
+ elif ctx._fixed_precision:
92
+ hmag = int(ctx.prec*0.3)
93
+ else:
94
+ hmag = orig + 10 + hextra
95
+ h = ctx.ldexp(ctx.one, -hmag)
96
+ ctx.prec = orig2 + 10 + hmag + 10
97
+ for k in range(len(params)):
98
+ params[k] += h
99
+ # Heuristically ensure that the perturbations
100
+ # are "independent" so that two perturbations
101
+ # don't accidentally cancel each other out
102
+ # in a subtraction.
103
+ h += h/(k+1)
104
+ if recompute:
105
+ terms = function(*params)
106
+ if discard_known_zeros:
107
+ terms = [term for (i, term) in enumerate(terms) if i not in discard]
108
+ if not terms:
109
+ return ctx.zero
110
+ evaluated_terms = []
111
+ for term_index, term_data in enumerate(terms):
112
+ w_s, c_s, alpha_s, beta_s, a_s, b_s, z = term_data
113
+ if verbose:
114
+ print()
115
+ print(" Evaluating term %i/%i : %iF%i" % \
116
+ (term_index+1, len(terms), len(a_s), len(b_s)))
117
+ print(" powers", ctx.nstr(w_s), ctx.nstr(c_s))
118
+ print(" gamma", ctx.nstr(alpha_s), ctx.nstr(beta_s))
119
+ print(" hyper", ctx.nstr(a_s), ctx.nstr(b_s))
120
+ print(" z", ctx.nstr(z))
121
+ #v = ctx.hyper(a_s, b_s, z, **kwargs)
122
+ #for a in alpha_s: v *= ctx.gamma(a)
123
+ #for b in beta_s: v *= ctx.rgamma(b)
124
+ #for w, c in zip(w_s, c_s): v *= ctx.power(w, c)
125
+ v = ctx.fprod([ctx.hyper(a_s, b_s, z, **kwargs)] + \
126
+ [ctx.gamma(a) for a in alpha_s] + \
127
+ [ctx.rgamma(b) for b in beta_s] + \
128
+ [ctx.power(w,c) for (w,c) in zip(w_s,c_s)])
129
+ if verbose:
130
+ print(" Value:", v)
131
+ evaluated_terms.append(v)
132
+
133
+ if len(terms) == 1 and (not perturb):
134
+ sumvalue = evaluated_terms[0]
135
+ break
136
+
137
+ if ctx._fixed_precision:
138
+ sumvalue = ctx.fsum(evaluated_terms)
139
+ break
140
+
141
+ sumvalue = ctx.fsum(evaluated_terms)
142
+ term_magnitudes = [ctx.mag(x) for x in evaluated_terms]
143
+ max_magnitude = max(term_magnitudes)
144
+ sum_magnitude = ctx.mag(sumvalue)
145
+ cancellation = max_magnitude - sum_magnitude
146
+ if verbose:
147
+ print()
148
+ print(" Cancellation:", cancellation, "bits")
149
+ print(" Increased precision:", ctx.prec - orig, "bits")
150
+
151
+ precision_ok = cancellation < ctx.prec - orig
152
+
153
+ if zeroprec is None:
154
+ zero_ok = False
155
+ else:
156
+ zero_ok = max_magnitude - ctx.prec < -zeroprec
157
+ if infprec is None:
158
+ inf_ok = False
159
+ else:
160
+ inf_ok = max_magnitude > infprec
161
+
162
+ if precision_ok and (not perturb) or ctx.isnan(cancellation):
163
+ break
164
+ elif precision_ok:
165
+ if perturbed_reference_value is None:
166
+ hextra += 20
167
+ perturbed_reference_value = sumvalue
168
+ continue
169
+ elif ctx.mag(sumvalue - perturbed_reference_value) <= \
170
+ ctx.mag(sumvalue) - orig:
171
+ break
172
+ elif zero_ok:
173
+ sumvalue = ctx.zero
174
+ break
175
+ elif inf_ok:
176
+ sumvalue = ctx.inf
177
+ break
178
+ elif 'hmag' in kwargs:
179
+ break
180
+ else:
181
+ hextra *= 2
182
+ perturbed_reference_value = sumvalue
183
+ # Increase precision
184
+ else:
185
+ increment = min(max(cancellation, orig//2), max(extraprec,orig))
186
+ ctx.prec += increment
187
+ if verbose:
188
+ print(" Must start over with increased precision")
189
+ continue
190
+ finally:
191
+ ctx.prec = orig
192
+ return +sumvalue
193
+
194
+ @defun
195
+ def hyper(ctx, a_s, b_s, z, **kwargs):
196
+ """
197
+ Hypergeometric function, general case.
198
+ """
199
+ z = ctx.convert(z)
200
+ p = len(a_s)
201
+ q = len(b_s)
202
+ a_s = [ctx._convert_param(a) for a in a_s]
203
+ b_s = [ctx._convert_param(b) for b in b_s]
204
+ # Reduce degree by eliminating common parameters
205
+ if kwargs.get('eliminate', True):
206
+ elim_nonpositive = kwargs.get('eliminate_all', False)
207
+ i = 0
208
+ while i < q and a_s:
209
+ b = b_s[i]
210
+ if b in a_s and (elim_nonpositive or not ctx.isnpint(b[0])):
211
+ a_s.remove(b)
212
+ b_s.remove(b)
213
+ p -= 1
214
+ q -= 1
215
+ else:
216
+ i += 1
217
+ # Handle special cases
218
+ if p == 0:
219
+ if q == 1: return ctx._hyp0f1(b_s, z, **kwargs)
220
+ elif q == 0: return ctx.exp(z)
221
+ elif p == 1:
222
+ if q == 1: return ctx._hyp1f1(a_s, b_s, z, **kwargs)
223
+ elif q == 2: return ctx._hyp1f2(a_s, b_s, z, **kwargs)
224
+ elif q == 0: return ctx._hyp1f0(a_s[0][0], z)
225
+ elif p == 2:
226
+ if q == 1: return ctx._hyp2f1(a_s, b_s, z, **kwargs)
227
+ elif q == 2: return ctx._hyp2f2(a_s, b_s, z, **kwargs)
228
+ elif q == 3: return ctx._hyp2f3(a_s, b_s, z, **kwargs)
229
+ elif q == 0: return ctx._hyp2f0(a_s, b_s, z, **kwargs)
230
+ elif p == q+1:
231
+ return ctx._hypq1fq(p, q, a_s, b_s, z, **kwargs)
232
+ elif p > q+1 and not kwargs.get('force_series'):
233
+ return ctx._hyp_borel(p, q, a_s, b_s, z, **kwargs)
234
+ coeffs, types = zip(*(a_s+b_s))
235
+ return ctx.hypsum(p, q, types, coeffs, z, **kwargs)
236
+
237
+ @defun
238
+ def hyp0f1(ctx,b,z,**kwargs):
239
+ return ctx.hyper([],[b],z,**kwargs)
240
+
241
+ @defun
242
+ def hyp1f1(ctx,a,b,z,**kwargs):
243
+ return ctx.hyper([a],[b],z,**kwargs)
244
+
245
+ @defun
246
+ def hyp1f2(ctx,a1,b1,b2,z,**kwargs):
247
+ return ctx.hyper([a1],[b1,b2],z,**kwargs)
248
+
249
+ @defun
250
+ def hyp2f1(ctx,a,b,c,z,**kwargs):
251
+ return ctx.hyper([a,b],[c],z,**kwargs)
252
+
253
+ @defun
254
+ def hyp2f2(ctx,a1,a2,b1,b2,z,**kwargs):
255
+ return ctx.hyper([a1,a2],[b1,b2],z,**kwargs)
256
+
257
+ @defun
258
+ def hyp2f3(ctx,a1,a2,b1,b2,b3,z,**kwargs):
259
+ return ctx.hyper([a1,a2],[b1,b2,b3],z,**kwargs)
260
+
261
+ @defun
262
+ def hyp2f0(ctx,a,b,z,**kwargs):
263
+ return ctx.hyper([a,b],[],z,**kwargs)
264
+
265
+ @defun
266
+ def hyp3f2(ctx,a1,a2,a3,b1,b2,z,**kwargs):
267
+ return ctx.hyper([a1,a2,a3],[b1,b2],z,**kwargs)
268
+
269
+ @defun_wrapped
270
+ def _hyp1f0(ctx, a, z):
271
+ return (1-z) ** (-a)
272
+
273
+ @defun
274
+ def _hyp0f1(ctx, b_s, z, **kwargs):
275
+ (b, btype), = b_s
276
+ if z:
277
+ magz = ctx.mag(z)
278
+ else:
279
+ magz = 0
280
+ if magz >= 8 and not kwargs.get('force_series'):
281
+ try:
282
+ # http://functions.wolfram.com/HypergeometricFunctions/
283
+ # Hypergeometric0F1/06/02/03/0004/
284
+ # TODO: handle the all-real case more efficiently!
285
+ # TODO: figure out how much precision is needed (exponential growth)
286
+ orig = ctx.prec
287
+ try:
288
+ ctx.prec += 12 + magz//2
289
+ def h():
290
+ w = ctx.sqrt(-z)
291
+ jw = ctx.j*w
292
+ u = 1/(4*jw)
293
+ c = ctx.mpq_1_2 - b
294
+ E = ctx.exp(2*jw)
295
+ T1 = ([-jw,E], [c,-1], [], [], [b-ctx.mpq_1_2, ctx.mpq_3_2-b], [], -u)
296
+ T2 = ([jw,E], [c,1], [], [], [b-ctx.mpq_1_2, ctx.mpq_3_2-b], [], u)
297
+ return T1, T2
298
+ v = ctx.hypercomb(h, [], force_series=True)
299
+ v = ctx.gamma(b)/(2*ctx.sqrt(ctx.pi))*v
300
+ finally:
301
+ ctx.prec = orig
302
+ if ctx._is_real_type(b) and ctx._is_real_type(z):
303
+ v = ctx._re(v)
304
+ return +v
305
+ except ctx.NoConvergence:
306
+ pass
307
+ return ctx.hypsum(0, 1, (btype,), [b], z, **kwargs)
308
+
309
+ @defun
310
+ def _hyp1f1(ctx, a_s, b_s, z, **kwargs):
311
+ (a, atype), = a_s
312
+ (b, btype), = b_s
313
+ if not z:
314
+ return ctx.one+z
315
+ magz = ctx.mag(z)
316
+ if magz >= 7 and not (ctx.isint(a) and ctx.re(a) <= 0):
317
+ if ctx.isinf(z):
318
+ if ctx.sign(a) == ctx.sign(b) == ctx.sign(z) == 1:
319
+ return ctx.inf
320
+ return ctx.nan * z
321
+ try:
322
+ try:
323
+ ctx.prec += magz
324
+ sector = ctx._im(z) < 0
325
+ def h(a,b):
326
+ if sector:
327
+ E = ctx.expjpi(ctx.fneg(a, exact=True))
328
+ else:
329
+ E = ctx.expjpi(a)
330
+ rz = 1/z
331
+ T1 = ([E,z], [1,-a], [b], [b-a], [a, 1+a-b], [], -rz)
332
+ T2 = ([ctx.exp(z),z], [1,a-b], [b], [a], [b-a, 1-a], [], rz)
333
+ return T1, T2
334
+ v = ctx.hypercomb(h, [a,b], force_series=True)
335
+ if ctx._is_real_type(a) and ctx._is_real_type(b) and ctx._is_real_type(z):
336
+ v = ctx._re(v)
337
+ return +v
338
+ except ctx.NoConvergence:
339
+ pass
340
+ finally:
341
+ ctx.prec -= magz
342
+ v = ctx.hypsum(1, 1, (atype, btype), [a, b], z, **kwargs)
343
+ return v
344
+
345
+ def _hyp2f1_gosper(ctx,a,b,c,z,**kwargs):
346
+ # Use Gosper's recurrence
347
+ # See http://www.math.utexas.edu/pipermail/maxima/2006/000126.html
348
+ _a,_b,_c,_z = a, b, c, z
349
+ orig = ctx.prec
350
+ maxprec = kwargs.get('maxprec', 100*orig)
351
+ extra = 10
352
+ while 1:
353
+ ctx.prec = orig + extra
354
+ #a = ctx.convert(_a)
355
+ #b = ctx.convert(_b)
356
+ #c = ctx.convert(_c)
357
+ z = ctx.convert(_z)
358
+ d = ctx.mpf(0)
359
+ e = ctx.mpf(1)
360
+ f = ctx.mpf(0)
361
+ k = 0
362
+ # Common subexpression elimination, unfortunately making
363
+ # things a bit unreadable. The formula is quite messy to begin
364
+ # with, though...
365
+ abz = a*b*z
366
+ ch = c * ctx.mpq_1_2
367
+ c1h = (c+1) * ctx.mpq_1_2
368
+ nz = 1-z
369
+ g = z/nz
370
+ abg = a*b*g
371
+ cba = c-b-a
372
+ z2 = z-2
373
+ tol = -ctx.prec - 10
374
+ nstr = ctx.nstr
375
+ nprint = ctx.nprint
376
+ mag = ctx.mag
377
+ maxmag = ctx.ninf
378
+ while 1:
379
+ kch = k+ch
380
+ kakbz = (k+a)*(k+b)*z / (4*(k+1)*kch*(k+c1h))
381
+ d1 = kakbz*(e-(k+cba)*d*g)
382
+ e1 = kakbz*(d*abg+(k+c)*e)
383
+ ft = d*(k*(cba*z+k*z2-c)-abz)/(2*kch*nz)
384
+ f1 = f + e - ft
385
+ maxmag = max(maxmag, mag(f1))
386
+ if mag(f1-f) < tol:
387
+ break
388
+ d, e, f = d1, e1, f1
389
+ k += 1
390
+ cancellation = maxmag - mag(f1)
391
+ if cancellation < extra:
392
+ break
393
+ else:
394
+ extra += cancellation
395
+ if extra > maxprec:
396
+ raise ctx.NoConvergence
397
+ return f1
398
+
399
+ @defun
400
+ def _hyp2f1(ctx, a_s, b_s, z, **kwargs):
401
+ (a, atype), (b, btype) = a_s
402
+ (c, ctype), = b_s
403
+ if z == 1:
404
+ # TODO: the following logic can be simplified
405
+ convergent = ctx.re(c-a-b) > 0
406
+ finite = (ctx.isint(a) and a <= 0) or (ctx.isint(b) and b <= 0)
407
+ zerodiv = ctx.isint(c) and c <= 0 and not \
408
+ ((ctx.isint(a) and c <= a <= 0) or (ctx.isint(b) and c <= b <= 0))
409
+ #print "bz", a, b, c, z, convergent, finite, zerodiv
410
+ # Gauss's theorem gives the value if convergent
411
+ if (convergent or finite) and not zerodiv:
412
+ return ctx.gammaprod([c, c-a-b], [c-a, c-b], _infsign=True)
413
+ # Otherwise, there is a pole and we take the
414
+ # sign to be that when approaching from below
415
+ # XXX: this evaluation is not necessarily correct in all cases
416
+ return ctx.hyp2f1(a,b,c,1-ctx.eps*2) * ctx.inf
417
+
418
+ # Equal to 1 (first term), unless there is a subsequent
419
+ # division by zero
420
+ if not z:
421
+ # Division by zero but power of z is higher than
422
+ # first order so cancels
423
+ if c or a == 0 or b == 0:
424
+ return 1+z
425
+ # Indeterminate
426
+ return ctx.nan
427
+
428
+ # Hit zero denominator unless numerator goes to 0 first
429
+ if ctx.isint(c) and c <= 0:
430
+ if (ctx.isint(a) and c <= a <= 0) or \
431
+ (ctx.isint(b) and c <= b <= 0):
432
+ pass
433
+ else:
434
+ # Pole in series
435
+ return ctx.inf
436
+
437
+ absz = abs(z)
438
+
439
+ # Fast case: standard series converges rapidly,
440
+ # possibly in finitely many terms
441
+ if absz <= 0.8 or (ctx.isint(a) and a <= 0 and a >= -1000) or \
442
+ (ctx.isint(b) and b <= 0 and b >= -1000):
443
+ return ctx.hypsum(2, 1, (atype, btype, ctype), [a, b, c], z, **kwargs)
444
+
445
+ orig = ctx.prec
446
+ try:
447
+ ctx.prec += 10
448
+
449
+ # Use 1/z transformation
450
+ if absz >= 1.3:
451
+ def h(a,b):
452
+ t = ctx.mpq_1-c; ab = a-b; rz = 1/z
453
+ T1 = ([-z],[-a], [c,-ab],[b,c-a], [a,t+a],[ctx.mpq_1+ab], rz)
454
+ T2 = ([-z],[-b], [c,ab],[a,c-b], [b,t+b],[ctx.mpq_1-ab], rz)
455
+ return T1, T2
456
+ v = ctx.hypercomb(h, [a,b], **kwargs)
457
+
458
+ # Use 1-z transformation
459
+ elif abs(1-z) <= 0.75:
460
+ def h(a,b):
461
+ t = c-a-b; ca = c-a; cb = c-b; rz = 1-z
462
+ T1 = [], [], [c,t], [ca,cb], [a,b], [1-t], rz
463
+ T2 = [rz], [t], [c,a+b-c], [a,b], [ca,cb], [1+t], rz
464
+ return T1, T2
465
+ v = ctx.hypercomb(h, [a,b], **kwargs)
466
+
467
+ # Use z/(z-1) transformation
468
+ elif abs(z/(z-1)) <= 0.75:
469
+ v = ctx.hyp2f1(a, c-b, c, z/(z-1)) / (1-z)**a
470
+
471
+ # Remaining part of unit circle
472
+ else:
473
+ v = _hyp2f1_gosper(ctx,a,b,c,z,**kwargs)
474
+
475
+ finally:
476
+ ctx.prec = orig
477
+ return +v
478
+
479
+ @defun
480
+ def _hypq1fq(ctx, p, q, a_s, b_s, z, **kwargs):
481
+ r"""
482
+ Evaluates 3F2, 4F3, 5F4, ...
483
+ """
484
+ a_s, a_types = zip(*a_s)
485
+ b_s, b_types = zip(*b_s)
486
+ a_s = list(a_s)
487
+ b_s = list(b_s)
488
+ absz = abs(z)
489
+ ispoly = False
490
+ for a in a_s:
491
+ if ctx.isint(a) and a <= 0:
492
+ ispoly = True
493
+ break
494
+ # Direct summation
495
+ if absz < 1 or ispoly:
496
+ try:
497
+ return ctx.hypsum(p, q, a_types+b_types, a_s+b_s, z, **kwargs)
498
+ except ctx.NoConvergence:
499
+ if absz > 1.1 or ispoly:
500
+ raise
501
+ # Use expansion at |z-1| -> 0.
502
+ # Reference: Wolfgang Buhring, "Generalized Hypergeometric Functions at
503
+ # Unit Argument", Proc. Amer. Math. Soc., Vol. 114, No. 1 (Jan. 1992),
504
+ # pp.145-153
505
+ # The current implementation has several problems:
506
+ # 1. We only implement it for 3F2. The expansion coefficients are
507
+ # given by extremely messy nested sums in the higher degree cases
508
+ # (see reference). Is efficient sequential generation of the coefficients
509
+ # possible in the > 3F2 case?
510
+ # 2. Although the series converges, it may do so slowly, so we need
511
+ # convergence acceleration. The acceleration implemented by
512
+ # nsum does not always help, so results returned are sometimes
513
+ # inaccurate! Can we do better?
514
+ # 3. We should check conditions for convergence, and possibly
515
+ # do a better job of cancelling out gamma poles if possible.
516
+ if z == 1:
517
+ # XXX: should also check for division by zero in the
518
+ # denominator of the series (cf. hyp2f1)
519
+ S = ctx.re(sum(b_s)-sum(a_s))
520
+ if S <= 0:
521
+ #return ctx.hyper(a_s, b_s, 1-ctx.eps*2, **kwargs) * ctx.inf
522
+ return ctx.hyper(a_s, b_s, 0.9, **kwargs) * ctx.inf
523
+ if (p,q) == (3,2) and abs(z-1) < 0.05: # and kwargs.get('sum1')
524
+ #print "Using alternate summation (experimental)"
525
+ a1,a2,a3 = a_s
526
+ b1,b2 = b_s
527
+ u = b1+b2-a3
528
+ initial = ctx.gammaprod([b2-a3,b1-a3,a1,a2],[b2-a3,b1-a3,1,u])
529
+ def term(k, _cache={0:initial}):
530
+ u = b1+b2-a3+k
531
+ if k in _cache:
532
+ t = _cache[k]
533
+ else:
534
+ t = _cache[k-1]
535
+ t *= (b1+k-a3-1)*(b2+k-a3-1)
536
+ t /= k*(u-1)
537
+ _cache[k] = t
538
+ return t * ctx.hyp2f1(a1,a2,u,z)
539
+ try:
540
+ S = ctx.nsum(term, [0,ctx.inf], verbose=kwargs.get('verbose'),
541
+ strict=kwargs.get('strict', True))
542
+ return S * ctx.gammaprod([b1,b2],[a1,a2,a3])
543
+ except ctx.NoConvergence:
544
+ pass
545
+ # Try to use convergence acceleration on and close to the unit circle.
546
+ # Problem: the convergence acceleration degenerates as |z-1| -> 0,
547
+ # except for special cases. Everywhere else, the Shanks transformation
548
+ # is very efficient.
549
+ if absz < 1.1 and ctx._re(z) <= 1:
550
+
551
+ def term(kk, _cache={0:ctx.one}):
552
+ k = int(kk)
553
+ if k != kk:
554
+ t = z ** ctx.mpf(kk) / ctx.fac(kk)
555
+ for a in a_s: t *= ctx.rf(a,kk)
556
+ for b in b_s: t /= ctx.rf(b,kk)
557
+ return t
558
+ if k in _cache:
559
+ return _cache[k]
560
+ t = term(k-1)
561
+ m = k-1
562
+ for j in xrange(p): t *= (a_s[j]+m)
563
+ for j in xrange(q): t /= (b_s[j]+m)
564
+ t *= z
565
+ t /= k
566
+ _cache[k] = t
567
+ return t
568
+
569
+ sum_method = kwargs.get('sum_method', 'r+s+e')
570
+
571
+ try:
572
+ return ctx.nsum(term, [0,ctx.inf], verbose=kwargs.get('verbose'),
573
+ strict=kwargs.get('strict', True),
574
+ method=sum_method.replace('e',''))
575
+ except ctx.NoConvergence:
576
+ if 'e' not in sum_method:
577
+ raise
578
+ pass
579
+
580
+ if kwargs.get('verbose'):
581
+ print("Attempting Euler-Maclaurin summation")
582
+
583
+
584
+ """
585
+ Somewhat slower version (one diffs_exp for each factor).
586
+ However, this would be faster with fast direct derivatives
587
+ of the gamma function.
588
+
589
+ def power_diffs(k0):
590
+ r = 0
591
+ l = ctx.log(z)
592
+ while 1:
593
+ yield z**ctx.mpf(k0) * l**r
594
+ r += 1
595
+
596
+ def loggamma_diffs(x, reciprocal=False):
597
+ sign = (-1) ** reciprocal
598
+ yield sign * ctx.loggamma(x)
599
+ i = 0
600
+ while 1:
601
+ yield sign * ctx.psi(i,x)
602
+ i += 1
603
+
604
+ def hyper_diffs(k0):
605
+ b2 = b_s + [1]
606
+ A = [ctx.diffs_exp(loggamma_diffs(a+k0)) for a in a_s]
607
+ B = [ctx.diffs_exp(loggamma_diffs(b+k0,True)) for b in b2]
608
+ Z = [power_diffs(k0)]
609
+ C = ctx.gammaprod([b for b in b2], [a for a in a_s])
610
+ for d in ctx.diffs_prod(A + B + Z):
611
+ v = C * d
612
+ yield v
613
+ """
614
+
615
+ def log_diffs(k0):
616
+ b2 = b_s + [1]
617
+ yield sum(ctx.loggamma(a+k0) for a in a_s) - \
618
+ sum(ctx.loggamma(b+k0) for b in b2) + k0*ctx.log(z)
619
+ i = 0
620
+ while 1:
621
+ v = sum(ctx.psi(i,a+k0) for a in a_s) - \
622
+ sum(ctx.psi(i,b+k0) for b in b2)
623
+ if i == 0:
624
+ v += ctx.log(z)
625
+ yield v
626
+ i += 1
627
+
628
+ def hyper_diffs(k0):
629
+ C = ctx.gammaprod([b for b in b_s], [a for a in a_s])
630
+ for d in ctx.diffs_exp(log_diffs(k0)):
631
+ v = C * d
632
+ yield v
633
+
634
+ tol = ctx.eps / 1024
635
+ prec = ctx.prec
636
+ try:
637
+ trunc = 50 * ctx.dps
638
+ ctx.prec += 20
639
+ for i in xrange(5):
640
+ head = ctx.fsum(term(k) for k in xrange(trunc))
641
+ tail, err = ctx.sumem(term, [trunc, ctx.inf], tol=tol,
642
+ adiffs=hyper_diffs(trunc),
643
+ verbose=kwargs.get('verbose'),
644
+ error=True,
645
+ _fast_abort=True)
646
+ if err < tol:
647
+ v = head + tail
648
+ break
649
+ trunc *= 2
650
+ # Need to increase precision because calculation of
651
+ # derivatives may be inaccurate
652
+ ctx.prec += ctx.prec//2
653
+ if i == 4:
654
+ raise ctx.NoConvergence(\
655
+ "Euler-Maclaurin summation did not converge")
656
+ finally:
657
+ ctx.prec = prec
658
+ return +v
659
+
660
+ # Use 1/z transformation
661
+ # http://functions.wolfram.com/HypergeometricFunctions/
662
+ # HypergeometricPFQ/06/01/05/02/0004/
663
+ def h(*args):
664
+ a_s = list(args[:p])
665
+ b_s = list(args[p:])
666
+ Ts = []
667
+ recz = ctx.one/z
668
+ negz = ctx.fneg(z, exact=True)
669
+ for k in range(q+1):
670
+ ak = a_s[k]
671
+ C = [negz]
672
+ Cp = [-ak]
673
+ Gn = b_s + [ak] + [a_s[j]-ak for j in range(q+1) if j != k]
674
+ Gd = a_s + [b_s[j]-ak for j in range(q)]
675
+ Fn = [ak] + [ak-b_s[j]+1 for j in range(q)]
676
+ Fd = [1-a_s[j]+ak for j in range(q+1) if j != k]
677
+ Ts.append((C, Cp, Gn, Gd, Fn, Fd, recz))
678
+ return Ts
679
+ return ctx.hypercomb(h, a_s+b_s, **kwargs)
680
+
681
+ @defun
682
+ def _hyp_borel(ctx, p, q, a_s, b_s, z, **kwargs):
683
+ if a_s:
684
+ a_s, a_types = zip(*a_s)
685
+ a_s = list(a_s)
686
+ else:
687
+ a_s, a_types = [], ()
688
+ if b_s:
689
+ b_s, b_types = zip(*b_s)
690
+ b_s = list(b_s)
691
+ else:
692
+ b_s, b_types = [], ()
693
+ kwargs['maxterms'] = kwargs.get('maxterms', ctx.prec)
694
+ try:
695
+ return ctx.hypsum(p, q, a_types+b_types, a_s+b_s, z, **kwargs)
696
+ except ctx.NoConvergence:
697
+ pass
698
+ prec = ctx.prec
699
+ try:
700
+ tol = kwargs.get('asymp_tol', ctx.eps/4)
701
+ ctx.prec += 10
702
+ # hypsum is has a conservative tolerance. So we try again:
703
+ def term(k, cache={0:ctx.one}):
704
+ if k in cache:
705
+ return cache[k]
706
+ t = term(k-1)
707
+ for a in a_s: t *= (a+(k-1))
708
+ for b in b_s: t /= (b+(k-1))
709
+ t *= z
710
+ t /= k
711
+ cache[k] = t
712
+ return t
713
+ s = ctx.one
714
+ for k in xrange(1, ctx.prec):
715
+ t = term(k)
716
+ s += t
717
+ if abs(t) <= tol:
718
+ return s
719
+ finally:
720
+ ctx.prec = prec
721
+ if p <= q+3:
722
+ contour = kwargs.get('contour')
723
+ if not contour:
724
+ if ctx.arg(z) < 0.25:
725
+ u = z / max(1, abs(z))
726
+ if ctx.arg(z) >= 0:
727
+ contour = [0, 2j, (2j+2)/u, 2/u, ctx.inf]
728
+ else:
729
+ contour = [0, -2j, (-2j+2)/u, 2/u, ctx.inf]
730
+ #contour = [0, 2j/z, 2/z, ctx.inf]
731
+ #contour = [0, 2j, 2/z, ctx.inf]
732
+ #contour = [0, 2j, ctx.inf]
733
+ else:
734
+ contour = [0, ctx.inf]
735
+ quad_kwargs = kwargs.get('quad_kwargs', {})
736
+ def g(t):
737
+ return ctx.exp(-t)*ctx.hyper(a_s, b_s+[1], t*z)
738
+ I, err = ctx.quad(g, contour, error=True, **quad_kwargs)
739
+ if err <= abs(I)*ctx.eps*8:
740
+ return I
741
+ raise ctx.NoConvergence
742
+
743
+
744
+ @defun
745
+ def _hyp2f2(ctx, a_s, b_s, z, **kwargs):
746
+ (a1, a1type), (a2, a2type) = a_s
747
+ (b1, b1type), (b2, b2type) = b_s
748
+
749
+ absz = abs(z)
750
+ magz = ctx.mag(z)
751
+ orig = ctx.prec
752
+
753
+ # Asymptotic expansion is ~ exp(z)
754
+ asymp_extraprec = magz
755
+
756
+ # Asymptotic series is in terms of 3F1
757
+ can_use_asymptotic = (not kwargs.get('force_series')) and \
758
+ (ctx.mag(absz) > 3)
759
+
760
+ # TODO: much of the following could be shared with 2F3 instead of
761
+ # copypasted
762
+ if can_use_asymptotic:
763
+ #print "using asymp"
764
+ try:
765
+ try:
766
+ ctx.prec += asymp_extraprec
767
+ # http://functions.wolfram.com/HypergeometricFunctions/
768
+ # Hypergeometric2F2/06/02/02/0002/
769
+ def h(a1,a2,b1,b2):
770
+ X = a1+a2-b1-b2
771
+ A2 = a1+a2
772
+ B2 = b1+b2
773
+ c = {}
774
+ c[0] = ctx.one
775
+ c[1] = (A2-1)*X+b1*b2-a1*a2
776
+ s1 = 0
777
+ k = 0
778
+ tprev = 0
779
+ while 1:
780
+ if k not in c:
781
+ uu1 = 1-B2+2*a1+a1**2+2*a2+a2**2-A2*B2+a1*a2+b1*b2+(2*B2-3*(A2+1))*k+2*k**2
782
+ uu2 = (k-A2+b1-1)*(k-A2+b2-1)*(k-X-2)
783
+ c[k] = ctx.one/k * (uu1*c[k-1]-uu2*c[k-2])
784
+ t1 = c[k] * z**(-k)
785
+ if abs(t1) < 0.1*ctx.eps:
786
+ #print "Convergence :)"
787
+ break
788
+ # Quit if the series doesn't converge quickly enough
789
+ if k > 5 and abs(tprev) / abs(t1) < 1.5:
790
+ #print "No convergence :("
791
+ raise ctx.NoConvergence
792
+ s1 += t1
793
+ tprev = t1
794
+ k += 1
795
+ S = ctx.exp(z)*s1
796
+ T1 = [z,S], [X,1], [b1,b2],[a1,a2],[],[],0
797
+ T2 = [-z],[-a1],[b1,b2,a2-a1],[a2,b1-a1,b2-a1],[a1,a1-b1+1,a1-b2+1],[a1-a2+1],-1/z
798
+ T3 = [-z],[-a2],[b1,b2,a1-a2],[a1,b1-a2,b2-a2],[a2,a2-b1+1,a2-b2+1],[-a1+a2+1],-1/z
799
+ return T1, T2, T3
800
+ v = ctx.hypercomb(h, [a1,a2,b1,b2], force_series=True, maxterms=4*ctx.prec)
801
+ if sum(ctx._is_real_type(u) for u in [a1,a2,b1,b2,z]) == 5:
802
+ v = ctx.re(v)
803
+ return v
804
+ except ctx.NoConvergence:
805
+ pass
806
+ finally:
807
+ ctx.prec = orig
808
+
809
+ return ctx.hypsum(2, 2, (a1type, a2type, b1type, b2type), [a1, a2, b1, b2], z, **kwargs)
810
+
811
+
812
+
813
+ @defun
814
+ def _hyp1f2(ctx, a_s, b_s, z, **kwargs):
815
+ (a1, a1type), = a_s
816
+ (b1, b1type), (b2, b2type) = b_s
817
+
818
+ absz = abs(z)
819
+ magz = ctx.mag(z)
820
+ orig = ctx.prec
821
+
822
+ # Asymptotic expansion is ~ exp(sqrt(z))
823
+ asymp_extraprec = z and magz//2
824
+
825
+ # Asymptotic series is in terms of 3F0
826
+ can_use_asymptotic = (not kwargs.get('force_series')) and \
827
+ (ctx.mag(absz) > 19) and \
828
+ (ctx.sqrt(absz) > 1.5*orig) # and \
829
+ # ctx._hyp_check_convergence([a1, a1-b1+1, a1-b2+1], [],
830
+ # 1/absz, orig+40+asymp_extraprec)
831
+
832
+ # TODO: much of the following could be shared with 2F3 instead of
833
+ # copypasted
834
+ if can_use_asymptotic:
835
+ #print "using asymp"
836
+ try:
837
+ try:
838
+ ctx.prec += asymp_extraprec
839
+ # http://functions.wolfram.com/HypergeometricFunctions/
840
+ # Hypergeometric1F2/06/02/03/
841
+ def h(a1,b1,b2):
842
+ X = ctx.mpq_1_2*(a1-b1-b2+ctx.mpq_1_2)
843
+ c = {}
844
+ c[0] = ctx.one
845
+ c[1] = 2*(ctx.mpq_1_4*(3*a1+b1+b2-2)*(a1-b1-b2)+b1*b2-ctx.mpq_3_16)
846
+ c[2] = 2*(b1*b2+ctx.mpq_1_4*(a1-b1-b2)*(3*a1+b1+b2-2)-ctx.mpq_3_16)**2+\
847
+ ctx.mpq_1_16*(-16*(2*a1-3)*b1*b2 + \
848
+ 4*(a1-b1-b2)*(-8*a1**2+11*a1+b1+b2-2)-3)
849
+ s1 = 0
850
+ s2 = 0
851
+ k = 0
852
+ tprev = 0
853
+ while 1:
854
+ if k not in c:
855
+ uu1 = (3*k**2+(-6*a1+2*b1+2*b2-4)*k + 3*a1**2 - \
856
+ (b1-b2)**2 - 2*a1*(b1+b2-2) + ctx.mpq_1_4)
857
+ uu2 = (k-a1+b1-b2-ctx.mpq_1_2)*(k-a1-b1+b2-ctx.mpq_1_2)*\
858
+ (k-a1+b1+b2-ctx.mpq_5_2)
859
+ c[k] = ctx.one/(2*k)*(uu1*c[k-1]-uu2*c[k-2])
860
+ w = c[k] * (-z)**(-0.5*k)
861
+ t1 = (-ctx.j)**k * ctx.mpf(2)**(-k) * w
862
+ t2 = ctx.j**k * ctx.mpf(2)**(-k) * w
863
+ if abs(t1) < 0.1*ctx.eps:
864
+ #print "Convergence :)"
865
+ break
866
+ # Quit if the series doesn't converge quickly enough
867
+ if k > 5 and abs(tprev) / abs(t1) < 1.5:
868
+ #print "No convergence :("
869
+ raise ctx.NoConvergence
870
+ s1 += t1
871
+ s2 += t2
872
+ tprev = t1
873
+ k += 1
874
+ S = ctx.expj(ctx.pi*X+2*ctx.sqrt(-z))*s1 + \
875
+ ctx.expj(-(ctx.pi*X+2*ctx.sqrt(-z)))*s2
876
+ T1 = [0.5*S, ctx.pi, -z], [1, -0.5, X], [b1, b2], [a1],\
877
+ [], [], 0
878
+ T2 = [-z], [-a1], [b1,b2],[b1-a1,b2-a1], \
879
+ [a1,a1-b1+1,a1-b2+1], [], 1/z
880
+ return T1, T2
881
+ v = ctx.hypercomb(h, [a1,b1,b2], force_series=True, maxterms=4*ctx.prec)
882
+ if sum(ctx._is_real_type(u) for u in [a1,b1,b2,z]) == 4:
883
+ v = ctx.re(v)
884
+ return v
885
+ except ctx.NoConvergence:
886
+ pass
887
+ finally:
888
+ ctx.prec = orig
889
+
890
+ #print "not using asymp"
891
+ return ctx.hypsum(1, 2, (a1type, b1type, b2type), [a1, b1, b2], z, **kwargs)
892
+
893
+
894
+
895
+ @defun
896
+ def _hyp2f3(ctx, a_s, b_s, z, **kwargs):
897
+ (a1, a1type), (a2, a2type) = a_s
898
+ (b1, b1type), (b2, b2type), (b3, b3type) = b_s
899
+
900
+ absz = abs(z)
901
+ magz = ctx.mag(z)
902
+
903
+ # Asymptotic expansion is ~ exp(sqrt(z))
904
+ asymp_extraprec = z and magz//2
905
+ orig = ctx.prec
906
+
907
+ # Asymptotic series is in terms of 4F1
908
+ # The square root below empirically provides a plausible criterion
909
+ # for the leading series to converge
910
+ can_use_asymptotic = (not kwargs.get('force_series')) and \
911
+ (ctx.mag(absz) > 19) and (ctx.sqrt(absz) > 1.5*orig)
912
+
913
+ if can_use_asymptotic:
914
+ #print "using asymp"
915
+ try:
916
+ try:
917
+ ctx.prec += asymp_extraprec
918
+ # http://functions.wolfram.com/HypergeometricFunctions/
919
+ # Hypergeometric2F3/06/02/03/01/0002/
920
+ def h(a1,a2,b1,b2,b3):
921
+ X = ctx.mpq_1_2*(a1+a2-b1-b2-b3+ctx.mpq_1_2)
922
+ A2 = a1+a2
923
+ B3 = b1+b2+b3
924
+ A = a1*a2
925
+ B = b1*b2+b3*b2+b1*b3
926
+ R = b1*b2*b3
927
+ c = {}
928
+ c[0] = ctx.one
929
+ c[1] = 2*(B - A + ctx.mpq_1_4*(3*A2+B3-2)*(A2-B3) - ctx.mpq_3_16)
930
+ c[2] = ctx.mpq_1_2*c[1]**2 + ctx.mpq_1_16*(-16*(2*A2-3)*(B-A) + 32*R +\
931
+ 4*(-8*A2**2 + 11*A2 + 8*A + B3 - 2)*(A2-B3)-3)
932
+ s1 = 0
933
+ s2 = 0
934
+ k = 0
935
+ tprev = 0
936
+ while 1:
937
+ if k not in c:
938
+ uu1 = (k-2*X-3)*(k-2*X-2*b1-1)*(k-2*X-2*b2-1)*\
939
+ (k-2*X-2*b3-1)
940
+ uu2 = (4*(k-1)**3 - 6*(4*X+B3)*(k-1)**2 + \
941
+ 2*(24*X**2+12*B3*X+4*B+B3-1)*(k-1) - 32*X**3 - \
942
+ 24*B3*X**2 - 4*B - 8*R - 4*(4*B+B3-1)*X + 2*B3-1)
943
+ uu3 = (5*(k-1)**2+2*(-10*X+A2-3*B3+3)*(k-1)+2*c[1])
944
+ c[k] = ctx.one/(2*k)*(uu1*c[k-3]-uu2*c[k-2]+uu3*c[k-1])
945
+ w = c[k] * ctx.power(-z, -0.5*k)
946
+ t1 = (-ctx.j)**k * ctx.mpf(2)**(-k) * w
947
+ t2 = ctx.j**k * ctx.mpf(2)**(-k) * w
948
+ if abs(t1) < 0.1*ctx.eps:
949
+ break
950
+ # Quit if the series doesn't converge quickly enough
951
+ if k > 5 and abs(tprev) / abs(t1) < 1.5:
952
+ raise ctx.NoConvergence
953
+ s1 += t1
954
+ s2 += t2
955
+ tprev = t1
956
+ k += 1
957
+ S = ctx.expj(ctx.pi*X+2*ctx.sqrt(-z))*s1 + \
958
+ ctx.expj(-(ctx.pi*X+2*ctx.sqrt(-z)))*s2
959
+ T1 = [0.5*S, ctx.pi, -z], [1, -0.5, X], [b1, b2, b3], [a1, a2],\
960
+ [], [], 0
961
+ T2 = [-z], [-a1], [b1,b2,b3,a2-a1],[a2,b1-a1,b2-a1,b3-a1], \
962
+ [a1,a1-b1+1,a1-b2+1,a1-b3+1], [a1-a2+1], 1/z
963
+ T3 = [-z], [-a2], [b1,b2,b3,a1-a2],[a1,b1-a2,b2-a2,b3-a2], \
964
+ [a2,a2-b1+1,a2-b2+1,a2-b3+1],[-a1+a2+1], 1/z
965
+ return T1, T2, T3
966
+ v = ctx.hypercomb(h, [a1,a2,b1,b2,b3], force_series=True, maxterms=4*ctx.prec)
967
+ if sum(ctx._is_real_type(u) for u in [a1,a2,b1,b2,b3,z]) == 6:
968
+ v = ctx.re(v)
969
+ return v
970
+ except ctx.NoConvergence:
971
+ pass
972
+ finally:
973
+ ctx.prec = orig
974
+
975
+ return ctx.hypsum(2, 3, (a1type, a2type, b1type, b2type, b3type), [a1, a2, b1, b2, b3], z, **kwargs)
976
+
977
+ @defun
978
+ def _hyp2f0(ctx, a_s, b_s, z, **kwargs):
979
+ (a, atype), (b, btype) = a_s
980
+ # We want to try aggressively to use the asymptotic expansion,
981
+ # and fall back only when absolutely necessary
982
+ try:
983
+ kwargsb = kwargs.copy()
984
+ kwargsb['maxterms'] = kwargsb.get('maxterms', ctx.prec)
985
+ return ctx.hypsum(2, 0, (atype,btype), [a,b], z, **kwargsb)
986
+ except ctx.NoConvergence:
987
+ if kwargs.get('force_series'):
988
+ raise
989
+ pass
990
+ def h(a, b):
991
+ w = ctx.sinpi(b)
992
+ rz = -1/z
993
+ T1 = ([ctx.pi,w,rz],[1,-1,a],[],[a-b+1,b],[a],[b],rz)
994
+ T2 = ([-ctx.pi,w,rz],[1,-1,1+a-b],[],[a,2-b],[a-b+1],[2-b],rz)
995
+ return T1, T2
996
+ return ctx.hypercomb(h, [a, 1+a-b], **kwargs)
997
+
998
+ @defun
999
+ def meijerg(ctx, a_s, b_s, z, r=1, series=None, **kwargs):
1000
+ an, ap = a_s
1001
+ bm, bq = b_s
1002
+ n = len(an)
1003
+ p = n + len(ap)
1004
+ m = len(bm)
1005
+ q = m + len(bq)
1006
+ a = an+ap
1007
+ b = bm+bq
1008
+ a = [ctx.convert(_) for _ in a]
1009
+ b = [ctx.convert(_) for _ in b]
1010
+ z = ctx.convert(z)
1011
+ if series is None:
1012
+ if p < q: series = 1
1013
+ if p > q: series = 2
1014
+ if p == q:
1015
+ if m+n == p and abs(z) > 1:
1016
+ series = 2
1017
+ else:
1018
+ series = 1
1019
+ if kwargs.get('verbose'):
1020
+ print("Meijer G m,n,p,q,series =", m,n,p,q,series)
1021
+ if series == 1:
1022
+ def h(*args):
1023
+ a = args[:p]
1024
+ b = args[p:]
1025
+ terms = []
1026
+ for k in range(m):
1027
+ bases = [z]
1028
+ expts = [b[k]/r]
1029
+ gn = [b[j]-b[k] for j in range(m) if j != k]
1030
+ gn += [1-a[j]+b[k] for j in range(n)]
1031
+ gd = [a[j]-b[k] for j in range(n,p)]
1032
+ gd += [1-b[j]+b[k] for j in range(m,q)]
1033
+ hn = [1-a[j]+b[k] for j in range(p)]
1034
+ hd = [1-b[j]+b[k] for j in range(q) if j != k]
1035
+ hz = (-ctx.one)**(p-m-n) * z**(ctx.one/r)
1036
+ terms.append((bases, expts, gn, gd, hn, hd, hz))
1037
+ return terms
1038
+ else:
1039
+ def h(*args):
1040
+ a = args[:p]
1041
+ b = args[p:]
1042
+ terms = []
1043
+ for k in range(n):
1044
+ bases = [z]
1045
+ if r == 1:
1046
+ expts = [a[k]-1]
1047
+ else:
1048
+ expts = [(a[k]-1)/ctx.convert(r)]
1049
+ gn = [a[k]-a[j] for j in range(n) if j != k]
1050
+ gn += [1-a[k]+b[j] for j in range(m)]
1051
+ gd = [a[k]-b[j] for j in range(m,q)]
1052
+ gd += [1-a[k]+a[j] for j in range(n,p)]
1053
+ hn = [1-a[k]+b[j] for j in range(q)]
1054
+ hd = [1+a[j]-a[k] for j in range(p) if j != k]
1055
+ hz = (-ctx.one)**(q-m-n) / z**(ctx.one/r)
1056
+ terms.append((bases, expts, gn, gd, hn, hd, hz))
1057
+ return terms
1058
+ return ctx.hypercomb(h, a+b, **kwargs)
1059
+
1060
+ @defun_wrapped
1061
+ def appellf1(ctx,a,b1,b2,c,x,y,**kwargs):
1062
+ # Assume x smaller
1063
+ # We will use x for the outer loop
1064
+ if abs(x) > abs(y):
1065
+ x, y = y, x
1066
+ b1, b2 = b2, b1
1067
+ def ok(x):
1068
+ return abs(x) < 0.99
1069
+ # Finite cases
1070
+ if ctx.isnpint(a):
1071
+ pass
1072
+ elif ctx.isnpint(b1):
1073
+ pass
1074
+ elif ctx.isnpint(b2):
1075
+ x, y, b1, b2 = y, x, b2, b1
1076
+ else:
1077
+ #print x, y
1078
+ # Note: ok if |y| > 1, because
1079
+ # 2F1 implements analytic continuation
1080
+ if not ok(x):
1081
+ u1 = (x-y)/(x-1)
1082
+ if not ok(u1):
1083
+ raise ValueError("Analytic continuation not implemented")
1084
+ #print "Using analytic continuation"
1085
+ return (1-x)**(-b1)*(1-y)**(c-a-b2)*\
1086
+ ctx.appellf1(c-a,b1,c-b1-b2,c,u1,y,**kwargs)
1087
+ return ctx.hyper2d({'m+n':[a],'m':[b1],'n':[b2]}, {'m+n':[c]}, x,y, **kwargs)
1088
+
1089
+ @defun
1090
+ def appellf2(ctx,a,b1,b2,c1,c2,x,y,**kwargs):
1091
+ # TODO: continuation
1092
+ return ctx.hyper2d({'m+n':[a],'m':[b1],'n':[b2]},
1093
+ {'m':[c1],'n':[c2]}, x,y, **kwargs)
1094
+
1095
+ @defun
1096
+ def appellf3(ctx,a1,a2,b1,b2,c,x,y,**kwargs):
1097
+ outer_polynomial = ctx.isnpint(a1) or ctx.isnpint(b1)
1098
+ inner_polynomial = ctx.isnpint(a2) or ctx.isnpint(b2)
1099
+ if not outer_polynomial:
1100
+ if inner_polynomial or abs(x) > abs(y):
1101
+ x, y = y, x
1102
+ a1,a2,b1,b2 = a2,a1,b2,b1
1103
+ return ctx.hyper2d({'m':[a1,b1],'n':[a2,b2]}, {'m+n':[c]},x,y,**kwargs)
1104
+
1105
+ @defun
1106
+ def appellf4(ctx,a,b,c1,c2,x,y,**kwargs):
1107
+ # TODO: continuation
1108
+ return ctx.hyper2d({'m+n':[a,b]}, {'m':[c1],'n':[c2]},x,y,**kwargs)
1109
+
1110
+ @defun
1111
+ def hyper2d(ctx, a, b, x, y, **kwargs):
1112
+ r"""
1113
+ Sums the generalized 2D hypergeometric series
1114
+
1115
+ .. math ::
1116
+
1117
+ \sum_{m=0}^{\infty} \sum_{n=0}^{\infty}
1118
+ \frac{P((a),m,n)}{Q((b),m,n)}
1119
+ \frac{x^m y^n} {m! n!}
1120
+
1121
+ where `(a) = (a_1,\ldots,a_r)`, `(b) = (b_1,\ldots,b_s)` and where
1122
+ `P` and `Q` are products of rising factorials such as `(a_j)_n` or
1123
+ `(a_j)_{m+n}`. `P` and `Q` are specified in the form of dicts, with
1124
+ the `m` and `n` dependence as keys and parameter lists as values.
1125
+ The supported rising factorials are given in the following table
1126
+ (note that only a few are supported in `Q`):
1127
+
1128
+ +------------+-------------------+--------+
1129
+ | Key | Rising factorial | `Q` |
1130
+ +============+===================+========+
1131
+ | ``'m'`` | `(a_j)_m` | Yes |
1132
+ +------------+-------------------+--------+
1133
+ | ``'n'`` | `(a_j)_n` | Yes |
1134
+ +------------+-------------------+--------+
1135
+ | ``'m+n'`` | `(a_j)_{m+n}` | Yes |
1136
+ +------------+-------------------+--------+
1137
+ | ``'m-n'`` | `(a_j)_{m-n}` | No |
1138
+ +------------+-------------------+--------+
1139
+ | ``'n-m'`` | `(a_j)_{n-m}` | No |
1140
+ +------------+-------------------+--------+
1141
+ | ``'2m+n'`` | `(a_j)_{2m+n}` | No |
1142
+ +------------+-------------------+--------+
1143
+ | ``'2m-n'`` | `(a_j)_{2m-n}` | No |
1144
+ +------------+-------------------+--------+
1145
+ | ``'2n-m'`` | `(a_j)_{2n-m}` | No |
1146
+ +------------+-------------------+--------+
1147
+
1148
+ For example, the Appell F1 and F4 functions
1149
+
1150
+ .. math ::
1151
+
1152
+ F_1 = \sum_{m=0}^{\infty} \sum_{n=0}^{\infty}
1153
+ \frac{(a)_{m+n} (b)_m (c)_n}{(d)_{m+n}}
1154
+ \frac{x^m y^n}{m! n!}
1155
+
1156
+ F_4 = \sum_{m=0}^{\infty} \sum_{n=0}^{\infty}
1157
+ \frac{(a)_{m+n} (b)_{m+n}}{(c)_m (d)_{n}}
1158
+ \frac{x^m y^n}{m! n!}
1159
+
1160
+ can be represented respectively as
1161
+
1162
+ ``hyper2d({'m+n':[a], 'm':[b], 'n':[c]}, {'m+n':[d]}, x, y)``
1163
+
1164
+ ``hyper2d({'m+n':[a,b]}, {'m':[c], 'n':[d]}, x, y)``
1165
+
1166
+ More generally, :func:`~mpmath.hyper2d` can evaluate any of the 34 distinct
1167
+ convergent second-order (generalized Gaussian) hypergeometric
1168
+ series enumerated by Horn, as well as the Kampe de Feriet
1169
+ function.
1170
+
1171
+ The series is computed by rewriting it so that the inner
1172
+ series (i.e. the series containing `n` and `y`) has the form of an
1173
+ ordinary generalized hypergeometric series and thereby can be
1174
+ evaluated efficiently using :func:`~mpmath.hyper`. If possible,
1175
+ manually swapping `x` and `y` and the corresponding parameters
1176
+ can sometimes give better results.
1177
+
1178
+ **Examples**
1179
+
1180
+ Two separable cases: a product of two geometric series, and a
1181
+ product of two Gaussian hypergeometric functions::
1182
+
1183
+ >>> from mpmath import *
1184
+ >>> mp.dps = 25; mp.pretty = True
1185
+ >>> x, y = mpf(0.25), mpf(0.5)
1186
+ >>> hyper2d({'m':1,'n':1}, {}, x,y)
1187
+ 2.666666666666666666666667
1188
+ >>> 1/(1-x)/(1-y)
1189
+ 2.666666666666666666666667
1190
+ >>> hyper2d({'m':[1,2],'n':[3,4]}, {'m':[5],'n':[6]}, x,y)
1191
+ 4.164358531238938319669856
1192
+ >>> hyp2f1(1,2,5,x)*hyp2f1(3,4,6,y)
1193
+ 4.164358531238938319669856
1194
+
1195
+ Some more series that can be done in closed form::
1196
+
1197
+ >>> hyper2d({'m':1,'n':1},{'m+n':1},x,y)
1198
+ 2.013417124712514809623881
1199
+ >>> (exp(x)*x-exp(y)*y)/(x-y)
1200
+ 2.013417124712514809623881
1201
+
1202
+ Six of the 34 Horn functions, G1-G3 and H1-H3::
1203
+
1204
+ >>> from mpmath import *
1205
+ >>> mp.dps = 10; mp.pretty = True
1206
+ >>> x, y = 0.0625, 0.125
1207
+ >>> a1,a2,b1,b2,c1,c2,d = 1.1,-1.2,-1.3,-1.4,1.5,-1.6,1.7
1208
+ >>> hyper2d({'m+n':a1,'n-m':b1,'m-n':b2},{},x,y) # G1
1209
+ 1.139090746
1210
+ >>> nsum(lambda m,n: rf(a1,m+n)*rf(b1,n-m)*rf(b2,m-n)*\
1211
+ ... x**m*y**n/fac(m)/fac(n), [0,inf], [0,inf])
1212
+ 1.139090746
1213
+ >>> hyper2d({'m':a1,'n':a2,'n-m':b1,'m-n':b2},{},x,y) # G2
1214
+ 0.9503682696
1215
+ >>> nsum(lambda m,n: rf(a1,m)*rf(a2,n)*rf(b1,n-m)*rf(b2,m-n)*\
1216
+ ... x**m*y**n/fac(m)/fac(n), [0,inf], [0,inf])
1217
+ 0.9503682696
1218
+ >>> hyper2d({'2n-m':a1,'2m-n':a2},{},x,y) # G3
1219
+ 1.029372029
1220
+ >>> nsum(lambda m,n: rf(a1,2*n-m)*rf(a2,2*m-n)*\
1221
+ ... x**m*y**n/fac(m)/fac(n), [0,inf], [0,inf])
1222
+ 1.029372029
1223
+ >>> hyper2d({'m-n':a1,'m+n':b1,'n':c1},{'m':d},x,y) # H1
1224
+ -1.605331256
1225
+ >>> nsum(lambda m,n: rf(a1,m-n)*rf(b1,m+n)*rf(c1,n)/rf(d,m)*\
1226
+ ... x**m*y**n/fac(m)/fac(n), [0,inf], [0,inf])
1227
+ -1.605331256
1228
+ >>> hyper2d({'m-n':a1,'m':b1,'n':[c1,c2]},{'m':d},x,y) # H2
1229
+ -2.35405404
1230
+ >>> nsum(lambda m,n: rf(a1,m-n)*rf(b1,m)*rf(c1,n)*rf(c2,n)/rf(d,m)*\
1231
+ ... x**m*y**n/fac(m)/fac(n), [0,inf], [0,inf])
1232
+ -2.35405404
1233
+ >>> hyper2d({'2m+n':a1,'n':b1},{'m+n':c1},x,y) # H3
1234
+ 0.974479074
1235
+ >>> nsum(lambda m,n: rf(a1,2*m+n)*rf(b1,n)/rf(c1,m+n)*\
1236
+ ... x**m*y**n/fac(m)/fac(n), [0,inf], [0,inf])
1237
+ 0.974479074
1238
+
1239
+ **References**
1240
+
1241
+ 1. [SrivastavaKarlsson]_
1242
+ 2. [Weisstein]_ http://mathworld.wolfram.com/HornFunction.html
1243
+ 3. [Weisstein]_ http://mathworld.wolfram.com/AppellHypergeometricFunction.html
1244
+
1245
+ """
1246
+ x = ctx.convert(x)
1247
+ y = ctx.convert(y)
1248
+ def parse(dct, key):
1249
+ args = dct.pop(key, [])
1250
+ try:
1251
+ args = list(args)
1252
+ except TypeError:
1253
+ args = [args]
1254
+ return [ctx.convert(arg) for arg in args]
1255
+ a_s = dict(a)
1256
+ b_s = dict(b)
1257
+ a_m = parse(a, 'm')
1258
+ a_n = parse(a, 'n')
1259
+ a_m_add_n = parse(a, 'm+n')
1260
+ a_m_sub_n = parse(a, 'm-n')
1261
+ a_n_sub_m = parse(a, 'n-m')
1262
+ a_2m_add_n = parse(a, '2m+n')
1263
+ a_2m_sub_n = parse(a, '2m-n')
1264
+ a_2n_sub_m = parse(a, '2n-m')
1265
+ b_m = parse(b, 'm')
1266
+ b_n = parse(b, 'n')
1267
+ b_m_add_n = parse(b, 'm+n')
1268
+ if a: raise ValueError("unsupported key: %r" % a.keys()[0])
1269
+ if b: raise ValueError("unsupported key: %r" % b.keys()[0])
1270
+ s = 0
1271
+ outer = ctx.one
1272
+ m = ctx.mpf(0)
1273
+ ok_count = 0
1274
+ prec = ctx.prec
1275
+ maxterms = kwargs.get('maxterms', 20*prec)
1276
+ try:
1277
+ ctx.prec += 10
1278
+ tol = +ctx.eps
1279
+ while 1:
1280
+ inner_sign = 1
1281
+ outer_sign = 1
1282
+ inner_a = list(a_n)
1283
+ inner_b = list(b_n)
1284
+ outer_a = [a+m for a in a_m]
1285
+ outer_b = [b+m for b in b_m]
1286
+ # (a)_{m+n} = (a)_m (a+m)_n
1287
+ for a in a_m_add_n:
1288
+ a = a+m
1289
+ inner_a.append(a)
1290
+ outer_a.append(a)
1291
+ # (b)_{m+n} = (b)_m (b+m)_n
1292
+ for b in b_m_add_n:
1293
+ b = b+m
1294
+ inner_b.append(b)
1295
+ outer_b.append(b)
1296
+ # (a)_{n-m} = (a-m)_n / (a-m)_m
1297
+ for a in a_n_sub_m:
1298
+ inner_a.append(a-m)
1299
+ outer_b.append(a-m-1)
1300
+ # (a)_{m-n} = (-1)^(m+n) (1-a-m)_m / (1-a-m)_n
1301
+ for a in a_m_sub_n:
1302
+ inner_sign *= (-1)
1303
+ outer_sign *= (-1)**(m)
1304
+ inner_b.append(1-a-m)
1305
+ outer_a.append(-a-m)
1306
+ # (a)_{2m+n} = (a)_{2m} (a+2m)_n
1307
+ for a in a_2m_add_n:
1308
+ inner_a.append(a+2*m)
1309
+ outer_a.append((a+2*m)*(1+a+2*m))
1310
+ # (a)_{2m-n} = (-1)^(2m+n) (1-a-2m)_{2m} / (1-a-2m)_n
1311
+ for a in a_2m_sub_n:
1312
+ inner_sign *= (-1)
1313
+ inner_b.append(1-a-2*m)
1314
+ outer_a.append((a+2*m)*(1+a+2*m))
1315
+ # (a)_{2n-m} = 4^n ((a-m)/2)_n ((a-m+1)/2)_n / (a-m)_m
1316
+ for a in a_2n_sub_m:
1317
+ inner_sign *= 4
1318
+ inner_a.append(0.5*(a-m))
1319
+ inner_a.append(0.5*(a-m+1))
1320
+ outer_b.append(a-m-1)
1321
+ inner = ctx.hyper(inner_a, inner_b, inner_sign*y,
1322
+ zeroprec=ctx.prec, **kwargs)
1323
+ term = outer * inner * outer_sign
1324
+ if abs(term) < tol:
1325
+ ok_count += 1
1326
+ else:
1327
+ ok_count = 0
1328
+ if ok_count >= 3 or not outer:
1329
+ break
1330
+ s += term
1331
+ for a in outer_a: outer *= a
1332
+ for b in outer_b: outer /= b
1333
+ m += 1
1334
+ outer = outer * x / m
1335
+ if m > maxterms:
1336
+ raise ctx.NoConvergence("maxterms exceeded in hyper2d")
1337
+ finally:
1338
+ ctx.prec = prec
1339
+ return +s
1340
+
1341
+ """
1342
+ @defun
1343
+ def kampe_de_feriet(ctx,a,b,c,d,e,f,x,y,**kwargs):
1344
+ return ctx.hyper2d({'m+n':a,'m':b,'n':c},
1345
+ {'m+n':d,'m':e,'n':f}, x,y, **kwargs)
1346
+ """
1347
+
1348
+ @defun
1349
+ def bihyper(ctx, a_s, b_s, z, **kwargs):
1350
+ r"""
1351
+ Evaluates the bilateral hypergeometric series
1352
+
1353
+ .. math ::
1354
+
1355
+ \,_AH_B(a_1, \ldots, a_k; b_1, \ldots, b_B; z) =
1356
+ \sum_{n=-\infty}^{\infty}
1357
+ \frac{(a_1)_n \ldots (a_A)_n}
1358
+ {(b_1)_n \ldots (b_B)_n} \, z^n
1359
+
1360
+ where, for direct convergence, `A = B` and `|z| = 1`, although a
1361
+ regularized sum exists more generally by considering the
1362
+ bilateral series as a sum of two ordinary hypergeometric
1363
+ functions. In order for the series to make sense, none of the
1364
+ parameters may be integers.
1365
+
1366
+ **Examples**
1367
+
1368
+ The value of `\,_2H_2` at `z = 1` is given by Dougall's formula::
1369
+
1370
+ >>> from mpmath import *
1371
+ >>> mp.dps = 25; mp.pretty = True
1372
+ >>> a,b,c,d = 0.5, 1.5, 2.25, 3.25
1373
+ >>> bihyper([a,b],[c,d],1)
1374
+ -14.49118026212345786148847
1375
+ >>> gammaprod([c,d,1-a,1-b,c+d-a-b-1],[c-a,d-a,c-b,d-b])
1376
+ -14.49118026212345786148847
1377
+
1378
+ The regularized function `\,_1H_0` can be expressed as the
1379
+ sum of one `\,_2F_0` function and one `\,_1F_1` function::
1380
+
1381
+ >>> a = mpf(0.25)
1382
+ >>> z = mpf(0.75)
1383
+ >>> bihyper([a], [], z)
1384
+ (0.2454393389657273841385582 + 0.2454393389657273841385582j)
1385
+ >>> hyper([a,1],[],z) + (hyper([1],[1-a],-1/z)-1)
1386
+ (0.2454393389657273841385582 + 0.2454393389657273841385582j)
1387
+ >>> hyper([a,1],[],z) + hyper([1],[2-a],-1/z)/z/(a-1)
1388
+ (0.2454393389657273841385582 + 0.2454393389657273841385582j)
1389
+
1390
+ **References**
1391
+
1392
+ 1. [Slater]_ (chapter 6: "Bilateral Series", pp. 180-189)
1393
+ 2. [Wikipedia]_ http://en.wikipedia.org/wiki/Bilateral_hypergeometric_series
1394
+
1395
+ """
1396
+ z = ctx.convert(z)
1397
+ c_s = a_s + b_s
1398
+ p = len(a_s)
1399
+ q = len(b_s)
1400
+ if (p, q) == (0,0) or (p, q) == (1,1):
1401
+ return ctx.zero * z
1402
+ neg = (p-q) % 2
1403
+ def h(*c_s):
1404
+ a_s = list(c_s[:p])
1405
+ b_s = list(c_s[p:])
1406
+ aa_s = [2-b for b in b_s]
1407
+ bb_s = [2-a for a in a_s]
1408
+ rp = [(-1)**neg * z] + [1-b for b in b_s] + [1-a for a in a_s]
1409
+ rc = [-1] + [1]*len(b_s) + [-1]*len(a_s)
1410
+ T1 = [], [], [], [], a_s + [1], b_s, z
1411
+ T2 = rp, rc, [], [], aa_s + [1], bb_s, (-1)**neg / z
1412
+ return T1, T2
1413
+ return ctx.hypercomb(h, c_s, **kwargs)
lib/python3.11/site-packages/mpmath/functions/orthogonal.py ADDED
@@ -0,0 +1,493 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .functions import defun, defun_wrapped
2
+
3
+ def _hermite_param(ctx, n, z, parabolic_cylinder):
4
+ """
5
+ Combined calculation of the Hermite polynomial H_n(z) (and its
6
+ generalization to complex n) and the parabolic cylinder
7
+ function D.
8
+ """
9
+ n, ntyp = ctx._convert_param(n)
10
+ z = ctx.convert(z)
11
+ q = -ctx.mpq_1_2
12
+ # For re(z) > 0, 2F0 -- http://functions.wolfram.com/
13
+ # HypergeometricFunctions/HermiteHGeneral/06/02/0009/
14
+ # Otherwise, there is a reflection formula
15
+ # 2F0 + http://functions.wolfram.com/HypergeometricFunctions/
16
+ # HermiteHGeneral/16/01/01/0006/
17
+ #
18
+ # TODO:
19
+ # An alternative would be to use
20
+ # http://functions.wolfram.com/HypergeometricFunctions/
21
+ # HermiteHGeneral/06/02/0006/
22
+ #
23
+ # Also, the 1F1 expansion
24
+ # http://functions.wolfram.com/HypergeometricFunctions/
25
+ # HermiteHGeneral/26/01/02/0001/
26
+ # should probably be used for tiny z
27
+ if not z:
28
+ T1 = [2, ctx.pi], [n, 0.5], [], [q*(n-1)], [], [], 0
29
+ if parabolic_cylinder:
30
+ T1[1][0] += q*n
31
+ return T1,
32
+ can_use_2f0 = ctx.isnpint(-n) or ctx.re(z) > 0 or \
33
+ (ctx.re(z) == 0 and ctx.im(z) > 0)
34
+ expprec = ctx.prec*4 + 20
35
+ if parabolic_cylinder:
36
+ u = ctx.fmul(ctx.fmul(z,z,prec=expprec), -0.25, exact=True)
37
+ w = ctx.fmul(z, ctx.sqrt(0.5,prec=expprec), prec=expprec)
38
+ else:
39
+ w = z
40
+ w2 = ctx.fmul(w, w, prec=expprec)
41
+ rw2 = ctx.fdiv(1, w2, prec=expprec)
42
+ nrw2 = ctx.fneg(rw2, exact=True)
43
+ nw = ctx.fneg(w, exact=True)
44
+ if can_use_2f0:
45
+ T1 = [2, w], [n, n], [], [], [q*n, q*(n-1)], [], nrw2
46
+ terms = [T1]
47
+ else:
48
+ T1 = [2, nw], [n, n], [], [], [q*n, q*(n-1)], [], nrw2
49
+ T2 = [2, ctx.pi, nw], [n+2, 0.5, 1], [], [q*n], [q*(n-1)], [1-q], w2
50
+ terms = [T1,T2]
51
+ # Multiply by prefactor for D_n
52
+ if parabolic_cylinder:
53
+ expu = ctx.exp(u)
54
+ for i in range(len(terms)):
55
+ terms[i][1][0] += q*n
56
+ terms[i][0].append(expu)
57
+ terms[i][1].append(1)
58
+ return tuple(terms)
59
+
60
+ @defun
61
+ def hermite(ctx, n, z, **kwargs):
62
+ return ctx.hypercomb(lambda: _hermite_param(ctx, n, z, 0), [], **kwargs)
63
+
64
+ @defun
65
+ def pcfd(ctx, n, z, **kwargs):
66
+ r"""
67
+ Gives the parabolic cylinder function in Whittaker's notation
68
+ `D_n(z) = U(-n-1/2, z)` (see :func:`~mpmath.pcfu`).
69
+ It solves the differential equation
70
+
71
+ .. math ::
72
+
73
+ y'' + \left(n + \frac{1}{2} - \frac{1}{4} z^2\right) y = 0.
74
+
75
+ and can be represented in terms of Hermite polynomials
76
+ (see :func:`~mpmath.hermite`) as
77
+
78
+ .. math ::
79
+
80
+ D_n(z) = 2^{-n/2} e^{-z^2/4} H_n\left(\frac{z}{\sqrt{2}}\right).
81
+
82
+ **Plots**
83
+
84
+ .. literalinclude :: /plots/pcfd.py
85
+ .. image :: /plots/pcfd.png
86
+
87
+ **Examples**
88
+
89
+ >>> from mpmath import *
90
+ >>> mp.dps = 25; mp.pretty = True
91
+ >>> pcfd(0,0); pcfd(1,0); pcfd(2,0); pcfd(3,0)
92
+ 1.0
93
+ 0.0
94
+ -1.0
95
+ 0.0
96
+ >>> pcfd(4,0); pcfd(-3,0)
97
+ 3.0
98
+ 0.6266570686577501256039413
99
+ >>> pcfd('1/2', 2+3j)
100
+ (-5.363331161232920734849056 - 3.858877821790010714163487j)
101
+ >>> pcfd(2, -10)
102
+ 1.374906442631438038871515e-9
103
+
104
+ Verifying the differential equation::
105
+
106
+ >>> n = mpf(2.5)
107
+ >>> y = lambda z: pcfd(n,z)
108
+ >>> z = 1.75
109
+ >>> chop(diff(y,z,2) + (n+0.5-0.25*z**2)*y(z))
110
+ 0.0
111
+
112
+ Rational Taylor series expansion when `n` is an integer::
113
+
114
+ >>> taylor(lambda z: pcfd(5,z), 0, 7)
115
+ [0.0, 15.0, 0.0, -13.75, 0.0, 3.96875, 0.0, -0.6015625]
116
+
117
+ """
118
+ return ctx.hypercomb(lambda: _hermite_param(ctx, n, z, 1), [], **kwargs)
119
+
120
+ @defun
121
+ def pcfu(ctx, a, z, **kwargs):
122
+ r"""
123
+ Gives the parabolic cylinder function `U(a,z)`, which may be
124
+ defined for `\Re(z) > 0` in terms of the confluent
125
+ U-function (see :func:`~mpmath.hyperu`) by
126
+
127
+ .. math ::
128
+
129
+ U(a,z) = 2^{-\frac{1}{4}-\frac{a}{2}} e^{-\frac{1}{4} z^2}
130
+ U\left(\frac{a}{2}+\frac{1}{4},
131
+ \frac{1}{2}, \frac{1}{2}z^2\right)
132
+
133
+ or, for arbitrary `z`,
134
+
135
+ .. math ::
136
+
137
+ e^{-\frac{1}{4}z^2} U(a,z) =
138
+ U(a,0) \,_1F_1\left(-\tfrac{a}{2}+\tfrac{1}{4};
139
+ \tfrac{1}{2}; -\tfrac{1}{2}z^2\right) +
140
+ U'(a,0) z \,_1F_1\left(-\tfrac{a}{2}+\tfrac{3}{4};
141
+ \tfrac{3}{2}; -\tfrac{1}{2}z^2\right).
142
+
143
+ **Examples**
144
+
145
+ Connection to other functions::
146
+
147
+ >>> from mpmath import *
148
+ >>> mp.dps = 25; mp.pretty = True
149
+ >>> z = mpf(3)
150
+ >>> pcfu(0.5,z)
151
+ 0.03210358129311151450551963
152
+ >>> sqrt(pi/2)*exp(z**2/4)*erfc(z/sqrt(2))
153
+ 0.03210358129311151450551963
154
+ >>> pcfu(0.5,-z)
155
+ 23.75012332835297233711255
156
+ >>> sqrt(pi/2)*exp(z**2/4)*erfc(-z/sqrt(2))
157
+ 23.75012332835297233711255
158
+ >>> pcfu(0.5,-z)
159
+ 23.75012332835297233711255
160
+ >>> sqrt(pi/2)*exp(z**2/4)*erfc(-z/sqrt(2))
161
+ 23.75012332835297233711255
162
+
163
+ """
164
+ n, _ = ctx._convert_param(a)
165
+ return ctx.pcfd(-n-ctx.mpq_1_2, z)
166
+
167
+ @defun
168
+ def pcfv(ctx, a, z, **kwargs):
169
+ r"""
170
+ Gives the parabolic cylinder function `V(a,z)`, which can be
171
+ represented in terms of :func:`~mpmath.pcfu` as
172
+
173
+ .. math ::
174
+
175
+ V(a,z) = \frac{\Gamma(a+\tfrac{1}{2}) (U(a,-z)-\sin(\pi a) U(a,z)}{\pi}.
176
+
177
+ **Examples**
178
+
179
+ Wronskian relation between `U` and `V`::
180
+
181
+ >>> from mpmath import *
182
+ >>> mp.dps = 25; mp.pretty = True
183
+ >>> a, z = 2, 3
184
+ >>> pcfu(a,z)*diff(pcfv,(a,z),(0,1))-diff(pcfu,(a,z),(0,1))*pcfv(a,z)
185
+ 0.7978845608028653558798921
186
+ >>> sqrt(2/pi)
187
+ 0.7978845608028653558798921
188
+ >>> a, z = 2.5, 3
189
+ >>> pcfu(a,z)*diff(pcfv,(a,z),(0,1))-diff(pcfu,(a,z),(0,1))*pcfv(a,z)
190
+ 0.7978845608028653558798921
191
+ >>> a, z = 0.25, -1
192
+ >>> pcfu(a,z)*diff(pcfv,(a,z),(0,1))-diff(pcfu,(a,z),(0,1))*pcfv(a,z)
193
+ 0.7978845608028653558798921
194
+ >>> a, z = 2+1j, 2+3j
195
+ >>> chop(pcfu(a,z)*diff(pcfv,(a,z),(0,1))-diff(pcfu,(a,z),(0,1))*pcfv(a,z))
196
+ 0.7978845608028653558798921
197
+
198
+ """
199
+ n, ntype = ctx._convert_param(a)
200
+ z = ctx.convert(z)
201
+ q = ctx.mpq_1_2
202
+ r = ctx.mpq_1_4
203
+ if ntype == 'Q' and ctx.isint(n*2):
204
+ # Faster for half-integers
205
+ def h():
206
+ jz = ctx.fmul(z, -1j, exact=True)
207
+ T1terms = _hermite_param(ctx, -n-q, z, 1)
208
+ T2terms = _hermite_param(ctx, n-q, jz, 1)
209
+ for T in T1terms:
210
+ T[0].append(1j)
211
+ T[1].append(1)
212
+ T[3].append(q-n)
213
+ u = ctx.expjpi((q*n-r)) * ctx.sqrt(2/ctx.pi)
214
+ for T in T2terms:
215
+ T[0].append(u)
216
+ T[1].append(1)
217
+ return T1terms + T2terms
218
+ v = ctx.hypercomb(h, [], **kwargs)
219
+ if ctx._is_real_type(n) and ctx._is_real_type(z):
220
+ v = ctx._re(v)
221
+ return v
222
+ else:
223
+ def h(n):
224
+ w = ctx.square_exp_arg(z, -0.25)
225
+ u = ctx.square_exp_arg(z, 0.5)
226
+ e = ctx.exp(w)
227
+ l = [ctx.pi, q, ctx.exp(w)]
228
+ Y1 = l, [-q, n*q+r, 1], [r-q*n], [], [q*n+r], [q], u
229
+ Y2 = l + [z], [-q, n*q-r, 1, 1], [1-r-q*n], [], [q*n+1-r], [1+q], u
230
+ c, s = ctx.cospi_sinpi(r+q*n)
231
+ Y1[0].append(s)
232
+ Y2[0].append(c)
233
+ for Y in (Y1, Y2):
234
+ Y[1].append(1)
235
+ Y[3].append(q-n)
236
+ return Y1, Y2
237
+ return ctx.hypercomb(h, [n], **kwargs)
238
+
239
+
240
+ @defun
241
+ def pcfw(ctx, a, z, **kwargs):
242
+ r"""
243
+ Gives the parabolic cylinder function `W(a,z)` defined in (DLMF 12.14).
244
+
245
+ **Examples**
246
+
247
+ Value at the origin::
248
+
249
+ >>> from mpmath import *
250
+ >>> mp.dps = 25; mp.pretty = True
251
+ >>> a = mpf(0.25)
252
+ >>> pcfw(a,0)
253
+ 0.9722833245718180765617104
254
+ >>> power(2,-0.75)*sqrt(abs(gamma(0.25+0.5j*a)/gamma(0.75+0.5j*a)))
255
+ 0.9722833245718180765617104
256
+ >>> diff(pcfw,(a,0),(0,1))
257
+ -0.5142533944210078966003624
258
+ >>> -power(2,-0.25)*sqrt(abs(gamma(0.75+0.5j*a)/gamma(0.25+0.5j*a)))
259
+ -0.5142533944210078966003624
260
+
261
+ """
262
+ n, _ = ctx._convert_param(a)
263
+ z = ctx.convert(z)
264
+ def terms():
265
+ phi2 = ctx.arg(ctx.gamma(0.5 + ctx.j*n))
266
+ phi2 = (ctx.loggamma(0.5+ctx.j*n) - ctx.loggamma(0.5-ctx.j*n))/2j
267
+ rho = ctx.pi/8 + 0.5*phi2
268
+ # XXX: cancellation computing k
269
+ k = ctx.sqrt(1 + ctx.exp(2*ctx.pi*n)) - ctx.exp(ctx.pi*n)
270
+ C = ctx.sqrt(k/2) * ctx.exp(0.25*ctx.pi*n)
271
+ yield C * ctx.expj(rho) * ctx.pcfu(ctx.j*n, z*ctx.expjpi(-0.25))
272
+ yield C * ctx.expj(-rho) * ctx.pcfu(-ctx.j*n, z*ctx.expjpi(0.25))
273
+ v = ctx.sum_accurately(terms)
274
+ if ctx._is_real_type(n) and ctx._is_real_type(z):
275
+ v = ctx._re(v)
276
+ return v
277
+
278
+ """
279
+ Even/odd PCFs. Useful?
280
+
281
+ @defun
282
+ def pcfy1(ctx, a, z, **kwargs):
283
+ a, _ = ctx._convert_param(n)
284
+ z = ctx.convert(z)
285
+ def h():
286
+ w = ctx.square_exp_arg(z)
287
+ w1 = ctx.fmul(w, -0.25, exact=True)
288
+ w2 = ctx.fmul(w, 0.5, exact=True)
289
+ e = ctx.exp(w1)
290
+ return [e], [1], [], [], [ctx.mpq_1_2*a+ctx.mpq_1_4], [ctx.mpq_1_2], w2
291
+ return ctx.hypercomb(h, [], **kwargs)
292
+
293
+ @defun
294
+ def pcfy2(ctx, a, z, **kwargs):
295
+ a, _ = ctx._convert_param(n)
296
+ z = ctx.convert(z)
297
+ def h():
298
+ w = ctx.square_exp_arg(z)
299
+ w1 = ctx.fmul(w, -0.25, exact=True)
300
+ w2 = ctx.fmul(w, 0.5, exact=True)
301
+ e = ctx.exp(w1)
302
+ return [e, z], [1, 1], [], [], [ctx.mpq_1_2*a+ctx.mpq_3_4], \
303
+ [ctx.mpq_3_2], w2
304
+ return ctx.hypercomb(h, [], **kwargs)
305
+ """
306
+
307
+ @defun_wrapped
308
+ def gegenbauer(ctx, n, a, z, **kwargs):
309
+ # Special cases: a+0.5, a*2 poles
310
+ if ctx.isnpint(a):
311
+ return 0*(z+n)
312
+ if ctx.isnpint(a+0.5):
313
+ # TODO: something else is required here
314
+ # E.g.: gegenbauer(-2, -0.5, 3) == -12
315
+ if ctx.isnpint(n+1):
316
+ raise NotImplementedError("Gegenbauer function with two limits")
317
+ def h(a):
318
+ a2 = 2*a
319
+ T = [], [], [n+a2], [n+1, a2], [-n, n+a2], [a+0.5], 0.5*(1-z)
320
+ return [T]
321
+ return ctx.hypercomb(h, [a], **kwargs)
322
+ def h(n):
323
+ a2 = 2*a
324
+ T = [], [], [n+a2], [n+1, a2], [-n, n+a2], [a+0.5], 0.5*(1-z)
325
+ return [T]
326
+ return ctx.hypercomb(h, [n], **kwargs)
327
+
328
+ @defun_wrapped
329
+ def jacobi(ctx, n, a, b, x, **kwargs):
330
+ if not ctx.isnpint(a):
331
+ def h(n):
332
+ return (([], [], [a+n+1], [n+1, a+1], [-n, a+b+n+1], [a+1], (1-x)*0.5),)
333
+ return ctx.hypercomb(h, [n], **kwargs)
334
+ if not ctx.isint(b):
335
+ def h(n, a):
336
+ return (([], [], [-b], [n+1, -b-n], [-n, a+b+n+1], [b+1], (x+1)*0.5),)
337
+ return ctx.hypercomb(h, [n, a], **kwargs)
338
+ # XXX: determine appropriate limit
339
+ return ctx.binomial(n+a,n) * ctx.hyp2f1(-n,1+n+a+b,a+1,(1-x)/2, **kwargs)
340
+
341
+ @defun_wrapped
342
+ def laguerre(ctx, n, a, z, **kwargs):
343
+ # XXX: limits, poles
344
+ #if ctx.isnpint(n):
345
+ # return 0*(a+z)
346
+ def h(a):
347
+ return (([], [], [a+n+1], [a+1, n+1], [-n], [a+1], z),)
348
+ return ctx.hypercomb(h, [a], **kwargs)
349
+
350
+ @defun_wrapped
351
+ def legendre(ctx, n, x, **kwargs):
352
+ if ctx.isint(n):
353
+ n = int(n)
354
+ # Accuracy near zeros
355
+ if (n + (n < 0)) & 1:
356
+ if not x:
357
+ return x
358
+ mag = ctx.mag(x)
359
+ if mag < -2*ctx.prec-10:
360
+ return x
361
+ if mag < -5:
362
+ ctx.prec += -mag
363
+ return ctx.hyp2f1(-n,n+1,1,(1-x)/2, **kwargs)
364
+
365
+ @defun
366
+ def legenp(ctx, n, m, z, type=2, **kwargs):
367
+ # Legendre function, 1st kind
368
+ n = ctx.convert(n)
369
+ m = ctx.convert(m)
370
+ # Faster
371
+ if not m:
372
+ return ctx.legendre(n, z, **kwargs)
373
+ # TODO: correct evaluation at singularities
374
+ if type == 2:
375
+ def h(n,m):
376
+ g = m*0.5
377
+ T = [1+z, 1-z], [g, -g], [], [1-m], [-n, n+1], [1-m], 0.5*(1-z)
378
+ return (T,)
379
+ return ctx.hypercomb(h, [n,m], **kwargs)
380
+ if type == 3:
381
+ def h(n,m):
382
+ g = m*0.5
383
+ T = [z+1, z-1], [g, -g], [], [1-m], [-n, n+1], [1-m], 0.5*(1-z)
384
+ return (T,)
385
+ return ctx.hypercomb(h, [n,m], **kwargs)
386
+ raise ValueError("requires type=2 or type=3")
387
+
388
+ @defun
389
+ def legenq(ctx, n, m, z, type=2, **kwargs):
390
+ # Legendre function, 2nd kind
391
+ n = ctx.convert(n)
392
+ m = ctx.convert(m)
393
+ z = ctx.convert(z)
394
+ if z in (1, -1):
395
+ #if ctx.isint(m):
396
+ # return ctx.nan
397
+ #return ctx.inf # unsigned
398
+ return ctx.nan
399
+ if type == 2:
400
+ def h(n, m):
401
+ cos, sin = ctx.cospi_sinpi(m)
402
+ s = 2 * sin / ctx.pi
403
+ c = cos
404
+ a = 1+z
405
+ b = 1-z
406
+ u = m/2
407
+ w = (1-z)/2
408
+ T1 = [s, c, a, b], [-1, 1, u, -u], [], [1-m], \
409
+ [-n, n+1], [1-m], w
410
+ T2 = [-s, a, b], [-1, -u, u], [n+m+1], [n-m+1, m+1], \
411
+ [-n, n+1], [m+1], w
412
+ return T1, T2
413
+ return ctx.hypercomb(h, [n, m], **kwargs)
414
+ if type == 3:
415
+ # The following is faster when there only is a single series
416
+ # Note: not valid for -1 < z < 0 (?)
417
+ if abs(z) > 1:
418
+ def h(n, m):
419
+ T1 = [ctx.expjpi(m), 2, ctx.pi, z, z-1, z+1], \
420
+ [1, -n-1, 0.5, -n-m-1, 0.5*m, 0.5*m], \
421
+ [n+m+1], [n+1.5], \
422
+ [0.5*(2+n+m), 0.5*(1+n+m)], [n+1.5], z**(-2)
423
+ return [T1]
424
+ return ctx.hypercomb(h, [n, m], **kwargs)
425
+ else:
426
+ # not valid for 1 < z < inf ?
427
+ def h(n, m):
428
+ s = 2 * ctx.sinpi(m) / ctx.pi
429
+ c = ctx.expjpi(m)
430
+ a = 1+z
431
+ b = z-1
432
+ u = m/2
433
+ w = (1-z)/2
434
+ T1 = [s, c, a, b], [-1, 1, u, -u], [], [1-m], \
435
+ [-n, n+1], [1-m], w
436
+ T2 = [-s, c, a, b], [-1, 1, -u, u], [n+m+1], [n-m+1, m+1], \
437
+ [-n, n+1], [m+1], w
438
+ return T1, T2
439
+ return ctx.hypercomb(h, [n, m], **kwargs)
440
+ raise ValueError("requires type=2 or type=3")
441
+
442
+ @defun_wrapped
443
+ def chebyt(ctx, n, x, **kwargs):
444
+ if (not x) and ctx.isint(n) and int(ctx._re(n)) % 2 == 1:
445
+ return x * 0
446
+ return ctx.hyp2f1(-n,n,(1,2),(1-x)/2, **kwargs)
447
+
448
+ @defun_wrapped
449
+ def chebyu(ctx, n, x, **kwargs):
450
+ if (not x) and ctx.isint(n) and int(ctx._re(n)) % 2 == 1:
451
+ return x * 0
452
+ return (n+1) * ctx.hyp2f1(-n, n+2, (3,2), (1-x)/2, **kwargs)
453
+
454
+ @defun
455
+ def spherharm(ctx, l, m, theta, phi, **kwargs):
456
+ l = ctx.convert(l)
457
+ m = ctx.convert(m)
458
+ theta = ctx.convert(theta)
459
+ phi = ctx.convert(phi)
460
+ l_isint = ctx.isint(l)
461
+ l_natural = l_isint and l >= 0
462
+ m_isint = ctx.isint(m)
463
+ if l_isint and l < 0 and m_isint:
464
+ return ctx.spherharm(-(l+1), m, theta, phi, **kwargs)
465
+ if theta == 0 and m_isint and m < 0:
466
+ return ctx.zero * 1j
467
+ if l_natural and m_isint:
468
+ if abs(m) > l:
469
+ return ctx.zero * 1j
470
+ # http://functions.wolfram.com/Polynomials/
471
+ # SphericalHarmonicY/26/01/02/0004/
472
+ def h(l,m):
473
+ absm = abs(m)
474
+ C = [-1, ctx.expj(m*phi),
475
+ (2*l+1)*ctx.fac(l+absm)/ctx.pi/ctx.fac(l-absm),
476
+ ctx.sin(theta)**2,
477
+ ctx.fac(absm), 2]
478
+ P = [0.5*m*(ctx.sign(m)+1), 1, 0.5, 0.5*absm, -1, -absm-1]
479
+ return ((C, P, [], [], [absm-l, l+absm+1], [absm+1],
480
+ ctx.sin(0.5*theta)**2),)
481
+ else:
482
+ # http://functions.wolfram.com/HypergeometricFunctions/
483
+ # SphericalHarmonicYGeneral/26/01/02/0001/
484
+ def h(l,m):
485
+ if ctx.isnpint(l-m+1) or ctx.isnpint(l+m+1) or ctx.isnpint(1-m):
486
+ return (([0], [-1], [], [], [], [], 0),)
487
+ cos, sin = ctx.cos_sin(0.5*theta)
488
+ C = [0.5*ctx.expj(m*phi), (2*l+1)/ctx.pi,
489
+ ctx.gamma(l-m+1), ctx.gamma(l+m+1),
490
+ cos**2, sin**2]
491
+ P = [1, 0.5, 0.5, -0.5, 0.5*m, -0.5*m]
492
+ return ((C, P, [], [1-m], [-l,l+1], [1-m], sin**2),)
493
+ return ctx.hypercomb(h, [l,m], **kwargs)
lib/python3.11/site-packages/mpmath/functions/qfunctions.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .functions import defun, defun_wrapped
2
+
3
+ @defun
4
+ def qp(ctx, a, q=None, n=None, **kwargs):
5
+ r"""
6
+ Evaluates the q-Pochhammer symbol (or q-rising factorial)
7
+
8
+ .. math ::
9
+
10
+ (a; q)_n = \prod_{k=0}^{n-1} (1-a q^k)
11
+
12
+ where `n = \infty` is permitted if `|q| < 1`. Called with two arguments,
13
+ ``qp(a,q)`` computes `(a;q)_{\infty}`; with a single argument, ``qp(q)``
14
+ computes `(q;q)_{\infty}`. The special case
15
+
16
+ .. math ::
17
+
18
+ \phi(q) = (q; q)_{\infty} = \prod_{k=1}^{\infty} (1-q^k) =
19
+ \sum_{k=-\infty}^{\infty} (-1)^k q^{(3k^2-k)/2}
20
+
21
+ is also known as the Euler function, or (up to a factor `q^{-1/24}`)
22
+ the Dedekind eta function.
23
+
24
+ **Examples**
25
+
26
+ If `n` is a positive integer, the function amounts to a finite product::
27
+
28
+ >>> from mpmath import *
29
+ >>> mp.dps = 25; mp.pretty = True
30
+ >>> qp(2,3,5)
31
+ -725305.0
32
+ >>> fprod(1-2*3**k for k in range(5))
33
+ -725305.0
34
+ >>> qp(2,3,0)
35
+ 1.0
36
+
37
+ Complex arguments are allowed::
38
+
39
+ >>> qp(2-1j, 0.75j)
40
+ (0.4628842231660149089976379 + 4.481821753552703090628793j)
41
+
42
+ The regular Pochhammer symbol `(a)_n` is obtained in the
43
+ following limit as `q \to 1`::
44
+
45
+ >>> a, n = 4, 7
46
+ >>> limit(lambda q: qp(q**a,q,n) / (1-q)**n, 1)
47
+ 604800.0
48
+ >>> rf(a,n)
49
+ 604800.0
50
+
51
+ The Taylor series of the reciprocal Euler function gives
52
+ the partition function `P(n)`, i.e. the number of ways of writing
53
+ `n` as a sum of positive integers::
54
+
55
+ >>> taylor(lambda q: 1/qp(q), 0, 10)
56
+ [1.0, 1.0, 2.0, 3.0, 5.0, 7.0, 11.0, 15.0, 22.0, 30.0, 42.0]
57
+
58
+ Special values include::
59
+
60
+ >>> qp(0)
61
+ 1.0
62
+ >>> findroot(diffun(qp), -0.4) # location of maximum
63
+ -0.4112484791779547734440257
64
+ >>> qp(_)
65
+ 1.228348867038575112586878
66
+
67
+ The q-Pochhammer symbol is related to the Jacobi theta functions.
68
+ For example, the following identity holds::
69
+
70
+ >>> q = mpf(0.5) # arbitrary
71
+ >>> qp(q)
72
+ 0.2887880950866024212788997
73
+ >>> root(3,-2)*root(q,-24)*jtheta(2,pi/6,root(q,6))
74
+ 0.2887880950866024212788997
75
+
76
+ """
77
+ a = ctx.convert(a)
78
+ if n is None:
79
+ n = ctx.inf
80
+ else:
81
+ n = ctx.convert(n)
82
+ if n < 0:
83
+ raise ValueError("n cannot be negative")
84
+ if q is None:
85
+ q = a
86
+ else:
87
+ q = ctx.convert(q)
88
+ if n == 0:
89
+ return ctx.one + 0*(a+q)
90
+ infinite = (n == ctx.inf)
91
+ same = (a == q)
92
+ if infinite:
93
+ if abs(q) >= 1:
94
+ if same and (q == -1 or q == 1):
95
+ return ctx.zero * q
96
+ raise ValueError("q-function only defined for |q| < 1")
97
+ elif q == 0:
98
+ return ctx.one - a
99
+ maxterms = kwargs.get('maxterms', 50*ctx.prec)
100
+ if infinite and same:
101
+ # Euler's pentagonal theorem
102
+ def terms():
103
+ t = 1
104
+ yield t
105
+ k = 1
106
+ x1 = q
107
+ x2 = q**2
108
+ while 1:
109
+ yield (-1)**k * x1
110
+ yield (-1)**k * x2
111
+ x1 *= q**(3*k+1)
112
+ x2 *= q**(3*k+2)
113
+ k += 1
114
+ if k > maxterms:
115
+ raise ctx.NoConvergence
116
+ return ctx.sum_accurately(terms)
117
+ # return ctx.nprod(lambda k: 1-a*q**k, [0,n-1])
118
+ def factors():
119
+ k = 0
120
+ r = ctx.one
121
+ while 1:
122
+ yield 1 - a*r
123
+ r *= q
124
+ k += 1
125
+ if k >= n:
126
+ return
127
+ if k > maxterms:
128
+ raise ctx.NoConvergence
129
+ return ctx.mul_accurately(factors)
130
+
131
+ @defun_wrapped
132
+ def qgamma(ctx, z, q, **kwargs):
133
+ r"""
134
+ Evaluates the q-gamma function
135
+
136
+ .. math ::
137
+
138
+ \Gamma_q(z) = \frac{(q; q)_{\infty}}{(q^z; q)_{\infty}} (1-q)^{1-z}.
139
+
140
+
141
+ **Examples**
142
+
143
+ Evaluation for real and complex arguments::
144
+
145
+ >>> from mpmath import *
146
+ >>> mp.dps = 25; mp.pretty = True
147
+ >>> qgamma(4,0.75)
148
+ 4.046875
149
+ >>> qgamma(6,6)
150
+ 121226245.0
151
+ >>> qgamma(3+4j, 0.5j)
152
+ (0.1663082382255199834630088 + 0.01952474576025952984418217j)
153
+
154
+ The q-gamma function satisfies a functional equation similar
155
+ to that of the ordinary gamma function::
156
+
157
+ >>> q = mpf(0.25)
158
+ >>> z = mpf(2.5)
159
+ >>> qgamma(z+1,q)
160
+ 1.428277424823760954685912
161
+ >>> (1-q**z)/(1-q)*qgamma(z,q)
162
+ 1.428277424823760954685912
163
+
164
+ """
165
+ if abs(q) > 1:
166
+ return ctx.qgamma(z,1/q)*q**((z-2)*(z-1)*0.5)
167
+ return ctx.qp(q, q, None, **kwargs) / \
168
+ ctx.qp(q**z, q, None, **kwargs) * (1-q)**(1-z)
169
+
170
+ @defun_wrapped
171
+ def qfac(ctx, z, q, **kwargs):
172
+ r"""
173
+ Evaluates the q-factorial,
174
+
175
+ .. math ::
176
+
177
+ [n]_q! = (1+q)(1+q+q^2)\cdots(1+q+\cdots+q^{n-1})
178
+
179
+ or more generally
180
+
181
+ .. math ::
182
+
183
+ [z]_q! = \frac{(q;q)_z}{(1-q)^z}.
184
+
185
+ **Examples**
186
+
187
+ >>> from mpmath import *
188
+ >>> mp.dps = 25; mp.pretty = True
189
+ >>> qfac(0,0)
190
+ 1.0
191
+ >>> qfac(4,3)
192
+ 2080.0
193
+ >>> qfac(5,6)
194
+ 121226245.0
195
+ >>> qfac(1+1j, 2+1j)
196
+ (0.4370556551322672478613695 + 0.2609739839216039203708921j)
197
+
198
+ """
199
+ if ctx.isint(z) and ctx._re(z) > 0:
200
+ n = int(ctx._re(z))
201
+ return ctx.qp(q, q, n, **kwargs) / (1-q)**n
202
+ return ctx.qgamma(z+1, q, **kwargs)
203
+
204
+ @defun
205
+ def qhyper(ctx, a_s, b_s, q, z, **kwargs):
206
+ r"""
207
+ Evaluates the basic hypergeometric series or hypergeometric q-series
208
+
209
+ .. math ::
210
+
211
+ \,_r\phi_s \left[\begin{matrix}
212
+ a_1 & a_2 & \ldots & a_r \\
213
+ b_1 & b_2 & \ldots & b_s
214
+ \end{matrix} ; q,z \right] =
215
+ \sum_{n=0}^\infty
216
+ \frac{(a_1;q)_n, \ldots, (a_r;q)_n}
217
+ {(b_1;q)_n, \ldots, (b_s;q)_n}
218
+ \left((-1)^n q^{n\choose 2}\right)^{1+s-r}
219
+ \frac{z^n}{(q;q)_n}
220
+
221
+ where `(a;q)_n` denotes the q-Pochhammer symbol (see :func:`~mpmath.qp`).
222
+
223
+ **Examples**
224
+
225
+ Evaluation works for real and complex arguments::
226
+
227
+ >>> from mpmath import *
228
+ >>> mp.dps = 25; mp.pretty = True
229
+ >>> qhyper([0.5], [2.25], 0.25, 4)
230
+ -0.1975849091263356009534385
231
+ >>> qhyper([0.5], [2.25], 0.25-0.25j, 4)
232
+ (2.806330244925716649839237 + 3.568997623337943121769938j)
233
+ >>> qhyper([1+j], [2,3+0.5j], 0.25, 3+4j)
234
+ (9.112885171773400017270226 - 1.272756997166375050700388j)
235
+
236
+ Comparing with a summation of the defining series, using
237
+ :func:`~mpmath.nsum`::
238
+
239
+ >>> b, q, z = 3, 0.25, 0.5
240
+ >>> qhyper([], [b], q, z)
241
+ 0.6221136748254495583228324
242
+ >>> nsum(lambda n: z**n / qp(q,q,n)/qp(b,q,n) * q**(n*(n-1)), [0,inf])
243
+ 0.6221136748254495583228324
244
+
245
+ """
246
+ #a_s = [ctx._convert_param(a)[0] for a in a_s]
247
+ #b_s = [ctx._convert_param(b)[0] for b in b_s]
248
+ #q = ctx._convert_param(q)[0]
249
+ a_s = [ctx.convert(a) for a in a_s]
250
+ b_s = [ctx.convert(b) for b in b_s]
251
+ q = ctx.convert(q)
252
+ z = ctx.convert(z)
253
+ r = len(a_s)
254
+ s = len(b_s)
255
+ d = 1+s-r
256
+ maxterms = kwargs.get('maxterms', 50*ctx.prec)
257
+ def terms():
258
+ t = ctx.one
259
+ yield t
260
+ qk = 1
261
+ k = 0
262
+ x = 1
263
+ while 1:
264
+ for a in a_s:
265
+ p = 1 - a*qk
266
+ t *= p
267
+ for b in b_s:
268
+ p = 1 - b*qk
269
+ if not p:
270
+ raise ValueError
271
+ t /= p
272
+ t *= z
273
+ x *= (-1)**d * qk ** d
274
+ qk *= q
275
+ t /= (1 - qk)
276
+ k += 1
277
+ yield t * x
278
+ if k > maxterms:
279
+ raise ctx.NoConvergence
280
+ return ctx.sum_accurately(terms)
lib/python3.11/site-packages/mpmath/functions/rszeta.py ADDED
@@ -0,0 +1,1403 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ---------------------------------------------------------------------
3
+ .. sectionauthor:: Juan Arias de Reyna <[email protected]>
4
+
5
+ This module implements zeta-related functions using the Riemann-Siegel
6
+ expansion: zeta_offline(s,k=0)
7
+
8
+ * coef(J, eps): Need in the computation of Rzeta(s,k)
9
+
10
+ * Rzeta_simul(s, der=0) computes Rzeta^(k)(s) and Rzeta^(k)(1-s) simultaneously
11
+ for 0 <= k <= der. Used by zeta_offline and z_offline
12
+
13
+ * Rzeta_set(s, derivatives) computes Rzeta^(k)(s) for given derivatives, used by
14
+ z_half(t,k) and zeta_half
15
+
16
+ * z_offline(w,k): Z(w) and its derivatives of order k <= 4
17
+ * z_half(t,k): Z(t) (Riemann Siegel function) and its derivatives of order k <= 4
18
+ * zeta_offline(s): zeta(s) and its derivatives of order k<= 4
19
+ * zeta_half(1/2+it,k): zeta(s) and its derivatives of order k<= 4
20
+
21
+ * rs_zeta(s,k=0) Computes zeta^(k)(s) Unifies zeta_half and zeta_offline
22
+ * rs_z(w,k=0) Computes Z^(k)(w) Unifies z_offline and z_half
23
+ ----------------------------------------------------------------------
24
+
25
+ This program uses Riemann-Siegel expansion even to compute
26
+ zeta(s) on points s = sigma + i t with sigma arbitrary not
27
+ necessarily equal to 1/2.
28
+
29
+ It is founded on a new deduction of the formula, with rigorous
30
+ and sharp bounds for the terms and rest of this expansion.
31
+
32
+ More information on the papers:
33
+
34
+ J. Arias de Reyna, High Precision Computation of Riemann's
35
+ Zeta Function by the Riemann-Siegel Formula I, II
36
+
37
+ We refer to them as I, II.
38
+
39
+ In them we shall find detailed explanation of all the
40
+ procedure.
41
+
42
+ The program uses Riemann-Siegel expansion.
43
+ This is useful when t is big, ( say t > 10000 ).
44
+ The precision is limited, roughly it can compute zeta(sigma+it)
45
+ with an error less than exp(-c t) for some constant c depending
46
+ on sigma. The program gives an error when the Riemann-Siegel
47
+ formula can not compute to the wanted precision.
48
+
49
+ """
50
+
51
+ import math
52
+
53
+ class RSCache(object):
54
+ def __init__(ctx):
55
+ ctx._rs_cache = [0, 10, {}, {}]
56
+
57
+ from .functions import defun
58
+
59
+ #-------------------------------------------------------------------------------#
60
+ # #
61
+ # coef(ctx, J, eps, _cache=[0, 10, {} ] ) #
62
+ # #
63
+ #-------------------------------------------------------------------------------#
64
+
65
+ # This function computes the coefficients c[n] defined on (I, equation (47))
66
+ # but see also (II, section 3.14).
67
+ #
68
+ # Since these coefficients are very difficult to compute we save the values
69
+ # in a cache. So if we compute several values of the functions Rzeta(s) for
70
+ # near values of s, we do not recompute these coefficients.
71
+ #
72
+ # c[n] are the Taylor coefficients of the function:
73
+ #
74
+ # F(z):= (exp(pi*j*(z*z/2+3/8))-j* sqrt(2) cos(pi*z/2))/(2*cos(pi *z))
75
+ #
76
+ #
77
+
78
+ def _coef(ctx, J, eps):
79
+ r"""
80
+ Computes the coefficients `c_n` for `0\le n\le 2J` with error less than eps
81
+
82
+ **Definition**
83
+
84
+ The coefficients c_n are defined by
85
+
86
+ .. math ::
87
+
88
+ \begin{equation}
89
+ F(z)=\frac{e^{\pi i
90
+ \bigl(\frac{z^2}{2}+\frac38\bigr)}-i\sqrt{2}\cos\frac{\pi}{2}z}{2\cos\pi
91
+ z}=\sum_{n=0}^\infty c_{2n} z^{2n}
92
+ \end{equation}
93
+
94
+ they are computed applying the relation
95
+
96
+ .. math ::
97
+
98
+ \begin{multline}
99
+ c_{2n}=-\frac{i}{\sqrt{2}}\Bigl(\frac{\pi}{2}\Bigr)^{2n}
100
+ \sum_{k=0}^n\frac{(-1)^k}{(2k)!}
101
+ 2^{2n-2k}\frac{(-1)^{n-k}E_{2n-2k}}{(2n-2k)!}+\\
102
+ +e^{3\pi i/8}\sum_{j=0}^n(-1)^j\frac{
103
+ E_{2j}}{(2j)!}\frac{i^{n-j}\pi^{n+j}}{(n-j)!2^{n-j+1}}.
104
+ \end{multline}
105
+ """
106
+
107
+ newJ = J+2 # compute more coefficients that are needed
108
+ neweps6 = eps/2. # compute with a slight more precision that are needed
109
+
110
+ # PREPARATION FOR THE COMPUTATION OF V(N) AND W(N)
111
+ # See II Section 3.16
112
+ #
113
+ # Computing the exponent wpvw of the error II equation (81)
114
+ wpvw = max(ctx.mag(10*(newJ+3)), 4*newJ+5-ctx.mag(neweps6))
115
+
116
+ # Preparation of Euler numbers (we need until the 2*RS_NEWJ)
117
+ E = ctx._eulernum(2*newJ)
118
+
119
+ # Now we have in the cache all the needed Euler numbers.
120
+ #
121
+ # Computing the powers of pi
122
+ #
123
+ # We need to compute the powers pi**n for 1<= n <= 2*J
124
+ # with relative error less than 2**(-wpvw)
125
+ # it is easy to show that this is obtained
126
+ # taking wppi as the least d with
127
+ # 2**d>40*J and 2**d> 4.24 *newJ + 2**wpvw
128
+ # In II Section 3.9 we need also that
129
+ # wppi > wptcoef[0], and that the powers
130
+ # here computed 0<= k <= 2*newJ are more
131
+ # than those needed there that are 2*L-2.
132
+ # so we need J >= L this will be checked
133
+ # before computing tcoef[]
134
+ wppi = max(ctx.mag(40*newJ), ctx.mag(newJ)+3 +wpvw)
135
+ ctx.prec = wppi
136
+ pipower = {}
137
+ pipower[0] = ctx.one
138
+ pipower[1] = ctx.pi
139
+ for n in range(2,2*newJ+1):
140
+ pipower[n] = pipower[n-1]*ctx.pi
141
+
142
+ # COMPUTING THE COEFFICIENTS v(n) AND w(n)
143
+ # see II equation (61) and equations (81) and (82)
144
+ ctx.prec = wpvw+2
145
+ v={}
146
+ w={}
147
+ for n in range(0,newJ+1):
148
+ va = (-1)**n * ctx._eulernum(2*n)
149
+ va = ctx.mpf(va)/ctx.fac(2*n)
150
+ v[n]=va*pipower[2*n]
151
+ for n in range(0,2*newJ+1):
152
+ wa = ctx.one/ctx.fac(n)
153
+ wa=wa/(2**n)
154
+ w[n]=wa*pipower[n]
155
+
156
+ # COMPUTATION OF THE CONVOLUTIONS RS_P1 AND RS_P2
157
+ # See II Section 3.16
158
+ ctx.prec = 15
159
+ wpp1a = 9 - ctx.mag(neweps6)
160
+ P1 = {}
161
+ for n in range(0,newJ+1):
162
+ ctx.prec = 15
163
+ wpp1 = max(ctx.mag(10*(n+4)),4*n+wpp1a)
164
+ ctx.prec = wpp1
165
+ sump = 0
166
+ for k in range(0,n+1):
167
+ sump += ((-1)**k) * v[k]*w[2*n-2*k]
168
+ P1[n]=((-1)**(n+1))*ctx.j*sump
169
+ P2={}
170
+ for n in range(0,newJ+1):
171
+ ctx.prec = 15
172
+ wpp2 = max(ctx.mag(10*(n+4)),4*n+wpp1a)
173
+ ctx.prec = wpp2
174
+ sump = 0
175
+ for k in range(0,n+1):
176
+ sump += (ctx.j**(n-k)) * v[k]*w[n-k]
177
+ P2[n]=sump
178
+ # COMPUTING THE COEFFICIENTS c[2n]
179
+ # See II Section 3.14
180
+ ctx.prec = 15
181
+ wpc0 = 5 - ctx.mag(neweps6)
182
+ wpc = max(6,4*newJ+wpc0)
183
+ ctx.prec = wpc
184
+ mu = ctx.sqrt(ctx.mpf('2'))/2
185
+ nu = ctx.expjpi(3./8)/2
186
+ c={}
187
+ for n in range(0,newJ):
188
+ ctx.prec = 15
189
+ wpc = max(6,4*n+wpc0)
190
+ ctx.prec = wpc
191
+ c[2*n] = mu*P1[n]+nu*P2[n]
192
+ for n in range(1,2*newJ,2):
193
+ c[n] = 0
194
+ return [newJ, neweps6, c, pipower]
195
+
196
+ def coef(ctx, J, eps):
197
+ _cache = ctx._rs_cache
198
+ if J <= _cache[0] and eps >= _cache[1]:
199
+ return _cache[2], _cache[3]
200
+ orig = ctx._mp.prec
201
+ try:
202
+ data = _coef(ctx._mp, J, eps)
203
+ finally:
204
+ ctx._mp.prec = orig
205
+ if ctx is not ctx._mp:
206
+ data[2] = dict((k,ctx.convert(v)) for (k,v) in data[2].items())
207
+ data[3] = dict((k,ctx.convert(v)) for (k,v) in data[3].items())
208
+ ctx._rs_cache[:] = data
209
+ return ctx._rs_cache[2], ctx._rs_cache[3]
210
+
211
+ #-------------------------------------------------------------------------------#
212
+ # #
213
+ # Rzeta_simul(s,k=0) #
214
+ # #
215
+ #-------------------------------------------------------------------------------#
216
+ # This function return a list with the values:
217
+ # Rzeta(sigma+it), conj(Rzeta(1-sigma+it)),Rzeta'(sigma+it), conj(Rzeta'(1-sigma+it)),
218
+ # .... , Rzeta^{(k)}(sigma+it), conj(Rzeta^{(k)}(1-sigma+it))
219
+ #
220
+ # Useful to compute the function zeta(s) and Z(w) or its derivatives.
221
+ #
222
+
223
+ def aux_M_Fp(ctx, xA, xeps4, a, xB1, xL):
224
+ # COMPUTING M NUMBER OF DERIVATIVES Fp[m] TO COMPUTE
225
+ # See II Section 3.11 equations (47) and (48)
226
+ aux1 = 126.0657606*xA/xeps4 # 126.06.. = 316/sqrt(2*pi)
227
+ aux1 = ctx.ln(aux1)
228
+ aux2 = (2*ctx.ln(ctx.pi)+ctx.ln(xB1)+ctx.ln(a))/3 -ctx.ln(2*ctx.pi)/2
229
+ m = 3*xL-3
230
+ aux3= (ctx.loggamma(m+1)-ctx.loggamma(m/3.0+2))/2 -ctx.loggamma((m+1)/2.)
231
+ while((aux1 < m*aux2+ aux3)and (m>1)):
232
+ m = m - 1
233
+ aux3 = (ctx.loggamma(m+1)-ctx.loggamma(m/3.0+2))/2 -ctx.loggamma((m+1)/2.)
234
+ xM = m
235
+ return xM
236
+
237
+ def aux_J_needed(ctx, xA, xeps4, a, xB1, xM):
238
+ # DETERMINATION OF J THE NUMBER OF TERMS NEEDED
239
+ # IN THE TAYLOR SERIES OF F.
240
+ # See II Section 3.11 equation (49))
241
+ # Only determine one
242
+ h1 = xeps4/(632*xA)
243
+ h2 = xB1*a * 126.31337419529260248 # = pi^2*e^2*sqrt(3)
244
+ h2 = h1 * ctx.power((h2/xM**2),(xM-1)/3) / xM
245
+ h3 = min(h1,h2)
246
+ return h3
247
+
248
+ def Rzeta_simul(ctx, s, der=0):
249
+ # First we take the value of ctx.prec
250
+ wpinitial = ctx.prec
251
+
252
+ # INITIALIZATION
253
+ # Take the real and imaginary part of s
254
+ t = ctx._im(s)
255
+ xsigma = ctx._re(s)
256
+ ysigma = 1 - xsigma
257
+
258
+ # Now compute several parameter that appear on the program
259
+ ctx.prec = 15
260
+ a = ctx.sqrt(t/(2*ctx.pi))
261
+ xasigma = a ** xsigma
262
+ yasigma = a ** ysigma
263
+
264
+ # We need a simple bound A1 < asigma (see II Section 3.1 and 3.3)
265
+ xA1=ctx.power(2, ctx.mag(xasigma)-1)
266
+ yA1=ctx.power(2, ctx.mag(yasigma)-1)
267
+
268
+ # We compute various epsilon's (see II end of Section 3.1)
269
+ eps = ctx.power(2, -wpinitial)
270
+ eps1 = eps/6.
271
+ xeps2 = eps * xA1/3.
272
+ yeps2 = eps * yA1/3.
273
+
274
+ # COMPUTING SOME COEFFICIENTS THAT DEPENDS
275
+ # ON sigma
276
+ # constant b and c (see I Theorem 2 formula (26) )
277
+ # coefficients A and B1 (see I Section 6.1 equation (50))
278
+ #
279
+ # here we not need high precision
280
+ ctx.prec = 15
281
+ if xsigma > 0:
282
+ xb = 2.
283
+ xc = math.pow(9,xsigma)/4.44288
284
+ # 4.44288 =(math.sqrt(2)*math.pi)
285
+ xA = math.pow(9,xsigma)
286
+ xB1 = 1
287
+ else:
288
+ xb = 2.25158 # math.sqrt( (3-2* math.log(2))*math.pi )
289
+ xc = math.pow(2,-xsigma)/4.44288
290
+ xA = math.pow(2,-xsigma)
291
+ xB1 = 1.10789 # = 2*sqrt(1-log(2))
292
+
293
+ if(ysigma > 0):
294
+ yb = 2.
295
+ yc = math.pow(9,ysigma)/4.44288
296
+ # 4.44288 =(math.sqrt(2)*math.pi)
297
+ yA = math.pow(9,ysigma)
298
+ yB1 = 1
299
+ else:
300
+ yb = 2.25158 # math.sqrt( (3-2* math.log(2))*math.pi )
301
+ yc = math.pow(2,-ysigma)/4.44288
302
+ yA = math.pow(2,-ysigma)
303
+ yB1 = 1.10789 # = 2*sqrt(1-log(2))
304
+
305
+ # COMPUTING L THE NUMBER OF TERMS NEEDED IN THE RIEMANN-SIEGEL
306
+ # CORRECTION
307
+ # See II Section 3.2
308
+ ctx.prec = 15
309
+ xL = 1
310
+ while 3*xc*ctx.gamma(xL*0.5) * ctx.power(xb*a,-xL) >= xeps2:
311
+ xL = xL+1
312
+ xL = max(2,xL)
313
+ yL = 1
314
+ while 3*yc*ctx.gamma(yL*0.5) * ctx.power(yb*a,-yL) >= yeps2:
315
+ yL = yL+1
316
+ yL = max(2,yL)
317
+
318
+ # The number L has to satify some conditions.
319
+ # If not RS can not compute Rzeta(s) with the prescribed precision
320
+ # (see II, Section 3.2 condition (20) ) and
321
+ # (II, Section 3.3 condition (22) ). Also we have added
322
+ # an additional technical condition in Section 3.17 Proposition 17
323
+ if ((3*xL >= 2*a*a/25.) or (3*xL+2+xsigma<0) or (abs(xsigma) > a/2.) or \
324
+ (3*yL >= 2*a*a/25.) or (3*yL+2+ysigma<0) or (abs(ysigma) > a/2.)):
325
+ ctx.prec = wpinitial
326
+ raise NotImplementedError("Riemann-Siegel can not compute with such precision")
327
+
328
+ # We take the maximum of the two values
329
+ L = max(xL, yL)
330
+
331
+ # INITIALIZATION (CONTINUATION)
332
+ #
333
+ # eps3 is the constant defined on (II, Section 3.5 equation (27) )
334
+ # each term of the RS correction must be computed with error <= eps3
335
+ xeps3 = xeps2/(4*xL)
336
+ yeps3 = yeps2/(4*yL)
337
+
338
+ # eps4 is defined on (II Section 3.6 equation (30) )
339
+ # each component of the formula (II Section 3.6 equation (29) )
340
+ # must be computed with error <= eps4
341
+ xeps4 = xeps3/(3*xL)
342
+ yeps4 = yeps3/(3*yL)
343
+
344
+ # COMPUTING M NUMBER OF DERIVATIVES Fp[m] TO COMPUTE
345
+ xM = aux_M_Fp(ctx, xA, xeps4, a, xB1, xL)
346
+ yM = aux_M_Fp(ctx, yA, yeps4, a, yB1, yL)
347
+ M = max(xM, yM)
348
+
349
+ # COMPUTING NUMBER OF TERMS J NEEDED
350
+ h3 = aux_J_needed(ctx, xA, xeps4, a, xB1, xM)
351
+ h4 = aux_J_needed(ctx, yA, yeps4, a, yB1, yM)
352
+ h3 = min(h3,h4)
353
+ J = 12
354
+ jvalue = (2*ctx.pi)**J / ctx.gamma(J+1)
355
+ while jvalue > h3:
356
+ J = J+1
357
+ jvalue = (2*ctx.pi)*jvalue/J
358
+
359
+ # COMPUTING eps5[m] for 1 <= m <= 21
360
+ # See II Section 10 equation (43)
361
+ # We choose the minimum of the two possibilities
362
+ eps5={}
363
+ xforeps5 = math.pi*math.pi*xB1*a
364
+ yforeps5 = math.pi*math.pi*yB1*a
365
+ for m in range(0,22):
366
+ xaux1 = math.pow(xforeps5, m/3)/(316.*xA)
367
+ yaux1 = math.pow(yforeps5, m/3)/(316.*yA)
368
+ aux1 = min(xaux1, yaux1)
369
+ aux2 = ctx.gamma(m+1)/ctx.gamma(m/3.0+0.5)
370
+ aux2 = math.sqrt(aux2)
371
+ eps5[m] = (aux1*aux2*min(xeps4,yeps4))
372
+
373
+ # COMPUTING wpfp
374
+ # See II Section 3.13 equation (59)
375
+ twenty = min(3*L-3, 21)+1
376
+ aux = 6812*J
377
+ wpfp = ctx.mag(44*J)
378
+ for m in range(0,twenty):
379
+ wpfp = max(wpfp, ctx.mag(aux*ctx.gamma(m+1)/eps5[m]))
380
+
381
+ # COMPUTING N AND p
382
+ # See II Section
383
+ ctx.prec = wpfp + ctx.mag(t)+20
384
+ a = ctx.sqrt(t/(2*ctx.pi))
385
+ N = ctx.floor(a)
386
+ p = 1-2*(a-N)
387
+
388
+ # now we get a rounded version of p
389
+ # to the precision wpfp
390
+ # this possibly is not necessary
391
+ num=ctx.floor(p*(ctx.mpf('2')**wpfp))
392
+ difference = p * (ctx.mpf('2')**wpfp)-num
393
+ if (difference < 0.5):
394
+ num = num
395
+ else:
396
+ num = num+1
397
+ p = ctx.convert(num * (ctx.mpf('2')**(-wpfp)))
398
+
399
+ # COMPUTING THE COEFFICIENTS c[n] = cc[n]
400
+ # We shall use the notation cc[n], since there is
401
+ # a constant that is called c
402
+ # See II Section 3.14
403
+ # We compute the coefficients and also save then in a
404
+ # cache. The bulk of the computation is passed to
405
+ # the function coef()
406
+ #
407
+ # eps6 is defined in II Section 3.13 equation (58)
408
+ eps6 = ctx.power(ctx.convert(2*ctx.pi), J)/(ctx.gamma(J+1)*3*J)
409
+
410
+ # Now we compute the coefficients
411
+ cc = {}
412
+ cont = {}
413
+ cont, pipowers = coef(ctx, J, eps6)
414
+ cc=cont.copy() # we need a copy since we have to change his values.
415
+ Fp={} # this is the adequate locus of this
416
+ for n in range(M, 3*L-2):
417
+ Fp[n] = 0
418
+ Fp={}
419
+ ctx.prec = wpfp
420
+ for m in range(0,M+1):
421
+ sumP = 0
422
+ for k in range(2*J-m-1,-1,-1):
423
+ sumP = (sumP * p)+ cc[k]
424
+ Fp[m] = sumP
425
+ # preparation of the new coefficients
426
+ for k in range(0,2*J-m-1):
427
+ cc[k] = (k+1)* cc[k+1]
428
+
429
+ # COMPUTING THE NUMBERS xd[u,n,k], yd[u,n,k]
430
+ # See II Section 3.17
431
+ #
432
+ # First we compute the working precisions xwpd[k]
433
+ # Se II equation (92)
434
+ xwpd={}
435
+ d1 = max(6,ctx.mag(40*L*L))
436
+ xd2 = 13+ctx.mag((1+abs(xsigma))*xA)-ctx.mag(xeps4)-1
437
+ xconst = ctx.ln(8/(ctx.pi*ctx.pi*a*a*xB1*xB1)) /2
438
+ for n in range(0,L):
439
+ xd3 = ctx.mag(ctx.sqrt(ctx.gamma(n-0.5)))-ctx.floor(n*xconst)+xd2
440
+ xwpd[n]=max(xd3,d1)
441
+
442
+ # procedure of II Section 3.17
443
+ ctx.prec = xwpd[1]+10
444
+ xpsigma = 1-(2*xsigma)
445
+ xd = {}
446
+ xd[0,0,-2]=0; xd[0,0,-1]=0; xd[0,0,0]=1; xd[0,0,1]=0
447
+ xd[0,-1,-2]=0; xd[0,-1,-1]=0; xd[0,-1,0]=1; xd[0,-1,1]=0
448
+ for n in range(1,L):
449
+ ctx.prec = xwpd[n]+10
450
+ for k in range(0,3*n//2+1):
451
+ m = 3*n-2*k
452
+ if(m!=0):
453
+ m1 = ctx.one/m
454
+ c1= m1/4
455
+ c2=(xpsigma*m1)/2
456
+ c3=-(m+1)
457
+ xd[0,n,k]=c3*xd[0,n-1,k-2]+c1*xd[0,n-1,k]+c2*xd[0,n-1,k-1]
458
+ else:
459
+ xd[0,n,k]=0
460
+ for r in range(0,k):
461
+ add=xd[0,n,r]*(ctx.mpf('1.0')*ctx.fac(2*k-2*r)/ctx.fac(k-r))
462
+ xd[0,n,k] -= ((-1)**(k-r))*add
463
+ xd[0,n,-2]=0; xd[0,n,-1]=0; xd[0,n,3*n//2+1]=0
464
+ for mu in range(-2,der+1):
465
+ for n in range(-2,L):
466
+ for k in range(-3,max(1,3*n//2+2)):
467
+ if( (mu<0)or (n<0) or(k<0)or (k>3*n//2)):
468
+ xd[mu,n,k] = 0
469
+ for mu in range(1,der+1):
470
+ for n in range(0,L):
471
+ ctx.prec = xwpd[n]+10
472
+ for k in range(0,3*n//2+1):
473
+ aux=(2*mu-2)*xd[mu-2,n-2,k-3]+2*(xsigma+n-2)*xd[mu-1,n-2,k-3]
474
+ xd[mu,n,k] = aux - xd[mu-1,n-1,k-1]
475
+
476
+ # Now we compute the working precisions ywpd[k]
477
+ # Se II equation (92)
478
+ ywpd={}
479
+ d1 = max(6,ctx.mag(40*L*L))
480
+ yd2 = 13+ctx.mag((1+abs(ysigma))*yA)-ctx.mag(yeps4)-1
481
+ yconst = ctx.ln(8/(ctx.pi*ctx.pi*a*a*yB1*yB1)) /2
482
+ for n in range(0,L):
483
+ yd3 = ctx.mag(ctx.sqrt(ctx.gamma(n-0.5)))-ctx.floor(n*yconst)+yd2
484
+ ywpd[n]=max(yd3,d1)
485
+
486
+ # procedure of II Section 3.17
487
+ ctx.prec = ywpd[1]+10
488
+ ypsigma = 1-(2*ysigma)
489
+ yd = {}
490
+ yd[0,0,-2]=0; yd[0,0,-1]=0; yd[0,0,0]=1; yd[0,0,1]=0
491
+ yd[0,-1,-2]=0; yd[0,-1,-1]=0; yd[0,-1,0]=1; yd[0,-1,1]=0
492
+ for n in range(1,L):
493
+ ctx.prec = ywpd[n]+10
494
+ for k in range(0,3*n//2+1):
495
+ m = 3*n-2*k
496
+ if(m!=0):
497
+ m1 = ctx.one/m
498
+ c1= m1/4
499
+ c2=(ypsigma*m1)/2
500
+ c3=-(m+1)
501
+ yd[0,n,k]=c3*yd[0,n-1,k-2]+c1*yd[0,n-1,k]+c2*yd[0,n-1,k-1]
502
+ else:
503
+ yd[0,n,k]=0
504
+ for r in range(0,k):
505
+ add=yd[0,n,r]*(ctx.mpf('1.0')*ctx.fac(2*k-2*r)/ctx.fac(k-r))
506
+ yd[0,n,k] -= ((-1)**(k-r))*add
507
+ yd[0,n,-2]=0; yd[0,n,-1]=0; yd[0,n,3*n//2+1]=0
508
+
509
+ for mu in range(-2,der+1):
510
+ for n in range(-2,L):
511
+ for k in range(-3,max(1,3*n//2+2)):
512
+ if( (mu<0)or (n<0) or(k<0)or (k>3*n//2)):
513
+ yd[mu,n,k] = 0
514
+ for mu in range(1,der+1):
515
+ for n in range(0,L):
516
+ ctx.prec = ywpd[n]+10
517
+ for k in range(0,3*n//2+1):
518
+ aux=(2*mu-2)*yd[mu-2,n-2,k-3]+2*(ysigma+n-2)*yd[mu-1,n-2,k-3]
519
+ yd[mu,n,k] = aux - yd[mu-1,n-1,k-1]
520
+
521
+ # COMPUTING THE COEFFICIENTS xtcoef[k,l]
522
+ # See II Section 3.9
523
+ #
524
+ # computing the needed wp
525
+ xwptcoef={}
526
+ xwpterm={}
527
+ ctx.prec = 15
528
+ c1 = ctx.mag(40*(L+2))
529
+ xc2 = ctx.mag(68*(L+2)*xA)
530
+ xc4 = ctx.mag(xB1*a*math.sqrt(ctx.pi))-1
531
+ for k in range(0,L):
532
+ xc3 = xc2 - k*xc4+ctx.mag(ctx.fac(k+0.5))/2.
533
+ xwptcoef[k] = (max(c1,xc3-ctx.mag(xeps4)+1)+1 +20)*1.5
534
+ xwpterm[k] = (max(c1,ctx.mag(L+2)+xc3-ctx.mag(xeps3)+1)+1 +20)
535
+ ywptcoef={}
536
+ ywpterm={}
537
+ ctx.prec = 15
538
+ c1 = ctx.mag(40*(L+2))
539
+ yc2 = ctx.mag(68*(L+2)*yA)
540
+ yc4 = ctx.mag(yB1*a*math.sqrt(ctx.pi))-1
541
+ for k in range(0,L):
542
+ yc3 = yc2 - k*yc4+ctx.mag(ctx.fac(k+0.5))/2.
543
+ ywptcoef[k] = ((max(c1,yc3-ctx.mag(yeps4)+1))+10)*1.5
544
+ ywpterm[k] = (max(c1,ctx.mag(L+2)+yc3-ctx.mag(yeps3)+1)+1)+10
545
+
546
+ # check of power of pi
547
+ # computing the fortcoef[mu,k,ell]
548
+ xfortcoef={}
549
+ for mu in range(0,der+1):
550
+ for k in range(0,L):
551
+ for ell in range(-2,3*k//2+1):
552
+ xfortcoef[mu,k,ell]=0
553
+ for mu in range(0,der+1):
554
+ for k in range(0,L):
555
+ ctx.prec = xwptcoef[k]
556
+ for ell in range(0,3*k//2+1):
557
+ xfortcoef[mu,k,ell]=xd[mu,k,ell]*Fp[3*k-2*ell]/pipowers[2*k-ell]
558
+ xfortcoef[mu,k,ell]=xfortcoef[mu,k,ell]/((2*ctx.j)**ell)
559
+
560
+ def trunc_a(t):
561
+ wp = ctx.prec
562
+ ctx.prec = wp + 2
563
+ aa = ctx.sqrt(t/(2*ctx.pi))
564
+ ctx.prec = wp
565
+ return aa
566
+
567
+ # computing the tcoef[k,ell]
568
+ xtcoef={}
569
+ for mu in range(0,der+1):
570
+ for k in range(0,L):
571
+ for ell in range(-2,3*k//2+1):
572
+ xtcoef[mu,k,ell]=0
573
+ ctx.prec = max(xwptcoef[0],ywptcoef[0])+3
574
+ aa= trunc_a(t)
575
+ la = -ctx.ln(aa)
576
+
577
+ for chi in range(0,der+1):
578
+ for k in range(0,L):
579
+ ctx.prec = xwptcoef[k]
580
+ for ell in range(0,3*k//2+1):
581
+ xtcoef[chi,k,ell] =0
582
+ for mu in range(0, chi+1):
583
+ tcoefter=ctx.binomial(chi,mu)*ctx.power(la,mu)*xfortcoef[chi-mu,k,ell]
584
+ xtcoef[chi,k,ell] += tcoefter
585
+
586
+ # COMPUTING THE COEFFICIENTS ytcoef[k,l]
587
+ # See II Section 3.9
588
+ #
589
+ # computing the needed wp
590
+ # check of power of pi
591
+ # computing the fortcoef[mu,k,ell]
592
+ yfortcoef={}
593
+ for mu in range(0,der+1):
594
+ for k in range(0,L):
595
+ for ell in range(-2,3*k//2+1):
596
+ yfortcoef[mu,k,ell]=0
597
+ for mu in range(0,der+1):
598
+ for k in range(0,L):
599
+ ctx.prec = ywptcoef[k]
600
+ for ell in range(0,3*k//2+1):
601
+ yfortcoef[mu,k,ell]=yd[mu,k,ell]*Fp[3*k-2*ell]/pipowers[2*k-ell]
602
+ yfortcoef[mu,k,ell]=yfortcoef[mu,k,ell]/((2*ctx.j)**ell)
603
+ # computing the tcoef[k,ell]
604
+ ytcoef={}
605
+ for chi in range(0,der+1):
606
+ for k in range(0,L):
607
+ for ell in range(-2,3*k//2+1):
608
+ ytcoef[chi,k,ell]=0
609
+ for chi in range(0,der+1):
610
+ for k in range(0,L):
611
+ ctx.prec = ywptcoef[k]
612
+ for ell in range(0,3*k//2+1):
613
+ ytcoef[chi,k,ell] =0
614
+ for mu in range(0, chi+1):
615
+ tcoefter=ctx.binomial(chi,mu)*ctx.power(la,mu)*yfortcoef[chi-mu,k,ell]
616
+ ytcoef[chi,k,ell] += tcoefter
617
+
618
+ # COMPUTING tv[k,ell]
619
+ # See II Section 3.8
620
+ #
621
+ # a has a good value
622
+ ctx.prec = max(xwptcoef[0], ywptcoef[0])+2
623
+ av = {}
624
+ av[0] = 1
625
+ av[1] = av[0]/a
626
+
627
+ ctx.prec = max(xwptcoef[0],ywptcoef[0])
628
+ for k in range(2,L):
629
+ av[k] = av[k-1] * av[1]
630
+
631
+ # Computing the quotients
632
+ xtv = {}
633
+ for chi in range(0,der+1):
634
+ for k in range(0,L):
635
+ ctx.prec = xwptcoef[k]
636
+ for ell in range(0,3*k//2+1):
637
+ xtv[chi,k,ell] = xtcoef[chi,k,ell]* av[k]
638
+ # Computing the quotients
639
+ ytv = {}
640
+ for chi in range(0,der+1):
641
+ for k in range(0,L):
642
+ ctx.prec = ywptcoef[k]
643
+ for ell in range(0,3*k//2+1):
644
+ ytv[chi,k,ell] = ytcoef[chi,k,ell]* av[k]
645
+
646
+ # COMPUTING THE TERMS xterm[k]
647
+ # See II Section 3.6
648
+ xterm = {}
649
+ for chi in range(0,der+1):
650
+ for n in range(0,L):
651
+ ctx.prec = xwpterm[n]
652
+ te = 0
653
+ for k in range(0, 3*n//2+1):
654
+ te += xtv[chi,n,k]
655
+ xterm[chi,n] = te
656
+
657
+ # COMPUTING THE TERMS yterm[k]
658
+ # See II Section 3.6
659
+ yterm = {}
660
+ for chi in range(0,der+1):
661
+ for n in range(0,L):
662
+ ctx.prec = ywpterm[n]
663
+ te = 0
664
+ for k in range(0, 3*n//2+1):
665
+ te += ytv[chi,n,k]
666
+ yterm[chi,n] = te
667
+
668
+ # COMPUTING rssum
669
+ # See II Section 3.5
670
+ xrssum={}
671
+ ctx.prec=15
672
+ xrsbound = math.sqrt(ctx.pi) * xc /(xb*a)
673
+ ctx.prec=15
674
+ xwprssum = ctx.mag(4.4*((L+3)**2)*xrsbound / xeps2)
675
+ xwprssum = max(xwprssum, ctx.mag(10*(L+1)))
676
+ ctx.prec = xwprssum
677
+ for chi in range(0,der+1):
678
+ xrssum[chi] = 0
679
+ for k in range(1,L+1):
680
+ xrssum[chi] += xterm[chi,L-k]
681
+ yrssum={}
682
+ ctx.prec=15
683
+ yrsbound = math.sqrt(ctx.pi) * yc /(yb*a)
684
+ ctx.prec=15
685
+ ywprssum = ctx.mag(4.4*((L+3)**2)*yrsbound / yeps2)
686
+ ywprssum = max(ywprssum, ctx.mag(10*(L+1)))
687
+ ctx.prec = ywprssum
688
+ for chi in range(0,der+1):
689
+ yrssum[chi] = 0
690
+ for k in range(1,L+1):
691
+ yrssum[chi] += yterm[chi,L-k]
692
+
693
+ # COMPUTING S3
694
+ # See II Section 3.19
695
+ ctx.prec = 15
696
+ A2 = 2**(max(ctx.mag(abs(xrssum[0])), ctx.mag(abs(yrssum[0]))))
697
+ eps8 = eps/(3*A2)
698
+ T = t *ctx.ln(t/(2*ctx.pi))
699
+ xwps3 = 5 + ctx.mag((1+(2/eps8)*ctx.power(a,-xsigma))*T)
700
+ ywps3 = 5 + ctx.mag((1+(2/eps8)*ctx.power(a,-ysigma))*T)
701
+
702
+ ctx.prec = max(xwps3, ywps3)
703
+
704
+ tpi = t/(2*ctx.pi)
705
+ arg = (t/2)*ctx.ln(tpi)-(t/2)-ctx.pi/8
706
+ U = ctx.expj(-arg)
707
+ a = trunc_a(t)
708
+ xasigma = ctx.power(a, -xsigma)
709
+ yasigma = ctx.power(a, -ysigma)
710
+ xS3 = ((-1)**(N-1)) * xasigma * U
711
+ yS3 = ((-1)**(N-1)) * yasigma * U
712
+
713
+ # COMPUTING S1 the zetasum
714
+ # See II Section 3.18
715
+ ctx.prec = 15
716
+ xwpsum = 4+ ctx.mag((N+ctx.power(N,1-xsigma))*ctx.ln(N) /eps1)
717
+ ywpsum = 4+ ctx.mag((N+ctx.power(N,1-ysigma))*ctx.ln(N) /eps1)
718
+ wpsum = max(xwpsum, ywpsum)
719
+
720
+ ctx.prec = wpsum +10
721
+ '''
722
+ # This can be improved
723
+ xS1={}
724
+ yS1={}
725
+ for chi in range(0,der+1):
726
+ xS1[chi] = 0
727
+ yS1[chi] = 0
728
+ for n in range(1,int(N)+1):
729
+ ln = ctx.ln(n)
730
+ xexpn = ctx.exp(-ln*(xsigma+ctx.j*t))
731
+ yexpn = ctx.conj(1/(n*xexpn))
732
+ for chi in range(0,der+1):
733
+ pown = ctx.power(-ln, chi)
734
+ xterm = pown*xexpn
735
+ yterm = pown*yexpn
736
+ xS1[chi] += xterm
737
+ yS1[chi] += yterm
738
+ '''
739
+ xS1, yS1 = ctx._zetasum(s, 1, int(N)-1, range(0,der+1), True)
740
+
741
+ # END OF COMPUTATION of xrz, yrz
742
+ # See II Section 3.1
743
+ ctx.prec = 15
744
+ xabsS1 = abs(xS1[der])
745
+ xabsS2 = abs(xrssum[der] * xS3)
746
+ xwpend = max(6, wpinitial+ctx.mag(6*(3*xabsS1+7*xabsS2) ) )
747
+
748
+ ctx.prec = xwpend
749
+ xrz={}
750
+ for chi in range(0,der+1):
751
+ xrz[chi] = xS1[chi]+xrssum[chi]*xS3
752
+
753
+ ctx.prec = 15
754
+ yabsS1 = abs(yS1[der])
755
+ yabsS2 = abs(yrssum[der] * yS3)
756
+ ywpend = max(6, wpinitial+ctx.mag(6*(3*yabsS1+7*yabsS2) ) )
757
+
758
+ ctx.prec = ywpend
759
+ yrz={}
760
+ for chi in range(0,der+1):
761
+ yrz[chi] = yS1[chi]+yrssum[chi]*yS3
762
+ yrz[chi] = ctx.conj(yrz[chi])
763
+ ctx.prec = wpinitial
764
+ return xrz, yrz
765
+
766
+ def Rzeta_set(ctx, s, derivatives=[0]):
767
+ r"""
768
+ Computes several derivatives of the auxiliary function of Riemann `R(s)`.
769
+
770
+ **Definition**
771
+
772
+ The function is defined by
773
+
774
+ .. math ::
775
+
776
+ \begin{equation}
777
+ {\mathop{\mathcal R }\nolimits}(s)=
778
+ \int_{0\swarrow1}\frac{x^{-s} e^{\pi i x^2}}{e^{\pi i x}-
779
+ e^{-\pi i x}}\,dx
780
+ \end{equation}
781
+
782
+ To this function we apply the Riemann-Siegel expansion.
783
+ """
784
+ der = max(derivatives)
785
+ # First we take the value of ctx.prec
786
+ # During the computation we will change ctx.prec, and finally we will
787
+ # restaurate the initial value
788
+ wpinitial = ctx.prec
789
+ # Take the real and imaginary part of s
790
+ t = ctx._im(s)
791
+ sigma = ctx._re(s)
792
+ # Now compute several parameter that appear on the program
793
+ ctx.prec = 15
794
+ a = ctx.sqrt(t/(2*ctx.pi)) # Careful
795
+ asigma = ctx.power(a, sigma) # Careful
796
+ # We need a simple bound A1 < asigma (see II Section 3.1 and 3.3)
797
+ A1 = ctx.power(2, ctx.mag(asigma)-1)
798
+ # We compute various epsilon's (see II end of Section 3.1)
799
+ eps = ctx.power(2, -wpinitial)
800
+ eps1 = eps/6.
801
+ eps2 = eps * A1/3.
802
+ # COMPUTING SOME COEFFICIENTS THAT DEPENDS
803
+ # ON sigma
804
+ # constant b and c (see I Theorem 2 formula (26) )
805
+ # coefficients A and B1 (see I Section 6.1 equation (50))
806
+ # here we not need high precision
807
+ ctx.prec = 15
808
+ if sigma > 0:
809
+ b = 2.
810
+ c = math.pow(9,sigma)/4.44288
811
+ # 4.44288 =(math.sqrt(2)*math.pi)
812
+ A = math.pow(9,sigma)
813
+ B1 = 1
814
+ else:
815
+ b = 2.25158 # math.sqrt( (3-2* math.log(2))*math.pi )
816
+ c = math.pow(2,-sigma)/4.44288
817
+ A = math.pow(2,-sigma)
818
+ B1 = 1.10789 # = 2*sqrt(1-log(2))
819
+ # COMPUTING L THE NUMBER OF TERMS NEEDED IN THE RIEMANN-SIEGEL
820
+ # CORRECTION
821
+ # See II Section 3.2
822
+ ctx.prec = 15
823
+ L = 1
824
+ while 3*c*ctx.gamma(L*0.5) * ctx.power(b*a,-L) >= eps2:
825
+ L = L+1
826
+ L = max(2,L)
827
+ # The number L has to satify some conditions.
828
+ # If not RS can not compute Rzeta(s) with the prescribed precision
829
+ # (see II, Section 3.2 condition (20) ) and
830
+ # (II, Section 3.3 condition (22) ). Also we have added
831
+ # an additional technical condition in Section 3.17 Proposition 17
832
+ if ((3*L >= 2*a*a/25.) or (3*L+2+sigma<0) or (abs(sigma)> a/2.)):
833
+ #print 'Error Riemann-Siegel can not compute with such precision'
834
+ ctx.prec = wpinitial
835
+ raise NotImplementedError("Riemann-Siegel can not compute with such precision")
836
+
837
+ # INITIALIZATION (CONTINUATION)
838
+ #
839
+ # eps3 is the constant defined on (II, Section 3.5 equation (27) )
840
+ # each term of the RS correction must be computed with error <= eps3
841
+ eps3 = eps2/(4*L)
842
+
843
+ # eps4 is defined on (II Section 3.6 equation (30) )
844
+ # each component of the formula (II Section 3.6 equation (29) )
845
+ # must be computed with error <= eps4
846
+ eps4 = eps3/(3*L)
847
+
848
+ # COMPUTING M. NUMBER OF DERIVATIVES Fp[m] TO COMPUTE
849
+ M = aux_M_Fp(ctx, A, eps4, a, B1, L)
850
+ Fp = {}
851
+ for n in range(M, 3*L-2):
852
+ Fp[n] = 0
853
+
854
+ # But I have not seen an instance of M != 3*L-3
855
+ #
856
+ # DETERMINATION OF J THE NUMBER OF TERMS NEEDED
857
+ # IN THE TAYLOR SERIES OF F.
858
+ # See II Section 3.11 equation (49))
859
+ h1 = eps4/(632*A)
860
+ h2 = ctx.pi*ctx.pi*B1*a *ctx.sqrt(3)*math.e*math.e
861
+ h2 = h1 * ctx.power((h2/M**2),(M-1)/3) / M
862
+ h3 = min(h1,h2)
863
+ J=12
864
+ jvalue = (2*ctx.pi)**J / ctx.gamma(J+1)
865
+ while jvalue > h3:
866
+ J = J+1
867
+ jvalue = (2*ctx.pi)*jvalue/J
868
+
869
+ # COMPUTING eps5[m] for 1 <= m <= 21
870
+ # See II Section 10 equation (43)
871
+ eps5={}
872
+ foreps5 = math.pi*math.pi*B1*a
873
+ for m in range(0,22):
874
+ aux1 = math.pow(foreps5, m/3)/(316.*A)
875
+ aux2 = ctx.gamma(m+1)/ctx.gamma(m/3.0+0.5)
876
+ aux2 = math.sqrt(aux2)
877
+ eps5[m] = aux1*aux2*eps4
878
+
879
+ # COMPUTING wpfp
880
+ # See II Section 3.13 equation (59)
881
+ twenty = min(3*L-3, 21)+1
882
+ aux = 6812*J
883
+ wpfp = ctx.mag(44*J)
884
+ for m in range(0, twenty):
885
+ wpfp = max(wpfp, ctx.mag(aux*ctx.gamma(m+1)/eps5[m]))
886
+ # COMPUTING N AND p
887
+ # See II Section
888
+ ctx.prec = wpfp + ctx.mag(t) + 20
889
+ a = ctx.sqrt(t/(2*ctx.pi))
890
+ N = ctx.floor(a)
891
+ p = 1-2*(a-N)
892
+
893
+ # now we get a rounded version of p to the precision wpfp
894
+ # this possibly is not necessary
895
+ num = ctx.floor(p*(ctx.mpf(2)**wpfp))
896
+ difference = p * (ctx.mpf(2)**wpfp)-num
897
+ if difference < 0.5:
898
+ num = num
899
+ else:
900
+ num = num+1
901
+ p = ctx.convert(num * (ctx.mpf(2)**(-wpfp)))
902
+
903
+ # COMPUTING THE COEFFICIENTS c[n] = cc[n]
904
+ # We shall use the notation cc[n], since there is
905
+ # a constant that is called c
906
+ # See II Section 3.14
907
+ # We compute the coefficients and also save then in a
908
+ # cache. The bulk of the computation is passed to
909
+ # the function coef()
910
+ #
911
+ # eps6 is defined in II Section 3.13 equation (58)
912
+ eps6 = ctx.power(2*ctx.pi, J)/(ctx.gamma(J+1)*3*J)
913
+
914
+ # Now we compute the coefficients
915
+ cc={}
916
+ cont={}
917
+ cont, pipowers = coef(ctx, J, eps6)
918
+ cc = cont.copy() # we need a copy since we have
919
+ Fp={}
920
+ for n in range(M, 3*L-2):
921
+ Fp[n] = 0
922
+ ctx.prec = wpfp
923
+ for m in range(0,M+1):
924
+ sumP = 0
925
+ for k in range(2*J-m-1,-1,-1):
926
+ sumP = (sumP * p) + cc[k]
927
+ Fp[m] = sumP
928
+ # preparation of the new coefficients
929
+ for k in range(0, 2*J-m-1):
930
+ cc[k] = (k+1) * cc[k+1]
931
+
932
+ # COMPUTING THE NUMBERS d[n,k]
933
+ # See II Section 3.17
934
+
935
+ # First we compute the working precisions wpd[k]
936
+ # Se II equation (92)
937
+ wpd = {}
938
+ d1 = max(6, ctx.mag(40*L*L))
939
+ d2 = 13+ctx.mag((1+abs(sigma))*A)-ctx.mag(eps4)-1
940
+ const = ctx.ln(8/(ctx.pi*ctx.pi*a*a*B1*B1)) /2
941
+ for n in range(0,L):
942
+ d3 = ctx.mag(ctx.sqrt(ctx.gamma(n-0.5)))-ctx.floor(n*const)+d2
943
+ wpd[n] = max(d3,d1)
944
+
945
+ # procedure of II Section 3.17
946
+ ctx.prec = wpd[1]+10
947
+ psigma = 1-(2*sigma)
948
+ d = {}
949
+ d[0,0,-2]=0; d[0,0,-1]=0; d[0,0,0]=1; d[0,0,1]=0
950
+ d[0,-1,-2]=0; d[0,-1,-1]=0; d[0,-1,0]=1; d[0,-1,1]=0
951
+ for n in range(1,L):
952
+ ctx.prec = wpd[n]+10
953
+ for k in range(0,3*n//2+1):
954
+ m = 3*n-2*k
955
+ if (m!=0):
956
+ m1 = ctx.one/m
957
+ c1 = m1/4
958
+ c2 = (psigma*m1)/2
959
+ c3 = -(m+1)
960
+ d[0,n,k] = c3*d[0,n-1,k-2]+c1*d[0,n-1,k]+c2*d[0,n-1,k-1]
961
+ else:
962
+ d[0,n,k]=0
963
+ for r in range(0,k):
964
+ add = d[0,n,r]*(ctx.one*ctx.fac(2*k-2*r)/ctx.fac(k-r))
965
+ d[0,n,k] -= ((-1)**(k-r))*add
966
+ d[0,n,-2]=0; d[0,n,-1]=0; d[0,n,3*n//2+1]=0
967
+
968
+ for mu in range(-2,der+1):
969
+ for n in range(-2,L):
970
+ for k in range(-3,max(1,3*n//2+2)):
971
+ if ((mu<0)or (n<0) or(k<0)or (k>3*n//2)):
972
+ d[mu,n,k] = 0
973
+
974
+ for mu in range(1,der+1):
975
+ for n in range(0,L):
976
+ ctx.prec = wpd[n]+10
977
+ for k in range(0,3*n//2+1):
978
+ aux=(2*mu-2)*d[mu-2,n-2,k-3]+2*(sigma+n-2)*d[mu-1,n-2,k-3]
979
+ d[mu,n,k] = aux - d[mu-1,n-1,k-1]
980
+
981
+ # COMPUTING THE COEFFICIENTS t[k,l]
982
+ # See II Section 3.9
983
+ #
984
+ # computing the needed wp
985
+ wptcoef = {}
986
+ wpterm = {}
987
+ ctx.prec = 15
988
+ c1 = ctx.mag(40*(L+2))
989
+ c2 = ctx.mag(68*(L+2)*A)
990
+ c4 = ctx.mag(B1*a*math.sqrt(ctx.pi))-1
991
+ for k in range(0,L):
992
+ c3 = c2 - k*c4+ctx.mag(ctx.fac(k+0.5))/2.
993
+ wptcoef[k] = max(c1,c3-ctx.mag(eps4)+1)+1 +10
994
+ wpterm[k] = max(c1,ctx.mag(L+2)+c3-ctx.mag(eps3)+1)+1 +10
995
+
996
+ # check of power of pi
997
+
998
+ # computing the fortcoef[mu,k,ell]
999
+ fortcoef={}
1000
+ for mu in derivatives:
1001
+ for k in range(0,L):
1002
+ for ell in range(-2,3*k//2+1):
1003
+ fortcoef[mu,k,ell]=0
1004
+
1005
+ for mu in derivatives:
1006
+ for k in range(0,L):
1007
+ ctx.prec = wptcoef[k]
1008
+ for ell in range(0,3*k//2+1):
1009
+ fortcoef[mu,k,ell]=d[mu,k,ell]*Fp[3*k-2*ell]/pipowers[2*k-ell]
1010
+ fortcoef[mu,k,ell]=fortcoef[mu,k,ell]/((2*ctx.j)**ell)
1011
+
1012
+ def trunc_a(t):
1013
+ wp = ctx.prec
1014
+ ctx.prec = wp + 2
1015
+ aa = ctx.sqrt(t/(2*ctx.pi))
1016
+ ctx.prec = wp
1017
+ return aa
1018
+
1019
+ # computing the tcoef[chi,k,ell]
1020
+ tcoef={}
1021
+ for chi in derivatives:
1022
+ for k in range(0,L):
1023
+ for ell in range(-2,3*k//2+1):
1024
+ tcoef[chi,k,ell]=0
1025
+ ctx.prec = wptcoef[0]+3
1026
+ aa = trunc_a(t)
1027
+ la = -ctx.ln(aa)
1028
+
1029
+ for chi in derivatives:
1030
+ for k in range(0,L):
1031
+ ctx.prec = wptcoef[k]
1032
+ for ell in range(0,3*k//2+1):
1033
+ tcoef[chi,k,ell] = 0
1034
+ for mu in range(0, chi+1):
1035
+ tcoefter = ctx.binomial(chi,mu) * la**mu * \
1036
+ fortcoef[chi-mu,k,ell]
1037
+ tcoef[chi,k,ell] += tcoefter
1038
+
1039
+ # COMPUTING tv[k,ell]
1040
+ # See II Section 3.8
1041
+
1042
+ # Computing the powers av[k] = a**(-k)
1043
+ ctx.prec = wptcoef[0] + 2
1044
+
1045
+ # a has a good value of a.
1046
+ # See II Section 3.6
1047
+ av = {}
1048
+ av[0] = 1
1049
+ av[1] = av[0]/a
1050
+
1051
+ ctx.prec = wptcoef[0]
1052
+ for k in range(2,L):
1053
+ av[k] = av[k-1] * av[1]
1054
+
1055
+ # Computing the quotients
1056
+ tv = {}
1057
+ for chi in derivatives:
1058
+ for k in range(0,L):
1059
+ ctx.prec = wptcoef[k]
1060
+ for ell in range(0,3*k//2+1):
1061
+ tv[chi,k,ell] = tcoef[chi,k,ell]* av[k]
1062
+
1063
+ # COMPUTING THE TERMS term[k]
1064
+ # See II Section 3.6
1065
+ term = {}
1066
+ for chi in derivatives:
1067
+ for n in range(0,L):
1068
+ ctx.prec = wpterm[n]
1069
+ te = 0
1070
+ for k in range(0, 3*n//2+1):
1071
+ te += tv[chi,n,k]
1072
+ term[chi,n] = te
1073
+
1074
+ # COMPUTING rssum
1075
+ # See II Section 3.5
1076
+ rssum={}
1077
+ ctx.prec=15
1078
+ rsbound = math.sqrt(ctx.pi) * c /(b*a)
1079
+ ctx.prec=15
1080
+ wprssum = ctx.mag(4.4*((L+3)**2)*rsbound / eps2)
1081
+ wprssum = max(wprssum, ctx.mag(10*(L+1)))
1082
+ ctx.prec = wprssum
1083
+ for chi in derivatives:
1084
+ rssum[chi] = 0
1085
+ for k in range(1,L+1):
1086
+ rssum[chi] += term[chi,L-k]
1087
+
1088
+ # COMPUTING S3
1089
+ # See II Section 3.19
1090
+ ctx.prec = 15
1091
+ A2 = 2**(ctx.mag(rssum[0]))
1092
+ eps8 = eps/(3* A2)
1093
+ T = t * ctx.ln(t/(2*ctx.pi))
1094
+ wps3 = 5 + ctx.mag((1+(2/eps8)*ctx.power(a,-sigma))*T)
1095
+
1096
+ ctx.prec = wps3
1097
+ tpi = t/(2*ctx.pi)
1098
+ arg = (t/2)*ctx.ln(tpi)-(t/2)-ctx.pi/8
1099
+ U = ctx.expj(-arg)
1100
+ a = trunc_a(t)
1101
+ asigma = ctx.power(a, -sigma)
1102
+ S3 = ((-1)**(N-1)) * asigma * U
1103
+
1104
+ # COMPUTING S1 the zetasum
1105
+ # See II Section 3.18
1106
+ ctx.prec = 15
1107
+ wpsum = 4 + ctx.mag((N+ctx.power(N,1-sigma))*ctx.ln(N)/eps1)
1108
+
1109
+ ctx.prec = wpsum + 10
1110
+ '''
1111
+ # This can be improved
1112
+ S1 = {}
1113
+ for chi in derivatives:
1114
+ S1[chi] = 0
1115
+ for n in range(1,int(N)+1):
1116
+ ln = ctx.ln(n)
1117
+ expn = ctx.exp(-ln*(sigma+ctx.j*t))
1118
+ for chi in derivatives:
1119
+ term = ctx.power(-ln, chi)*expn
1120
+ S1[chi] += term
1121
+ '''
1122
+ S1 = ctx._zetasum(s, 1, int(N)-1, derivatives)[0]
1123
+
1124
+ # END OF COMPUTATION
1125
+ # See II Section 3.1
1126
+ ctx.prec = 15
1127
+ absS1 = abs(S1[der])
1128
+ absS2 = abs(rssum[der] * S3)
1129
+ wpend = max(6, wpinitial + ctx.mag(6*(3*absS1+7*absS2)))
1130
+ ctx.prec = wpend
1131
+ rz = {}
1132
+ for chi in derivatives:
1133
+ rz[chi] = S1[chi]+rssum[chi]*S3
1134
+ ctx.prec = wpinitial
1135
+ return rz
1136
+
1137
+
1138
+ def z_half(ctx,t,der=0):
1139
+ r"""
1140
+ z_half(t,der=0) Computes Z^(der)(t)
1141
+ """
1142
+ s=ctx.mpf('0.5')+ctx.j*t
1143
+ wpinitial = ctx.prec
1144
+ ctx.prec = 15
1145
+ tt = t/(2*ctx.pi)
1146
+ wptheta = wpinitial +1 + ctx.mag(3*(tt**1.5)*ctx.ln(tt))
1147
+ wpz = wpinitial + 1 + ctx.mag(12*tt*ctx.ln(tt))
1148
+ ctx.prec = wptheta
1149
+ theta = ctx.siegeltheta(t)
1150
+ ctx.prec = wpz
1151
+ rz = Rzeta_set(ctx,s, range(der+1))
1152
+ if der > 0: ps1 = ctx._re(ctx.psi(0,s/2)/2 - ctx.ln(ctx.pi)/2)
1153
+ if der > 1: ps2 = ctx._re(ctx.j*ctx.psi(1,s/2)/4)
1154
+ if der > 2: ps3 = ctx._re(-ctx.psi(2,s/2)/8)
1155
+ if der > 3: ps4 = ctx._re(-ctx.j*ctx.psi(3,s/2)/16)
1156
+ exptheta = ctx.expj(theta)
1157
+ if der == 0:
1158
+ z = 2*exptheta*rz[0]
1159
+ if der == 1:
1160
+ zf = 2j*exptheta
1161
+ z = zf*(ps1*rz[0]+rz[1])
1162
+ if der == 2:
1163
+ zf = 2 * exptheta
1164
+ z = -zf*(2*rz[1]*ps1+rz[0]*ps1**2+rz[2]-ctx.j*rz[0]*ps2)
1165
+ if der == 3:
1166
+ zf = -2j*exptheta
1167
+ z = 3*rz[1]*ps1**2+rz[0]*ps1**3+3*ps1*rz[2]
1168
+ z = zf*(z-3j*rz[1]*ps2-3j*rz[0]*ps1*ps2+rz[3]-rz[0]*ps3)
1169
+ if der == 4:
1170
+ zf = 2*exptheta
1171
+ z = 4*rz[1]*ps1**3+rz[0]*ps1**4+6*ps1**2*rz[2]
1172
+ z = z-12j*rz[1]*ps1*ps2-6j*rz[0]*ps1**2*ps2-6j*rz[2]*ps2-3*rz[0]*ps2*ps2
1173
+ z = z + 4*ps1*rz[3]-4*rz[1]*ps3-4*rz[0]*ps1*ps3+rz[4]+ctx.j*rz[0]*ps4
1174
+ z = zf*z
1175
+ ctx.prec = wpinitial
1176
+ return ctx._re(z)
1177
+
1178
+ def zeta_half(ctx, s, k=0):
1179
+ """
1180
+ zeta_half(s,k=0) Computes zeta^(k)(s) when Re s = 0.5
1181
+ """
1182
+ wpinitial = ctx.prec
1183
+ sigma = ctx._re(s)
1184
+ t = ctx._im(s)
1185
+ #--- compute wptheta, wpR, wpbasic ---
1186
+ ctx.prec = 53
1187
+ # X see II Section 3.21 (109) and (110)
1188
+ if sigma > 0:
1189
+ X = ctx.sqrt(abs(s))
1190
+ else:
1191
+ X = (2*ctx.pi)**(sigma-1) * abs(1-s)**(0.5-sigma)
1192
+ # M1 see II Section 3.21 (111) and (112)
1193
+ if sigma > 0:
1194
+ M1 = 2*ctx.sqrt(t/(2*ctx.pi))
1195
+ else:
1196
+ M1 = 4 * t * X
1197
+ # T see II Section 3.21 (113)
1198
+ abst = abs(0.5-s)
1199
+ T = 2* abst*math.log(abst)
1200
+ # computing wpbasic, wptheta, wpR see II Section 3.21
1201
+ wpbasic = max(6,3+ctx.mag(t))
1202
+ wpbasic2 = 2+ctx.mag(2.12*M1+21.2*M1*X+1.3*M1*X*T)+wpinitial+1
1203
+ wpbasic = max(wpbasic, wpbasic2)
1204
+ wptheta = max(4, 3+ctx.mag(2.7*M1*X)+wpinitial+1)
1205
+ wpR = 3+ctx.mag(1.1+2*X)+wpinitial+1
1206
+ ctx.prec = wptheta
1207
+ theta = ctx.siegeltheta(t-ctx.j*(sigma-ctx.mpf('0.5')))
1208
+ if k > 0: ps1 = (ctx._re(ctx.psi(0,s/2)))/2 - ctx.ln(ctx.pi)/2
1209
+ if k > 1: ps2 = -(ctx._im(ctx.psi(1,s/2)))/4
1210
+ if k > 2: ps3 = -(ctx._re(ctx.psi(2,s/2)))/8
1211
+ if k > 3: ps4 = (ctx._im(ctx.psi(3,s/2)))/16
1212
+ ctx.prec = wpR
1213
+ xrz = Rzeta_set(ctx,s,range(k+1))
1214
+ yrz={}
1215
+ for chi in range(0,k+1):
1216
+ yrz[chi] = ctx.conj(xrz[chi])
1217
+ ctx.prec = wpbasic
1218
+ exptheta = ctx.expj(-2*theta)
1219
+ if k==0:
1220
+ zv = xrz[0]+exptheta*yrz[0]
1221
+ if k==1:
1222
+ zv1 = -yrz[1] - 2*yrz[0]*ps1
1223
+ zv = xrz[1] + exptheta*zv1
1224
+ if k==2:
1225
+ zv1 = 4*yrz[1]*ps1+4*yrz[0]*(ps1**2)+yrz[2]+2j*yrz[0]*ps2
1226
+ zv = xrz[2]+exptheta*zv1
1227
+ if k==3:
1228
+ zv1 = -12*yrz[1]*ps1**2-8*yrz[0]*ps1**3-6*yrz[2]*ps1-6j*yrz[1]*ps2
1229
+ zv1 = zv1 - 12j*yrz[0]*ps1*ps2-yrz[3]+2*yrz[0]*ps3
1230
+ zv = xrz[3]+exptheta*zv1
1231
+ if k == 4:
1232
+ zv1 = 32*yrz[1]*ps1**3 +16*yrz[0]*ps1**4+24*yrz[2]*ps1**2
1233
+ zv1 = zv1 +48j*yrz[1]*ps1*ps2+48j*yrz[0]*(ps1**2)*ps2
1234
+ zv1 = zv1+12j*yrz[2]*ps2-12*yrz[0]*ps2**2+8*yrz[3]*ps1-8*yrz[1]*ps3
1235
+ zv1 = zv1-16*yrz[0]*ps1*ps3+yrz[4]-2j*yrz[0]*ps4
1236
+ zv = xrz[4]+exptheta*zv1
1237
+ ctx.prec = wpinitial
1238
+ return zv
1239
+
1240
+ def zeta_offline(ctx, s, k=0):
1241
+ """
1242
+ Computes zeta^(k)(s) off the line
1243
+ """
1244
+ wpinitial = ctx.prec
1245
+ sigma = ctx._re(s)
1246
+ t = ctx._im(s)
1247
+ #--- compute wptheta, wpR, wpbasic ---
1248
+ ctx.prec = 53
1249
+ # X see II Section 3.21 (109) and (110)
1250
+ if sigma > 0:
1251
+ X = ctx.power(abs(s), 0.5)
1252
+ else:
1253
+ X = ctx.power(2*ctx.pi, sigma-1)*ctx.power(abs(1-s),0.5-sigma)
1254
+ # M1 see II Section 3.21 (111) and (112)
1255
+ if (sigma > 0):
1256
+ M1 = 2*ctx.sqrt(t/(2*ctx.pi))
1257
+ else:
1258
+ M1 = 4 * t * X
1259
+ # M2 see II Section 3.21 (111) and (112)
1260
+ if (1-sigma > 0):
1261
+ M2 = 2*ctx.sqrt(t/(2*ctx.pi))
1262
+ else:
1263
+ M2 = 4*t*ctx.power(2*ctx.pi, -sigma)*ctx.power(abs(s),sigma-0.5)
1264
+ # T see II Section 3.21 (113)
1265
+ abst = abs(0.5-s)
1266
+ T = 2* abst*math.log(abst)
1267
+ # computing wpbasic, wptheta, wpR see II Section 3.21
1268
+ wpbasic = max(6,3+ctx.mag(t))
1269
+ wpbasic2 = 2+ctx.mag(2.12*M1+21.2*M2*X+1.3*M2*X*T)+wpinitial+1
1270
+ wpbasic = max(wpbasic, wpbasic2)
1271
+ wptheta = max(4, 3+ctx.mag(2.7*M2*X)+wpinitial+1)
1272
+ wpR = 3+ctx.mag(1.1+2*X)+wpinitial+1
1273
+ ctx.prec = wptheta
1274
+ theta = ctx.siegeltheta(t-ctx.j*(sigma-ctx.mpf('0.5')))
1275
+ s1 = s
1276
+ s2 = ctx.conj(1-s1)
1277
+ ctx.prec = wpR
1278
+ xrz, yrz = Rzeta_simul(ctx, s, k)
1279
+ if k > 0: ps1 = (ctx.psi(0,s1/2)+ctx.psi(0,(1-s1)/2))/4 - ctx.ln(ctx.pi)/2
1280
+ if k > 1: ps2 = ctx.j*(ctx.psi(1,s1/2)-ctx.psi(1,(1-s1)/2))/8
1281
+ if k > 2: ps3 = -(ctx.psi(2,s1/2)+ctx.psi(2,(1-s1)/2))/16
1282
+ if k > 3: ps4 = -ctx.j*(ctx.psi(3,s1/2)-ctx.psi(3,(1-s1)/2))/32
1283
+ ctx.prec = wpbasic
1284
+ exptheta = ctx.expj(-2*theta)
1285
+ if k == 0:
1286
+ zv = xrz[0]+exptheta*yrz[0]
1287
+ if k == 1:
1288
+ zv1 = -yrz[1]-2*yrz[0]*ps1
1289
+ zv = xrz[1]+exptheta*zv1
1290
+ if k == 2:
1291
+ zv1 = 4*yrz[1]*ps1+4*yrz[0]*(ps1**2) +yrz[2]+2j*yrz[0]*ps2
1292
+ zv = xrz[2]+exptheta*zv1
1293
+ if k == 3:
1294
+ zv1 = -12*yrz[1]*ps1**2 -8*yrz[0]*ps1**3-6*yrz[2]*ps1-6j*yrz[1]*ps2
1295
+ zv1 = zv1 - 12j*yrz[0]*ps1*ps2-yrz[3]+2*yrz[0]*ps3
1296
+ zv = xrz[3]+exptheta*zv1
1297
+ if k == 4:
1298
+ zv1 = 32*yrz[1]*ps1**3 +16*yrz[0]*ps1**4+24*yrz[2]*ps1**2
1299
+ zv1 = zv1 +48j*yrz[1]*ps1*ps2+48j*yrz[0]*(ps1**2)*ps2
1300
+ zv1 = zv1+12j*yrz[2]*ps2-12*yrz[0]*ps2**2+8*yrz[3]*ps1-8*yrz[1]*ps3
1301
+ zv1 = zv1-16*yrz[0]*ps1*ps3+yrz[4]-2j*yrz[0]*ps4
1302
+ zv = xrz[4]+exptheta*zv1
1303
+ ctx.prec = wpinitial
1304
+ return zv
1305
+
1306
+ def z_offline(ctx, w, k=0):
1307
+ r"""
1308
+ Computes Z(w) and its derivatives off the line
1309
+ """
1310
+ s = ctx.mpf('0.5')+ctx.j*w
1311
+ s1 = s
1312
+ s2 = ctx.conj(1-s1)
1313
+ wpinitial = ctx.prec
1314
+ ctx.prec = 35
1315
+ # X see II Section 3.21 (109) and (110)
1316
+ # M1 see II Section 3.21 (111) and (112)
1317
+ if (ctx._re(s1) >= 0):
1318
+ M1 = 2*ctx.sqrt(ctx._im(s1)/(2 * ctx.pi))
1319
+ X = ctx.sqrt(abs(s1))
1320
+ else:
1321
+ X = (2*ctx.pi)**(ctx._re(s1)-1) * abs(1-s1)**(0.5-ctx._re(s1))
1322
+ M1 = 4 * ctx._im(s1)*X
1323
+ # M2 see II Section 3.21 (111) and (112)
1324
+ if (ctx._re(s2) >= 0):
1325
+ M2 = 2*ctx.sqrt(ctx._im(s2)/(2 * ctx.pi))
1326
+ else:
1327
+ M2 = 4 * ctx._im(s2)*(2*ctx.pi)**(ctx._re(s2)-1)*abs(1-s2)**(0.5-ctx._re(s2))
1328
+ # T see II Section 3.21 Prop. 27
1329
+ T = 2*abs(ctx.siegeltheta(w))
1330
+ # defining some precisions
1331
+ # see II Section 3.22 (115), (116), (117)
1332
+ aux1 = ctx.sqrt(X)
1333
+ aux2 = aux1*(M1+M2)
1334
+ aux3 = 3 +wpinitial
1335
+ wpbasic = max(6, 3+ctx.mag(T), ctx.mag(aux2*(26+2*T))+aux3)
1336
+ wptheta = max(4,ctx.mag(2.04*aux2)+aux3)
1337
+ wpR = ctx.mag(4*aux1)+aux3
1338
+ # now the computations
1339
+ ctx.prec = wptheta
1340
+ theta = ctx.siegeltheta(w)
1341
+ ctx.prec = wpR
1342
+ xrz, yrz = Rzeta_simul(ctx,s,k)
1343
+ pta = 0.25 + 0.5j*w
1344
+ ptb = 0.25 - 0.5j*w
1345
+ if k > 0: ps1 = 0.25*(ctx.psi(0,pta)+ctx.psi(0,ptb)) - ctx.ln(ctx.pi)/2
1346
+ if k > 1: ps2 = (1j/8)*(ctx.psi(1,pta)-ctx.psi(1,ptb))
1347
+ if k > 2: ps3 = (-1./16)*(ctx.psi(2,pta)+ctx.psi(2,ptb))
1348
+ if k > 3: ps4 = (-1j/32)*(ctx.psi(3,pta)-ctx.psi(3,ptb))
1349
+ ctx.prec = wpbasic
1350
+ exptheta = ctx.expj(theta)
1351
+ if k == 0:
1352
+ zv = exptheta*xrz[0]+yrz[0]/exptheta
1353
+ j = ctx.j
1354
+ if k == 1:
1355
+ zv = j*exptheta*(xrz[1]+xrz[0]*ps1)-j*(yrz[1]+yrz[0]*ps1)/exptheta
1356
+ if k == 2:
1357
+ zv = exptheta*(-2*xrz[1]*ps1-xrz[0]*ps1**2-xrz[2]+j*xrz[0]*ps2)
1358
+ zv =zv + (-2*yrz[1]*ps1-yrz[0]*ps1**2-yrz[2]-j*yrz[0]*ps2)/exptheta
1359
+ if k == 3:
1360
+ zv1 = -3*xrz[1]*ps1**2-xrz[0]*ps1**3-3*xrz[2]*ps1+j*3*xrz[1]*ps2
1361
+ zv1 = (zv1+ 3j*xrz[0]*ps1*ps2-xrz[3]+xrz[0]*ps3)*j*exptheta
1362
+ zv2 = 3*yrz[1]*ps1**2+yrz[0]*ps1**3+3*yrz[2]*ps1+j*3*yrz[1]*ps2
1363
+ zv2 = j*(zv2 + 3j*yrz[0]*ps1*ps2+ yrz[3]-yrz[0]*ps3)/exptheta
1364
+ zv = zv1+zv2
1365
+ if k == 4:
1366
+ zv1 = 4*xrz[1]*ps1**3+xrz[0]*ps1**4 + 6*xrz[2]*ps1**2
1367
+ zv1 = zv1-12j*xrz[1]*ps1*ps2-6j*xrz[0]*ps1**2*ps2-6j*xrz[2]*ps2
1368
+ zv1 = zv1-3*xrz[0]*ps2*ps2+4*xrz[3]*ps1-4*xrz[1]*ps3-4*xrz[0]*ps1*ps3
1369
+ zv1 = zv1+xrz[4]+j*xrz[0]*ps4
1370
+ zv2 = 4*yrz[1]*ps1**3+yrz[0]*ps1**4 + 6*yrz[2]*ps1**2
1371
+ zv2 = zv2+12j*yrz[1]*ps1*ps2+6j*yrz[0]*ps1**2*ps2+6j*yrz[2]*ps2
1372
+ zv2 = zv2-3*yrz[0]*ps2*ps2+4*yrz[3]*ps1-4*yrz[1]*ps3-4*yrz[0]*ps1*ps3
1373
+ zv2 = zv2+yrz[4]-j*yrz[0]*ps4
1374
+ zv = exptheta*zv1+zv2/exptheta
1375
+ ctx.prec = wpinitial
1376
+ return zv
1377
+
1378
+ @defun
1379
+ def rs_zeta(ctx, s, derivative=0, **kwargs):
1380
+ if derivative > 4:
1381
+ raise NotImplementedError
1382
+ s = ctx.convert(s)
1383
+ re = ctx._re(s); im = ctx._im(s)
1384
+ if im < 0:
1385
+ z = ctx.conj(ctx.rs_zeta(ctx.conj(s), derivative))
1386
+ return z
1387
+ critical_line = (re == 0.5)
1388
+ if critical_line:
1389
+ return zeta_half(ctx, s, derivative)
1390
+ else:
1391
+ return zeta_offline(ctx, s, derivative)
1392
+
1393
+ @defun
1394
+ def rs_z(ctx, w, derivative=0):
1395
+ w = ctx.convert(w)
1396
+ re = ctx._re(w); im = ctx._im(w)
1397
+ if re < 0:
1398
+ return rs_z(ctx, -w, derivative)
1399
+ critical_line = (im == 0)
1400
+ if critical_line :
1401
+ return z_half(ctx, w, derivative)
1402
+ else:
1403
+ return z_offline(ctx, w, derivative)
lib/python3.11/site-packages/mpmath/functions/signals.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .functions import defun_wrapped
2
+
3
+ @defun_wrapped
4
+ def squarew(ctx, t, amplitude=1, period=1):
5
+ P = period
6
+ A = amplitude
7
+ return A*((-1)**ctx.floor(2*t/P))
8
+
9
+ @defun_wrapped
10
+ def trianglew(ctx, t, amplitude=1, period=1):
11
+ A = amplitude
12
+ P = period
13
+
14
+ return 2*A*(0.5 - ctx.fabs(1 - 2*ctx.frac(t/P + 0.25)))
15
+
16
+ @defun_wrapped
17
+ def sawtoothw(ctx, t, amplitude=1, period=1):
18
+ A = amplitude
19
+ P = period
20
+ return A*ctx.frac(t/P)
21
+
22
+ @defun_wrapped
23
+ def unit_triangle(ctx, t, amplitude=1):
24
+ A = amplitude
25
+ if t <= -1 or t >= 1:
26
+ return ctx.zero
27
+ return A*(-ctx.fabs(t) + 1)
28
+
29
+ @defun_wrapped
30
+ def sigmoid(ctx, t, amplitude=1):
31
+ A = amplitude
32
+ return A / (1 + ctx.exp(-t))
lib/python3.11/site-packages/mpmath/functions/theta.py ADDED
@@ -0,0 +1,1049 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .functions import defun, defun_wrapped
2
+
3
+ @defun
4
+ def _jacobi_theta2(ctx, z, q):
5
+ extra1 = 10
6
+ extra2 = 20
7
+ # the loops below break when the fixed precision quantities
8
+ # a and b go to zero;
9
+ # right shifting small negative numbers by wp one obtains -1, not zero,
10
+ # so the condition a**2 + b**2 > MIN is used to break the loops.
11
+ MIN = 2
12
+ if z == ctx.zero:
13
+ if (not ctx._im(q)):
14
+ wp = ctx.prec + extra1
15
+ x = ctx.to_fixed(ctx._re(q), wp)
16
+ x2 = (x*x) >> wp
17
+ a = b = x2
18
+ s = x2
19
+ while abs(a) > MIN:
20
+ b = (b*x2) >> wp
21
+ a = (a*b) >> wp
22
+ s += a
23
+ s = (1 << (wp+1)) + (s << 1)
24
+ s = ctx.ldexp(s, -wp)
25
+ else:
26
+ wp = ctx.prec + extra1
27
+ xre = ctx.to_fixed(ctx._re(q), wp)
28
+ xim = ctx.to_fixed(ctx._im(q), wp)
29
+ x2re = (xre*xre - xim*xim) >> wp
30
+ x2im = (xre*xim) >> (wp-1)
31
+ are = bre = x2re
32
+ aim = bim = x2im
33
+ sre = (1<<wp) + are
34
+ sim = aim
35
+ while are**2 + aim**2 > MIN:
36
+ bre, bim = (bre * x2re - bim * x2im) >> wp, \
37
+ (bre * x2im + bim * x2re) >> wp
38
+ are, aim = (are * bre - aim * bim) >> wp, \
39
+ (are * bim + aim * bre) >> wp
40
+ sre += are
41
+ sim += aim
42
+ sre = (sre << 1)
43
+ sim = (sim << 1)
44
+ sre = ctx.ldexp(sre, -wp)
45
+ sim = ctx.ldexp(sim, -wp)
46
+ s = ctx.mpc(sre, sim)
47
+ else:
48
+ if (not ctx._im(q)) and (not ctx._im(z)):
49
+ wp = ctx.prec + extra1
50
+ x = ctx.to_fixed(ctx._re(q), wp)
51
+ x2 = (x*x) >> wp
52
+ a = b = x2
53
+ c1, s1 = ctx.cos_sin(ctx._re(z), prec=wp)
54
+ cn = c1 = ctx.to_fixed(c1, wp)
55
+ sn = s1 = ctx.to_fixed(s1, wp)
56
+ c2 = (c1*c1 - s1*s1) >> wp
57
+ s2 = (c1 * s1) >> (wp - 1)
58
+ cn, sn = (cn*c2 - sn*s2) >> wp, (sn*c2 + cn*s2) >> wp
59
+ s = c1 + ((a * cn) >> wp)
60
+ while abs(a) > MIN:
61
+ b = (b*x2) >> wp
62
+ a = (a*b) >> wp
63
+ cn, sn = (cn*c2 - sn*s2) >> wp, (sn*c2 + cn*s2) >> wp
64
+ s += (a * cn) >> wp
65
+ s = (s << 1)
66
+ s = ctx.ldexp(s, -wp)
67
+ s *= ctx.nthroot(q, 4)
68
+ return s
69
+ # case z real, q complex
70
+ elif not ctx._im(z):
71
+ wp = ctx.prec + extra2
72
+ xre = ctx.to_fixed(ctx._re(q), wp)
73
+ xim = ctx.to_fixed(ctx._im(q), wp)
74
+ x2re = (xre*xre - xim*xim) >> wp
75
+ x2im = (xre*xim) >> (wp - 1)
76
+ are = bre = x2re
77
+ aim = bim = x2im
78
+ c1, s1 = ctx.cos_sin(ctx._re(z), prec=wp)
79
+ cn = c1 = ctx.to_fixed(c1, wp)
80
+ sn = s1 = ctx.to_fixed(s1, wp)
81
+ c2 = (c1*c1 - s1*s1) >> wp
82
+ s2 = (c1 * s1) >> (wp - 1)
83
+ cn, sn = (cn*c2 - sn*s2) >> wp, (sn*c2 + cn*s2) >> wp
84
+ sre = c1 + ((are * cn) >> wp)
85
+ sim = ((aim * cn) >> wp)
86
+ while are**2 + aim**2 > MIN:
87
+ bre, bim = (bre * x2re - bim * x2im) >> wp, \
88
+ (bre * x2im + bim * x2re) >> wp
89
+ are, aim = (are * bre - aim * bim) >> wp, \
90
+ (are * bim + aim * bre) >> wp
91
+ cn, sn = (cn*c2 - sn*s2) >> wp, (sn*c2 + cn*s2) >> wp
92
+ sre += ((are * cn) >> wp)
93
+ sim += ((aim * cn) >> wp)
94
+ sre = (sre << 1)
95
+ sim = (sim << 1)
96
+ sre = ctx.ldexp(sre, -wp)
97
+ sim = ctx.ldexp(sim, -wp)
98
+ s = ctx.mpc(sre, sim)
99
+ #case z complex, q real
100
+ elif not ctx._im(q):
101
+ wp = ctx.prec + extra2
102
+ x = ctx.to_fixed(ctx._re(q), wp)
103
+ x2 = (x*x) >> wp
104
+ a = b = x2
105
+ prec0 = ctx.prec
106
+ ctx.prec = wp
107
+ c1, s1 = ctx.cos_sin(z)
108
+ ctx.prec = prec0
109
+ cnre = c1re = ctx.to_fixed(ctx._re(c1), wp)
110
+ cnim = c1im = ctx.to_fixed(ctx._im(c1), wp)
111
+ snre = s1re = ctx.to_fixed(ctx._re(s1), wp)
112
+ snim = s1im = ctx.to_fixed(ctx._im(s1), wp)
113
+ #c2 = (c1*c1 - s1*s1) >> wp
114
+ c2re = (c1re*c1re - c1im*c1im - s1re*s1re + s1im*s1im) >> wp
115
+ c2im = (c1re*c1im - s1re*s1im) >> (wp - 1)
116
+ #s2 = (c1 * s1) >> (wp - 1)
117
+ s2re = (c1re*s1re - c1im*s1im) >> (wp - 1)
118
+ s2im = (c1re*s1im + c1im*s1re) >> (wp - 1)
119
+ #cn, sn = (cn*c2 - sn*s2) >> wp, (sn*c2 + cn*s2) >> wp
120
+ t1 = (cnre*c2re - cnim*c2im - snre*s2re + snim*s2im) >> wp
121
+ t2 = (cnre*c2im + cnim*c2re - snre*s2im - snim*s2re) >> wp
122
+ t3 = (snre*c2re - snim*c2im + cnre*s2re - cnim*s2im) >> wp
123
+ t4 = (snre*c2im + snim*c2re + cnre*s2im + cnim*s2re) >> wp
124
+ cnre = t1
125
+ cnim = t2
126
+ snre = t3
127
+ snim = t4
128
+ sre = c1re + ((a * cnre) >> wp)
129
+ sim = c1im + ((a * cnim) >> wp)
130
+ while abs(a) > MIN:
131
+ b = (b*x2) >> wp
132
+ a = (a*b) >> wp
133
+ t1 = (cnre*c2re - cnim*c2im - snre*s2re + snim*s2im) >> wp
134
+ t2 = (cnre*c2im + cnim*c2re - snre*s2im - snim*s2re) >> wp
135
+ t3 = (snre*c2re - snim*c2im + cnre*s2re - cnim*s2im) >> wp
136
+ t4 = (snre*c2im + snim*c2re + cnre*s2im + cnim*s2re) >> wp
137
+ cnre = t1
138
+ cnim = t2
139
+ snre = t3
140
+ snim = t4
141
+ sre += ((a * cnre) >> wp)
142
+ sim += ((a * cnim) >> wp)
143
+ sre = (sre << 1)
144
+ sim = (sim << 1)
145
+ sre = ctx.ldexp(sre, -wp)
146
+ sim = ctx.ldexp(sim, -wp)
147
+ s = ctx.mpc(sre, sim)
148
+ # case z and q complex
149
+ else:
150
+ wp = ctx.prec + extra2
151
+ xre = ctx.to_fixed(ctx._re(q), wp)
152
+ xim = ctx.to_fixed(ctx._im(q), wp)
153
+ x2re = (xre*xre - xim*xim) >> wp
154
+ x2im = (xre*xim) >> (wp - 1)
155
+ are = bre = x2re
156
+ aim = bim = x2im
157
+ prec0 = ctx.prec
158
+ ctx.prec = wp
159
+ # cos(z), sin(z) with z complex
160
+ c1, s1 = ctx.cos_sin(z)
161
+ ctx.prec = prec0
162
+ cnre = c1re = ctx.to_fixed(ctx._re(c1), wp)
163
+ cnim = c1im = ctx.to_fixed(ctx._im(c1), wp)
164
+ snre = s1re = ctx.to_fixed(ctx._re(s1), wp)
165
+ snim = s1im = ctx.to_fixed(ctx._im(s1), wp)
166
+ c2re = (c1re*c1re - c1im*c1im - s1re*s1re + s1im*s1im) >> wp
167
+ c2im = (c1re*c1im - s1re*s1im) >> (wp - 1)
168
+ s2re = (c1re*s1re - c1im*s1im) >> (wp - 1)
169
+ s2im = (c1re*s1im + c1im*s1re) >> (wp - 1)
170
+ t1 = (cnre*c2re - cnim*c2im - snre*s2re + snim*s2im) >> wp
171
+ t2 = (cnre*c2im + cnim*c2re - snre*s2im - snim*s2re) >> wp
172
+ t3 = (snre*c2re - snim*c2im + cnre*s2re - cnim*s2im) >> wp
173
+ t4 = (snre*c2im + snim*c2re + cnre*s2im + cnim*s2re) >> wp
174
+ cnre = t1
175
+ cnim = t2
176
+ snre = t3
177
+ snim = t4
178
+ n = 1
179
+ termre = c1re
180
+ termim = c1im
181
+ sre = c1re + ((are * cnre - aim * cnim) >> wp)
182
+ sim = c1im + ((are * cnim + aim * cnre) >> wp)
183
+ n = 3
184
+ termre = ((are * cnre - aim * cnim) >> wp)
185
+ termim = ((are * cnim + aim * cnre) >> wp)
186
+ sre = c1re + ((are * cnre - aim * cnim) >> wp)
187
+ sim = c1im + ((are * cnim + aim * cnre) >> wp)
188
+ n = 5
189
+ while are**2 + aim**2 > MIN:
190
+ bre, bim = (bre * x2re - bim * x2im) >> wp, \
191
+ (bre * x2im + bim * x2re) >> wp
192
+ are, aim = (are * bre - aim * bim) >> wp, \
193
+ (are * bim + aim * bre) >> wp
194
+ #cn, sn = (cn*c1 - sn*s1) >> wp, (sn*c1 + cn*s1) >> wp
195
+ t1 = (cnre*c2re - cnim*c2im - snre*s2re + snim*s2im) >> wp
196
+ t2 = (cnre*c2im + cnim*c2re - snre*s2im - snim*s2re) >> wp
197
+ t3 = (snre*c2re - snim*c2im + cnre*s2re - cnim*s2im) >> wp
198
+ t4 = (snre*c2im + snim*c2re + cnre*s2im + cnim*s2re) >> wp
199
+ cnre = t1
200
+ cnim = t2
201
+ snre = t3
202
+ snim = t4
203
+ termre = ((are * cnre - aim * cnim) >> wp)
204
+ termim = ((aim * cnre + are * cnim) >> wp)
205
+ sre += ((are * cnre - aim * cnim) >> wp)
206
+ sim += ((aim * cnre + are * cnim) >> wp)
207
+ n += 2
208
+ sre = (sre << 1)
209
+ sim = (sim << 1)
210
+ sre = ctx.ldexp(sre, -wp)
211
+ sim = ctx.ldexp(sim, -wp)
212
+ s = ctx.mpc(sre, sim)
213
+ s *= ctx.nthroot(q, 4)
214
+ return s
215
+
216
+ @defun
217
+ def _djacobi_theta2(ctx, z, q, nd):
218
+ MIN = 2
219
+ extra1 = 10
220
+ extra2 = 20
221
+ if (not ctx._im(q)) and (not ctx._im(z)):
222
+ wp = ctx.prec + extra1
223
+ x = ctx.to_fixed(ctx._re(q), wp)
224
+ x2 = (x*x) >> wp
225
+ a = b = x2
226
+ c1, s1 = ctx.cos_sin(ctx._re(z), prec=wp)
227
+ cn = c1 = ctx.to_fixed(c1, wp)
228
+ sn = s1 = ctx.to_fixed(s1, wp)
229
+ c2 = (c1*c1 - s1*s1) >> wp
230
+ s2 = (c1 * s1) >> (wp - 1)
231
+ cn, sn = (cn*c2 - sn*s2) >> wp, (sn*c2 + cn*s2) >> wp
232
+ if (nd&1):
233
+ s = s1 + ((a * sn * 3**nd) >> wp)
234
+ else:
235
+ s = c1 + ((a * cn * 3**nd) >> wp)
236
+ n = 2
237
+ while abs(a) > MIN:
238
+ b = (b*x2) >> wp
239
+ a = (a*b) >> wp
240
+ cn, sn = (cn*c2 - sn*s2) >> wp, (sn*c2 + cn*s2) >> wp
241
+ if nd&1:
242
+ s += (a * sn * (2*n+1)**nd) >> wp
243
+ else:
244
+ s += (a * cn * (2*n+1)**nd) >> wp
245
+ n += 1
246
+ s = -(s << 1)
247
+ s = ctx.ldexp(s, -wp)
248
+ # case z real, q complex
249
+ elif not ctx._im(z):
250
+ wp = ctx.prec + extra2
251
+ xre = ctx.to_fixed(ctx._re(q), wp)
252
+ xim = ctx.to_fixed(ctx._im(q), wp)
253
+ x2re = (xre*xre - xim*xim) >> wp
254
+ x2im = (xre*xim) >> (wp - 1)
255
+ are = bre = x2re
256
+ aim = bim = x2im
257
+ c1, s1 = ctx.cos_sin(ctx._re(z), prec=wp)
258
+ cn = c1 = ctx.to_fixed(c1, wp)
259
+ sn = s1 = ctx.to_fixed(s1, wp)
260
+ c2 = (c1*c1 - s1*s1) >> wp
261
+ s2 = (c1 * s1) >> (wp - 1)
262
+ cn, sn = (cn*c2 - sn*s2) >> wp, (sn*c2 + cn*s2) >> wp
263
+ if (nd&1):
264
+ sre = s1 + ((are * sn * 3**nd) >> wp)
265
+ sim = ((aim * sn * 3**nd) >> wp)
266
+ else:
267
+ sre = c1 + ((are * cn * 3**nd) >> wp)
268
+ sim = ((aim * cn * 3**nd) >> wp)
269
+ n = 5
270
+ while are**2 + aim**2 > MIN:
271
+ bre, bim = (bre * x2re - bim * x2im) >> wp, \
272
+ (bre * x2im + bim * x2re) >> wp
273
+ are, aim = (are * bre - aim * bim) >> wp, \
274
+ (are * bim + aim * bre) >> wp
275
+ cn, sn = (cn*c2 - sn*s2) >> wp, (sn*c2 + cn*s2) >> wp
276
+
277
+ if (nd&1):
278
+ sre += ((are * sn * n**nd) >> wp)
279
+ sim += ((aim * sn * n**nd) >> wp)
280
+ else:
281
+ sre += ((are * cn * n**nd) >> wp)
282
+ sim += ((aim * cn * n**nd) >> wp)
283
+ n += 2
284
+ sre = -(sre << 1)
285
+ sim = -(sim << 1)
286
+ sre = ctx.ldexp(sre, -wp)
287
+ sim = ctx.ldexp(sim, -wp)
288
+ s = ctx.mpc(sre, sim)
289
+ #case z complex, q real
290
+ elif not ctx._im(q):
291
+ wp = ctx.prec + extra2
292
+ x = ctx.to_fixed(ctx._re(q), wp)
293
+ x2 = (x*x) >> wp
294
+ a = b = x2
295
+ prec0 = ctx.prec
296
+ ctx.prec = wp
297
+ c1, s1 = ctx.cos_sin(z)
298
+ ctx.prec = prec0
299
+ cnre = c1re = ctx.to_fixed(ctx._re(c1), wp)
300
+ cnim = c1im = ctx.to_fixed(ctx._im(c1), wp)
301
+ snre = s1re = ctx.to_fixed(ctx._re(s1), wp)
302
+ snim = s1im = ctx.to_fixed(ctx._im(s1), wp)
303
+ #c2 = (c1*c1 - s1*s1) >> wp
304
+ c2re = (c1re*c1re - c1im*c1im - s1re*s1re + s1im*s1im) >> wp
305
+ c2im = (c1re*c1im - s1re*s1im) >> (wp - 1)
306
+ #s2 = (c1 * s1) >> (wp - 1)
307
+ s2re = (c1re*s1re - c1im*s1im) >> (wp - 1)
308
+ s2im = (c1re*s1im + c1im*s1re) >> (wp - 1)
309
+ #cn, sn = (cn*c2 - sn*s2) >> wp, (sn*c2 + cn*s2) >> wp
310
+ t1 = (cnre*c2re - cnim*c2im - snre*s2re + snim*s2im) >> wp
311
+ t2 = (cnre*c2im + cnim*c2re - snre*s2im - snim*s2re) >> wp
312
+ t3 = (snre*c2re - snim*c2im + cnre*s2re - cnim*s2im) >> wp
313
+ t4 = (snre*c2im + snim*c2re + cnre*s2im + cnim*s2re) >> wp
314
+ cnre = t1
315
+ cnim = t2
316
+ snre = t3
317
+ snim = t4
318
+ if (nd&1):
319
+ sre = s1re + ((a * snre * 3**nd) >> wp)
320
+ sim = s1im + ((a * snim * 3**nd) >> wp)
321
+ else:
322
+ sre = c1re + ((a * cnre * 3**nd) >> wp)
323
+ sim = c1im + ((a * cnim * 3**nd) >> wp)
324
+ n = 5
325
+ while abs(a) > MIN:
326
+ b = (b*x2) >> wp
327
+ a = (a*b) >> wp
328
+ t1 = (cnre*c2re - cnim*c2im - snre*s2re + snim*s2im) >> wp
329
+ t2 = (cnre*c2im + cnim*c2re - snre*s2im - snim*s2re) >> wp
330
+ t3 = (snre*c2re - snim*c2im + cnre*s2re - cnim*s2im) >> wp
331
+ t4 = (snre*c2im + snim*c2re + cnre*s2im + cnim*s2re) >> wp
332
+ cnre = t1
333
+ cnim = t2
334
+ snre = t3
335
+ snim = t4
336
+ if (nd&1):
337
+ sre += ((a * snre * n**nd) >> wp)
338
+ sim += ((a * snim * n**nd) >> wp)
339
+ else:
340
+ sre += ((a * cnre * n**nd) >> wp)
341
+ sim += ((a * cnim * n**nd) >> wp)
342
+ n += 2
343
+ sre = -(sre << 1)
344
+ sim = -(sim << 1)
345
+ sre = ctx.ldexp(sre, -wp)
346
+ sim = ctx.ldexp(sim, -wp)
347
+ s = ctx.mpc(sre, sim)
348
+ # case z and q complex
349
+ else:
350
+ wp = ctx.prec + extra2
351
+ xre = ctx.to_fixed(ctx._re(q), wp)
352
+ xim = ctx.to_fixed(ctx._im(q), wp)
353
+ x2re = (xre*xre - xim*xim) >> wp
354
+ x2im = (xre*xim) >> (wp - 1)
355
+ are = bre = x2re
356
+ aim = bim = x2im
357
+ prec0 = ctx.prec
358
+ ctx.prec = wp
359
+ # cos(2*z), sin(2*z) with z complex
360
+ c1, s1 = ctx.cos_sin(z)
361
+ ctx.prec = prec0
362
+ cnre = c1re = ctx.to_fixed(ctx._re(c1), wp)
363
+ cnim = c1im = ctx.to_fixed(ctx._im(c1), wp)
364
+ snre = s1re = ctx.to_fixed(ctx._re(s1), wp)
365
+ snim = s1im = ctx.to_fixed(ctx._im(s1), wp)
366
+ c2re = (c1re*c1re - c1im*c1im - s1re*s1re + s1im*s1im) >> wp
367
+ c2im = (c1re*c1im - s1re*s1im) >> (wp - 1)
368
+ s2re = (c1re*s1re - c1im*s1im) >> (wp - 1)
369
+ s2im = (c1re*s1im + c1im*s1re) >> (wp - 1)
370
+ t1 = (cnre*c2re - cnim*c2im - snre*s2re + snim*s2im) >> wp
371
+ t2 = (cnre*c2im + cnim*c2re - snre*s2im - snim*s2re) >> wp
372
+ t3 = (snre*c2re - snim*c2im + cnre*s2re - cnim*s2im) >> wp
373
+ t4 = (snre*c2im + snim*c2re + cnre*s2im + cnim*s2re) >> wp
374
+ cnre = t1
375
+ cnim = t2
376
+ snre = t3
377
+ snim = t4
378
+ if (nd&1):
379
+ sre = s1re + (((are * snre - aim * snim) * 3**nd) >> wp)
380
+ sim = s1im + (((are * snim + aim * snre)* 3**nd) >> wp)
381
+ else:
382
+ sre = c1re + (((are * cnre - aim * cnim) * 3**nd) >> wp)
383
+ sim = c1im + (((are * cnim + aim * cnre)* 3**nd) >> wp)
384
+ n = 5
385
+ while are**2 + aim**2 > MIN:
386
+ bre, bim = (bre * x2re - bim * x2im) >> wp, \
387
+ (bre * x2im + bim * x2re) >> wp
388
+ are, aim = (are * bre - aim * bim) >> wp, \
389
+ (are * bim + aim * bre) >> wp
390
+ #cn, sn = (cn*c1 - sn*s1) >> wp, (sn*c1 + cn*s1) >> wp
391
+ t1 = (cnre*c2re - cnim*c2im - snre*s2re + snim*s2im) >> wp
392
+ t2 = (cnre*c2im + cnim*c2re - snre*s2im - snim*s2re) >> wp
393
+ t3 = (snre*c2re - snim*c2im + cnre*s2re - cnim*s2im) >> wp
394
+ t4 = (snre*c2im + snim*c2re + cnre*s2im + cnim*s2re) >> wp
395
+ cnre = t1
396
+ cnim = t2
397
+ snre = t3
398
+ snim = t4
399
+ if (nd&1):
400
+ sre += (((are * snre - aim * snim) * n**nd) >> wp)
401
+ sim += (((aim * snre + are * snim) * n**nd) >> wp)
402
+ else:
403
+ sre += (((are * cnre - aim * cnim) * n**nd) >> wp)
404
+ sim += (((aim * cnre + are * cnim) * n**nd) >> wp)
405
+ n += 2
406
+ sre = -(sre << 1)
407
+ sim = -(sim << 1)
408
+ sre = ctx.ldexp(sre, -wp)
409
+ sim = ctx.ldexp(sim, -wp)
410
+ s = ctx.mpc(sre, sim)
411
+ s *= ctx.nthroot(q, 4)
412
+ if (nd&1):
413
+ return (-1)**(nd//2) * s
414
+ else:
415
+ return (-1)**(1 + nd//2) * s
416
+
417
+ @defun
418
+ def _jacobi_theta3(ctx, z, q):
419
+ extra1 = 10
420
+ extra2 = 20
421
+ MIN = 2
422
+ if z == ctx.zero:
423
+ if not ctx._im(q):
424
+ wp = ctx.prec + extra1
425
+ x = ctx.to_fixed(ctx._re(q), wp)
426
+ s = x
427
+ a = b = x
428
+ x2 = (x*x) >> wp
429
+ while abs(a) > MIN:
430
+ b = (b*x2) >> wp
431
+ a = (a*b) >> wp
432
+ s += a
433
+ s = (1 << wp) + (s << 1)
434
+ s = ctx.ldexp(s, -wp)
435
+ return s
436
+ else:
437
+ wp = ctx.prec + extra1
438
+ xre = ctx.to_fixed(ctx._re(q), wp)
439
+ xim = ctx.to_fixed(ctx._im(q), wp)
440
+ x2re = (xre*xre - xim*xim) >> wp
441
+ x2im = (xre*xim) >> (wp - 1)
442
+ sre = are = bre = xre
443
+ sim = aim = bim = xim
444
+ while are**2 + aim**2 > MIN:
445
+ bre, bim = (bre * x2re - bim * x2im) >> wp, \
446
+ (bre * x2im + bim * x2re) >> wp
447
+ are, aim = (are * bre - aim * bim) >> wp, \
448
+ (are * bim + aim * bre) >> wp
449
+ sre += are
450
+ sim += aim
451
+ sre = (1 << wp) + (sre << 1)
452
+ sim = (sim << 1)
453
+ sre = ctx.ldexp(sre, -wp)
454
+ sim = ctx.ldexp(sim, -wp)
455
+ s = ctx.mpc(sre, sim)
456
+ return s
457
+ else:
458
+ if (not ctx._im(q)) and (not ctx._im(z)):
459
+ s = 0
460
+ wp = ctx.prec + extra1
461
+ x = ctx.to_fixed(ctx._re(q), wp)
462
+ a = b = x
463
+ x2 = (x*x) >> wp
464
+ c1, s1 = ctx.cos_sin(ctx._re(z)*2, prec=wp)
465
+ c1 = ctx.to_fixed(c1, wp)
466
+ s1 = ctx.to_fixed(s1, wp)
467
+ cn = c1
468
+ sn = s1
469
+ s += (a * cn) >> wp
470
+ while abs(a) > MIN:
471
+ b = (b*x2) >> wp
472
+ a = (a*b) >> wp
473
+ cn, sn = (cn*c1 - sn*s1) >> wp, (sn*c1 + cn*s1) >> wp
474
+ s += (a * cn) >> wp
475
+ s = (1 << wp) + (s << 1)
476
+ s = ctx.ldexp(s, -wp)
477
+ return s
478
+ # case z real, q complex
479
+ elif not ctx._im(z):
480
+ wp = ctx.prec + extra2
481
+ xre = ctx.to_fixed(ctx._re(q), wp)
482
+ xim = ctx.to_fixed(ctx._im(q), wp)
483
+ x2re = (xre*xre - xim*xim) >> wp
484
+ x2im = (xre*xim) >> (wp - 1)
485
+ are = bre = xre
486
+ aim = bim = xim
487
+ c1, s1 = ctx.cos_sin(ctx._re(z)*2, prec=wp)
488
+ c1 = ctx.to_fixed(c1, wp)
489
+ s1 = ctx.to_fixed(s1, wp)
490
+ cn = c1
491
+ sn = s1
492
+ sre = (are * cn) >> wp
493
+ sim = (aim * cn) >> wp
494
+ while are**2 + aim**2 > MIN:
495
+ bre, bim = (bre * x2re - bim * x2im) >> wp, \
496
+ (bre * x2im + bim * x2re) >> wp
497
+ are, aim = (are * bre - aim * bim) >> wp, \
498
+ (are * bim + aim * bre) >> wp
499
+ cn, sn = (cn*c1 - sn*s1) >> wp, (sn*c1 + cn*s1) >> wp
500
+ sre += (are * cn) >> wp
501
+ sim += (aim * cn) >> wp
502
+ sre = (1 << wp) + (sre << 1)
503
+ sim = (sim << 1)
504
+ sre = ctx.ldexp(sre, -wp)
505
+ sim = ctx.ldexp(sim, -wp)
506
+ s = ctx.mpc(sre, sim)
507
+ return s
508
+ #case z complex, q real
509
+ elif not ctx._im(q):
510
+ wp = ctx.prec + extra2
511
+ x = ctx.to_fixed(ctx._re(q), wp)
512
+ a = b = x
513
+ x2 = (x*x) >> wp
514
+ prec0 = ctx.prec
515
+ ctx.prec = wp
516
+ c1, s1 = ctx.cos_sin(2*z)
517
+ ctx.prec = prec0
518
+ cnre = c1re = ctx.to_fixed(ctx._re(c1), wp)
519
+ cnim = c1im = ctx.to_fixed(ctx._im(c1), wp)
520
+ snre = s1re = ctx.to_fixed(ctx._re(s1), wp)
521
+ snim = s1im = ctx.to_fixed(ctx._im(s1), wp)
522
+ sre = (a * cnre) >> wp
523
+ sim = (a * cnim) >> wp
524
+ while abs(a) > MIN:
525
+ b = (b*x2) >> wp
526
+ a = (a*b) >> wp
527
+ t1 = (cnre*c1re - cnim*c1im - snre*s1re + snim*s1im) >> wp
528
+ t2 = (cnre*c1im + cnim*c1re - snre*s1im - snim*s1re) >> wp
529
+ t3 = (snre*c1re - snim*c1im + cnre*s1re - cnim*s1im) >> wp
530
+ t4 = (snre*c1im + snim*c1re + cnre*s1im + cnim*s1re) >> wp
531
+ cnre = t1
532
+ cnim = t2
533
+ snre = t3
534
+ snim = t4
535
+ sre += (a * cnre) >> wp
536
+ sim += (a * cnim) >> wp
537
+ sre = (1 << wp) + (sre << 1)
538
+ sim = (sim << 1)
539
+ sre = ctx.ldexp(sre, -wp)
540
+ sim = ctx.ldexp(sim, -wp)
541
+ s = ctx.mpc(sre, sim)
542
+ return s
543
+ # case z and q complex
544
+ else:
545
+ wp = ctx.prec + extra2
546
+ xre = ctx.to_fixed(ctx._re(q), wp)
547
+ xim = ctx.to_fixed(ctx._im(q), wp)
548
+ x2re = (xre*xre - xim*xim) >> wp
549
+ x2im = (xre*xim) >> (wp - 1)
550
+ are = bre = xre
551
+ aim = bim = xim
552
+ prec0 = ctx.prec
553
+ ctx.prec = wp
554
+ # cos(2*z), sin(2*z) with z complex
555
+ c1, s1 = ctx.cos_sin(2*z)
556
+ ctx.prec = prec0
557
+ cnre = c1re = ctx.to_fixed(ctx._re(c1), wp)
558
+ cnim = c1im = ctx.to_fixed(ctx._im(c1), wp)
559
+ snre = s1re = ctx.to_fixed(ctx._re(s1), wp)
560
+ snim = s1im = ctx.to_fixed(ctx._im(s1), wp)
561
+ sre = (are * cnre - aim * cnim) >> wp
562
+ sim = (aim * cnre + are * cnim) >> wp
563
+ while are**2 + aim**2 > MIN:
564
+ bre, bim = (bre * x2re - bim * x2im) >> wp, \
565
+ (bre * x2im + bim * x2re) >> wp
566
+ are, aim = (are * bre - aim * bim) >> wp, \
567
+ (are * bim + aim * bre) >> wp
568
+ t1 = (cnre*c1re - cnim*c1im - snre*s1re + snim*s1im) >> wp
569
+ t2 = (cnre*c1im + cnim*c1re - snre*s1im - snim*s1re) >> wp
570
+ t3 = (snre*c1re - snim*c1im + cnre*s1re - cnim*s1im) >> wp
571
+ t4 = (snre*c1im + snim*c1re + cnre*s1im + cnim*s1re) >> wp
572
+ cnre = t1
573
+ cnim = t2
574
+ snre = t3
575
+ snim = t4
576
+ sre += (are * cnre - aim * cnim) >> wp
577
+ sim += (aim * cnre + are * cnim) >> wp
578
+ sre = (1 << wp) + (sre << 1)
579
+ sim = (sim << 1)
580
+ sre = ctx.ldexp(sre, -wp)
581
+ sim = ctx.ldexp(sim, -wp)
582
+ s = ctx.mpc(sre, sim)
583
+ return s
584
+
585
+ @defun
586
+ def _djacobi_theta3(ctx, z, q, nd):
587
+ """nd=1,2,3 order of the derivative with respect to z"""
588
+ MIN = 2
589
+ extra1 = 10
590
+ extra2 = 20
591
+ if (not ctx._im(q)) and (not ctx._im(z)):
592
+ s = 0
593
+ wp = ctx.prec + extra1
594
+ x = ctx.to_fixed(ctx._re(q), wp)
595
+ a = b = x
596
+ x2 = (x*x) >> wp
597
+ c1, s1 = ctx.cos_sin(ctx._re(z)*2, prec=wp)
598
+ c1 = ctx.to_fixed(c1, wp)
599
+ s1 = ctx.to_fixed(s1, wp)
600
+ cn = c1
601
+ sn = s1
602
+ if (nd&1):
603
+ s += (a * sn) >> wp
604
+ else:
605
+ s += (a * cn) >> wp
606
+ n = 2
607
+ while abs(a) > MIN:
608
+ b = (b*x2) >> wp
609
+ a = (a*b) >> wp
610
+ cn, sn = (cn*c1 - sn*s1) >> wp, (sn*c1 + cn*s1) >> wp
611
+ if nd&1:
612
+ s += (a * sn * n**nd) >> wp
613
+ else:
614
+ s += (a * cn * n**nd) >> wp
615
+ n += 1
616
+ s = -(s << (nd+1))
617
+ s = ctx.ldexp(s, -wp)
618
+ # case z real, q complex
619
+ elif not ctx._im(z):
620
+ wp = ctx.prec + extra2
621
+ xre = ctx.to_fixed(ctx._re(q), wp)
622
+ xim = ctx.to_fixed(ctx._im(q), wp)
623
+ x2re = (xre*xre - xim*xim) >> wp
624
+ x2im = (xre*xim) >> (wp - 1)
625
+ are = bre = xre
626
+ aim = bim = xim
627
+ c1, s1 = ctx.cos_sin(ctx._re(z)*2, prec=wp)
628
+ c1 = ctx.to_fixed(c1, wp)
629
+ s1 = ctx.to_fixed(s1, wp)
630
+ cn = c1
631
+ sn = s1
632
+ if (nd&1):
633
+ sre = (are * sn) >> wp
634
+ sim = (aim * sn) >> wp
635
+ else:
636
+ sre = (are * cn) >> wp
637
+ sim = (aim * cn) >> wp
638
+ n = 2
639
+ while are**2 + aim**2 > MIN:
640
+ bre, bim = (bre * x2re - bim * x2im) >> wp, \
641
+ (bre * x2im + bim * x2re) >> wp
642
+ are, aim = (are * bre - aim * bim) >> wp, \
643
+ (are * bim + aim * bre) >> wp
644
+ cn, sn = (cn*c1 - sn*s1) >> wp, (sn*c1 + cn*s1) >> wp
645
+ if nd&1:
646
+ sre += (are * sn * n**nd) >> wp
647
+ sim += (aim * sn * n**nd) >> wp
648
+ else:
649
+ sre += (are * cn * n**nd) >> wp
650
+ sim += (aim * cn * n**nd) >> wp
651
+ n += 1
652
+ sre = -(sre << (nd+1))
653
+ sim = -(sim << (nd+1))
654
+ sre = ctx.ldexp(sre, -wp)
655
+ sim = ctx.ldexp(sim, -wp)
656
+ s = ctx.mpc(sre, sim)
657
+ #case z complex, q real
658
+ elif not ctx._im(q):
659
+ wp = ctx.prec + extra2
660
+ x = ctx.to_fixed(ctx._re(q), wp)
661
+ a = b = x
662
+ x2 = (x*x) >> wp
663
+ prec0 = ctx.prec
664
+ ctx.prec = wp
665
+ c1, s1 = ctx.cos_sin(2*z)
666
+ ctx.prec = prec0
667
+ cnre = c1re = ctx.to_fixed(ctx._re(c1), wp)
668
+ cnim = c1im = ctx.to_fixed(ctx._im(c1), wp)
669
+ snre = s1re = ctx.to_fixed(ctx._re(s1), wp)
670
+ snim = s1im = ctx.to_fixed(ctx._im(s1), wp)
671
+ if (nd&1):
672
+ sre = (a * snre) >> wp
673
+ sim = (a * snim) >> wp
674
+ else:
675
+ sre = (a * cnre) >> wp
676
+ sim = (a * cnim) >> wp
677
+ n = 2
678
+ while abs(a) > MIN:
679
+ b = (b*x2) >> wp
680
+ a = (a*b) >> wp
681
+ t1 = (cnre*c1re - cnim*c1im - snre*s1re + snim*s1im) >> wp
682
+ t2 = (cnre*c1im + cnim*c1re - snre*s1im - snim*s1re) >> wp
683
+ t3 = (snre*c1re - snim*c1im + cnre*s1re - cnim*s1im) >> wp
684
+ t4 = (snre*c1im + snim*c1re + cnre*s1im + cnim*s1re) >> wp
685
+ cnre = t1
686
+ cnim = t2
687
+ snre = t3
688
+ snim = t4
689
+ if (nd&1):
690
+ sre += (a * snre * n**nd) >> wp
691
+ sim += (a * snim * n**nd) >> wp
692
+ else:
693
+ sre += (a * cnre * n**nd) >> wp
694
+ sim += (a * cnim * n**nd) >> wp
695
+ n += 1
696
+ sre = -(sre << (nd+1))
697
+ sim = -(sim << (nd+1))
698
+ sre = ctx.ldexp(sre, -wp)
699
+ sim = ctx.ldexp(sim, -wp)
700
+ s = ctx.mpc(sre, sim)
701
+ # case z and q complex
702
+ else:
703
+ wp = ctx.prec + extra2
704
+ xre = ctx.to_fixed(ctx._re(q), wp)
705
+ xim = ctx.to_fixed(ctx._im(q), wp)
706
+ x2re = (xre*xre - xim*xim) >> wp
707
+ x2im = (xre*xim) >> (wp - 1)
708
+ are = bre = xre
709
+ aim = bim = xim
710
+ prec0 = ctx.prec
711
+ ctx.prec = wp
712
+ # cos(2*z), sin(2*z) with z complex
713
+ c1, s1 = ctx.cos_sin(2*z)
714
+ ctx.prec = prec0
715
+ cnre = c1re = ctx.to_fixed(ctx._re(c1), wp)
716
+ cnim = c1im = ctx.to_fixed(ctx._im(c1), wp)
717
+ snre = s1re = ctx.to_fixed(ctx._re(s1), wp)
718
+ snim = s1im = ctx.to_fixed(ctx._im(s1), wp)
719
+ if (nd&1):
720
+ sre = (are * snre - aim * snim) >> wp
721
+ sim = (aim * snre + are * snim) >> wp
722
+ else:
723
+ sre = (are * cnre - aim * cnim) >> wp
724
+ sim = (aim * cnre + are * cnim) >> wp
725
+ n = 2
726
+ while are**2 + aim**2 > MIN:
727
+ bre, bim = (bre * x2re - bim * x2im) >> wp, \
728
+ (bre * x2im + bim * x2re) >> wp
729
+ are, aim = (are * bre - aim * bim) >> wp, \
730
+ (are * bim + aim * bre) >> wp
731
+ t1 = (cnre*c1re - cnim*c1im - snre*s1re + snim*s1im) >> wp
732
+ t2 = (cnre*c1im + cnim*c1re - snre*s1im - snim*s1re) >> wp
733
+ t3 = (snre*c1re - snim*c1im + cnre*s1re - cnim*s1im) >> wp
734
+ t4 = (snre*c1im + snim*c1re + cnre*s1im + cnim*s1re) >> wp
735
+ cnre = t1
736
+ cnim = t2
737
+ snre = t3
738
+ snim = t4
739
+ if(nd&1):
740
+ sre += ((are * snre - aim * snim) * n**nd) >> wp
741
+ sim += ((aim * snre + are * snim) * n**nd) >> wp
742
+ else:
743
+ sre += ((are * cnre - aim * cnim) * n**nd) >> wp
744
+ sim += ((aim * cnre + are * cnim) * n**nd) >> wp
745
+ n += 1
746
+ sre = -(sre << (nd+1))
747
+ sim = -(sim << (nd+1))
748
+ sre = ctx.ldexp(sre, -wp)
749
+ sim = ctx.ldexp(sim, -wp)
750
+ s = ctx.mpc(sre, sim)
751
+ if (nd&1):
752
+ return (-1)**(nd//2) * s
753
+ else:
754
+ return (-1)**(1 + nd//2) * s
755
+
756
+ @defun
757
+ def _jacobi_theta2a(ctx, z, q):
758
+ """
759
+ case ctx._im(z) != 0
760
+ theta(2, z, q) =
761
+ q**1/4 * Sum(q**(n*n + n) * exp(j*(2*n + 1)*z), n=-inf, inf)
762
+ max term for minimum (2*n+1)*log(q).real - 2* ctx._im(z)
763
+ n0 = int(ctx._im(z)/log(q).real - 1/2)
764
+ theta(2, z, q) =
765
+ q**1/4 * Sum(q**(n*n + n) * exp(j*(2*n + 1)*z), n=n0, inf) +
766
+ q**1/4 * Sum(q**(n*n + n) * exp(j*(2*n + 1)*z), n, n0-1, -inf)
767
+ """
768
+ n = n0 = int(ctx._im(z)/ctx._re(ctx.log(q)) - 1/2)
769
+ e2 = ctx.expj(2*z)
770
+ e = e0 = ctx.expj((2*n+1)*z)
771
+ a = q**(n*n + n)
772
+ # leading term
773
+ term = a * e
774
+ s = term
775
+ eps1 = ctx.eps*abs(term)
776
+ while 1:
777
+ n += 1
778
+ e = e * e2
779
+ term = q**(n*n + n) * e
780
+ if abs(term) < eps1:
781
+ break
782
+ s += term
783
+ e = e0
784
+ e2 = ctx.expj(-2*z)
785
+ n = n0
786
+ while 1:
787
+ n -= 1
788
+ e = e * e2
789
+ term = q**(n*n + n) * e
790
+ if abs(term) < eps1:
791
+ break
792
+ s += term
793
+ s = s * ctx.nthroot(q, 4)
794
+ return s
795
+
796
+ @defun
797
+ def _jacobi_theta3a(ctx, z, q):
798
+ """
799
+ case ctx._im(z) != 0
800
+ theta3(z, q) = Sum(q**(n*n) * exp(j*2*n*z), n, -inf, inf)
801
+ max term for n*abs(log(q).real) + ctx._im(z) ~= 0
802
+ n0 = int(- ctx._im(z)/abs(log(q).real))
803
+ """
804
+ n = n0 = int(-ctx._im(z)/abs(ctx._re(ctx.log(q))))
805
+ e2 = ctx.expj(2*z)
806
+ e = e0 = ctx.expj(2*n*z)
807
+ s = term = q**(n*n) * e
808
+ eps1 = ctx.eps*abs(term)
809
+ while 1:
810
+ n += 1
811
+ e = e * e2
812
+ term = q**(n*n) * e
813
+ if abs(term) < eps1:
814
+ break
815
+ s += term
816
+ e = e0
817
+ e2 = ctx.expj(-2*z)
818
+ n = n0
819
+ while 1:
820
+ n -= 1
821
+ e = e * e2
822
+ term = q**(n*n) * e
823
+ if abs(term) < eps1:
824
+ break
825
+ s += term
826
+ return s
827
+
828
+ @defun
829
+ def _djacobi_theta2a(ctx, z, q, nd):
830
+ """
831
+ case ctx._im(z) != 0
832
+ dtheta(2, z, q, nd) =
833
+ j* q**1/4 * Sum(q**(n*n + n) * (2*n+1)*exp(j*(2*n + 1)*z), n=-inf, inf)
834
+ max term for (2*n0+1)*log(q).real - 2* ctx._im(z) ~= 0
835
+ n0 = int(ctx._im(z)/log(q).real - 1/2)
836
+ """
837
+ n = n0 = int(ctx._im(z)/ctx._re(ctx.log(q)) - 1/2)
838
+ e2 = ctx.expj(2*z)
839
+ e = e0 = ctx.expj((2*n + 1)*z)
840
+ a = q**(n*n + n)
841
+ # leading term
842
+ term = (2*n+1)**nd * a * e
843
+ s = term
844
+ eps1 = ctx.eps*abs(term)
845
+ while 1:
846
+ n += 1
847
+ e = e * e2
848
+ term = (2*n+1)**nd * q**(n*n + n) * e
849
+ if abs(term) < eps1:
850
+ break
851
+ s += term
852
+ e = e0
853
+ e2 = ctx.expj(-2*z)
854
+ n = n0
855
+ while 1:
856
+ n -= 1
857
+ e = e * e2
858
+ term = (2*n+1)**nd * q**(n*n + n) * e
859
+ if abs(term) < eps1:
860
+ break
861
+ s += term
862
+ return ctx.j**nd * s * ctx.nthroot(q, 4)
863
+
864
+ @defun
865
+ def _djacobi_theta3a(ctx, z, q, nd):
866
+ """
867
+ case ctx._im(z) != 0
868
+ djtheta3(z, q, nd) = (2*j)**nd *
869
+ Sum(q**(n*n) * n**nd * exp(j*2*n*z), n, -inf, inf)
870
+ max term for minimum n*abs(log(q).real) + ctx._im(z)
871
+ """
872
+ n = n0 = int(-ctx._im(z)/abs(ctx._re(ctx.log(q))))
873
+ e2 = ctx.expj(2*z)
874
+ e = e0 = ctx.expj(2*n*z)
875
+ a = q**(n*n) * e
876
+ s = term = n**nd * a
877
+ if n != 0:
878
+ eps1 = ctx.eps*abs(term)
879
+ else:
880
+ eps1 = ctx.eps*abs(a)
881
+ while 1:
882
+ n += 1
883
+ e = e * e2
884
+ a = q**(n*n) * e
885
+ term = n**nd * a
886
+ if n != 0:
887
+ aterm = abs(term)
888
+ else:
889
+ aterm = abs(a)
890
+ if aterm < eps1:
891
+ break
892
+ s += term
893
+ e = e0
894
+ e2 = ctx.expj(-2*z)
895
+ n = n0
896
+ while 1:
897
+ n -= 1
898
+ e = e * e2
899
+ a = q**(n*n) * e
900
+ term = n**nd * a
901
+ if n != 0:
902
+ aterm = abs(term)
903
+ else:
904
+ aterm = abs(a)
905
+ if aterm < eps1:
906
+ break
907
+ s += term
908
+ return (2*ctx.j)**nd * s
909
+
910
+ @defun
911
+ def jtheta(ctx, n, z, q, derivative=0):
912
+ if derivative:
913
+ return ctx._djtheta(n, z, q, derivative)
914
+
915
+ z = ctx.convert(z)
916
+ q = ctx.convert(q)
917
+
918
+ # Implementation note
919
+ # If ctx._im(z) is close to zero, _jacobi_theta2 and _jacobi_theta3
920
+ # are used,
921
+ # which compute the series starting from n=0 using fixed precision
922
+ # numbers;
923
+ # otherwise _jacobi_theta2a and _jacobi_theta3a are used, which compute
924
+ # the series starting from n=n0, which is the largest term.
925
+
926
+ # TODO: write _jacobi_theta2a and _jacobi_theta3a using fixed-point
927
+
928
+ if abs(q) > ctx.THETA_Q_LIM:
929
+ raise ValueError('abs(q) > THETA_Q_LIM = %f' % ctx.THETA_Q_LIM)
930
+
931
+ extra = 10
932
+ if z:
933
+ M = ctx.mag(z)
934
+ if M > 5 or (n == 1 and M < -5):
935
+ extra += 2*abs(M)
936
+ cz = 0.5
937
+ extra2 = 50
938
+ prec0 = ctx.prec
939
+ try:
940
+ ctx.prec += extra
941
+ if n == 1:
942
+ if ctx._im(z):
943
+ if abs(ctx._im(z)) < cz * abs(ctx._re(ctx.log(q))):
944
+ ctx.dps += extra2
945
+ res = ctx._jacobi_theta2(z - ctx.pi/2, q)
946
+ else:
947
+ ctx.dps += 10
948
+ res = ctx._jacobi_theta2a(z - ctx.pi/2, q)
949
+ else:
950
+ res = ctx._jacobi_theta2(z - ctx.pi/2, q)
951
+ elif n == 2:
952
+ if ctx._im(z):
953
+ if abs(ctx._im(z)) < cz * abs(ctx._re(ctx.log(q))):
954
+ ctx.dps += extra2
955
+ res = ctx._jacobi_theta2(z, q)
956
+ else:
957
+ ctx.dps += 10
958
+ res = ctx._jacobi_theta2a(z, q)
959
+ else:
960
+ res = ctx._jacobi_theta2(z, q)
961
+ elif n == 3:
962
+ if ctx._im(z):
963
+ if abs(ctx._im(z)) < cz * abs(ctx._re(ctx.log(q))):
964
+ ctx.dps += extra2
965
+ res = ctx._jacobi_theta3(z, q)
966
+ else:
967
+ ctx.dps += 10
968
+ res = ctx._jacobi_theta3a(z, q)
969
+ else:
970
+ res = ctx._jacobi_theta3(z, q)
971
+ elif n == 4:
972
+ if ctx._im(z):
973
+ if abs(ctx._im(z)) < cz * abs(ctx._re(ctx.log(q))):
974
+ ctx.dps += extra2
975
+ res = ctx._jacobi_theta3(z, -q)
976
+ else:
977
+ ctx.dps += 10
978
+ res = ctx._jacobi_theta3a(z, -q)
979
+ else:
980
+ res = ctx._jacobi_theta3(z, -q)
981
+ else:
982
+ raise ValueError
983
+ finally:
984
+ ctx.prec = prec0
985
+ return res
986
+
987
+ @defun
988
+ def _djtheta(ctx, n, z, q, derivative=1):
989
+ z = ctx.convert(z)
990
+ q = ctx.convert(q)
991
+ nd = int(derivative)
992
+
993
+ if abs(q) > ctx.THETA_Q_LIM:
994
+ raise ValueError('abs(q) > THETA_Q_LIM = %f' % ctx.THETA_Q_LIM)
995
+ extra = 10 + ctx.prec * nd // 10
996
+ if z:
997
+ M = ctx.mag(z)
998
+ if M > 5 or (n != 1 and M < -5):
999
+ extra += 2*abs(M)
1000
+ cz = 0.5
1001
+ extra2 = 50
1002
+ prec0 = ctx.prec
1003
+ try:
1004
+ ctx.prec += extra
1005
+ if n == 1:
1006
+ if ctx._im(z):
1007
+ if abs(ctx._im(z)) < cz * abs(ctx._re(ctx.log(q))):
1008
+ ctx.dps += extra2
1009
+ res = ctx._djacobi_theta2(z - ctx.pi/2, q, nd)
1010
+ else:
1011
+ ctx.dps += 10
1012
+ res = ctx._djacobi_theta2a(z - ctx.pi/2, q, nd)
1013
+ else:
1014
+ res = ctx._djacobi_theta2(z - ctx.pi/2, q, nd)
1015
+ elif n == 2:
1016
+ if ctx._im(z):
1017
+ if abs(ctx._im(z)) < cz * abs(ctx._re(ctx.log(q))):
1018
+ ctx.dps += extra2
1019
+ res = ctx._djacobi_theta2(z, q, nd)
1020
+ else:
1021
+ ctx.dps += 10
1022
+ res = ctx._djacobi_theta2a(z, q, nd)
1023
+ else:
1024
+ res = ctx._djacobi_theta2(z, q, nd)
1025
+ elif n == 3:
1026
+ if ctx._im(z):
1027
+ if abs(ctx._im(z)) < cz * abs(ctx._re(ctx.log(q))):
1028
+ ctx.dps += extra2
1029
+ res = ctx._djacobi_theta3(z, q, nd)
1030
+ else:
1031
+ ctx.dps += 10
1032
+ res = ctx._djacobi_theta3a(z, q, nd)
1033
+ else:
1034
+ res = ctx._djacobi_theta3(z, q, nd)
1035
+ elif n == 4:
1036
+ if ctx._im(z):
1037
+ if abs(ctx._im(z)) < cz * abs(ctx._re(ctx.log(q))):
1038
+ ctx.dps += extra2
1039
+ res = ctx._djacobi_theta3(z, -q, nd)
1040
+ else:
1041
+ ctx.dps += 10
1042
+ res = ctx._djacobi_theta3a(z, -q, nd)
1043
+ else:
1044
+ res = ctx._djacobi_theta3(z, -q, nd)
1045
+ else:
1046
+ raise ValueError
1047
+ finally:
1048
+ ctx.prec = prec0
1049
+ return +res
lib/python3.11/site-packages/mpmath/functions/zeta.py ADDED
@@ -0,0 +1,1154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import print_function
2
+
3
+ from ..libmp.backend import xrange
4
+ from .functions import defun, defun_wrapped, defun_static
5
+
6
+ @defun
7
+ def stieltjes(ctx, n, a=1):
8
+ n = ctx.convert(n)
9
+ a = ctx.convert(a)
10
+ if n < 0:
11
+ return ctx.bad_domain("Stieltjes constants defined for n >= 0")
12
+ if hasattr(ctx, "stieltjes_cache"):
13
+ stieltjes_cache = ctx.stieltjes_cache
14
+ else:
15
+ stieltjes_cache = ctx.stieltjes_cache = {}
16
+ if a == 1:
17
+ if n == 0:
18
+ return +ctx.euler
19
+ if n in stieltjes_cache:
20
+ prec, s = stieltjes_cache[n]
21
+ if prec >= ctx.prec:
22
+ return +s
23
+ mag = 1
24
+ def f(x):
25
+ xa = x/a
26
+ v = (xa-ctx.j)*ctx.ln(a-ctx.j*x)**n/(1+xa**2)/(ctx.exp(2*ctx.pi*x)-1)
27
+ return ctx._re(v) / mag
28
+ orig = ctx.prec
29
+ try:
30
+ # Normalize integrand by approx. magnitude to
31
+ # speed up quadrature (which uses absolute error)
32
+ if n > 50:
33
+ ctx.prec = 20
34
+ mag = ctx.quad(f, [0,ctx.inf], maxdegree=3)
35
+ ctx.prec = orig + 10 + int(n**0.5)
36
+ s = ctx.quad(f, [0,ctx.inf], maxdegree=20)
37
+ v = ctx.ln(a)**n/(2*a) - ctx.ln(a)**(n+1)/(n+1) + 2*s/a*mag
38
+ finally:
39
+ ctx.prec = orig
40
+ if a == 1 and ctx.isint(n):
41
+ stieltjes_cache[n] = (ctx.prec, v)
42
+ return +v
43
+
44
+ @defun_wrapped
45
+ def siegeltheta(ctx, t, derivative=0):
46
+ d = int(derivative)
47
+ if (t == ctx.inf or t == ctx.ninf):
48
+ if d < 2:
49
+ if t == ctx.ninf and d == 0:
50
+ return ctx.ninf
51
+ return ctx.inf
52
+ else:
53
+ return ctx.zero
54
+ if d == 0:
55
+ if ctx._im(t):
56
+ # XXX: cancellation occurs
57
+ a = ctx.loggamma(0.25+0.5j*t)
58
+ b = ctx.loggamma(0.25-0.5j*t)
59
+ return -ctx.ln(ctx.pi)/2*t - 0.5j*(a-b)
60
+ else:
61
+ if ctx.isinf(t):
62
+ return t
63
+ return ctx._im(ctx.loggamma(0.25+0.5j*t)) - ctx.ln(ctx.pi)/2*t
64
+ if d > 0:
65
+ a = (-0.5j)**(d-1)*ctx.polygamma(d-1, 0.25-0.5j*t)
66
+ b = (0.5j)**(d-1)*ctx.polygamma(d-1, 0.25+0.5j*t)
67
+ if ctx._im(t):
68
+ if d == 1:
69
+ return -0.5*ctx.log(ctx.pi)+0.25*(a+b)
70
+ else:
71
+ return 0.25*(a+b)
72
+ else:
73
+ if d == 1:
74
+ return ctx._re(-0.5*ctx.log(ctx.pi)+0.25*(a+b))
75
+ else:
76
+ return ctx._re(0.25*(a+b))
77
+
78
+ @defun_wrapped
79
+ def grampoint(ctx, n):
80
+ # asymptotic expansion, from
81
+ # http://mathworld.wolfram.com/GramPoint.html
82
+ g = 2*ctx.pi*ctx.exp(1+ctx.lambertw((8*n+1)/(8*ctx.e)))
83
+ return ctx.findroot(lambda t: ctx.siegeltheta(t)-ctx.pi*n, g)
84
+
85
+
86
+ @defun_wrapped
87
+ def siegelz(ctx, t, **kwargs):
88
+ d = int(kwargs.get("derivative", 0))
89
+ t = ctx.convert(t)
90
+ t1 = ctx._re(t)
91
+ t2 = ctx._im(t)
92
+ prec = ctx.prec
93
+ try:
94
+ if abs(t1) > 500*prec and t2**2 < t1:
95
+ v = ctx.rs_z(t, d)
96
+ if ctx._is_real_type(t):
97
+ return ctx._re(v)
98
+ return v
99
+ except NotImplementedError:
100
+ pass
101
+ ctx.prec += 21
102
+ e1 = ctx.expj(ctx.siegeltheta(t))
103
+ z = ctx.zeta(0.5+ctx.j*t)
104
+ if d == 0:
105
+ v = e1*z
106
+ ctx.prec=prec
107
+ if ctx._is_real_type(t):
108
+ return ctx._re(v)
109
+ return +v
110
+ z1 = ctx.zeta(0.5+ctx.j*t, derivative=1)
111
+ theta1 = ctx.siegeltheta(t, derivative=1)
112
+ if d == 1:
113
+ v = ctx.j*e1*(z1+z*theta1)
114
+ ctx.prec=prec
115
+ if ctx._is_real_type(t):
116
+ return ctx._re(v)
117
+ return +v
118
+ z2 = ctx.zeta(0.5+ctx.j*t, derivative=2)
119
+ theta2 = ctx.siegeltheta(t, derivative=2)
120
+ comb1 = theta1**2-ctx.j*theta2
121
+ if d == 2:
122
+ def terms():
123
+ return [2*z1*theta1, z2, z*comb1]
124
+ v = ctx.sum_accurately(terms, 1)
125
+ v = -e1*v
126
+ ctx.prec = prec
127
+ if ctx._is_real_type(t):
128
+ return ctx._re(v)
129
+ return +v
130
+ ctx.prec += 10
131
+ z3 = ctx.zeta(0.5+ctx.j*t, derivative=3)
132
+ theta3 = ctx.siegeltheta(t, derivative=3)
133
+ comb2 = theta1**3-3*ctx.j*theta1*theta2-theta3
134
+ if d == 3:
135
+ def terms():
136
+ return [3*theta1*z2, 3*z1*comb1, z3+z*comb2]
137
+ v = ctx.sum_accurately(terms, 1)
138
+ v = -ctx.j*e1*v
139
+ ctx.prec = prec
140
+ if ctx._is_real_type(t):
141
+ return ctx._re(v)
142
+ return +v
143
+ z4 = ctx.zeta(0.5+ctx.j*t, derivative=4)
144
+ theta4 = ctx.siegeltheta(t, derivative=4)
145
+ def terms():
146
+ return [theta1**4, -6*ctx.j*theta1**2*theta2, -3*theta2**2,
147
+ -4*theta1*theta3, ctx.j*theta4]
148
+ comb3 = ctx.sum_accurately(terms, 1)
149
+ if d == 4:
150
+ def terms():
151
+ return [6*theta1**2*z2, -6*ctx.j*z2*theta2, 4*theta1*z3,
152
+ 4*z1*comb2, z4, z*comb3]
153
+ v = ctx.sum_accurately(terms, 1)
154
+ v = e1*v
155
+ ctx.prec = prec
156
+ if ctx._is_real_type(t):
157
+ return ctx._re(v)
158
+ return +v
159
+ if d > 4:
160
+ h = lambda x: ctx.siegelz(x, derivative=4)
161
+ return ctx.diff(h, t, n=d-4)
162
+
163
+
164
+ _zeta_zeros = [
165
+ 14.134725142,21.022039639,25.010857580,30.424876126,32.935061588,
166
+ 37.586178159,40.918719012,43.327073281,48.005150881,49.773832478,
167
+ 52.970321478,56.446247697,59.347044003,60.831778525,65.112544048,
168
+ 67.079810529,69.546401711,72.067157674,75.704690699,77.144840069,
169
+ 79.337375020,82.910380854,84.735492981,87.425274613,88.809111208,
170
+ 92.491899271,94.651344041,95.870634228,98.831194218,101.317851006,
171
+ 103.725538040,105.446623052,107.168611184,111.029535543,111.874659177,
172
+ 114.320220915,116.226680321,118.790782866,121.370125002,122.946829294,
173
+ 124.256818554,127.516683880,129.578704200,131.087688531,133.497737203,
174
+ 134.756509753,138.116042055,139.736208952,141.123707404,143.111845808,
175
+ 146.000982487,147.422765343,150.053520421,150.925257612,153.024693811,
176
+ 156.112909294,157.597591818,158.849988171,161.188964138,163.030709687,
177
+ 165.537069188,167.184439978,169.094515416,169.911976479,173.411536520,
178
+ 174.754191523,176.441434298,178.377407776,179.916484020,182.207078484,
179
+ 184.874467848,185.598783678,187.228922584,189.416158656,192.026656361,
180
+ 193.079726604,195.265396680,196.876481841,198.015309676,201.264751944,
181
+ 202.493594514,204.189671803,205.394697202,207.906258888,209.576509717,
182
+ 211.690862595,213.347919360,214.547044783,216.169538508,219.067596349,
183
+ 220.714918839,221.430705555,224.007000255,224.983324670,227.421444280,
184
+ 229.337413306,231.250188700,231.987235253,233.693404179,236.524229666,
185
+ ]
186
+
187
+ def _load_zeta_zeros(url):
188
+ import urllib
189
+ d = urllib.urlopen(url)
190
+ L = [float(x) for x in d.readlines()]
191
+ # Sanity check
192
+ assert round(L[0]) == 14
193
+ _zeta_zeros[:] = L
194
+
195
+ @defun
196
+ def oldzetazero(ctx, n, url='http://www.dtc.umn.edu/~odlyzko/zeta_tables/zeros1'):
197
+ n = int(n)
198
+ if n < 0:
199
+ return ctx.zetazero(-n).conjugate()
200
+ if n == 0:
201
+ raise ValueError("n must be nonzero")
202
+ if n > len(_zeta_zeros) and n <= 100000:
203
+ _load_zeta_zeros(url)
204
+ if n > len(_zeta_zeros):
205
+ raise NotImplementedError("n too large for zetazeros")
206
+ return ctx.mpc(0.5, ctx.findroot(ctx.siegelz, _zeta_zeros[n-1]))
207
+
208
+ @defun_wrapped
209
+ def riemannr(ctx, x):
210
+ if x == 0:
211
+ return ctx.zero
212
+ # Check if a simple asymptotic estimate is accurate enough
213
+ if abs(x) > 1000:
214
+ a = ctx.li(x)
215
+ b = 0.5*ctx.li(ctx.sqrt(x))
216
+ if abs(b) < abs(a)*ctx.eps:
217
+ return a
218
+ if abs(x) < 0.01:
219
+ # XXX
220
+ ctx.prec += int(-ctx.log(abs(x),2))
221
+ # Sum Gram's series
222
+ s = t = ctx.one
223
+ u = ctx.ln(x)
224
+ k = 1
225
+ while abs(t) > abs(s)*ctx.eps:
226
+ t = t * u / k
227
+ s += t / (k * ctx._zeta_int(k+1))
228
+ k += 1
229
+ return s
230
+
231
+ @defun_static
232
+ def primepi(ctx, x):
233
+ x = int(x)
234
+ if x < 2:
235
+ return 0
236
+ return len(ctx.list_primes(x))
237
+
238
+ # TODO: fix the interface wrt contexts
239
+ @defun_wrapped
240
+ def primepi2(ctx, x):
241
+ x = int(x)
242
+ if x < 2:
243
+ return ctx._iv.zero
244
+ if x < 2657:
245
+ return ctx._iv.mpf(ctx.primepi(x))
246
+ mid = ctx.li(x)
247
+ # Schoenfeld's estimate for x >= 2657, assuming RH
248
+ err = ctx.sqrt(x,rounding='u')*ctx.ln(x,rounding='u')/8/ctx.pi(rounding='d')
249
+ a = ctx.floor((ctx._iv.mpf(mid)-err).a, rounding='d')
250
+ b = ctx.ceil((ctx._iv.mpf(mid)+err).b, rounding='u')
251
+ return ctx._iv.mpf([a,b])
252
+
253
+ @defun_wrapped
254
+ def primezeta(ctx, s):
255
+ if ctx.isnan(s):
256
+ return s
257
+ if ctx.re(s) <= 0:
258
+ raise ValueError("prime zeta function defined only for re(s) > 0")
259
+ if s == 1:
260
+ return ctx.inf
261
+ if s == 0.5:
262
+ return ctx.mpc(ctx.ninf, ctx.pi)
263
+ r = ctx.re(s)
264
+ if r > ctx.prec:
265
+ return 0.5**s
266
+ else:
267
+ wp = ctx.prec + int(r)
268
+ def terms():
269
+ orig = ctx.prec
270
+ # zeta ~ 1+eps; need to set precision
271
+ # to get logarithm accurately
272
+ k = 0
273
+ while 1:
274
+ k += 1
275
+ u = ctx.moebius(k)
276
+ if not u:
277
+ continue
278
+ ctx.prec = wp
279
+ t = u*ctx.ln(ctx.zeta(k*s))/k
280
+ if not t:
281
+ return
282
+ #print ctx.prec, ctx.nstr(t)
283
+ ctx.prec = orig
284
+ yield t
285
+ return ctx.sum_accurately(terms)
286
+
287
+ # TODO: for bernpoly and eulerpoly, ensure that all exact zeros are covered
288
+
289
+ @defun_wrapped
290
+ def bernpoly(ctx, n, z):
291
+ # Slow implementation:
292
+ #return sum(ctx.binomial(n,k)*ctx.bernoulli(k)*z**(n-k) for k in xrange(0,n+1))
293
+ n = int(n)
294
+ if n < 0:
295
+ raise ValueError("Bernoulli polynomials only defined for n >= 0")
296
+ if z == 0 or (z == 1 and n > 1):
297
+ return ctx.bernoulli(n)
298
+ if z == 0.5:
299
+ return (ctx.ldexp(1,1-n)-1)*ctx.bernoulli(n)
300
+ if n <= 3:
301
+ if n == 0: return z ** 0
302
+ if n == 1: return z - 0.5
303
+ if n == 2: return (6*z*(z-1)+1)/6
304
+ if n == 3: return z*(z*(z-1.5)+0.5)
305
+ if ctx.isinf(z):
306
+ return z ** n
307
+ if ctx.isnan(z):
308
+ return z
309
+ if abs(z) > 2:
310
+ def terms():
311
+ t = ctx.one
312
+ yield t
313
+ r = ctx.one/z
314
+ k = 1
315
+ while k <= n:
316
+ t = t*(n+1-k)/k*r
317
+ if not (k > 2 and k & 1):
318
+ yield t*ctx.bernoulli(k)
319
+ k += 1
320
+ return ctx.sum_accurately(terms) * z**n
321
+ else:
322
+ def terms():
323
+ yield ctx.bernoulli(n)
324
+ t = ctx.one
325
+ k = 1
326
+ while k <= n:
327
+ t = t*(n+1-k)/k * z
328
+ m = n-k
329
+ if not (m > 2 and m & 1):
330
+ yield t*ctx.bernoulli(m)
331
+ k += 1
332
+ return ctx.sum_accurately(terms)
333
+
334
+ @defun_wrapped
335
+ def eulerpoly(ctx, n, z):
336
+ n = int(n)
337
+ if n < 0:
338
+ raise ValueError("Euler polynomials only defined for n >= 0")
339
+ if n <= 2:
340
+ if n == 0: return z ** 0
341
+ if n == 1: return z - 0.5
342
+ if n == 2: return z*(z-1)
343
+ if ctx.isinf(z):
344
+ return z**n
345
+ if ctx.isnan(z):
346
+ return z
347
+ m = n+1
348
+ if z == 0:
349
+ return -2*(ctx.ldexp(1,m)-1)*ctx.bernoulli(m)/m * z**0
350
+ if z == 1:
351
+ return 2*(ctx.ldexp(1,m)-1)*ctx.bernoulli(m)/m * z**0
352
+ if z == 0.5:
353
+ if n % 2:
354
+ return ctx.zero
355
+ # Use exact code for Euler numbers
356
+ if n < 100 or n*ctx.mag(0.46839865*n) < ctx.prec*0.25:
357
+ return ctx.ldexp(ctx._eulernum(n), -n)
358
+ # http://functions.wolfram.com/Polynomials/EulerE2/06/01/02/01/0002/
359
+ def terms():
360
+ t = ctx.one
361
+ k = 0
362
+ w = ctx.ldexp(1,n+2)
363
+ while 1:
364
+ v = n-k+1
365
+ if not (v > 2 and v & 1):
366
+ yield (2-w)*ctx.bernoulli(v)*t
367
+ k += 1
368
+ if k > n:
369
+ break
370
+ t = t*z*(n-k+2)/k
371
+ w *= 0.5
372
+ return ctx.sum_accurately(terms) / m
373
+
374
+ @defun
375
+ def eulernum(ctx, n, exact=False):
376
+ n = int(n)
377
+ if exact:
378
+ return int(ctx._eulernum(n))
379
+ if n < 100:
380
+ return ctx.mpf(ctx._eulernum(n))
381
+ if n % 2:
382
+ return ctx.zero
383
+ return ctx.ldexp(ctx.eulerpoly(n,0.5), n)
384
+
385
+ # TODO: this should be implemented low-level
386
+ def polylog_series(ctx, s, z):
387
+ tol = +ctx.eps
388
+ l = ctx.zero
389
+ k = 1
390
+ zk = z
391
+ while 1:
392
+ term = zk / k**s
393
+ l += term
394
+ if abs(term) < tol:
395
+ break
396
+ zk *= z
397
+ k += 1
398
+ return l
399
+
400
+ def polylog_continuation(ctx, n, z):
401
+ if n < 0:
402
+ return z*0
403
+ twopij = 2j * ctx.pi
404
+ a = -twopij**n/ctx.fac(n) * ctx.bernpoly(n, ctx.ln(z)/twopij)
405
+ if ctx._is_real_type(z) and z < 0:
406
+ a = ctx._re(a)
407
+ if ctx._im(z) < 0 or (ctx._im(z) == 0 and ctx._re(z) >= 1):
408
+ a -= twopij*ctx.ln(z)**(n-1)/ctx.fac(n-1)
409
+ return a
410
+
411
+ def polylog_unitcircle(ctx, n, z):
412
+ tol = +ctx.eps
413
+ if n > 1:
414
+ l = ctx.zero
415
+ logz = ctx.ln(z)
416
+ logmz = ctx.one
417
+ m = 0
418
+ while 1:
419
+ if (n-m) != 1:
420
+ term = ctx.zeta(n-m) * logmz / ctx.fac(m)
421
+ if term and abs(term) < tol:
422
+ break
423
+ l += term
424
+ logmz *= logz
425
+ m += 1
426
+ l += ctx.ln(z)**(n-1)/ctx.fac(n-1)*(ctx.harmonic(n-1)-ctx.ln(-ctx.ln(z)))
427
+ elif n < 1: # else
428
+ l = ctx.fac(-n)*(-ctx.ln(z))**(n-1)
429
+ logz = ctx.ln(z)
430
+ logkz = ctx.one
431
+ k = 0
432
+ while 1:
433
+ b = ctx.bernoulli(k-n+1)
434
+ if b:
435
+ term = b*logkz/(ctx.fac(k)*(k-n+1))
436
+ if abs(term) < tol:
437
+ break
438
+ l -= term
439
+ logkz *= logz
440
+ k += 1
441
+ else:
442
+ raise ValueError
443
+ if ctx._is_real_type(z) and z < 0:
444
+ l = ctx._re(l)
445
+ return l
446
+
447
+ def polylog_general(ctx, s, z):
448
+ v = ctx.zero
449
+ u = ctx.ln(z)
450
+ if not abs(u) < 5: # theoretically |u| < 2*pi
451
+ j = ctx.j
452
+ v = 1-s
453
+ y = ctx.ln(-z)/(2*ctx.pi*j)
454
+ return ctx.gamma(v)*(j**v*ctx.zeta(v,0.5+y) + j**-v*ctx.zeta(v,0.5-y))/(2*ctx.pi)**v
455
+ t = 1
456
+ k = 0
457
+ while 1:
458
+ term = ctx.zeta(s-k) * t
459
+ if abs(term) < ctx.eps:
460
+ break
461
+ v += term
462
+ k += 1
463
+ t *= u
464
+ t /= k
465
+ return ctx.gamma(1-s)*(-u)**(s-1) + v
466
+
467
+ @defun_wrapped
468
+ def polylog(ctx, s, z):
469
+ s = ctx.convert(s)
470
+ z = ctx.convert(z)
471
+ if z == 1:
472
+ return ctx.zeta(s)
473
+ if z == -1:
474
+ return -ctx.altzeta(s)
475
+ if s == 0:
476
+ return z/(1-z)
477
+ if s == 1:
478
+ return -ctx.ln(1-z)
479
+ if s == -1:
480
+ return z/(1-z)**2
481
+ if abs(z) <= 0.75 or (not ctx.isint(s) and abs(z) < 0.9):
482
+ return polylog_series(ctx, s, z)
483
+ if abs(z) >= 1.4 and ctx.isint(s):
484
+ return (-1)**(s+1)*polylog_series(ctx, s, 1/z) + polylog_continuation(ctx, int(ctx.re(s)), z)
485
+ if ctx.isint(s):
486
+ return polylog_unitcircle(ctx, int(ctx.re(s)), z)
487
+ return polylog_general(ctx, s, z)
488
+
489
+ @defun_wrapped
490
+ def clsin(ctx, s, z, pi=False):
491
+ if ctx.isint(s) and s < 0 and int(s) % 2 == 1:
492
+ return z*0
493
+ if pi:
494
+ a = ctx.expjpi(z)
495
+ else:
496
+ a = ctx.expj(z)
497
+ if ctx._is_real_type(z) and ctx._is_real_type(s):
498
+ return ctx.im(ctx.polylog(s,a))
499
+ b = 1/a
500
+ return (-0.5j)*(ctx.polylog(s,a) - ctx.polylog(s,b))
501
+
502
+ @defun_wrapped
503
+ def clcos(ctx, s, z, pi=False):
504
+ if ctx.isint(s) and s < 0 and int(s) % 2 == 0:
505
+ return z*0
506
+ if pi:
507
+ a = ctx.expjpi(z)
508
+ else:
509
+ a = ctx.expj(z)
510
+ if ctx._is_real_type(z) and ctx._is_real_type(s):
511
+ return ctx.re(ctx.polylog(s,a))
512
+ b = 1/a
513
+ return 0.5*(ctx.polylog(s,a) + ctx.polylog(s,b))
514
+
515
+ @defun
516
+ def altzeta(ctx, s, **kwargs):
517
+ try:
518
+ return ctx._altzeta(s, **kwargs)
519
+ except NotImplementedError:
520
+ return ctx._altzeta_generic(s)
521
+
522
+ @defun_wrapped
523
+ def _altzeta_generic(ctx, s):
524
+ if s == 1:
525
+ return ctx.ln2 + 0*s
526
+ return -ctx.powm1(2, 1-s) * ctx.zeta(s)
527
+
528
+ @defun
529
+ def zeta(ctx, s, a=1, derivative=0, method=None, **kwargs):
530
+ d = int(derivative)
531
+ if a == 1 and not (d or method):
532
+ try:
533
+ return ctx._zeta(s, **kwargs)
534
+ except NotImplementedError:
535
+ pass
536
+ s = ctx.convert(s)
537
+ prec = ctx.prec
538
+ method = kwargs.get('method')
539
+ verbose = kwargs.get('verbose')
540
+ if (not s) and (not derivative):
541
+ return ctx.mpf(0.5) - ctx._convert_param(a)[0]
542
+ if a == 1 and method != 'euler-maclaurin':
543
+ im = abs(ctx._im(s))
544
+ re = abs(ctx._re(s))
545
+ #if (im < prec or method == 'borwein') and not derivative:
546
+ # try:
547
+ # if verbose:
548
+ # print "zeta: Attempting to use the Borwein algorithm"
549
+ # return ctx._zeta(s, **kwargs)
550
+ # except NotImplementedError:
551
+ # if verbose:
552
+ # print "zeta: Could not use the Borwein algorithm"
553
+ # pass
554
+ if abs(im) > 500*prec and 10*re < prec and derivative <= 4 or \
555
+ method == 'riemann-siegel':
556
+ try: # py2.4 compatible try block
557
+ try:
558
+ if verbose:
559
+ print("zeta: Attempting to use the Riemann-Siegel algorithm")
560
+ return ctx.rs_zeta(s, derivative, **kwargs)
561
+ except NotImplementedError:
562
+ if verbose:
563
+ print("zeta: Could not use the Riemann-Siegel algorithm")
564
+ pass
565
+ finally:
566
+ ctx.prec = prec
567
+ if s == 1:
568
+ return ctx.inf
569
+ abss = abs(s)
570
+ if abss == ctx.inf:
571
+ if ctx.re(s) == ctx.inf:
572
+ if d == 0:
573
+ return ctx.one
574
+ return ctx.zero
575
+ return s*0
576
+ elif ctx.isnan(abss):
577
+ return 1/s
578
+ if ctx.re(s) > 2*ctx.prec and a == 1 and not derivative:
579
+ return ctx.one + ctx.power(2, -s)
580
+ return +ctx._hurwitz(s, a, d, **kwargs)
581
+
582
+ @defun
583
+ def _hurwitz(ctx, s, a=1, d=0, **kwargs):
584
+ prec = ctx.prec
585
+ verbose = kwargs.get('verbose')
586
+ try:
587
+ extraprec = 10
588
+ ctx.prec += extraprec
589
+ # We strongly want to special-case rational a
590
+ a, atype = ctx._convert_param(a)
591
+ if ctx.re(s) < 0:
592
+ if verbose:
593
+ print("zeta: Attempting reflection formula")
594
+ try:
595
+ return _hurwitz_reflection(ctx, s, a, d, atype)
596
+ except NotImplementedError:
597
+ pass
598
+ if verbose:
599
+ print("zeta: Reflection formula failed")
600
+ if verbose:
601
+ print("zeta: Using the Euler-Maclaurin algorithm")
602
+ while 1:
603
+ ctx.prec = prec + extraprec
604
+ T1, T2 = _hurwitz_em(ctx, s, a, d, prec+10, verbose)
605
+ cancellation = ctx.mag(T1) - ctx.mag(T1+T2)
606
+ if verbose:
607
+ print("Term 1:", T1)
608
+ print("Term 2:", T2)
609
+ print("Cancellation:", cancellation, "bits")
610
+ if cancellation < extraprec:
611
+ return T1 + T2
612
+ else:
613
+ extraprec = max(2*extraprec, min(cancellation + 5, 100*prec))
614
+ if extraprec > kwargs.get('maxprec', 100*prec):
615
+ raise ctx.NoConvergence("zeta: too much cancellation")
616
+ finally:
617
+ ctx.prec = prec
618
+
619
+ def _hurwitz_reflection(ctx, s, a, d, atype):
620
+ # TODO: implement for derivatives
621
+ if d != 0:
622
+ raise NotImplementedError
623
+ res = ctx.re(s)
624
+ negs = -s
625
+ # Integer reflection formula
626
+ if ctx.isnpint(s):
627
+ n = int(res)
628
+ if n <= 0:
629
+ return ctx.bernpoly(1-n, a) / (n-1)
630
+ if not (atype == 'Q' or atype == 'Z'):
631
+ raise NotImplementedError
632
+ t = 1-s
633
+ # We now require a to be standardized
634
+ v = 0
635
+ shift = 0
636
+ b = a
637
+ while ctx.re(b) > 1:
638
+ b -= 1
639
+ v -= b**negs
640
+ shift -= 1
641
+ while ctx.re(b) <= 0:
642
+ v += b**negs
643
+ b += 1
644
+ shift += 1
645
+ # Rational reflection formula
646
+ try:
647
+ p, q = a._mpq_
648
+ except:
649
+ assert a == int(a)
650
+ p = int(a)
651
+ q = 1
652
+ p += shift*q
653
+ assert 1 <= p <= q
654
+ g = ctx.fsum(ctx.cospi(t/2-2*k*b)*ctx._hurwitz(t,(k,q)) \
655
+ for k in range(1,q+1))
656
+ g *= 2*ctx.gamma(t)/(2*ctx.pi*q)**t
657
+ v += g
658
+ return v
659
+
660
+ def _hurwitz_em(ctx, s, a, d, prec, verbose):
661
+ # May not be converted at this point
662
+ a = ctx.convert(a)
663
+ tol = -prec
664
+ # Estimate number of terms for Euler-Maclaurin summation; could be improved
665
+ M1 = 0
666
+ M2 = prec // 3
667
+ N = M2
668
+ lsum = 0
669
+ # This speeds up the recurrence for derivatives
670
+ if ctx.isint(s):
671
+ s = int(ctx._re(s))
672
+ s1 = s-1
673
+ while 1:
674
+ # Truncated L-series
675
+ l = ctx._zetasum(s, M1+a, M2-M1-1, [d])[0][0]
676
+ #if d:
677
+ # l = ctx.fsum((-ctx.ln(n+a))**d * (n+a)**negs for n in range(M1,M2))
678
+ #else:
679
+ # l = ctx.fsum((n+a)**negs for n in range(M1,M2))
680
+ lsum += l
681
+ M2a = M2+a
682
+ logM2a = ctx.ln(M2a)
683
+ logM2ad = logM2a**d
684
+ logs = [logM2ad]
685
+ logr = 1/logM2a
686
+ rM2a = 1/M2a
687
+ M2as = M2a**(-s)
688
+ if d:
689
+ tailsum = ctx.gammainc(d+1, s1*logM2a) / s1**(d+1)
690
+ else:
691
+ tailsum = 1/((s1)*(M2a)**s1)
692
+ tailsum += 0.5 * logM2ad * M2as
693
+ U = [1]
694
+ r = M2as
695
+ fact = 2
696
+ for j in range(1, N+1):
697
+ # TODO: the following could perhaps be tidied a bit
698
+ j2 = 2*j
699
+ if j == 1:
700
+ upds = [1]
701
+ else:
702
+ upds = [j2-2, j2-1]
703
+ for m in upds:
704
+ D = min(m,d+1)
705
+ if m <= d:
706
+ logs.append(logs[-1] * logr)
707
+ Un = [0]*(D+1)
708
+ for i in xrange(D): Un[i] = (1-m-s)*U[i]
709
+ for i in xrange(1,D+1): Un[i] += (d-(i-1))*U[i-1]
710
+ U = Un
711
+ r *= rM2a
712
+ t = ctx.fdot(U, logs) * r * ctx.bernoulli(j2)/(-fact)
713
+ tailsum += t
714
+ if ctx.mag(t) < tol:
715
+ return lsum, (-1)**d * tailsum
716
+ fact *= (j2+1)*(j2+2)
717
+ if verbose:
718
+ print("Sum range:", M1, M2, "term magnitude", ctx.mag(t), "tolerance", tol)
719
+ M1, M2 = M2, M2*2
720
+ if ctx.re(s) < 0:
721
+ N += N//2
722
+
723
+
724
+
725
+ @defun
726
+ def _zetasum(ctx, s, a, n, derivatives=[0], reflect=False):
727
+ """
728
+ Returns [xd0,xd1,...,xdr], [yd0,yd1,...ydr] where
729
+
730
+ xdk = D^k ( 1/a^s + 1/(a+1)^s + ... + 1/(a+n)^s )
731
+ ydk = D^k conj( 1/a^(1-s) + 1/(a+1)^(1-s) + ... + 1/(a+n)^(1-s) )
732
+
733
+ D^k = kth derivative with respect to s, k ranges over the given list of
734
+ derivatives (which should consist of either a single element
735
+ or a range 0,1,...r). If reflect=False, the ydks are not computed.
736
+ """
737
+ #print "zetasum", s, a, n
738
+ # don't use the fixed-point code if there are large exponentials
739
+ if abs(ctx.re(s)) < 0.5 * ctx.prec:
740
+ try:
741
+ return ctx._zetasum_fast(s, a, n, derivatives, reflect)
742
+ except NotImplementedError:
743
+ pass
744
+ negs = ctx.fneg(s, exact=True)
745
+ have_derivatives = derivatives != [0]
746
+ have_one_derivative = len(derivatives) == 1
747
+ if not reflect:
748
+ if not have_derivatives:
749
+ return [ctx.fsum((a+k)**negs for k in xrange(n+1))], []
750
+ if have_one_derivative:
751
+ d = derivatives[0]
752
+ x = ctx.fsum(ctx.ln(a+k)**d * (a+k)**negs for k in xrange(n+1))
753
+ return [(-1)**d * x], []
754
+ maxd = max(derivatives)
755
+ if not have_one_derivative:
756
+ derivatives = range(maxd+1)
757
+ xs = [ctx.zero for d in derivatives]
758
+ if reflect:
759
+ ys = [ctx.zero for d in derivatives]
760
+ else:
761
+ ys = []
762
+ for k in xrange(n+1):
763
+ w = a + k
764
+ xterm = w ** negs
765
+ if reflect:
766
+ yterm = ctx.conj(ctx.one / (w * xterm))
767
+ if have_derivatives:
768
+ logw = -ctx.ln(w)
769
+ if have_one_derivative:
770
+ logw = logw ** maxd
771
+ xs[0] += xterm * logw
772
+ if reflect:
773
+ ys[0] += yterm * logw
774
+ else:
775
+ t = ctx.one
776
+ for d in derivatives:
777
+ xs[d] += xterm * t
778
+ if reflect:
779
+ ys[d] += yterm * t
780
+ t *= logw
781
+ else:
782
+ xs[0] += xterm
783
+ if reflect:
784
+ ys[0] += yterm
785
+ return xs, ys
786
+
787
+ @defun
788
+ def dirichlet(ctx, s, chi=[1], derivative=0):
789
+ s = ctx.convert(s)
790
+ q = len(chi)
791
+ d = int(derivative)
792
+ if d > 2:
793
+ raise NotImplementedError("arbitrary order derivatives")
794
+ prec = ctx.prec
795
+ try:
796
+ ctx.prec += 10
797
+ if s == 1:
798
+ have_pole = True
799
+ for x in chi:
800
+ if x and x != 1:
801
+ have_pole = False
802
+ h = +ctx.eps
803
+ ctx.prec *= 2*(d+1)
804
+ s += h
805
+ if have_pole:
806
+ return +ctx.inf
807
+ z = ctx.zero
808
+ for p in range(1,q+1):
809
+ if chi[p%q]:
810
+ if d == 1:
811
+ z += chi[p%q] * (ctx.zeta(s, (p,q), 1) - \
812
+ ctx.zeta(s, (p,q))*ctx.log(q))
813
+ else:
814
+ z += chi[p%q] * ctx.zeta(s, (p,q))
815
+ z /= q**s
816
+ finally:
817
+ ctx.prec = prec
818
+ return +z
819
+
820
+
821
+ def secondzeta_main_term(ctx, s, a, **kwargs):
822
+ tol = ctx.eps
823
+ f = lambda n: ctx.gammainc(0.5*s, a*gamm**2, regularized=True)*gamm**(-s)
824
+ totsum = term = ctx.zero
825
+ mg = ctx.inf
826
+ n = 0
827
+ while mg > tol:
828
+ totsum += term
829
+ n += 1
830
+ gamm = ctx.im(ctx.zetazero_memoized(n))
831
+ term = f(n)
832
+ mg = abs(term)
833
+ err = 0
834
+ if kwargs.get("error"):
835
+ sg = ctx.re(s)
836
+ err = 0.5*ctx.pi**(-1)*max(1,sg)*a**(sg-0.5)*ctx.log(gamm/(2*ctx.pi))*\
837
+ ctx.gammainc(-0.5, a*gamm**2)/abs(ctx.gamma(s/2))
838
+ err = abs(err)
839
+ return +totsum, err, n
840
+
841
+ def secondzeta_prime_term(ctx, s, a, **kwargs):
842
+ tol = ctx.eps
843
+ f = lambda n: ctx.gammainc(0.5*(1-s),0.25*ctx.log(n)**2 * a**(-1))*\
844
+ ((0.5*ctx.log(n))**(s-1))*ctx.mangoldt(n)/ctx.sqrt(n)/\
845
+ (2*ctx.gamma(0.5*s)*ctx.sqrt(ctx.pi))
846
+ totsum = term = ctx.zero
847
+ mg = ctx.inf
848
+ n = 1
849
+ while mg > tol or n < 9:
850
+ totsum += term
851
+ n += 1
852
+ term = f(n)
853
+ if term == 0:
854
+ mg = ctx.inf
855
+ else:
856
+ mg = abs(term)
857
+ if kwargs.get("error"):
858
+ err = mg
859
+ return +totsum, err, n
860
+
861
+ def secondzeta_exp_term(ctx, s, a):
862
+ if ctx.isint(s) and ctx.re(s) <= 0:
863
+ m = int(round(ctx.re(s)))
864
+ if not m & 1:
865
+ return ctx.mpf('-0.25')**(-m//2)
866
+ tol = ctx.eps
867
+ f = lambda n: (0.25*a)**n/((n+0.5*s)*ctx.fac(n))
868
+ totsum = ctx.zero
869
+ term = f(0)
870
+ mg = ctx.inf
871
+ n = 0
872
+ while mg > tol:
873
+ totsum += term
874
+ n += 1
875
+ term = f(n)
876
+ mg = abs(term)
877
+ v = a**(0.5*s)*totsum/ctx.gamma(0.5*s)
878
+ return v
879
+
880
+ def secondzeta_singular_term(ctx, s, a, **kwargs):
881
+ factor = a**(0.5*(s-1))/(4*ctx.sqrt(ctx.pi)*ctx.gamma(0.5*s))
882
+ extraprec = ctx.mag(factor)
883
+ ctx.prec += extraprec
884
+ factor = a**(0.5*(s-1))/(4*ctx.sqrt(ctx.pi)*ctx.gamma(0.5*s))
885
+ tol = ctx.eps
886
+ f = lambda n: ctx.bernpoly(n,0.75)*(4*ctx.sqrt(a))**n*\
887
+ ctx.gamma(0.5*n)/((s+n-1)*ctx.fac(n))
888
+ totsum = ctx.zero
889
+ mg1 = ctx.inf
890
+ n = 1
891
+ term = f(n)
892
+ mg2 = abs(term)
893
+ while mg2 > tol and mg2 <= mg1:
894
+ totsum += term
895
+ n += 1
896
+ term = f(n)
897
+ totsum += term
898
+ n +=1
899
+ term = f(n)
900
+ mg1 = mg2
901
+ mg2 = abs(term)
902
+ totsum += term
903
+ pole = -2*(s-1)**(-2)+(ctx.euler+ctx.log(16*ctx.pi**2*a))*(s-1)**(-1)
904
+ st = factor*(pole+totsum)
905
+ err = 0
906
+ if kwargs.get("error"):
907
+ if not ((mg2 > tol) and (mg2 <= mg1)):
908
+ if mg2 <= tol:
909
+ err = ctx.mpf(10)**int(ctx.log(abs(factor*tol),10))
910
+ if mg2 > mg1:
911
+ err = ctx.mpf(10)**int(ctx.log(abs(factor*mg1),10))
912
+ err = max(err, ctx.eps*1.)
913
+ ctx.prec -= extraprec
914
+ return +st, err
915
+
916
+ @defun
917
+ def secondzeta(ctx, s, a = 0.015, **kwargs):
918
+ r"""
919
+ Evaluates the secondary zeta function `Z(s)`, defined for
920
+ `\mathrm{Re}(s)>1` by
921
+
922
+ .. math ::
923
+
924
+ Z(s) = \sum_{n=1}^{\infty} \frac{1}{\tau_n^s}
925
+
926
+ where `\frac12+i\tau_n` runs through the zeros of `\zeta(s)` with
927
+ imaginary part positive.
928
+
929
+ `Z(s)` extends to a meromorphic function on `\mathbb{C}` with a
930
+ double pole at `s=1` and simple poles at the points `-2n` for
931
+ `n=0`, 1, 2, ...
932
+
933
+ **Examples**
934
+
935
+ >>> from mpmath import *
936
+ >>> mp.pretty = True; mp.dps = 15
937
+ >>> secondzeta(2)
938
+ 0.023104993115419
939
+ >>> xi = lambda s: 0.5*s*(s-1)*pi**(-0.5*s)*gamma(0.5*s)*zeta(s)
940
+ >>> Xi = lambda t: xi(0.5+t*j)
941
+ >>> chop(-0.5*diff(Xi,0,n=2)/Xi(0))
942
+ 0.023104993115419
943
+
944
+ We may ask for an approximate error value::
945
+
946
+ >>> secondzeta(0.5+100j, error=True)
947
+ ((-0.216272011276718 - 0.844952708937228j), 2.22044604925031e-16)
948
+
949
+ The function has poles at the negative odd integers,
950
+ and dyadic rational values at the negative even integers::
951
+
952
+ >>> mp.dps = 30
953
+ >>> secondzeta(-8)
954
+ -0.67236328125
955
+ >>> secondzeta(-7)
956
+ +inf
957
+
958
+ **Implementation notes**
959
+
960
+ The function is computed as sum of four terms `Z(s)=A(s)-P(s)+E(s)-S(s)`
961
+ respectively main, prime, exponential and singular terms.
962
+ The main term `A(s)` is computed from the zeros of zeta.
963
+ The prime term depends on the von Mangoldt function.
964
+ The singular term is responsible for the poles of the function.
965
+
966
+ The four terms depends on a small parameter `a`. We may change the
967
+ value of `a`. Theoretically this has no effect on the sum of the four
968
+ terms, but in practice may be important.
969
+
970
+ A smaller value of the parameter `a` makes `A(s)` depend on
971
+ a smaller number of zeros of zeta, but `P(s)` uses more values of
972
+ von Mangoldt function.
973
+
974
+ We may also add a verbose option to obtain data about the
975
+ values of the four terms.
976
+
977
+ >>> mp.dps = 10
978
+ >>> secondzeta(0.5 + 40j, error=True, verbose=True)
979
+ main term = (-30190318549.138656312556 - 13964804384.624622876523j)
980
+ computed using 19 zeros of zeta
981
+ prime term = (132717176.89212754625045 + 188980555.17563978290601j)
982
+ computed using 9 values of the von Mangoldt function
983
+ exponential term = (542447428666.07179812536 + 362434922978.80192435203j)
984
+ singular term = (512124392939.98154322355 + 348281138038.65531023921j)
985
+ ((0.059471043 + 0.3463514534j), 1.455191523e-11)
986
+
987
+ >>> secondzeta(0.5 + 40j, a=0.04, error=True, verbose=True)
988
+ main term = (-151962888.19606243907725 - 217930683.90210294051982j)
989
+ computed using 9 zeros of zeta
990
+ prime term = (2476659342.3038722372461 + 28711581821.921627163136j)
991
+ computed using 37 values of the von Mangoldt function
992
+ exponential term = (178506047114.7838188264 + 819674143244.45677330576j)
993
+ singular term = (175877424884.22441310708 + 790744630738.28669174871j)
994
+ ((0.059471043 + 0.3463514534j), 1.455191523e-11)
995
+
996
+ Notice the great cancellation between the four terms. Changing `a`, the
997
+ four terms are very different numbers but the cancellation gives
998
+ the good value of Z(s).
999
+
1000
+ **References**
1001
+
1002
+ A. Voros, Zeta functions for the Riemann zeros, Ann. Institute Fourier,
1003
+ 53, (2003) 665--699.
1004
+
1005
+ A. Voros, Zeta functions over Zeros of Zeta Functions, Lecture Notes
1006
+ of the Unione Matematica Italiana, Springer, 2009.
1007
+ """
1008
+ s = ctx.convert(s)
1009
+ a = ctx.convert(a)
1010
+ tol = ctx.eps
1011
+ if ctx.isint(s) and ctx.re(s) <= 1:
1012
+ if abs(s-1) < tol*1000:
1013
+ return ctx.inf
1014
+ m = int(round(ctx.re(s)))
1015
+ if m & 1:
1016
+ return ctx.inf
1017
+ else:
1018
+ return ((-1)**(-m//2)*\
1019
+ ctx.fraction(8-ctx.eulernum(-m,exact=True),2**(-m+3)))
1020
+ prec = ctx.prec
1021
+ try:
1022
+ t3 = secondzeta_exp_term(ctx, s, a)
1023
+ extraprec = max(ctx.mag(t3),0)
1024
+ ctx.prec += extraprec + 3
1025
+ t1, r1, gt = secondzeta_main_term(ctx,s,a,error='True', verbose='True')
1026
+ t2, r2, pt = secondzeta_prime_term(ctx,s,a,error='True', verbose='True')
1027
+ t4, r4 = secondzeta_singular_term(ctx,s,a,error='True')
1028
+ t3 = secondzeta_exp_term(ctx, s, a)
1029
+ err = r1+r2+r4
1030
+ t = t1-t2+t3-t4
1031
+ if kwargs.get("verbose"):
1032
+ print('main term =', t1)
1033
+ print(' computed using', gt, 'zeros of zeta')
1034
+ print('prime term =', t2)
1035
+ print(' computed using', pt, 'values of the von Mangoldt function')
1036
+ print('exponential term =', t3)
1037
+ print('singular term =', t4)
1038
+ finally:
1039
+ ctx.prec = prec
1040
+ if kwargs.get("error"):
1041
+ w = max(ctx.mag(abs(t)),0)
1042
+ err = max(err*2**w, ctx.eps*1.*2**w)
1043
+ return +t, err
1044
+ return +t
1045
+
1046
+
1047
+ @defun_wrapped
1048
+ def lerchphi(ctx, z, s, a):
1049
+ r"""
1050
+ Gives the Lerch transcendent, defined for `|z| < 1` and
1051
+ `\Re{a} > 0` by
1052
+
1053
+ .. math ::
1054
+
1055
+ \Phi(z,s,a) = \sum_{k=0}^{\infty} \frac{z^k}{(a+k)^s}
1056
+
1057
+ and generally by the recurrence `\Phi(z,s,a) = z \Phi(z,s,a+1) + a^{-s}`
1058
+ along with the integral representation valid for `\Re{a} > 0`
1059
+
1060
+ .. math ::
1061
+
1062
+ \Phi(z,s,a) = \frac{1}{2 a^s} +
1063
+ \int_0^{\infty} \frac{z^t}{(a+t)^s} dt -
1064
+ 2 \int_0^{\infty} \frac{\sin(t \log z - s
1065
+ \operatorname{arctan}(t/a)}{(a^2 + t^2)^{s/2}
1066
+ (e^{2 \pi t}-1)} dt.
1067
+
1068
+ The Lerch transcendent generalizes the Hurwitz zeta function :func:`zeta`
1069
+ (`z = 1`) and the polylogarithm :func:`polylog` (`a = 1`).
1070
+
1071
+ **Examples**
1072
+
1073
+ Several evaluations in terms of simpler functions::
1074
+
1075
+ >>> from mpmath import *
1076
+ >>> mp.dps = 25; mp.pretty = True
1077
+ >>> lerchphi(-1,2,0.5); 4*catalan
1078
+ 3.663862376708876060218414
1079
+ 3.663862376708876060218414
1080
+ >>> diff(lerchphi, (-1,-2,1), (0,1,0)); 7*zeta(3)/(4*pi**2)
1081
+ 0.2131391994087528954617607
1082
+ 0.2131391994087528954617607
1083
+ >>> lerchphi(-4,1,1); log(5)/4
1084
+ 0.4023594781085250936501898
1085
+ 0.4023594781085250936501898
1086
+ >>> lerchphi(-3+2j,1,0.5); 2*atanh(sqrt(-3+2j))/sqrt(-3+2j)
1087
+ (1.142423447120257137774002 + 0.2118232380980201350495795j)
1088
+ (1.142423447120257137774002 + 0.2118232380980201350495795j)
1089
+
1090
+ Evaluation works for complex arguments and `|z| \ge 1`::
1091
+
1092
+ >>> lerchphi(1+2j, 3-j, 4+2j)
1093
+ (0.002025009957009908600539469 + 0.003327897536813558807438089j)
1094
+ >>> lerchphi(-2,2,-2.5)
1095
+ -12.28676272353094275265944
1096
+ >>> lerchphi(10,10,10)
1097
+ (-4.462130727102185701817349e-11 - 1.575172198981096218823481e-12j)
1098
+ >>> lerchphi(10,10,-10.5)
1099
+ (112658784011940.5605789002 - 498113185.5756221777743631j)
1100
+
1101
+ Some degenerate cases::
1102
+
1103
+ >>> lerchphi(0,1,2)
1104
+ 0.5
1105
+ >>> lerchphi(0,1,-2)
1106
+ -0.5
1107
+
1108
+ Reduction to simpler functions::
1109
+
1110
+ >>> lerchphi(1, 4.25+1j, 1)
1111
+ (1.044674457556746668033975 - 0.04674508654012658932271226j)
1112
+ >>> zeta(4.25+1j)
1113
+ (1.044674457556746668033975 - 0.04674508654012658932271226j)
1114
+ >>> lerchphi(1 - 0.5**10, 4.25+1j, 1)
1115
+ (1.044629338021507546737197 - 0.04667768813963388181708101j)
1116
+ >>> lerchphi(3, 4, 1)
1117
+ (1.249503297023366545192592 - 0.2314252413375664776474462j)
1118
+ >>> polylog(4, 3) / 3
1119
+ (1.249503297023366545192592 - 0.2314252413375664776474462j)
1120
+ >>> lerchphi(3, 4, 1 - 0.5**10)
1121
+ (1.253978063946663945672674 - 0.2316736622836535468765376j)
1122
+
1123
+ **References**
1124
+
1125
+ 1. [DLMF]_ section 25.14
1126
+
1127
+ """
1128
+ if z == 0:
1129
+ return a ** (-s)
1130
+ # Faster, but these cases are useful for testing right now
1131
+ if z == 1:
1132
+ return ctx.zeta(s, a)
1133
+ if a == 1:
1134
+ return ctx.polylog(s, z) / z
1135
+ if ctx.re(a) < 1:
1136
+ if ctx.isnpint(a):
1137
+ raise ValueError("Lerch transcendent complex infinity")
1138
+ m = int(ctx.ceil(1-ctx.re(a)))
1139
+ v = ctx.zero
1140
+ zpow = ctx.one
1141
+ for n in xrange(m):
1142
+ v += zpow / (a+n)**s
1143
+ zpow *= z
1144
+ return zpow * ctx.lerchphi(z,s, a+m) + v
1145
+ g = ctx.ln(z)
1146
+ v = 1/(2*a**s) + ctx.gammainc(1-s, -a*g) * (-g)**(s-1) / z**a
1147
+ h = s / 2
1148
+ r = 2*ctx.pi
1149
+ f = lambda t: ctx.sin(s*ctx.atan(t/a)-t*g) / \
1150
+ ((a**2+t**2)**h * ctx.expm1(r*t))
1151
+ v += 2*ctx.quad(f, [0, ctx.inf])
1152
+ if not ctx.im(z) and not ctx.im(s) and not ctx.im(a) and ctx.re(z) < 1:
1153
+ v = ctx.chop(v)
1154
+ return v
lib/python3.11/site-packages/mpmath/functions/zetazeros.py ADDED
@@ -0,0 +1,1018 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ The function zetazero(n) computes the n-th nontrivial zero of zeta(s).
3
+
4
+ The general strategy is to locate a block of Gram intervals B where we
5
+ know exactly the number of zeros contained and which of those zeros
6
+ is that which we search.
7
+
8
+ If n <= 400 000 000 we know exactly the Rosser exceptions, contained
9
+ in a list in this file. Hence for n<=400 000 000 we simply
10
+ look at these list of exceptions. If our zero is implicated in one of
11
+ these exceptions we have our block B. In other case we simply locate
12
+ the good Rosser block containing our zero.
13
+
14
+ For n > 400 000 000 we apply the method of Turing, as complemented by
15
+ Lehman, Brent and Trudgian to find a suitable B.
16
+ """
17
+
18
+ from .functions import defun, defun_wrapped
19
+
20
+ def find_rosser_block_zero(ctx, n):
21
+ """for n<400 000 000 determines a block were one find our zero"""
22
+ for k in range(len(_ROSSER_EXCEPTIONS)//2):
23
+ a=_ROSSER_EXCEPTIONS[2*k][0]
24
+ b=_ROSSER_EXCEPTIONS[2*k][1]
25
+ if ((a<= n-2) and (n-1 <= b)):
26
+ t0 = ctx.grampoint(a)
27
+ t1 = ctx.grampoint(b)
28
+ v0 = ctx._fp.siegelz(t0)
29
+ v1 = ctx._fp.siegelz(t1)
30
+ my_zero_number = n-a-1
31
+ zero_number_block = b-a
32
+ pattern = _ROSSER_EXCEPTIONS[2*k+1]
33
+ return (my_zero_number, [a,b], [t0,t1], [v0,v1])
34
+ k = n-2
35
+ t,v,b = compute_triple_tvb(ctx, k)
36
+ T = [t]
37
+ V = [v]
38
+ while b < 0:
39
+ k -= 1
40
+ t,v,b = compute_triple_tvb(ctx, k)
41
+ T.insert(0,t)
42
+ V.insert(0,v)
43
+ my_zero_number = n-k-1
44
+ m = n-1
45
+ t,v,b = compute_triple_tvb(ctx, m)
46
+ T.append(t)
47
+ V.append(v)
48
+ while b < 0:
49
+ m += 1
50
+ t,v,b = compute_triple_tvb(ctx, m)
51
+ T.append(t)
52
+ V.append(v)
53
+ return (my_zero_number, [k,m], T, V)
54
+
55
+ def wpzeros(t):
56
+ """Precision needed to compute higher zeros"""
57
+ wp = 53
58
+ if t > 3*10**8:
59
+ wp = 63
60
+ if t > 10**11:
61
+ wp = 70
62
+ if t > 10**14:
63
+ wp = 83
64
+ return wp
65
+
66
+ def separate_zeros_in_block(ctx, zero_number_block, T, V, limitloop=None,
67
+ fp_tolerance=None):
68
+ """Separate the zeros contained in the block T, limitloop
69
+ determines how long one must search"""
70
+ if limitloop is None:
71
+ limitloop = ctx.inf
72
+ loopnumber = 0
73
+ variations = count_variations(V)
74
+ while ((variations < zero_number_block) and (loopnumber <limitloop)):
75
+ a = T[0]
76
+ v = V[0]
77
+ newT = [a]
78
+ newV = [v]
79
+ variations = 0
80
+ for n in range(1,len(T)):
81
+ b2 = T[n]
82
+ u = V[n]
83
+ if (u*v>0):
84
+ alpha = ctx.sqrt(u/v)
85
+ b= (alpha*a+b2)/(alpha+1)
86
+ else:
87
+ b = (a+b2)/2
88
+ if fp_tolerance < 10:
89
+ w = ctx._fp.siegelz(b)
90
+ if abs(w)<fp_tolerance:
91
+ w = ctx.siegelz(b)
92
+ else:
93
+ w=ctx.siegelz(b)
94
+ if v*w<0:
95
+ variations += 1
96
+ newT.append(b)
97
+ newV.append(w)
98
+ u = V[n]
99
+ if u*w <0:
100
+ variations += 1
101
+ newT.append(b2)
102
+ newV.append(u)
103
+ a = b2
104
+ v = u
105
+ T = newT
106
+ V = newV
107
+ loopnumber +=1
108
+ if (limitloop>ITERATION_LIMIT)and(loopnumber>2)and(variations+2==zero_number_block):
109
+ dtMax=0
110
+ dtSec=0
111
+ kMax = 0
112
+ for k1 in range(1,len(T)):
113
+ dt = T[k1]-T[k1-1]
114
+ if dt > dtMax:
115
+ kMax=k1
116
+ dtSec = dtMax
117
+ dtMax = dt
118
+ elif (dt<dtMax) and(dt >dtSec):
119
+ dtSec = dt
120
+ if dtMax>3*dtSec:
121
+ f = lambda x: ctx.rs_z(x,derivative=1)
122
+ t0=T[kMax-1]
123
+ t1 = T[kMax]
124
+ t=ctx.findroot(f, (t0,t1), solver ='illinois',verify=False, verbose=False)
125
+ v = ctx.siegelz(t)
126
+ if (t0<t) and (t<t1) and (v*V[kMax]<0):
127
+ T.insert(kMax,t)
128
+ V.insert(kMax,v)
129
+ variations = count_variations(V)
130
+ if variations == zero_number_block:
131
+ separated = True
132
+ else:
133
+ separated = False
134
+ return (T,V, separated)
135
+
136
+ def separate_my_zero(ctx, my_zero_number, zero_number_block, T, V, prec):
137
+ """If we know which zero of this block is mine,
138
+ the function separates the zero"""
139
+ variations = 0
140
+ v0 = V[0]
141
+ for k in range(1,len(V)):
142
+ v1 = V[k]
143
+ if v0*v1 < 0:
144
+ variations +=1
145
+ if variations == my_zero_number:
146
+ k0 = k
147
+ leftv = v0
148
+ rightv = v1
149
+ v0 = v1
150
+ t1 = T[k0]
151
+ t0 = T[k0-1]
152
+ ctx.prec = prec
153
+ wpz = wpzeros(my_zero_number*ctx.log(my_zero_number))
154
+
155
+ guard = 4*ctx.mag(my_zero_number)
156
+ precs = [ctx.prec+4]
157
+ index=0
158
+ while precs[0] > 2*wpz:
159
+ index +=1
160
+ precs = [precs[0] // 2 +3+2*index] + precs
161
+ ctx.prec = precs[0] + guard
162
+ r = ctx.findroot(lambda x:ctx.siegelz(x), (t0,t1), solver ='illinois', verbose=False)
163
+ #print "first step at", ctx.dps, "digits"
164
+ z=ctx.mpc(0.5,r)
165
+ for prec in precs[1:]:
166
+ ctx.prec = prec + guard
167
+ #print "refining to", ctx.dps, "digits"
168
+ znew = z - ctx.zeta(z) / ctx.zeta(z, derivative=1)
169
+ #print "difference", ctx.nstr(abs(z-znew))
170
+ z=ctx.mpc(0.5,ctx.im(znew))
171
+ return ctx.im(z)
172
+
173
+ def sure_number_block(ctx, n):
174
+ """The number of good Rosser blocks needed to apply
175
+ Turing method
176
+ References:
177
+ R. P. Brent, On the Zeros of the Riemann Zeta Function
178
+ in the Critical Strip, Math. Comp. 33 (1979) 1361--1372
179
+ T. Trudgian, Improvements to Turing Method, Math. Comp."""
180
+ if n < 9*10**5:
181
+ return(2)
182
+ g = ctx.grampoint(n-100)
183
+ lg = ctx._fp.ln(g)
184
+ brent = 0.0061 * lg**2 +0.08*lg
185
+ trudgian = 0.0031 * lg**2 +0.11*lg
186
+ N = ctx.ceil(min(brent,trudgian))
187
+ N = int(N)
188
+ return N
189
+
190
+ def compute_triple_tvb(ctx, n):
191
+ t = ctx.grampoint(n)
192
+ v = ctx._fp.siegelz(t)
193
+ if ctx.mag(abs(v))<ctx.mag(t)-45:
194
+ v = ctx.siegelz(t)
195
+ b = v*(-1)**n
196
+ return t,v,b
197
+
198
+
199
+
200
+ ITERATION_LIMIT = 4
201
+
202
+ def search_supergood_block(ctx, n, fp_tolerance):
203
+ """To use for n>400 000 000"""
204
+ sb = sure_number_block(ctx, n)
205
+ number_goodblocks = 0
206
+ m2 = n-1
207
+ t, v, b = compute_triple_tvb(ctx, m2)
208
+ Tf = [t]
209
+ Vf = [v]
210
+ while b < 0:
211
+ m2 += 1
212
+ t,v,b = compute_triple_tvb(ctx, m2)
213
+ Tf.append(t)
214
+ Vf.append(v)
215
+ goodpoints = [m2]
216
+ T = [t]
217
+ V = [v]
218
+ while number_goodblocks < 2*sb:
219
+ m2 += 1
220
+ t, v, b = compute_triple_tvb(ctx, m2)
221
+ T.append(t)
222
+ V.append(v)
223
+ while b < 0:
224
+ m2 += 1
225
+ t,v,b = compute_triple_tvb(ctx, m2)
226
+ T.append(t)
227
+ V.append(v)
228
+ goodpoints.append(m2)
229
+ zn = len(T)-1
230
+ A, B, separated =\
231
+ separate_zeros_in_block(ctx, zn, T, V, limitloop=ITERATION_LIMIT,
232
+ fp_tolerance=fp_tolerance)
233
+ Tf.pop()
234
+ Tf.extend(A)
235
+ Vf.pop()
236
+ Vf.extend(B)
237
+ if separated:
238
+ number_goodblocks += 1
239
+ else:
240
+ number_goodblocks = 0
241
+ T = [t]
242
+ V = [v]
243
+ # Now the same procedure to the left
244
+ number_goodblocks = 0
245
+ m2 = n-2
246
+ t, v, b = compute_triple_tvb(ctx, m2)
247
+ Tf.insert(0,t)
248
+ Vf.insert(0,v)
249
+ while b < 0:
250
+ m2 -= 1
251
+ t,v,b = compute_triple_tvb(ctx, m2)
252
+ Tf.insert(0,t)
253
+ Vf.insert(0,v)
254
+ goodpoints.insert(0,m2)
255
+ T = [t]
256
+ V = [v]
257
+ while number_goodblocks < 2*sb:
258
+ m2 -= 1
259
+ t, v, b = compute_triple_tvb(ctx, m2)
260
+ T.insert(0,t)
261
+ V.insert(0,v)
262
+ while b < 0:
263
+ m2 -= 1
264
+ t,v,b = compute_triple_tvb(ctx, m2)
265
+ T.insert(0,t)
266
+ V.insert(0,v)
267
+ goodpoints.insert(0,m2)
268
+ zn = len(T)-1
269
+ A, B, separated =\
270
+ separate_zeros_in_block(ctx, zn, T, V, limitloop=ITERATION_LIMIT, fp_tolerance=fp_tolerance)
271
+ A.pop()
272
+ Tf = A+Tf
273
+ B.pop()
274
+ Vf = B+Vf
275
+ if separated:
276
+ number_goodblocks += 1
277
+ else:
278
+ number_goodblocks = 0
279
+ T = [t]
280
+ V = [v]
281
+ r = goodpoints[2*sb]
282
+ lg = len(goodpoints)
283
+ s = goodpoints[lg-2*sb-1]
284
+ tr, vr, br = compute_triple_tvb(ctx, r)
285
+ ar = Tf.index(tr)
286
+ ts, vs, bs = compute_triple_tvb(ctx, s)
287
+ as1 = Tf.index(ts)
288
+ T = Tf[ar:as1+1]
289
+ V = Vf[ar:as1+1]
290
+ zn = s-r
291
+ A, B, separated =\
292
+ separate_zeros_in_block(ctx, zn,T,V,limitloop=ITERATION_LIMIT, fp_tolerance=fp_tolerance)
293
+ if separated:
294
+ return (n-r-1,[r,s],A,B)
295
+ q = goodpoints[sb]
296
+ lg = len(goodpoints)
297
+ t = goodpoints[lg-sb-1]
298
+ tq, vq, bq = compute_triple_tvb(ctx, q)
299
+ aq = Tf.index(tq)
300
+ tt, vt, bt = compute_triple_tvb(ctx, t)
301
+ at = Tf.index(tt)
302
+ T = Tf[aq:at+1]
303
+ V = Vf[aq:at+1]
304
+ return (n-q-1,[q,t],T,V)
305
+
306
+ def count_variations(V):
307
+ count = 0
308
+ vold = V[0]
309
+ for n in range(1, len(V)):
310
+ vnew = V[n]
311
+ if vold*vnew < 0:
312
+ count +=1
313
+ vold = vnew
314
+ return count
315
+
316
+ def pattern_construct(ctx, block, T, V):
317
+ pattern = '('
318
+ a = block[0]
319
+ b = block[1]
320
+ t0,v0,b0 = compute_triple_tvb(ctx, a)
321
+ k = 0
322
+ k0 = 0
323
+ for n in range(a+1,b+1):
324
+ t1,v1,b1 = compute_triple_tvb(ctx, n)
325
+ lgT =len(T)
326
+ while (k < lgT) and (T[k] <= t1):
327
+ k += 1
328
+ L = V[k0:k]
329
+ L.append(v1)
330
+ L.insert(0,v0)
331
+ count = count_variations(L)
332
+ pattern = pattern + ("%s" % count)
333
+ if b1 > 0:
334
+ pattern = pattern + ')('
335
+ k0 = k
336
+ t0,v0,b0 = t1,v1,b1
337
+ pattern = pattern[:-1]
338
+ return pattern
339
+
340
+ @defun
341
+ def zetazero(ctx, n, info=False, round=True):
342
+ r"""
343
+ Computes the `n`-th nontrivial zero of `\zeta(s)` on the critical line,
344
+ i.e. returns an approximation of the `n`-th largest complex number
345
+ `s = \frac{1}{2} + ti` for which `\zeta(s) = 0`. Equivalently, the
346
+ imaginary part `t` is a zero of the Z-function (:func:`~mpmath.siegelz`).
347
+
348
+ **Examples**
349
+
350
+ The first few zeros::
351
+
352
+ >>> from mpmath import *
353
+ >>> mp.dps = 25; mp.pretty = True
354
+ >>> zetazero(1)
355
+ (0.5 + 14.13472514173469379045725j)
356
+ >>> zetazero(2)
357
+ (0.5 + 21.02203963877155499262848j)
358
+ >>> zetazero(20)
359
+ (0.5 + 77.14484006887480537268266j)
360
+
361
+ Verifying that the values are zeros::
362
+
363
+ >>> for n in range(1,5):
364
+ ... s = zetazero(n)
365
+ ... chop(zeta(s)), chop(siegelz(s.imag))
366
+ ...
367
+ (0.0, 0.0)
368
+ (0.0, 0.0)
369
+ (0.0, 0.0)
370
+ (0.0, 0.0)
371
+
372
+ Negative indices give the conjugate zeros (`n = 0` is undefined)::
373
+
374
+ >>> zetazero(-1)
375
+ (0.5 - 14.13472514173469379045725j)
376
+
377
+ :func:`~mpmath.zetazero` supports arbitrarily large `n` and arbitrary precision::
378
+
379
+ >>> mp.dps = 15
380
+ >>> zetazero(1234567)
381
+ (0.5 + 727690.906948208j)
382
+ >>> mp.dps = 50
383
+ >>> zetazero(1234567)
384
+ (0.5 + 727690.9069482075392389420041147142092708393819935j)
385
+ >>> chop(zeta(_)/_)
386
+ 0.0
387
+
388
+ with *info=True*, :func:`~mpmath.zetazero` gives additional information::
389
+
390
+ >>> mp.dps = 15
391
+ >>> zetazero(542964976,info=True)
392
+ ((0.5 + 209039046.578535j), [542964969, 542964978], 6, '(013111110)')
393
+
394
+ This means that the zero is between Gram points 542964969 and 542964978;
395
+ it is the 6-th zero between them. Finally (01311110) is the pattern
396
+ of zeros in this interval. The numbers indicate the number of zeros
397
+ in each Gram interval (Rosser blocks between parenthesis). In this case
398
+ there is only one Rosser block of length nine.
399
+ """
400
+ n = int(n)
401
+ if n < 0:
402
+ return ctx.zetazero(-n).conjugate()
403
+ if n == 0:
404
+ raise ValueError("n must be nonzero")
405
+ wpinitial = ctx.prec
406
+ try:
407
+ wpz, fp_tolerance = comp_fp_tolerance(ctx, n)
408
+ ctx.prec = wpz
409
+ if n < 400000000:
410
+ my_zero_number, block, T, V =\
411
+ find_rosser_block_zero(ctx, n)
412
+ else:
413
+ my_zero_number, block, T, V =\
414
+ search_supergood_block(ctx, n, fp_tolerance)
415
+ zero_number_block = block[1]-block[0]
416
+ T, V, separated = separate_zeros_in_block(ctx, zero_number_block, T, V,
417
+ limitloop=ctx.inf, fp_tolerance=fp_tolerance)
418
+ if info:
419
+ pattern = pattern_construct(ctx,block,T,V)
420
+ prec = max(wpinitial, wpz)
421
+ t = separate_my_zero(ctx, my_zero_number, zero_number_block,T,V,prec)
422
+ v = ctx.mpc(0.5,t)
423
+ finally:
424
+ ctx.prec = wpinitial
425
+ if round:
426
+ v =+v
427
+ if info:
428
+ return (v,block,my_zero_number,pattern)
429
+ else:
430
+ return v
431
+
432
+ def gram_index(ctx, t):
433
+ if t > 10**13:
434
+ wp = 3*ctx.log(t, 10)
435
+ else:
436
+ wp = 0
437
+ prec = ctx.prec
438
+ try:
439
+ ctx.prec += wp
440
+ h = int(ctx.siegeltheta(t)/ctx.pi)
441
+ finally:
442
+ ctx.prec = prec
443
+ return(h)
444
+
445
+ def count_to(ctx, t, T, V):
446
+ count = 0
447
+ vold = V[0]
448
+ told = T[0]
449
+ tnew = T[1]
450
+ k = 1
451
+ while tnew < t:
452
+ vnew = V[k]
453
+ if vold*vnew < 0:
454
+ count += 1
455
+ vold = vnew
456
+ k += 1
457
+ tnew = T[k]
458
+ a = ctx.siegelz(t)
459
+ if a*vold < 0:
460
+ count += 1
461
+ return count
462
+
463
+ def comp_fp_tolerance(ctx, n):
464
+ wpz = wpzeros(n*ctx.log(n))
465
+ if n < 15*10**8:
466
+ fp_tolerance = 0.0005
467
+ elif n <= 10**14:
468
+ fp_tolerance = 0.1
469
+ else:
470
+ fp_tolerance = 100
471
+ return wpz, fp_tolerance
472
+
473
+ @defun
474
+ def nzeros(ctx, t):
475
+ r"""
476
+ Computes the number of zeros of the Riemann zeta function in
477
+ `(0,1) \times (0,t]`, usually denoted by `N(t)`.
478
+
479
+ **Examples**
480
+
481
+ The first zero has imaginary part between 14 and 15::
482
+
483
+ >>> from mpmath import *
484
+ >>> mp.dps = 15; mp.pretty = True
485
+ >>> nzeros(14)
486
+ 0
487
+ >>> nzeros(15)
488
+ 1
489
+ >>> zetazero(1)
490
+ (0.5 + 14.1347251417347j)
491
+
492
+ Some closely spaced zeros::
493
+
494
+ >>> nzeros(10**7)
495
+ 21136125
496
+ >>> zetazero(21136125)
497
+ (0.5 + 9999999.32718175j)
498
+ >>> zetazero(21136126)
499
+ (0.5 + 10000000.2400236j)
500
+ >>> nzeros(545439823.215)
501
+ 1500000001
502
+ >>> zetazero(1500000001)
503
+ (0.5 + 545439823.201985j)
504
+ >>> zetazero(1500000002)
505
+ (0.5 + 545439823.325697j)
506
+
507
+ This confirms the data given by J. van de Lune,
508
+ H. J. J. te Riele and D. T. Winter in 1986.
509
+ """
510
+ if t < 14.1347251417347:
511
+ return 0
512
+ x = gram_index(ctx, t)
513
+ k = int(ctx.floor(x))
514
+ wpinitial = ctx.prec
515
+ wpz, fp_tolerance = comp_fp_tolerance(ctx, k)
516
+ ctx.prec = wpz
517
+ a = ctx.siegelz(t)
518
+ if k == -1 and a < 0:
519
+ return 0
520
+ elif k == -1 and a > 0:
521
+ return 1
522
+ if k+2 < 400000000:
523
+ Rblock = find_rosser_block_zero(ctx, k+2)
524
+ else:
525
+ Rblock = search_supergood_block(ctx, k+2, fp_tolerance)
526
+ n1, n2 = Rblock[1]
527
+ if n2-n1 == 1:
528
+ b = Rblock[3][0]
529
+ if a*b > 0:
530
+ ctx.prec = wpinitial
531
+ return k+1
532
+ else:
533
+ ctx.prec = wpinitial
534
+ return k+2
535
+ my_zero_number,block, T, V = Rblock
536
+ zero_number_block = n2-n1
537
+ T, V, separated = separate_zeros_in_block(ctx,\
538
+ zero_number_block, T, V,\
539
+ limitloop=ctx.inf,\
540
+ fp_tolerance=fp_tolerance)
541
+ n = count_to(ctx, t, T, V)
542
+ ctx.prec = wpinitial
543
+ return n+n1+1
544
+
545
+ @defun_wrapped
546
+ def backlunds(ctx, t):
547
+ r"""
548
+ Computes the function
549
+ `S(t) = \operatorname{arg} \zeta(\frac{1}{2} + it) / \pi`.
550
+
551
+ See Titchmarsh Section 9.3 for details of the definition.
552
+
553
+ **Examples**
554
+
555
+ >>> from mpmath import *
556
+ >>> mp.dps = 15; mp.pretty = True
557
+ >>> backlunds(217.3)
558
+ 0.16302205431184
559
+
560
+ Generally, the value is a small number. At Gram points it is an integer,
561
+ frequently equal to 0::
562
+
563
+ >>> chop(backlunds(grampoint(200)))
564
+ 0.0
565
+ >>> backlunds(extraprec(10)(grampoint)(211))
566
+ 1.0
567
+ >>> backlunds(extraprec(10)(grampoint)(232))
568
+ -1.0
569
+
570
+ The number of zeros of the Riemann zeta function up to height `t`
571
+ satisfies `N(t) = \theta(t)/\pi + 1 + S(t)` (see :func:nzeros` and
572
+ :func:`siegeltheta`)::
573
+
574
+ >>> t = 1234.55
575
+ >>> nzeros(t)
576
+ 842
577
+ >>> siegeltheta(t)/pi+1+backlunds(t)
578
+ 842.0
579
+
580
+ """
581
+ return ctx.nzeros(t)-1-ctx.siegeltheta(t)/ctx.pi
582
+
583
+
584
+ """
585
+ _ROSSER_EXCEPTIONS is a list of all exceptions to
586
+ Rosser's rule for n <= 400 000 000.
587
+
588
+ Alternately the entry is of type [n,m], or a string.
589
+ The string is the zero pattern of the Block and the relevant
590
+ adjacent. For example (010)3 corresponds to a block
591
+ composed of three Gram intervals, the first ant third without
592
+ a zero and the intermediate with a zero. The next Gram interval
593
+ contain three zeros. So that in total we have 4 zeros in 4 Gram
594
+ blocks. n and m are the indices of the Gram points of this
595
+ interval of four Gram intervals. The Rosser exception is therefore
596
+ formed by the three Gram intervals that are signaled between
597
+ parenthesis.
598
+
599
+ We have included also some Rosser's exceptions beyond n=400 000 000
600
+ that are noted in the literature by some reason.
601
+
602
+ The list is composed from the data published in the references:
603
+
604
+ R. P. Brent, J. van de Lune, H. J. J. te Riele, D. T. Winter,
605
+ 'On the Zeros of the Riemann Zeta Function in the Critical Strip. II',
606
+ Math. Comp. 39 (1982) 681--688.
607
+ See also Corrigenda in Math. Comp. 46 (1986) 771.
608
+
609
+ J. van de Lune, H. J. J. te Riele,
610
+ 'On the Zeros of the Riemann Zeta Function in the Critical Strip. III',
611
+ Math. Comp. 41 (1983) 759--767.
612
+ See also Corrigenda in Math. Comp. 46 (1986) 771.
613
+
614
+ J. van de Lune,
615
+ 'Sums of Equal Powers of Positive Integers',
616
+ Dissertation,
617
+ Vrije Universiteit te Amsterdam, Centrum voor Wiskunde en Informatica,
618
+ Amsterdam, 1984.
619
+
620
+ Thanks to the authors all this papers and those others that have
621
+ contributed to make this possible.
622
+ """
623
+
624
+
625
+
626
+
627
+
628
+
629
+
630
+ _ROSSER_EXCEPTIONS = \
631
+ [[13999525, 13999528], '(00)3',
632
+ [30783329, 30783332], '(00)3',
633
+ [30930926, 30930929], '3(00)',
634
+ [37592215, 37592218], '(00)3',
635
+ [40870156, 40870159], '(00)3',
636
+ [43628107, 43628110], '(00)3',
637
+ [46082042, 46082045], '(00)3',
638
+ [46875667, 46875670], '(00)3',
639
+ [49624540, 49624543], '3(00)',
640
+ [50799238, 50799241], '(00)3',
641
+ [55221453, 55221456], '3(00)',
642
+ [56948779, 56948782], '3(00)',
643
+ [60515663, 60515666], '(00)3',
644
+ [61331766, 61331770], '(00)40',
645
+ [69784843, 69784846], '3(00)',
646
+ [75052114, 75052117], '(00)3',
647
+ [79545240, 79545243], '3(00)',
648
+ [79652247, 79652250], '3(00)',
649
+ [83088043, 83088046], '(00)3',
650
+ [83689522, 83689525], '3(00)',
651
+ [85348958, 85348961], '(00)3',
652
+ [86513820, 86513823], '(00)3',
653
+ [87947596, 87947599], '3(00)',
654
+ [88600095, 88600098], '(00)3',
655
+ [93681183, 93681186], '(00)3',
656
+ [100316551, 100316554], '3(00)',
657
+ [100788444, 100788447], '(00)3',
658
+ [106236172, 106236175], '(00)3',
659
+ [106941327, 106941330], '3(00)',
660
+ [107287955, 107287958], '(00)3',
661
+ [107532016, 107532019], '3(00)',
662
+ [110571044, 110571047], '(00)3',
663
+ [111885253, 111885256], '3(00)',
664
+ [113239783, 113239786], '(00)3',
665
+ [120159903, 120159906], '(00)3',
666
+ [121424391, 121424394], '3(00)',
667
+ [121692931, 121692934], '3(00)',
668
+ [121934170, 121934173], '3(00)',
669
+ [122612848, 122612851], '3(00)',
670
+ [126116567, 126116570], '(00)3',
671
+ [127936513, 127936516], '(00)3',
672
+ [128710277, 128710280], '3(00)',
673
+ [129398902, 129398905], '3(00)',
674
+ [130461096, 130461099], '3(00)',
675
+ [131331947, 131331950], '3(00)',
676
+ [137334071, 137334074], '3(00)',
677
+ [137832603, 137832606], '(00)3',
678
+ [138799471, 138799474], '3(00)',
679
+ [139027791, 139027794], '(00)3',
680
+ [141617806, 141617809], '(00)3',
681
+ [144454931, 144454934], '(00)3',
682
+ [145402379, 145402382], '3(00)',
683
+ [146130245, 146130248], '3(00)',
684
+ [147059770, 147059773], '(00)3',
685
+ [147896099, 147896102], '3(00)',
686
+ [151097113, 151097116], '(00)3',
687
+ [152539438, 152539441], '(00)3',
688
+ [152863168, 152863171], '3(00)',
689
+ [153522726, 153522729], '3(00)',
690
+ [155171524, 155171527], '3(00)',
691
+ [155366607, 155366610], '(00)3',
692
+ [157260686, 157260689], '3(00)',
693
+ [157269224, 157269227], '(00)3',
694
+ [157755123, 157755126], '(00)3',
695
+ [158298484, 158298487], '3(00)',
696
+ [160369050, 160369053], '3(00)',
697
+ [162962787, 162962790], '(00)3',
698
+ [163724709, 163724712], '(00)3',
699
+ [164198113, 164198116], '3(00)',
700
+ [164689301, 164689305], '(00)40',
701
+ [164880228, 164880231], '3(00)',
702
+ [166201932, 166201935], '(00)3',
703
+ [168573836, 168573839], '(00)3',
704
+ [169750763, 169750766], '(00)3',
705
+ [170375507, 170375510], '(00)3',
706
+ [170704879, 170704882], '3(00)',
707
+ [172000992, 172000995], '3(00)',
708
+ [173289941, 173289944], '(00)3',
709
+ [173737613, 173737616], '3(00)',
710
+ [174102513, 174102516], '(00)3',
711
+ [174284990, 174284993], '(00)3',
712
+ [174500513, 174500516], '(00)3',
713
+ [175710609, 175710612], '(00)3',
714
+ [176870843, 176870846], '3(00)',
715
+ [177332732, 177332735], '3(00)',
716
+ [177902861, 177902864], '3(00)',
717
+ [179979095, 179979098], '(00)3',
718
+ [181233726, 181233729], '3(00)',
719
+ [181625435, 181625438], '(00)3',
720
+ [182105255, 182105259], '22(00)',
721
+ [182223559, 182223562], '3(00)',
722
+ [191116404, 191116407], '3(00)',
723
+ [191165599, 191165602], '3(00)',
724
+ [191297535, 191297539], '(00)22',
725
+ [192485616, 192485619], '(00)3',
726
+ [193264634, 193264638], '22(00)',
727
+ [194696968, 194696971], '(00)3',
728
+ [195876805, 195876808], '(00)3',
729
+ [195916548, 195916551], '3(00)',
730
+ [196395160, 196395163], '3(00)',
731
+ [196676303, 196676306], '(00)3',
732
+ [197889882, 197889885], '3(00)',
733
+ [198014122, 198014125], '(00)3',
734
+ [199235289, 199235292], '(00)3',
735
+ [201007375, 201007378], '(00)3',
736
+ [201030605, 201030608], '3(00)',
737
+ [201184290, 201184293], '3(00)',
738
+ [201685414, 201685418], '(00)22',
739
+ [202762875, 202762878], '3(00)',
740
+ [202860957, 202860960], '3(00)',
741
+ [203832577, 203832580], '3(00)',
742
+ [205880544, 205880547], '(00)3',
743
+ [206357111, 206357114], '(00)3',
744
+ [207159767, 207159770], '3(00)',
745
+ [207167343, 207167346], '3(00)',
746
+ [207482539, 207482543], '3(010)',
747
+ [207669540, 207669543], '3(00)',
748
+ [208053426, 208053429], '(00)3',
749
+ [208110027, 208110030], '3(00)',
750
+ [209513826, 209513829], '3(00)',
751
+ [212623522, 212623525], '(00)3',
752
+ [213841715, 213841718], '(00)3',
753
+ [214012333, 214012336], '(00)3',
754
+ [214073567, 214073570], '(00)3',
755
+ [215170600, 215170603], '3(00)',
756
+ [215881039, 215881042], '3(00)',
757
+ [216274604, 216274607], '3(00)',
758
+ [216957120, 216957123], '3(00)',
759
+ [217323208, 217323211], '(00)3',
760
+ [218799264, 218799267], '(00)3',
761
+ [218803557, 218803560], '3(00)',
762
+ [219735146, 219735149], '(00)3',
763
+ [219830062, 219830065], '3(00)',
764
+ [219897904, 219897907], '(00)3',
765
+ [221205545, 221205548], '(00)3',
766
+ [223601929, 223601932], '(00)3',
767
+ [223907076, 223907079], '3(00)',
768
+ [223970397, 223970400], '(00)3',
769
+ [224874044, 224874048], '22(00)',
770
+ [225291157, 225291160], '(00)3',
771
+ [227481734, 227481737], '(00)3',
772
+ [228006442, 228006445], '3(00)',
773
+ [228357900, 228357903], '(00)3',
774
+ [228386399, 228386402], '(00)3',
775
+ [228907446, 228907449], '(00)3',
776
+ [228984552, 228984555], '3(00)',
777
+ [229140285, 229140288], '3(00)',
778
+ [231810024, 231810027], '(00)3',
779
+ [232838062, 232838065], '3(00)',
780
+ [234389088, 234389091], '3(00)',
781
+ [235588194, 235588197], '(00)3',
782
+ [236645695, 236645698], '(00)3',
783
+ [236962876, 236962879], '3(00)',
784
+ [237516723, 237516727], '04(00)',
785
+ [240004911, 240004914], '(00)3',
786
+ [240221306, 240221309], '3(00)',
787
+ [241389213, 241389217], '(010)3',
788
+ [241549003, 241549006], '(00)3',
789
+ [241729717, 241729720], '(00)3',
790
+ [241743684, 241743687], '3(00)',
791
+ [243780200, 243780203], '3(00)',
792
+ [243801317, 243801320], '(00)3',
793
+ [244122072, 244122075], '(00)3',
794
+ [244691224, 244691227], '3(00)',
795
+ [244841577, 244841580], '(00)3',
796
+ [245813461, 245813464], '(00)3',
797
+ [246299475, 246299478], '(00)3',
798
+ [246450176, 246450179], '3(00)',
799
+ [249069349, 249069352], '(00)3',
800
+ [250076378, 250076381], '(00)3',
801
+ [252442157, 252442160], '3(00)',
802
+ [252904231, 252904234], '3(00)',
803
+ [255145220, 255145223], '(00)3',
804
+ [255285971, 255285974], '3(00)',
805
+ [256713230, 256713233], '(00)3',
806
+ [257992082, 257992085], '(00)3',
807
+ [258447955, 258447959], '22(00)',
808
+ [259298045, 259298048], '3(00)',
809
+ [262141503, 262141506], '(00)3',
810
+ [263681743, 263681746], '3(00)',
811
+ [266527881, 266527885], '(010)3',
812
+ [266617122, 266617125], '(00)3',
813
+ [266628044, 266628047], '3(00)',
814
+ [267305763, 267305766], '(00)3',
815
+ [267388404, 267388407], '3(00)',
816
+ [267441672, 267441675], '3(00)',
817
+ [267464886, 267464889], '(00)3',
818
+ [267554907, 267554910], '3(00)',
819
+ [269787480, 269787483], '(00)3',
820
+ [270881434, 270881437], '(00)3',
821
+ [270997583, 270997586], '3(00)',
822
+ [272096378, 272096381], '3(00)',
823
+ [272583009, 272583012], '(00)3',
824
+ [274190881, 274190884], '3(00)',
825
+ [274268747, 274268750], '(00)3',
826
+ [275297429, 275297432], '3(00)',
827
+ [275545476, 275545479], '3(00)',
828
+ [275898479, 275898482], '3(00)',
829
+ [275953000, 275953003], '(00)3',
830
+ [277117197, 277117201], '(00)22',
831
+ [277447310, 277447313], '3(00)',
832
+ [279059657, 279059660], '3(00)',
833
+ [279259144, 279259147], '3(00)',
834
+ [279513636, 279513639], '3(00)',
835
+ [279849069, 279849072], '3(00)',
836
+ [280291419, 280291422], '(00)3',
837
+ [281449425, 281449428], '3(00)',
838
+ [281507953, 281507956], '3(00)',
839
+ [281825600, 281825603], '(00)3',
840
+ [282547093, 282547096], '3(00)',
841
+ [283120963, 283120966], '3(00)',
842
+ [283323493, 283323496], '(00)3',
843
+ [284764535, 284764538], '3(00)',
844
+ [286172639, 286172642], '3(00)',
845
+ [286688824, 286688827], '(00)3',
846
+ [287222172, 287222175], '3(00)',
847
+ [287235534, 287235537], '3(00)',
848
+ [287304861, 287304864], '3(00)',
849
+ [287433571, 287433574], '(00)3',
850
+ [287823551, 287823554], '(00)3',
851
+ [287872422, 287872425], '3(00)',
852
+ [288766615, 288766618], '3(00)',
853
+ [290122963, 290122966], '3(00)',
854
+ [290450849, 290450853], '(00)22',
855
+ [291426141, 291426144], '3(00)',
856
+ [292810353, 292810356], '3(00)',
857
+ [293109861, 293109864], '3(00)',
858
+ [293398054, 293398057], '3(00)',
859
+ [294134426, 294134429], '3(00)',
860
+ [294216438, 294216441], '(00)3',
861
+ [295367141, 295367144], '3(00)',
862
+ [297834111, 297834114], '3(00)',
863
+ [299099969, 299099972], '3(00)',
864
+ [300746958, 300746961], '3(00)',
865
+ [301097423, 301097426], '(00)3',
866
+ [301834209, 301834212], '(00)3',
867
+ [302554791, 302554794], '(00)3',
868
+ [303497445, 303497448], '3(00)',
869
+ [304165344, 304165347], '3(00)',
870
+ [304790218, 304790222], '3(010)',
871
+ [305302352, 305302355], '(00)3',
872
+ [306785996, 306785999], '3(00)',
873
+ [307051443, 307051446], '3(00)',
874
+ [307481539, 307481542], '3(00)',
875
+ [308605569, 308605572], '3(00)',
876
+ [309237610, 309237613], '3(00)',
877
+ [310509287, 310509290], '(00)3',
878
+ [310554057, 310554060], '3(00)',
879
+ [310646345, 310646348], '3(00)',
880
+ [311274896, 311274899], '(00)3',
881
+ [311894272, 311894275], '3(00)',
882
+ [312269470, 312269473], '(00)3',
883
+ [312306601, 312306605], '(00)40',
884
+ [312683193, 312683196], '3(00)',
885
+ [314499804, 314499807], '3(00)',
886
+ [314636802, 314636805], '(00)3',
887
+ [314689897, 314689900], '3(00)',
888
+ [314721319, 314721322], '3(00)',
889
+ [316132890, 316132893], '3(00)',
890
+ [316217470, 316217474], '(010)3',
891
+ [316465705, 316465708], '3(00)',
892
+ [316542790, 316542793], '(00)3',
893
+ [320822347, 320822350], '3(00)',
894
+ [321733242, 321733245], '3(00)',
895
+ [324413970, 324413973], '(00)3',
896
+ [325950140, 325950143], '(00)3',
897
+ [326675884, 326675887], '(00)3',
898
+ [326704208, 326704211], '3(00)',
899
+ [327596247, 327596250], '3(00)',
900
+ [328123172, 328123175], '3(00)',
901
+ [328182212, 328182215], '(00)3',
902
+ [328257498, 328257501], '3(00)',
903
+ [328315836, 328315839], '(00)3',
904
+ [328800974, 328800977], '(00)3',
905
+ [328998509, 328998512], '3(00)',
906
+ [329725370, 329725373], '(00)3',
907
+ [332080601, 332080604], '(00)3',
908
+ [332221246, 332221249], '(00)3',
909
+ [332299899, 332299902], '(00)3',
910
+ [332532822, 332532825], '(00)3',
911
+ [333334544, 333334548], '(00)22',
912
+ [333881266, 333881269], '3(00)',
913
+ [334703267, 334703270], '3(00)',
914
+ [334875138, 334875141], '3(00)',
915
+ [336531451, 336531454], '3(00)',
916
+ [336825907, 336825910], '(00)3',
917
+ [336993167, 336993170], '(00)3',
918
+ [337493998, 337494001], '3(00)',
919
+ [337861034, 337861037], '3(00)',
920
+ [337899191, 337899194], '(00)3',
921
+ [337958123, 337958126], '(00)3',
922
+ [342331982, 342331985], '3(00)',
923
+ [342676068, 342676071], '3(00)',
924
+ [347063781, 347063784], '3(00)',
925
+ [347697348, 347697351], '3(00)',
926
+ [347954319, 347954322], '3(00)',
927
+ [348162775, 348162778], '3(00)',
928
+ [349210702, 349210705], '(00)3',
929
+ [349212913, 349212916], '3(00)',
930
+ [349248650, 349248653], '(00)3',
931
+ [349913500, 349913503], '3(00)',
932
+ [350891529, 350891532], '3(00)',
933
+ [351089323, 351089326], '3(00)',
934
+ [351826158, 351826161], '3(00)',
935
+ [352228580, 352228583], '(00)3',
936
+ [352376244, 352376247], '3(00)',
937
+ [352853758, 352853761], '(00)3',
938
+ [355110439, 355110442], '(00)3',
939
+ [355808090, 355808094], '(00)40',
940
+ [355941556, 355941559], '3(00)',
941
+ [356360231, 356360234], '(00)3',
942
+ [356586657, 356586660], '3(00)',
943
+ [356892926, 356892929], '(00)3',
944
+ [356908232, 356908235], '3(00)',
945
+ [357912730, 357912733], '3(00)',
946
+ [358120344, 358120347], '3(00)',
947
+ [359044096, 359044099], '(00)3',
948
+ [360819357, 360819360], '3(00)',
949
+ [361399662, 361399666], '(010)3',
950
+ [362361315, 362361318], '(00)3',
951
+ [363610112, 363610115], '(00)3',
952
+ [363964804, 363964807], '3(00)',
953
+ [364527375, 364527378], '(00)3',
954
+ [365090327, 365090330], '(00)3',
955
+ [365414539, 365414542], '3(00)',
956
+ [366738474, 366738477], '3(00)',
957
+ [368714778, 368714783], '04(010)',
958
+ [368831545, 368831548], '(00)3',
959
+ [368902387, 368902390], '(00)3',
960
+ [370109769, 370109772], '3(00)',
961
+ [370963333, 370963336], '3(00)',
962
+ [372541136, 372541140], '3(010)',
963
+ [372681562, 372681565], '(00)3',
964
+ [373009410, 373009413], '(00)3',
965
+ [373458970, 373458973], '3(00)',
966
+ [375648658, 375648661], '3(00)',
967
+ [376834728, 376834731], '3(00)',
968
+ [377119945, 377119948], '(00)3',
969
+ [377335703, 377335706], '(00)3',
970
+ [378091745, 378091748], '3(00)',
971
+ [379139522, 379139525], '3(00)',
972
+ [380279160, 380279163], '(00)3',
973
+ [380619442, 380619445], '3(00)',
974
+ [381244231, 381244234], '3(00)',
975
+ [382327446, 382327450], '(010)3',
976
+ [382357073, 382357076], '3(00)',
977
+ [383545479, 383545482], '3(00)',
978
+ [384363766, 384363769], '(00)3',
979
+ [384401786, 384401790], '22(00)',
980
+ [385198212, 385198215], '3(00)',
981
+ [385824476, 385824479], '(00)3',
982
+ [385908194, 385908197], '3(00)',
983
+ [386946806, 386946809], '3(00)',
984
+ [387592175, 387592179], '22(00)',
985
+ [388329293, 388329296], '(00)3',
986
+ [388679566, 388679569], '3(00)',
987
+ [388832142, 388832145], '3(00)',
988
+ [390087103, 390087106], '(00)3',
989
+ [390190926, 390190930], '(00)22',
990
+ [390331207, 390331210], '3(00)',
991
+ [391674495, 391674498], '3(00)',
992
+ [391937831, 391937834], '3(00)',
993
+ [391951632, 391951636], '(00)22',
994
+ [392963986, 392963989], '(00)3',
995
+ [393007921, 393007924], '3(00)',
996
+ [393373210, 393373213], '3(00)',
997
+ [393759572, 393759575], '(00)3',
998
+ [394036662, 394036665], '(00)3',
999
+ [395813866, 395813869], '(00)3',
1000
+ [395956690, 395956693], '3(00)',
1001
+ [396031670, 396031673], '3(00)',
1002
+ [397076433, 397076436], '3(00)',
1003
+ [397470601, 397470604], '3(00)',
1004
+ [398289458, 398289461], '3(00)',
1005
+ #
1006
+ [368714778, 368714783], '04(010)',
1007
+ [437953499, 437953504], '04(010)',
1008
+ [526196233, 526196238], '032(00)',
1009
+ [744719566, 744719571], '(010)40',
1010
+ [750375857, 750375862], '032(00)',
1011
+ [958241932, 958241937], '04(010)',
1012
+ [983377342, 983377347], '(00)410',
1013
+ [1003780080, 1003780085], '04(010)',
1014
+ [1070232754, 1070232759], '(00)230',
1015
+ [1209834865, 1209834870], '032(00)',
1016
+ [1257209100, 1257209105], '(00)410',
1017
+ [1368002233, 1368002238], '(00)230'
1018
+ ]
lib/python3.11/site-packages/mpmath/identification.py ADDED
@@ -0,0 +1,844 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Implements the PSLQ algorithm for integer relation detection,
3
+ and derivative algorithms for constant recognition.
4
+ """
5
+
6
+ from .libmp.backend import xrange
7
+ from .libmp import int_types, sqrt_fixed
8
+
9
+ # round to nearest integer (can be done more elegantly...)
10
+ def round_fixed(x, prec):
11
+ return ((x + (1<<(prec-1))) >> prec) << prec
12
+
13
+ class IdentificationMethods(object):
14
+ pass
15
+
16
+
17
+ def pslq(ctx, x, tol=None, maxcoeff=1000, maxsteps=100, verbose=False):
18
+ r"""
19
+ Given a vector of real numbers `x = [x_0, x_1, ..., x_n]`, ``pslq(x)``
20
+ uses the PSLQ algorithm to find a list of integers
21
+ `[c_0, c_1, ..., c_n]` such that
22
+
23
+ .. math ::
24
+
25
+ |c_1 x_1 + c_2 x_2 + ... + c_n x_n| < \mathrm{tol}
26
+
27
+ and such that `\max |c_k| < \mathrm{maxcoeff}`. If no such vector
28
+ exists, :func:`~mpmath.pslq` returns ``None``. The tolerance defaults to
29
+ 3/4 of the working precision.
30
+
31
+ **Examples**
32
+
33
+ Find rational approximations for `\pi`::
34
+
35
+ >>> from mpmath import *
36
+ >>> mp.dps = 15; mp.pretty = True
37
+ >>> pslq([-1, pi], tol=0.01)
38
+ [22, 7]
39
+ >>> pslq([-1, pi], tol=0.001)
40
+ [355, 113]
41
+ >>> mpf(22)/7; mpf(355)/113; +pi
42
+ 3.14285714285714
43
+ 3.14159292035398
44
+ 3.14159265358979
45
+
46
+ Pi is not a rational number with denominator less than 1000::
47
+
48
+ >>> pslq([-1, pi])
49
+ >>>
50
+
51
+ To within the standard precision, it can however be approximated
52
+ by at least one rational number with denominator less than `10^{12}`::
53
+
54
+ >>> p, q = pslq([-1, pi], maxcoeff=10**12)
55
+ >>> print(p); print(q)
56
+ 238410049439
57
+ 75888275702
58
+ >>> mpf(p)/q
59
+ 3.14159265358979
60
+
61
+ The PSLQ algorithm can be applied to long vectors. For example,
62
+ we can investigate the rational (in)dependence of integer square
63
+ roots::
64
+
65
+ >>> mp.dps = 30
66
+ >>> pslq([sqrt(n) for n in range(2, 5+1)])
67
+ >>>
68
+ >>> pslq([sqrt(n) for n in range(2, 6+1)])
69
+ >>>
70
+ >>> pslq([sqrt(n) for n in range(2, 8+1)])
71
+ [2, 0, 0, 0, 0, 0, -1]
72
+
73
+ **Machin formulas**
74
+
75
+ A famous formula for `\pi` is Machin's,
76
+
77
+ .. math ::
78
+
79
+ \frac{\pi}{4} = 4 \operatorname{acot} 5 - \operatorname{acot} 239
80
+
81
+ There are actually infinitely many formulas of this type. Two
82
+ others are
83
+
84
+ .. math ::
85
+
86
+ \frac{\pi}{4} = \operatorname{acot} 1
87
+
88
+ \frac{\pi}{4} = 12 \operatorname{acot} 49 + 32 \operatorname{acot} 57
89
+ + 5 \operatorname{acot} 239 + 12 \operatorname{acot} 110443
90
+
91
+ We can easily verify the formulas using the PSLQ algorithm::
92
+
93
+ >>> mp.dps = 30
94
+ >>> pslq([pi/4, acot(1)])
95
+ [1, -1]
96
+ >>> pslq([pi/4, acot(5), acot(239)])
97
+ [1, -4, 1]
98
+ >>> pslq([pi/4, acot(49), acot(57), acot(239), acot(110443)])
99
+ [1, -12, -32, 5, -12]
100
+
101
+ We could try to generate a custom Machin-like formula by running
102
+ the PSLQ algorithm with a few inverse cotangent values, for example
103
+ acot(2), acot(3) ... acot(10). Unfortunately, there is a linear
104
+ dependence among these values, resulting in only that dependence
105
+ being detected, with a zero coefficient for `\pi`::
106
+
107
+ >>> pslq([pi] + [acot(n) for n in range(2,11)])
108
+ [0, 1, -1, 0, 0, 0, -1, 0, 0, 0]
109
+
110
+ We get better luck by removing linearly dependent terms::
111
+
112
+ >>> pslq([pi] + [acot(n) for n in range(2,11) if n not in (3, 5)])
113
+ [1, -8, 0, 0, 4, 0, 0, 0]
114
+
115
+ In other words, we found the following formula::
116
+
117
+ >>> 8*acot(2) - 4*acot(7)
118
+ 3.14159265358979323846264338328
119
+ >>> +pi
120
+ 3.14159265358979323846264338328
121
+
122
+ **Algorithm**
123
+
124
+ This is a fairly direct translation to Python of the pseudocode given by
125
+ David Bailey, "The PSLQ Integer Relation Algorithm":
126
+ http://www.cecm.sfu.ca/organics/papers/bailey/paper/html/node3.html
127
+
128
+ The present implementation uses fixed-point instead of floating-point
129
+ arithmetic, since this is significantly (about 7x) faster.
130
+ """
131
+
132
+ n = len(x)
133
+ if n < 2:
134
+ raise ValueError("n cannot be less than 2")
135
+
136
+ # At too low precision, the algorithm becomes meaningless
137
+ prec = ctx.prec
138
+ if prec < 53:
139
+ raise ValueError("prec cannot be less than 53")
140
+
141
+ if verbose and prec // max(2,n) < 5:
142
+ print("Warning: precision for PSLQ may be too low")
143
+
144
+ target = int(prec * 0.75)
145
+
146
+ if tol is None:
147
+ tol = ctx.mpf(2)**(-target)
148
+ else:
149
+ tol = ctx.convert(tol)
150
+
151
+ extra = 60
152
+ prec += extra
153
+
154
+ if verbose:
155
+ print("PSLQ using prec %i and tol %s" % (prec, ctx.nstr(tol)))
156
+
157
+ tol = ctx.to_fixed(tol, prec)
158
+ assert tol
159
+
160
+ # Convert to fixed-point numbers. The dummy None is added so we can
161
+ # use 1-based indexing. (This just allows us to be consistent with
162
+ # Bailey's indexing. The algorithm is 100 lines long, so debugging
163
+ # a single wrong index can be painful.)
164
+ x = [None] + [ctx.to_fixed(ctx.mpf(xk), prec) for xk in x]
165
+
166
+ # Sanity check on magnitudes
167
+ minx = min(abs(xx) for xx in x[1:])
168
+ if not minx:
169
+ raise ValueError("PSLQ requires a vector of nonzero numbers")
170
+ if minx < tol//100:
171
+ if verbose:
172
+ print("STOPPING: (one number is too small)")
173
+ return None
174
+
175
+ g = sqrt_fixed((4<<prec)//3, prec)
176
+ A = {}
177
+ B = {}
178
+ H = {}
179
+ # Initialization
180
+ # step 1
181
+ for i in xrange(1, n+1):
182
+ for j in xrange(1, n+1):
183
+ A[i,j] = B[i,j] = (i==j) << prec
184
+ H[i,j] = 0
185
+ # step 2
186
+ s = [None] + [0] * n
187
+ for k in xrange(1, n+1):
188
+ t = 0
189
+ for j in xrange(k, n+1):
190
+ t += (x[j]**2 >> prec)
191
+ s[k] = sqrt_fixed(t, prec)
192
+ t = s[1]
193
+ y = x[:]
194
+ for k in xrange(1, n+1):
195
+ y[k] = (x[k] << prec) // t
196
+ s[k] = (s[k] << prec) // t
197
+ # step 3
198
+ for i in xrange(1, n+1):
199
+ for j in xrange(i+1, n):
200
+ H[i,j] = 0
201
+ if i <= n-1:
202
+ if s[i]:
203
+ H[i,i] = (s[i+1] << prec) // s[i]
204
+ else:
205
+ H[i,i] = 0
206
+ for j in range(1, i):
207
+ sjj1 = s[j]*s[j+1]
208
+ if sjj1:
209
+ H[i,j] = ((-y[i]*y[j])<<prec)//sjj1
210
+ else:
211
+ H[i,j] = 0
212
+ # step 4
213
+ for i in xrange(2, n+1):
214
+ for j in xrange(i-1, 0, -1):
215
+ #t = floor(H[i,j]/H[j,j] + 0.5)
216
+ if H[j,j]:
217
+ t = round_fixed((H[i,j] << prec)//H[j,j], prec)
218
+ else:
219
+ #t = 0
220
+ continue
221
+ y[j] = y[j] + (t*y[i] >> prec)
222
+ for k in xrange(1, j+1):
223
+ H[i,k] = H[i,k] - (t*H[j,k] >> prec)
224
+ for k in xrange(1, n+1):
225
+ A[i,k] = A[i,k] - (t*A[j,k] >> prec)
226
+ B[k,j] = B[k,j] + (t*B[k,i] >> prec)
227
+ # Main algorithm
228
+ for REP in range(maxsteps):
229
+ # Step 1
230
+ m = -1
231
+ szmax = -1
232
+ for i in range(1, n):
233
+ h = H[i,i]
234
+ sz = (g**i * abs(h)) >> (prec*(i-1))
235
+ if sz > szmax:
236
+ m = i
237
+ szmax = sz
238
+ # Step 2
239
+ y[m], y[m+1] = y[m+1], y[m]
240
+ for i in xrange(1,n+1): H[m,i], H[m+1,i] = H[m+1,i], H[m,i]
241
+ for i in xrange(1,n+1): A[m,i], A[m+1,i] = A[m+1,i], A[m,i]
242
+ for i in xrange(1,n+1): B[i,m], B[i,m+1] = B[i,m+1], B[i,m]
243
+ # Step 3
244
+ if m <= n - 2:
245
+ t0 = sqrt_fixed((H[m,m]**2 + H[m,m+1]**2)>>prec, prec)
246
+ # A zero element probably indicates that the precision has
247
+ # been exhausted. XXX: this could be spurious, due to
248
+ # using fixed-point arithmetic
249
+ if not t0:
250
+ break
251
+ t1 = (H[m,m] << prec) // t0
252
+ t2 = (H[m,m+1] << prec) // t0
253
+ for i in xrange(m, n+1):
254
+ t3 = H[i,m]
255
+ t4 = H[i,m+1]
256
+ H[i,m] = (t1*t3+t2*t4) >> prec
257
+ H[i,m+1] = (-t2*t3+t1*t4) >> prec
258
+ # Step 4
259
+ for i in xrange(m+1, n+1):
260
+ for j in xrange(min(i-1, m+1), 0, -1):
261
+ try:
262
+ t = round_fixed((H[i,j] << prec)//H[j,j], prec)
263
+ # Precision probably exhausted
264
+ except ZeroDivisionError:
265
+ break
266
+ y[j] = y[j] + ((t*y[i]) >> prec)
267
+ for k in xrange(1, j+1):
268
+ H[i,k] = H[i,k] - (t*H[j,k] >> prec)
269
+ for k in xrange(1, n+1):
270
+ A[i,k] = A[i,k] - (t*A[j,k] >> prec)
271
+ B[k,j] = B[k,j] + (t*B[k,i] >> prec)
272
+ # Until a relation is found, the error typically decreases
273
+ # slowly (e.g. a factor 1-10) with each step TODO: we could
274
+ # compare err from two successive iterations. If there is a
275
+ # large drop (several orders of magnitude), that indicates a
276
+ # "high quality" relation was detected. Reporting this to
277
+ # the user somehow might be useful.
278
+ best_err = maxcoeff<<prec
279
+ for i in xrange(1, n+1):
280
+ err = abs(y[i])
281
+ # Maybe we are done?
282
+ if err < tol:
283
+ # We are done if the coefficients are acceptable
284
+ vec = [int(round_fixed(B[j,i], prec) >> prec) for j in \
285
+ range(1,n+1)]
286
+ if max(abs(v) for v in vec) < maxcoeff:
287
+ if verbose:
288
+ print("FOUND relation at iter %i/%i, error: %s" % \
289
+ (REP, maxsteps, ctx.nstr(err / ctx.mpf(2)**prec, 1)))
290
+ return vec
291
+ best_err = min(err, best_err)
292
+ # Calculate a lower bound for the norm. We could do this
293
+ # more exactly (using the Euclidean norm) but there is probably
294
+ # no practical benefit.
295
+ recnorm = max(abs(h) for h in H.values())
296
+ if recnorm:
297
+ norm = ((1 << (2*prec)) // recnorm) >> prec
298
+ norm //= 100
299
+ else:
300
+ norm = ctx.inf
301
+ if verbose:
302
+ print("%i/%i: Error: %8s Norm: %s" % \
303
+ (REP, maxsteps, ctx.nstr(best_err / ctx.mpf(2)**prec, 1), norm))
304
+ if norm >= maxcoeff:
305
+ break
306
+ if verbose:
307
+ print("CANCELLING after step %i/%i." % (REP, maxsteps))
308
+ print("Could not find an integer relation. Norm bound: %s" % norm)
309
+ return None
310
+
311
+ def findpoly(ctx, x, n=1, **kwargs):
312
+ r"""
313
+ ``findpoly(x, n)`` returns the coefficients of an integer
314
+ polynomial `P` of degree at most `n` such that `P(x) \approx 0`.
315
+ If no polynomial having `x` as a root can be found,
316
+ :func:`~mpmath.findpoly` returns ``None``.
317
+
318
+ :func:`~mpmath.findpoly` works by successively calling :func:`~mpmath.pslq` with
319
+ the vectors `[1, x]`, `[1, x, x^2]`, `[1, x, x^2, x^3]`, ...,
320
+ `[1, x, x^2, .., x^n]` as input. Keyword arguments given to
321
+ :func:`~mpmath.findpoly` are forwarded verbatim to :func:`~mpmath.pslq`. In
322
+ particular, you can specify a tolerance for `P(x)` with ``tol``
323
+ and a maximum permitted coefficient size with ``maxcoeff``.
324
+
325
+ For large values of `n`, it is recommended to run :func:`~mpmath.findpoly`
326
+ at high precision; preferably 50 digits or more.
327
+
328
+ **Examples**
329
+
330
+ By default (degree `n = 1`), :func:`~mpmath.findpoly` simply finds a linear
331
+ polynomial with a rational root::
332
+
333
+ >>> from mpmath import *
334
+ >>> mp.dps = 15; mp.pretty = True
335
+ >>> findpoly(0.7)
336
+ [-10, 7]
337
+
338
+ The generated coefficient list is valid input to ``polyval`` and
339
+ ``polyroots``::
340
+
341
+ >>> nprint(polyval(findpoly(phi, 2), phi), 1)
342
+ -2.0e-16
343
+ >>> for r in polyroots(findpoly(phi, 2)):
344
+ ... print(r)
345
+ ...
346
+ -0.618033988749895
347
+ 1.61803398874989
348
+
349
+ Numbers of the form `m + n \sqrt p` for integers `(m, n, p)` are
350
+ solutions to quadratic equations. As we find here, `1+\sqrt 2`
351
+ is a root of the polynomial `x^2 - 2x - 1`::
352
+
353
+ >>> findpoly(1+sqrt(2), 2)
354
+ [1, -2, -1]
355
+ >>> findroot(lambda x: x**2 - 2*x - 1, 1)
356
+ 2.4142135623731
357
+
358
+ Despite only containing square roots, the following number results
359
+ in a polynomial of degree 4::
360
+
361
+ >>> findpoly(sqrt(2)+sqrt(3), 4)
362
+ [1, 0, -10, 0, 1]
363
+
364
+ In fact, `x^4 - 10x^2 + 1` is the *minimal polynomial* of
365
+ `r = \sqrt 2 + \sqrt 3`, meaning that a rational polynomial of
366
+ lower degree having `r` as a root does not exist. Given sufficient
367
+ precision, :func:`~mpmath.findpoly` will usually find the correct
368
+ minimal polynomial of a given algebraic number.
369
+
370
+ **Non-algebraic numbers**
371
+
372
+ If :func:`~mpmath.findpoly` fails to find a polynomial with given
373
+ coefficient size and tolerance constraints, that means no such
374
+ polynomial exists.
375
+
376
+ We can verify that `\pi` is not an algebraic number of degree 3 with
377
+ coefficients less than 1000::
378
+
379
+ >>> mp.dps = 15
380
+ >>> findpoly(pi, 3)
381
+ >>>
382
+
383
+ It is always possible to find an algebraic approximation of a number
384
+ using one (or several) of the following methods:
385
+
386
+ 1. Increasing the permitted degree
387
+ 2. Allowing larger coefficients
388
+ 3. Reducing the tolerance
389
+
390
+ One example of each method is shown below::
391
+
392
+ >>> mp.dps = 15
393
+ >>> findpoly(pi, 4)
394
+ [95, -545, 863, -183, -298]
395
+ >>> findpoly(pi, 3, maxcoeff=10000)
396
+ [836, -1734, -2658, -457]
397
+ >>> findpoly(pi, 3, tol=1e-7)
398
+ [-4, 22, -29, -2]
399
+
400
+ It is unknown whether Euler's constant is transcendental (or even
401
+ irrational). We can use :func:`~mpmath.findpoly` to check that if is
402
+ an algebraic number, its minimal polynomial must have degree
403
+ at least 7 and a coefficient of magnitude at least 1000000::
404
+
405
+ >>> mp.dps = 200
406
+ >>> findpoly(euler, 6, maxcoeff=10**6, tol=1e-100, maxsteps=1000)
407
+ >>>
408
+
409
+ Note that the high precision and strict tolerance is necessary
410
+ for such high-degree runs, since otherwise unwanted low-accuracy
411
+ approximations will be detected. It may also be necessary to set
412
+ maxsteps high to prevent a premature exit (before the coefficient
413
+ bound has been reached). Running with ``verbose=True`` to get an
414
+ idea what is happening can be useful.
415
+ """
416
+ x = ctx.mpf(x)
417
+ if n < 1:
418
+ raise ValueError("n cannot be less than 1")
419
+ if x == 0:
420
+ return [1, 0]
421
+ xs = [ctx.mpf(1)]
422
+ for i in range(1,n+1):
423
+ xs.append(x**i)
424
+ a = ctx.pslq(xs, **kwargs)
425
+ if a is not None:
426
+ return a[::-1]
427
+
428
+ def fracgcd(p, q):
429
+ x, y = p, q
430
+ while y:
431
+ x, y = y, x % y
432
+ if x != 1:
433
+ p //= x
434
+ q //= x
435
+ if q == 1:
436
+ return p
437
+ return p, q
438
+
439
+ def pslqstring(r, constants):
440
+ q = r[0]
441
+ r = r[1:]
442
+ s = []
443
+ for i in range(len(r)):
444
+ p = r[i]
445
+ if p:
446
+ z = fracgcd(-p,q)
447
+ cs = constants[i][1]
448
+ if cs == '1':
449
+ cs = ''
450
+ else:
451
+ cs = '*' + cs
452
+ if isinstance(z, int_types):
453
+ if z > 0: term = str(z) + cs
454
+ else: term = ("(%s)" % z) + cs
455
+ else:
456
+ term = ("(%s/%s)" % z) + cs
457
+ s.append(term)
458
+ s = ' + '.join(s)
459
+ if '+' in s or '*' in s:
460
+ s = '(' + s + ')'
461
+ return s or '0'
462
+
463
+ def prodstring(r, constants):
464
+ q = r[0]
465
+ r = r[1:]
466
+ num = []
467
+ den = []
468
+ for i in range(len(r)):
469
+ p = r[i]
470
+ if p:
471
+ z = fracgcd(-p,q)
472
+ cs = constants[i][1]
473
+ if isinstance(z, int_types):
474
+ if abs(z) == 1: t = cs
475
+ else: t = '%s**%s' % (cs, abs(z))
476
+ ([num,den][z<0]).append(t)
477
+ else:
478
+ t = '%s**(%s/%s)' % (cs, abs(z[0]), z[1])
479
+ ([num,den][z[0]<0]).append(t)
480
+ num = '*'.join(num)
481
+ den = '*'.join(den)
482
+ if num and den: return "(%s)/(%s)" % (num, den)
483
+ if num: return num
484
+ if den: return "1/(%s)" % den
485
+
486
+ def quadraticstring(ctx,t,a,b,c):
487
+ if c < 0:
488
+ a,b,c = -a,-b,-c
489
+ u1 = (-b+ctx.sqrt(b**2-4*a*c))/(2*c)
490
+ u2 = (-b-ctx.sqrt(b**2-4*a*c))/(2*c)
491
+ if abs(u1-t) < abs(u2-t):
492
+ if b: s = '((%s+sqrt(%s))/%s)' % (-b,b**2-4*a*c,2*c)
493
+ else: s = '(sqrt(%s)/%s)' % (-4*a*c,2*c)
494
+ else:
495
+ if b: s = '((%s-sqrt(%s))/%s)' % (-b,b**2-4*a*c,2*c)
496
+ else: s = '(-sqrt(%s)/%s)' % (-4*a*c,2*c)
497
+ return s
498
+
499
+ # Transformation y = f(x,c), with inverse function x = f(y,c)
500
+ # The third entry indicates whether the transformation is
501
+ # redundant when c = 1
502
+ transforms = [
503
+ (lambda ctx,x,c: x*c, '$y/$c', 0),
504
+ (lambda ctx,x,c: x/c, '$c*$y', 1),
505
+ (lambda ctx,x,c: c/x, '$c/$y', 0),
506
+ (lambda ctx,x,c: (x*c)**2, 'sqrt($y)/$c', 0),
507
+ (lambda ctx,x,c: (x/c)**2, '$c*sqrt($y)', 1),
508
+ (lambda ctx,x,c: (c/x)**2, '$c/sqrt($y)', 0),
509
+ (lambda ctx,x,c: c*x**2, 'sqrt($y)/sqrt($c)', 1),
510
+ (lambda ctx,x,c: x**2/c, 'sqrt($c)*sqrt($y)', 1),
511
+ (lambda ctx,x,c: c/x**2, 'sqrt($c)/sqrt($y)', 1),
512
+ (lambda ctx,x,c: ctx.sqrt(x*c), '$y**2/$c', 0),
513
+ (lambda ctx,x,c: ctx.sqrt(x/c), '$c*$y**2', 1),
514
+ (lambda ctx,x,c: ctx.sqrt(c/x), '$c/$y**2', 0),
515
+ (lambda ctx,x,c: c*ctx.sqrt(x), '$y**2/$c**2', 1),
516
+ (lambda ctx,x,c: ctx.sqrt(x)/c, '$c**2*$y**2', 1),
517
+ (lambda ctx,x,c: c/ctx.sqrt(x), '$c**2/$y**2', 1),
518
+ (lambda ctx,x,c: ctx.exp(x*c), 'log($y)/$c', 0),
519
+ (lambda ctx,x,c: ctx.exp(x/c), '$c*log($y)', 1),
520
+ (lambda ctx,x,c: ctx.exp(c/x), '$c/log($y)', 0),
521
+ (lambda ctx,x,c: c*ctx.exp(x), 'log($y/$c)', 1),
522
+ (lambda ctx,x,c: ctx.exp(x)/c, 'log($c*$y)', 1),
523
+ (lambda ctx,x,c: c/ctx.exp(x), 'log($c/$y)', 0),
524
+ (lambda ctx,x,c: ctx.ln(x*c), 'exp($y)/$c', 0),
525
+ (lambda ctx,x,c: ctx.ln(x/c), '$c*exp($y)', 1),
526
+ (lambda ctx,x,c: ctx.ln(c/x), '$c/exp($y)', 0),
527
+ (lambda ctx,x,c: c*ctx.ln(x), 'exp($y/$c)', 1),
528
+ (lambda ctx,x,c: ctx.ln(x)/c, 'exp($c*$y)', 1),
529
+ (lambda ctx,x,c: c/ctx.ln(x), 'exp($c/$y)', 0),
530
+ ]
531
+
532
+ def identify(ctx, x, constants=[], tol=None, maxcoeff=1000, full=False,
533
+ verbose=False):
534
+ r"""
535
+ Given a real number `x`, ``identify(x)`` attempts to find an exact
536
+ formula for `x`. This formula is returned as a string. If no match
537
+ is found, ``None`` is returned. With ``full=True``, a list of
538
+ matching formulas is returned.
539
+
540
+ As a simple example, :func:`~mpmath.identify` will find an algebraic
541
+ formula for the golden ratio::
542
+
543
+ >>> from mpmath import *
544
+ >>> mp.dps = 15; mp.pretty = True
545
+ >>> identify(phi)
546
+ '((1+sqrt(5))/2)'
547
+
548
+ :func:`~mpmath.identify` can identify simple algebraic numbers and simple
549
+ combinations of given base constants, as well as certain basic
550
+ transformations thereof. More specifically, :func:`~mpmath.identify`
551
+ looks for the following:
552
+
553
+ 1. Fractions
554
+ 2. Quadratic algebraic numbers
555
+ 3. Rational linear combinations of the base constants
556
+ 4. Any of the above after first transforming `x` into `f(x)` where
557
+ `f(x)` is `1/x`, `\sqrt x`, `x^2`, `\log x` or `\exp x`, either
558
+ directly or with `x` or `f(x)` multiplied or divided by one of
559
+ the base constants
560
+ 5. Products of fractional powers of the base constants and
561
+ small integers
562
+
563
+ Base constants can be given as a list of strings representing mpmath
564
+ expressions (:func:`~mpmath.identify` will ``eval`` the strings to numerical
565
+ values and use the original strings for the output), or as a dict of
566
+ formula:value pairs.
567
+
568
+ In order not to produce spurious results, :func:`~mpmath.identify` should
569
+ be used with high precision; preferably 50 digits or more.
570
+
571
+ **Examples**
572
+
573
+ Simple identifications can be performed safely at standard
574
+ precision. Here the default recognition of rational, algebraic,
575
+ and exp/log of algebraic numbers is demonstrated::
576
+
577
+ >>> mp.dps = 15
578
+ >>> identify(0.22222222222222222)
579
+ '(2/9)'
580
+ >>> identify(1.9662210973805663)
581
+ 'sqrt(((24+sqrt(48))/8))'
582
+ >>> identify(4.1132503787829275)
583
+ 'exp((sqrt(8)/2))'
584
+ >>> identify(0.881373587019543)
585
+ 'log(((2+sqrt(8))/2))'
586
+
587
+ By default, :func:`~mpmath.identify` does not recognize `\pi`. At standard
588
+ precision it finds a not too useful approximation. At slightly
589
+ increased precision, this approximation is no longer accurate
590
+ enough and :func:`~mpmath.identify` more correctly returns ``None``::
591
+
592
+ >>> identify(pi)
593
+ '(2**(176/117)*3**(20/117)*5**(35/39))/(7**(92/117))'
594
+ >>> mp.dps = 30
595
+ >>> identify(pi)
596
+ >>>
597
+
598
+ Numbers such as `\pi`, and simple combinations of user-defined
599
+ constants, can be identified if they are provided explicitly::
600
+
601
+ >>> identify(3*pi-2*e, ['pi', 'e'])
602
+ '(3*pi + (-2)*e)'
603
+
604
+ Here is an example using a dict of constants. Note that the
605
+ constants need not be "atomic"; :func:`~mpmath.identify` can just
606
+ as well express the given number in terms of expressions
607
+ given by formulas::
608
+
609
+ >>> identify(pi+e, {'a':pi+2, 'b':2*e})
610
+ '((-2) + 1*a + (1/2)*b)'
611
+
612
+ Next, we attempt some identifications with a set of base constants.
613
+ It is necessary to increase the precision a bit.
614
+
615
+ >>> mp.dps = 50
616
+ >>> base = ['sqrt(2)','pi','log(2)']
617
+ >>> identify(0.25, base)
618
+ '(1/4)'
619
+ >>> identify(3*pi + 2*sqrt(2) + 5*log(2)/7, base)
620
+ '(2*sqrt(2) + 3*pi + (5/7)*log(2))'
621
+ >>> identify(exp(pi+2), base)
622
+ 'exp((2 + 1*pi))'
623
+ >>> identify(1/(3+sqrt(2)), base)
624
+ '((3/7) + (-1/7)*sqrt(2))'
625
+ >>> identify(sqrt(2)/(3*pi+4), base)
626
+ 'sqrt(2)/(4 + 3*pi)'
627
+ >>> identify(5**(mpf(1)/3)*pi*log(2)**2, base)
628
+ '5**(1/3)*pi*log(2)**2'
629
+
630
+ An example of an erroneous solution being found when too low
631
+ precision is used::
632
+
633
+ >>> mp.dps = 15
634
+ >>> identify(1/(3*pi-4*e+sqrt(8)), ['pi', 'e', 'sqrt(2)'])
635
+ '((11/25) + (-158/75)*pi + (76/75)*e + (44/15)*sqrt(2))'
636
+ >>> mp.dps = 50
637
+ >>> identify(1/(3*pi-4*e+sqrt(8)), ['pi', 'e', 'sqrt(2)'])
638
+ '1/(3*pi + (-4)*e + 2*sqrt(2))'
639
+
640
+ **Finding approximate solutions**
641
+
642
+ The tolerance ``tol`` defaults to 3/4 of the working precision.
643
+ Lowering the tolerance is useful for finding approximate matches.
644
+ We can for example try to generate approximations for pi::
645
+
646
+ >>> mp.dps = 15
647
+ >>> identify(pi, tol=1e-2)
648
+ '(22/7)'
649
+ >>> identify(pi, tol=1e-3)
650
+ '(355/113)'
651
+ >>> identify(pi, tol=1e-10)
652
+ '(5**(339/269))/(2**(64/269)*3**(13/269)*7**(92/269))'
653
+
654
+ With ``full=True``, and by supplying a few base constants,
655
+ ``identify`` can generate almost endless lists of approximations
656
+ for any number (the output below has been truncated to show only
657
+ the first few)::
658
+
659
+ >>> for p in identify(pi, ['e', 'catalan'], tol=1e-5, full=True):
660
+ ... print(p)
661
+ ... # doctest: +ELLIPSIS
662
+ e/log((6 + (-4/3)*e))
663
+ (3**3*5*e*catalan**2)/(2*7**2)
664
+ sqrt(((-13) + 1*e + 22*catalan))
665
+ log(((-6) + 24*e + 4*catalan)/e)
666
+ exp(catalan*((-1/5) + (8/15)*e))
667
+ catalan*(6 + (-6)*e + 15*catalan)
668
+ sqrt((5 + 26*e + (-3)*catalan))/e
669
+ e*sqrt(((-27) + 2*e + 25*catalan))
670
+ log(((-1) + (-11)*e + 59*catalan))
671
+ ((3/20) + (21/20)*e + (3/20)*catalan)
672
+ ...
673
+
674
+ The numerical values are roughly as close to `\pi` as permitted by the
675
+ specified tolerance:
676
+
677
+ >>> e/log(6-4*e/3)
678
+ 3.14157719846001
679
+ >>> 135*e*catalan**2/98
680
+ 3.14166950419369
681
+ >>> sqrt(e-13+22*catalan)
682
+ 3.14158000062992
683
+ >>> log(24*e-6+4*catalan)-1
684
+ 3.14158791577159
685
+
686
+ **Symbolic processing**
687
+
688
+ The output formula can be evaluated as a Python expression.
689
+ Note however that if fractions (like '2/3') are present in
690
+ the formula, Python's :func:`~mpmath.eval()` may erroneously perform
691
+ integer division. Note also that the output is not necessarily
692
+ in the algebraically simplest form::
693
+
694
+ >>> identify(sqrt(2))
695
+ '(sqrt(8)/2)'
696
+
697
+ As a solution to both problems, consider using SymPy's
698
+ :func:`~mpmath.sympify` to convert the formula into a symbolic expression.
699
+ SymPy can be used to pretty-print or further simplify the formula
700
+ symbolically::
701
+
702
+ >>> from sympy import sympify # doctest: +SKIP
703
+ >>> sympify(identify(sqrt(2))) # doctest: +SKIP
704
+ 2**(1/2)
705
+
706
+ Sometimes :func:`~mpmath.identify` can simplify an expression further than
707
+ a symbolic algorithm::
708
+
709
+ >>> from sympy import simplify # doctest: +SKIP
710
+ >>> x = sympify('-1/(-3/2+(1/2)*5**(1/2))*(3/2-1/2*5**(1/2))**(1/2)') # doctest: +SKIP
711
+ >>> x # doctest: +SKIP
712
+ (3/2 - 5**(1/2)/2)**(-1/2)
713
+ >>> x = simplify(x) # doctest: +SKIP
714
+ >>> x # doctest: +SKIP
715
+ 2/(6 - 2*5**(1/2))**(1/2)
716
+ >>> mp.dps = 30 # doctest: +SKIP
717
+ >>> x = sympify(identify(x.evalf(30))) # doctest: +SKIP
718
+ >>> x # doctest: +SKIP
719
+ 1/2 + 5**(1/2)/2
720
+
721
+ (In fact, this functionality is available directly in SymPy as the
722
+ function :func:`~mpmath.nsimplify`, which is essentially a wrapper for
723
+ :func:`~mpmath.identify`.)
724
+
725
+ **Miscellaneous issues and limitations**
726
+
727
+ The input `x` must be a real number. All base constants must be
728
+ positive real numbers and must not be rationals or rational linear
729
+ combinations of each other.
730
+
731
+ The worst-case computation time grows quickly with the number of
732
+ base constants. Already with 3 or 4 base constants,
733
+ :func:`~mpmath.identify` may require several seconds to finish. To search
734
+ for relations among a large number of constants, you should
735
+ consider using :func:`~mpmath.pslq` directly.
736
+
737
+ The extended transformations are applied to x, not the constants
738
+ separately. As a result, ``identify`` will for example be able to
739
+ recognize ``exp(2*pi+3)`` with ``pi`` given as a base constant, but
740
+ not ``2*exp(pi)+3``. It will be able to recognize the latter if
741
+ ``exp(pi)`` is given explicitly as a base constant.
742
+
743
+ """
744
+
745
+ solutions = []
746
+
747
+ def addsolution(s):
748
+ if verbose: print("Found: ", s)
749
+ solutions.append(s)
750
+
751
+ x = ctx.mpf(x)
752
+
753
+ # Further along, x will be assumed positive
754
+ if x == 0:
755
+ if full: return ['0']
756
+ else: return '0'
757
+ if x < 0:
758
+ sol = ctx.identify(-x, constants, tol, maxcoeff, full, verbose)
759
+ if sol is None:
760
+ return sol
761
+ if full:
762
+ return ["-(%s)"%s for s in sol]
763
+ else:
764
+ return "-(%s)" % sol
765
+
766
+ if tol:
767
+ tol = ctx.mpf(tol)
768
+ else:
769
+ tol = ctx.eps**0.7
770
+ M = maxcoeff
771
+
772
+ if constants:
773
+ if isinstance(constants, dict):
774
+ constants = [(ctx.mpf(v), name) for (name, v) in sorted(constants.items())]
775
+ else:
776
+ namespace = dict((name, getattr(ctx,name)) for name in dir(ctx))
777
+ constants = [(eval(p, namespace), p) for p in constants]
778
+ else:
779
+ constants = []
780
+
781
+ # We always want to find at least rational terms
782
+ if 1 not in [value for (name, value) in constants]:
783
+ constants = [(ctx.mpf(1), '1')] + constants
784
+
785
+ # PSLQ with simple algebraic and functional transformations
786
+ for ft, ftn, red in transforms:
787
+ for c, cn in constants:
788
+ if red and cn == '1':
789
+ continue
790
+ t = ft(ctx,x,c)
791
+ # Prevent exponential transforms from wreaking havoc
792
+ if abs(t) > M**2 or abs(t) < tol:
793
+ continue
794
+ # Linear combination of base constants
795
+ r = ctx.pslq([t] + [a[0] for a in constants], tol, M)
796
+ s = None
797
+ if r is not None and max(abs(uw) for uw in r) <= M and r[0]:
798
+ s = pslqstring(r, constants)
799
+ # Quadratic algebraic numbers
800
+ else:
801
+ q = ctx.pslq([ctx.one, t, t**2], tol, M)
802
+ if q is not None and len(q) == 3 and q[2]:
803
+ aa, bb, cc = q
804
+ if max(abs(aa),abs(bb),abs(cc)) <= M:
805
+ s = quadraticstring(ctx,t,aa,bb,cc)
806
+ if s:
807
+ if cn == '1' and ('/$c' in ftn):
808
+ s = ftn.replace('$y', s).replace('/$c', '')
809
+ else:
810
+ s = ftn.replace('$y', s).replace('$c', cn)
811
+ addsolution(s)
812
+ if not full: return solutions[0]
813
+
814
+ if verbose:
815
+ print(".")
816
+
817
+ # Check for a direct multiplicative formula
818
+ if x != 1:
819
+ # Allow fractional powers of fractions
820
+ ilogs = [2,3,5,7]
821
+ # Watch out for existing fractional powers of fractions
822
+ logs = []
823
+ for a, s in constants:
824
+ if not sum(bool(ctx.findpoly(ctx.ln(a)/ctx.ln(i),1)) for i in ilogs):
825
+ logs.append((ctx.ln(a), s))
826
+ logs = [(ctx.ln(i),str(i)) for i in ilogs] + logs
827
+ r = ctx.pslq([ctx.ln(x)] + [a[0] for a in logs], tol, M)
828
+ if r is not None and max(abs(uw) for uw in r) <= M and r[0]:
829
+ addsolution(prodstring(r, logs))
830
+ if not full: return solutions[0]
831
+
832
+ if full:
833
+ return sorted(solutions, key=len)
834
+ else:
835
+ return None
836
+
837
+ IdentificationMethods.pslq = pslq
838
+ IdentificationMethods.findpoly = findpoly
839
+ IdentificationMethods.identify = identify
840
+
841
+
842
+ if __name__ == '__main__':
843
+ import doctest
844
+ doctest.testmod()
lib/python3.11/site-packages/mpmath/libmp/__init__.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .libmpf import (prec_to_dps, dps_to_prec, repr_dps,
2
+ round_down, round_up, round_floor, round_ceiling, round_nearest,
3
+ to_pickable, from_pickable, ComplexResult,
4
+ fzero, fnzero, fone, fnone, ftwo, ften, fhalf, fnan, finf, fninf,
5
+ math_float_inf, round_int, normalize, normalize1,
6
+ from_man_exp, from_int, to_man_exp, to_int, mpf_ceil, mpf_floor,
7
+ mpf_nint, mpf_frac,
8
+ from_float, from_npfloat, from_Decimal, to_float, from_rational, to_rational, to_fixed,
9
+ mpf_rand, mpf_eq, mpf_hash, mpf_cmp, mpf_lt, mpf_le, mpf_gt, mpf_ge,
10
+ mpf_pos, mpf_neg, mpf_abs, mpf_sign, mpf_add, mpf_sub, mpf_sum,
11
+ mpf_mul, mpf_mul_int, mpf_shift, mpf_frexp,
12
+ mpf_div, mpf_rdiv_int, mpf_mod, mpf_pow_int,
13
+ mpf_perturb,
14
+ to_digits_exp, to_str, str_to_man_exp, from_str, from_bstr, to_bstr,
15
+ mpf_sqrt, mpf_hypot)
16
+
17
+ from .libmpc import (mpc_one, mpc_zero, mpc_two, mpc_half,
18
+ mpc_is_inf, mpc_is_infnan, mpc_to_str, mpc_to_complex, mpc_hash,
19
+ mpc_conjugate, mpc_is_nonzero, mpc_add, mpc_add_mpf,
20
+ mpc_sub, mpc_sub_mpf, mpc_pos, mpc_neg, mpc_shift, mpc_abs,
21
+ mpc_arg, mpc_floor, mpc_ceil, mpc_nint, mpc_frac, mpc_mul, mpc_square,
22
+ mpc_mul_mpf, mpc_mul_imag_mpf, mpc_mul_int,
23
+ mpc_div, mpc_div_mpf, mpc_reciprocal, mpc_mpf_div,
24
+ complex_int_pow, mpc_pow, mpc_pow_mpf, mpc_pow_int,
25
+ mpc_sqrt, mpc_nthroot, mpc_cbrt, mpc_exp, mpc_log, mpc_cos, mpc_sin,
26
+ mpc_tan, mpc_cos_pi, mpc_sin_pi, mpc_cosh, mpc_sinh, mpc_tanh,
27
+ mpc_atan, mpc_acos, mpc_asin, mpc_asinh, mpc_acosh, mpc_atanh,
28
+ mpc_fibonacci, mpf_expj, mpf_expjpi, mpc_expj, mpc_expjpi,
29
+ mpc_cos_sin, mpc_cos_sin_pi)
30
+
31
+ from .libelefun import (ln2_fixed, mpf_ln2, ln10_fixed, mpf_ln10,
32
+ pi_fixed, mpf_pi, e_fixed, mpf_e, phi_fixed, mpf_phi,
33
+ degree_fixed, mpf_degree,
34
+ mpf_pow, mpf_nthroot, mpf_cbrt, log_int_fixed, agm_fixed,
35
+ mpf_log, mpf_log_hypot, mpf_exp, mpf_cos_sin, mpf_cos, mpf_sin, mpf_tan,
36
+ mpf_cos_sin_pi, mpf_cos_pi, mpf_sin_pi, mpf_cosh_sinh,
37
+ mpf_cosh, mpf_sinh, mpf_tanh, mpf_atan, mpf_atan2, mpf_asin,
38
+ mpf_acos, mpf_asinh, mpf_acosh, mpf_atanh, mpf_fibonacci)
39
+
40
+ from .libhyper import (NoConvergence, make_hyp_summator,
41
+ mpf_erf, mpf_erfc, mpf_ei, mpc_ei, mpf_e1, mpc_e1, mpf_expint,
42
+ mpf_ci_si, mpf_ci, mpf_si, mpc_ci, mpc_si, mpf_besseljn,
43
+ mpc_besseljn, mpf_agm, mpf_agm1, mpc_agm, mpc_agm1,
44
+ mpf_ellipk, mpc_ellipk, mpf_ellipe, mpc_ellipe)
45
+
46
+ from .gammazeta import (catalan_fixed, mpf_catalan,
47
+ khinchin_fixed, mpf_khinchin, glaisher_fixed, mpf_glaisher,
48
+ apery_fixed, mpf_apery, euler_fixed, mpf_euler, mertens_fixed,
49
+ mpf_mertens, twinprime_fixed, mpf_twinprime,
50
+ mpf_bernoulli, bernfrac, mpf_gamma_int,
51
+ mpf_factorial, mpc_factorial, mpf_gamma, mpc_gamma,
52
+ mpf_loggamma, mpc_loggamma, mpf_rgamma, mpc_rgamma,
53
+ mpf_harmonic, mpc_harmonic, mpf_psi0, mpc_psi0,
54
+ mpf_psi, mpc_psi, mpf_zeta_int, mpf_zeta, mpc_zeta,
55
+ mpf_altzeta, mpc_altzeta, mpf_zetasum, mpc_zetasum)
56
+
57
+ from .libmpi import (mpi_str,
58
+ mpi_from_str, mpi_to_str,
59
+ mpi_eq, mpi_ne,
60
+ mpi_lt, mpi_le, mpi_gt, mpi_ge,
61
+ mpi_add, mpi_sub, mpi_delta, mpi_mid,
62
+ mpi_pos, mpi_neg, mpi_abs, mpi_mul, mpi_div, mpi_exp,
63
+ mpi_log, mpi_sqrt, mpi_pow_int, mpi_pow, mpi_cos_sin,
64
+ mpi_cos, mpi_sin, mpi_tan, mpi_cot,
65
+ mpi_atan, mpi_atan2,
66
+ mpci_pos, mpci_neg, mpci_add, mpci_sub, mpci_mul, mpci_div, mpci_pow,
67
+ mpci_abs, mpci_pow, mpci_exp, mpci_log, mpci_cos, mpci_sin,
68
+ mpi_gamma, mpci_gamma, mpi_loggamma, mpci_loggamma,
69
+ mpi_rgamma, mpci_rgamma, mpi_factorial, mpci_factorial)
70
+
71
+ from .libintmath import (trailing, bitcount, numeral, bin_to_radix,
72
+ isqrt, isqrt_small, isqrt_fast, sqrt_fixed, sqrtrem, ifib, ifac,
73
+ list_primes, isprime, moebius, gcd, eulernum, stirling1, stirling2)
74
+
75
+ from .backend import (gmpy, sage, BACKEND, STRICT, MPZ, MPZ_TYPE,
76
+ MPZ_ZERO, MPZ_ONE, MPZ_TWO, MPZ_THREE, MPZ_FIVE, int_types,
77
+ HASH_MODULUS, HASH_BITS)
lib/python3.11/site-packages/mpmath/libmp/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (10.4 kB). View file
 
lib/python3.11/site-packages/mpmath/libmp/__pycache__/backend.cpython-311.pyc ADDED
Binary file (2.86 kB). View file
 
lib/python3.11/site-packages/mpmath/libmp/__pycache__/gammazeta.cpython-311.pyc ADDED
Binary file (89 kB). View file
 
lib/python3.11/site-packages/mpmath/libmp/__pycache__/libelefun.cpython-311.pyc ADDED
Binary file (54.8 kB). View file
 
lib/python3.11/site-packages/mpmath/libmp/__pycache__/libhyper.cpython-311.pyc ADDED
Binary file (52.9 kB). View file
 
lib/python3.11/site-packages/mpmath/libmp/__pycache__/libintmath.cpython-311.pyc ADDED
Binary file (22.9 kB). View file
 
lib/python3.11/site-packages/mpmath/libmp/__pycache__/libmpc.cpython-311.pyc ADDED
Binary file (43.8 kB). View file
 
lib/python3.11/site-packages/mpmath/libmp/__pycache__/libmpf.cpython-311.pyc ADDED
Binary file (52.6 kB). View file
 
lib/python3.11/site-packages/mpmath/libmp/__pycache__/libmpi.cpython-311.pyc ADDED
Binary file (43.4 kB). View file
 
lib/python3.11/site-packages/mpmath/libmp/backend.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ #----------------------------------------------------------------------------#
5
+ # Support GMPY for high-speed large integer arithmetic. #
6
+ # #
7
+ # To allow an external module to handle arithmetic, we need to make sure #
8
+ # that all high-precision variables are declared of the correct type. MPZ #
9
+ # is the constructor for the high-precision type. It defaults to Python's #
10
+ # long type but can be assinged another type, typically gmpy.mpz. #
11
+ # #
12
+ # MPZ must be used for the mantissa component of an mpf and must be used #
13
+ # for internal fixed-point operations. #
14
+ # #
15
+ # Side-effects #
16
+ # 1) "is" cannot be used to test for special values. Must use "==". #
17
+ # 2) There are bugs in GMPY prior to v1.02 so we must use v1.03 or later. #
18
+ #----------------------------------------------------------------------------#
19
+
20
+ # So we can import it from this module
21
+ gmpy = None
22
+ sage = None
23
+ sage_utils = None
24
+
25
+ if sys.version_info[0] < 3:
26
+ python3 = False
27
+ else:
28
+ python3 = True
29
+
30
+ BACKEND = 'python'
31
+
32
+ if not python3:
33
+ MPZ = long
34
+ xrange = xrange
35
+ basestring = basestring
36
+
37
+ def exec_(_code_, _globs_=None, _locs_=None):
38
+ """Execute code in a namespace."""
39
+ if _globs_ is None:
40
+ frame = sys._getframe(1)
41
+ _globs_ = frame.f_globals
42
+ if _locs_ is None:
43
+ _locs_ = frame.f_locals
44
+ del frame
45
+ elif _locs_ is None:
46
+ _locs_ = _globs_
47
+ exec("""exec _code_ in _globs_, _locs_""")
48
+ else:
49
+ MPZ = int
50
+ xrange = range
51
+ basestring = str
52
+
53
+ import builtins
54
+ exec_ = getattr(builtins, "exec")
55
+
56
+ # Define constants for calculating hash on Python 3.2.
57
+ if sys.version_info >= (3, 2):
58
+ HASH_MODULUS = sys.hash_info.modulus
59
+ if sys.hash_info.width == 32:
60
+ HASH_BITS = 31
61
+ else:
62
+ HASH_BITS = 61
63
+ else:
64
+ HASH_MODULUS = None
65
+ HASH_BITS = None
66
+
67
+ if 'MPMATH_NOGMPY' not in os.environ:
68
+ try:
69
+ try:
70
+ import gmpy2 as gmpy
71
+ except ImportError:
72
+ try:
73
+ import gmpy
74
+ except ImportError:
75
+ raise ImportError
76
+ if gmpy.version() >= '1.03':
77
+ BACKEND = 'gmpy'
78
+ MPZ = gmpy.mpz
79
+ except:
80
+ pass
81
+
82
+ if ('MPMATH_NOSAGE' not in os.environ and 'SAGE_ROOT' in os.environ or
83
+ 'MPMATH_SAGE' in os.environ):
84
+ try:
85
+ import sage.all
86
+ import sage.libs.mpmath.utils as _sage_utils
87
+ sage = sage.all
88
+ sage_utils = _sage_utils
89
+ BACKEND = 'sage'
90
+ MPZ = sage.Integer
91
+ except:
92
+ pass
93
+
94
+ if 'MPMATH_STRICT' in os.environ:
95
+ STRICT = True
96
+ else:
97
+ STRICT = False
98
+
99
+ MPZ_TYPE = type(MPZ(0))
100
+ MPZ_ZERO = MPZ(0)
101
+ MPZ_ONE = MPZ(1)
102
+ MPZ_TWO = MPZ(2)
103
+ MPZ_THREE = MPZ(3)
104
+ MPZ_FIVE = MPZ(5)
105
+
106
+ try:
107
+ if BACKEND == 'python':
108
+ int_types = (int, long)
109
+ else:
110
+ int_types = (int, long, MPZ_TYPE)
111
+ except NameError:
112
+ if BACKEND == 'python':
113
+ int_types = (int,)
114
+ else:
115
+ int_types = (int, MPZ_TYPE)
lib/python3.11/site-packages/mpmath/libmp/gammazeta.py ADDED
@@ -0,0 +1,2167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ -----------------------------------------------------------------------
3
+ This module implements gamma- and zeta-related functions:
4
+
5
+ * Bernoulli numbers
6
+ * Factorials
7
+ * The gamma function
8
+ * Polygamma functions
9
+ * Harmonic numbers
10
+ * The Riemann zeta function
11
+ * Constants related to these functions
12
+
13
+ -----------------------------------------------------------------------
14
+ """
15
+
16
+ import math
17
+ import sys
18
+
19
+ from .backend import xrange
20
+ from .backend import MPZ, MPZ_ZERO, MPZ_ONE, MPZ_THREE, gmpy
21
+
22
+ from .libintmath import list_primes, ifac, ifac2, moebius
23
+
24
+ from .libmpf import (\
25
+ round_floor, round_ceiling, round_down, round_up,
26
+ round_nearest, round_fast,
27
+ lshift, sqrt_fixed, isqrt_fast,
28
+ fzero, fone, fnone, fhalf, ftwo, finf, fninf, fnan,
29
+ from_int, to_int, to_fixed, from_man_exp, from_rational,
30
+ mpf_pos, mpf_neg, mpf_abs, mpf_add, mpf_sub,
31
+ mpf_mul, mpf_mul_int, mpf_div, mpf_sqrt, mpf_pow_int,
32
+ mpf_rdiv_int,
33
+ mpf_perturb, mpf_le, mpf_lt, mpf_gt, mpf_shift,
34
+ negative_rnd, reciprocal_rnd,
35
+ bitcount, to_float, mpf_floor, mpf_sign, ComplexResult
36
+ )
37
+
38
+ from .libelefun import (\
39
+ constant_memo,
40
+ def_mpf_constant,
41
+ mpf_pi, pi_fixed, ln2_fixed, log_int_fixed, mpf_ln2,
42
+ mpf_exp, mpf_log, mpf_pow, mpf_cosh,
43
+ mpf_cos_sin, mpf_cosh_sinh, mpf_cos_sin_pi, mpf_cos_pi, mpf_sin_pi,
44
+ ln_sqrt2pi_fixed, mpf_ln_sqrt2pi, sqrtpi_fixed, mpf_sqrtpi,
45
+ cos_sin_fixed, exp_fixed
46
+ )
47
+
48
+ from .libmpc import (\
49
+ mpc_zero, mpc_one, mpc_half, mpc_two,
50
+ mpc_abs, mpc_shift, mpc_pos, mpc_neg,
51
+ mpc_add, mpc_sub, mpc_mul, mpc_div,
52
+ mpc_add_mpf, mpc_mul_mpf, mpc_div_mpf, mpc_mpf_div,
53
+ mpc_mul_int, mpc_pow_int,
54
+ mpc_log, mpc_exp, mpc_pow,
55
+ mpc_cos_pi, mpc_sin_pi,
56
+ mpc_reciprocal, mpc_square,
57
+ mpc_sub_mpf
58
+ )
59
+
60
+
61
+
62
+ # Catalan's constant is computed using Lupas's rapidly convergent series
63
+ # (listed on http://mathworld.wolfram.com/CatalansConstant.html)
64
+ # oo
65
+ # ___ n-1 8n 2 3 2
66
+ # 1 \ (-1) 2 (40n - 24n + 3) [(2n)!] (n!)
67
+ # K = --- ) -----------------------------------------
68
+ # 64 /___ 3 2
69
+ # n (2n-1) [(4n)!]
70
+ # n = 1
71
+
72
+ @constant_memo
73
+ def catalan_fixed(prec):
74
+ prec = prec + 20
75
+ a = one = MPZ_ONE << prec
76
+ s, t, n = 0, 1, 1
77
+ while t:
78
+ a *= 32 * n**3 * (2*n-1)
79
+ a //= (3-16*n+16*n**2)**2
80
+ t = a * (-1)**(n-1) * (40*n**2-24*n+3) // (n**3 * (2*n-1))
81
+ s += t
82
+ n += 1
83
+ return s >> (20 + 6)
84
+
85
+ # Khinchin's constant is relatively difficult to compute. Here
86
+ # we use the rational zeta series
87
+
88
+ # oo 2*n-1
89
+ # ___ ___
90
+ # \ ` zeta(2*n)-1 \ ` (-1)^(k+1)
91
+ # log(K)*log(2) = ) ------------ ) ----------
92
+ # /___. n /___. k
93
+ # n = 1 k = 1
94
+
95
+ # which adds half a digit per term. The essential trick for achieving
96
+ # reasonable efficiency is to recycle both the values of the zeta
97
+ # function (essentially Bernoulli numbers) and the partial terms of
98
+ # the inner sum.
99
+
100
+ # An alternative might be to use K = 2*exp[1/log(2) X] where
101
+
102
+ # / 1 1 [ pi*x*(1-x^2) ]
103
+ # X = | ------ log [ ------------ ].
104
+ # / 0 x(1+x) [ sin(pi*x) ]
105
+
106
+ # and integrate numerically. In practice, this seems to be slightly
107
+ # slower than the zeta series at high precision.
108
+
109
+ @constant_memo
110
+ def khinchin_fixed(prec):
111
+ wp = int(prec + prec**0.5 + 15)
112
+ s = MPZ_ZERO
113
+ fac = from_int(4)
114
+ t = ONE = MPZ_ONE << wp
115
+ pi = mpf_pi(wp)
116
+ pipow = twopi2 = mpf_shift(mpf_mul(pi, pi, wp), 2)
117
+ n = 1
118
+ while 1:
119
+ zeta2n = mpf_abs(mpf_bernoulli(2*n, wp))
120
+ zeta2n = mpf_mul(zeta2n, pipow, wp)
121
+ zeta2n = mpf_div(zeta2n, fac, wp)
122
+ zeta2n = to_fixed(zeta2n, wp)
123
+ term = (((zeta2n - ONE) * t) // n) >> wp
124
+ if term < 100:
125
+ break
126
+ #if not n % 10:
127
+ # print n, math.log(int(abs(term)))
128
+ s += term
129
+ t += ONE//(2*n+1) - ONE//(2*n)
130
+ n += 1
131
+ fac = mpf_mul_int(fac, (2*n)*(2*n-1), wp)
132
+ pipow = mpf_mul(pipow, twopi2, wp)
133
+ s = (s << wp) // ln2_fixed(wp)
134
+ K = mpf_exp(from_man_exp(s, -wp), wp)
135
+ K = to_fixed(K, prec)
136
+ return K
137
+
138
+
139
+ # Glaisher's constant is defined as A = exp(1/2 - zeta'(-1)).
140
+ # One way to compute it would be to perform direct numerical
141
+ # differentiation, but computing arbitrary Riemann zeta function
142
+ # values at high precision is expensive. We instead use the formula
143
+
144
+ # A = exp((6 (-zeta'(2))/pi^2 + log 2 pi + gamma)/12)
145
+
146
+ # and compute zeta'(2) from the series representation
147
+
148
+ # oo
149
+ # ___
150
+ # \ log k
151
+ # -zeta'(2) = ) -----
152
+ # /___ 2
153
+ # k
154
+ # k = 2
155
+
156
+ # This series converges exceptionally slowly, but can be accelerated
157
+ # using Euler-Maclaurin formula. The important insight is that the
158
+ # E-M integral can be done in closed form and that the high order
159
+ # are given by
160
+
161
+ # n / \
162
+ # d | log x | a + b log x
163
+ # --- | ----- | = -----------
164
+ # n | 2 | 2 + n
165
+ # dx \ x / x
166
+
167
+ # where a and b are integers given by a simple recurrence. Note
168
+ # that just one logarithm is needed. However, lots of integer
169
+ # logarithms are required for the initial summation.
170
+
171
+ # This algorithm could possibly be turned into a faster algorithm
172
+ # for general evaluation of zeta(s) or zeta'(s); this should be
173
+ # looked into.
174
+
175
+ @constant_memo
176
+ def glaisher_fixed(prec):
177
+ wp = prec + 30
178
+ # Number of direct terms to sum before applying the Euler-Maclaurin
179
+ # formula to the tail. TODO: choose more intelligently
180
+ N = int(0.33*prec + 5)
181
+ ONE = MPZ_ONE << wp
182
+ # Euler-Maclaurin, step 1: sum log(k)/k**2 for k from 2 to N-1
183
+ s = MPZ_ZERO
184
+ for k in range(2, N):
185
+ #print k, N
186
+ s += log_int_fixed(k, wp) // k**2
187
+ logN = log_int_fixed(N, wp)
188
+ #logN = to_fixed(mpf_log(from_int(N), wp+20), wp)
189
+ # E-M step 2: integral of log(x)/x**2 from N to inf
190
+ s += (ONE + logN) // N
191
+ # E-M step 3: endpoint correction term f(N)/2
192
+ s += logN // (N**2 * 2)
193
+ # E-M step 4: the series of derivatives
194
+ pN = N**3
195
+ a = 1
196
+ b = -2
197
+ j = 3
198
+ fac = from_int(2)
199
+ k = 1
200
+ while 1:
201
+ # D(2*k-1) * B(2*k) / fac(2*k) [D(n) = nth derivative]
202
+ D = ((a << wp) + b*logN) // pN
203
+ D = from_man_exp(D, -wp)
204
+ B = mpf_bernoulli(2*k, wp)
205
+ term = mpf_mul(B, D, wp)
206
+ term = mpf_div(term, fac, wp)
207
+ term = to_fixed(term, wp)
208
+ if abs(term) < 100:
209
+ break
210
+ #if not k % 10:
211
+ # print k, math.log(int(abs(term)), 10)
212
+ s -= term
213
+ # Advance derivative twice
214
+ a, b, pN, j = b-a*j, -j*b, pN*N, j+1
215
+ a, b, pN, j = b-a*j, -j*b, pN*N, j+1
216
+ k += 1
217
+ fac = mpf_mul_int(fac, (2*k)*(2*k-1), wp)
218
+ # A = exp((6*s/pi**2 + log(2*pi) + euler)/12)
219
+ pi = pi_fixed(wp)
220
+ s *= 6
221
+ s = (s << wp) // (pi**2 >> wp)
222
+ s += euler_fixed(wp)
223
+ s += to_fixed(mpf_log(from_man_exp(2*pi, -wp), wp), wp)
224
+ s //= 12
225
+ A = mpf_exp(from_man_exp(s, -wp), wp)
226
+ return to_fixed(A, prec)
227
+
228
+ # Apery's constant can be computed using the very rapidly convergent
229
+ # series
230
+ # oo
231
+ # ___ 2 10
232
+ # \ n 205 n + 250 n + 77 (n!)
233
+ # zeta(3) = ) (-1) ------------------- ----------
234
+ # /___ 64 5
235
+ # n = 0 ((2n+1)!)
236
+
237
+ @constant_memo
238
+ def apery_fixed(prec):
239
+ prec += 20
240
+ d = MPZ_ONE << prec
241
+ term = MPZ(77) << prec
242
+ n = 1
243
+ s = MPZ_ZERO
244
+ while term:
245
+ s += term
246
+ d *= (n**10)
247
+ d //= (((2*n+1)**5) * (2*n)**5)
248
+ term = (-1)**n * (205*(n**2) + 250*n + 77) * d
249
+ n += 1
250
+ return s >> (20 + 6)
251
+
252
+ """
253
+ Euler's constant (gamma) is computed using the Brent-McMillan formula,
254
+ gamma ~= I(n)/J(n) - log(n), where
255
+
256
+ I(n) = sum_{k=0,1,2,...} (n**k / k!)**2 * H(k)
257
+ J(n) = sum_{k=0,1,2,...} (n**k / k!)**2
258
+ H(k) = 1 + 1/2 + 1/3 + ... + 1/k
259
+
260
+ The error is bounded by O(exp(-4n)). Choosing n to be a power
261
+ of two, 2**p, the logarithm becomes particularly easy to calculate.[1]
262
+
263
+ We use the formulation of Algorithm 3.9 in [2] to make the summation
264
+ more efficient.
265
+
266
+ Reference:
267
+ [1] Xavier Gourdon & Pascal Sebah, The Euler constant: gamma
268
+ http://numbers.computation.free.fr/Constants/Gamma/gamma.pdf
269
+
270
+ [2] [BorweinBailey]_
271
+ """
272
+
273
+ @constant_memo
274
+ def euler_fixed(prec):
275
+ extra = 30
276
+ prec += extra
277
+ # choose p such that exp(-4*(2**p)) < 2**-n
278
+ p = int(math.log((prec/4) * math.log(2), 2)) + 1
279
+ n = 2**p
280
+ A = U = -p*ln2_fixed(prec)
281
+ B = V = MPZ_ONE << prec
282
+ k = 1
283
+ while 1:
284
+ B = B*n**2//k**2
285
+ A = (A*n**2//k + B)//k
286
+ U += A
287
+ V += B
288
+ if max(abs(A), abs(B)) < 100:
289
+ break
290
+ k += 1
291
+ return (U<<(prec-extra))//V
292
+
293
+ # Use zeta accelerated formulas for the Mertens and twin
294
+ # prime constants; see
295
+ # http://mathworld.wolfram.com/MertensConstant.html
296
+ # http://mathworld.wolfram.com/TwinPrimesConstant.html
297
+
298
+ @constant_memo
299
+ def mertens_fixed(prec):
300
+ wp = prec + 20
301
+ m = 2
302
+ s = mpf_euler(wp)
303
+ while 1:
304
+ t = mpf_zeta_int(m, wp)
305
+ if t == fone:
306
+ break
307
+ t = mpf_log(t, wp)
308
+ t = mpf_mul_int(t, moebius(m), wp)
309
+ t = mpf_div(t, from_int(m), wp)
310
+ s = mpf_add(s, t)
311
+ m += 1
312
+ return to_fixed(s, prec)
313
+
314
+ @constant_memo
315
+ def twinprime_fixed(prec):
316
+ def I(n):
317
+ return sum(moebius(d)<<(n//d) for d in xrange(1,n+1) if not n%d)//n
318
+ wp = 2*prec + 30
319
+ res = fone
320
+ primes = [from_rational(1,p,wp) for p in [2,3,5,7]]
321
+ ppowers = [mpf_mul(p,p,wp) for p in primes]
322
+ n = 2
323
+ while 1:
324
+ a = mpf_zeta_int(n, wp)
325
+ for i in range(4):
326
+ a = mpf_mul(a, mpf_sub(fone, ppowers[i]), wp)
327
+ ppowers[i] = mpf_mul(ppowers[i], primes[i], wp)
328
+ a = mpf_pow_int(a, -I(n), wp)
329
+ if mpf_pos(a, prec+10, 'n') == fone:
330
+ break
331
+ #from libmpf import to_str
332
+ #print n, to_str(mpf_sub(fone, a), 6)
333
+ res = mpf_mul(res, a, wp)
334
+ n += 1
335
+ res = mpf_mul(res, from_int(3*15*35), wp)
336
+ res = mpf_div(res, from_int(4*16*36), wp)
337
+ return to_fixed(res, prec)
338
+
339
+
340
+ mpf_euler = def_mpf_constant(euler_fixed)
341
+ mpf_apery = def_mpf_constant(apery_fixed)
342
+ mpf_khinchin = def_mpf_constant(khinchin_fixed)
343
+ mpf_glaisher = def_mpf_constant(glaisher_fixed)
344
+ mpf_catalan = def_mpf_constant(catalan_fixed)
345
+ mpf_mertens = def_mpf_constant(mertens_fixed)
346
+ mpf_twinprime = def_mpf_constant(twinprime_fixed)
347
+
348
+
349
+ #-----------------------------------------------------------------------#
350
+ # #
351
+ # Bernoulli numbers #
352
+ # #
353
+ #-----------------------------------------------------------------------#
354
+
355
+ MAX_BERNOULLI_CACHE = 3000
356
+
357
+
358
+ r"""
359
+ Small Bernoulli numbers and factorials are used in numerous summations,
360
+ so it is critical for speed that sequential computation is fast and that
361
+ values are cached up to a fairly high threshold.
362
+
363
+ On the other hand, we also want to support fast computation of isolated
364
+ large numbers. Currently, no such acceleration is provided for integer
365
+ factorials (though it is for large floating-point factorials, which are
366
+ computed via gamma if the precision is low enough).
367
+
368
+ For sequential computation of Bernoulli numbers, we use Ramanujan's formula
369
+
370
+ / n + 3 \
371
+ B = (A(n) - S(n)) / | |
372
+ n \ n /
373
+
374
+ where A(n) = (n+3)/3 when n = 0 or 2 (mod 6), A(n) = -(n+3)/6
375
+ when n = 4 (mod 6), and
376
+
377
+ [n/6]
378
+ ___
379
+ \ / n + 3 \
380
+ S(n) = ) | | * B
381
+ /___ \ n - 6*k / n-6*k
382
+ k = 1
383
+
384
+ For isolated large Bernoulli numbers, we use the Riemann zeta function
385
+ to calculate a numerical value for B_n. The von Staudt-Clausen theorem
386
+ can then be used to optionally find the exact value of the
387
+ numerator and denominator.
388
+ """
389
+
390
+ bernoulli_cache = {}
391
+ f3 = from_int(3)
392
+ f6 = from_int(6)
393
+
394
+ def bernoulli_size(n):
395
+ """Accurately estimate the size of B_n (even n > 2 only)"""
396
+ lgn = math.log(n,2)
397
+ return int(2.326 + 0.5*lgn + n*(lgn - 4.094))
398
+
399
+ BERNOULLI_PREC_CUTOFF = bernoulli_size(MAX_BERNOULLI_CACHE)
400
+
401
+ def mpf_bernoulli(n, prec, rnd=None):
402
+ """Computation of Bernoulli numbers (numerically)"""
403
+ if n < 2:
404
+ if n < 0:
405
+ raise ValueError("Bernoulli numbers only defined for n >= 0")
406
+ if n == 0:
407
+ return fone
408
+ if n == 1:
409
+ return mpf_neg(fhalf)
410
+ # For odd n > 1, the Bernoulli numbers are zero
411
+ if n & 1:
412
+ return fzero
413
+ # If precision is extremely high, we can save time by computing
414
+ # the Bernoulli number at a lower precision that is sufficient to
415
+ # obtain the exact fraction, round to the exact fraction, and
416
+ # convert the fraction back to an mpf value at the original precision
417
+ if prec > BERNOULLI_PREC_CUTOFF and prec > bernoulli_size(n)*1.1 + 1000:
418
+ p, q = bernfrac(n)
419
+ return from_rational(p, q, prec, rnd or round_floor)
420
+ if n > MAX_BERNOULLI_CACHE:
421
+ return mpf_bernoulli_huge(n, prec, rnd)
422
+ wp = prec + 30
423
+ # Reuse nearby precisions
424
+ wp += 32 - (prec & 31)
425
+ cached = bernoulli_cache.get(wp)
426
+ if cached:
427
+ numbers, state = cached
428
+ if n in numbers:
429
+ if not rnd:
430
+ return numbers[n]
431
+ return mpf_pos(numbers[n], prec, rnd)
432
+ m, bin, bin1 = state
433
+ if n - m > 10:
434
+ return mpf_bernoulli_huge(n, prec, rnd)
435
+ else:
436
+ if n > 10:
437
+ return mpf_bernoulli_huge(n, prec, rnd)
438
+ numbers = {0:fone}
439
+ m, bin, bin1 = state = [2, MPZ(10), MPZ_ONE]
440
+ bernoulli_cache[wp] = (numbers, state)
441
+ while m <= n:
442
+ #print m
443
+ case = m % 6
444
+ # Accurately estimate size of B_m so we can use
445
+ # fixed point math without using too much precision
446
+ szbm = bernoulli_size(m)
447
+ s = 0
448
+ sexp = max(0, szbm) - wp
449
+ if m < 6:
450
+ a = MPZ_ZERO
451
+ else:
452
+ a = bin1
453
+ for j in xrange(1, m//6+1):
454
+ usign, uman, uexp, ubc = u = numbers[m-6*j]
455
+ if usign:
456
+ uman = -uman
457
+ s += lshift(a*uman, uexp-sexp)
458
+ # Update inner binomial coefficient
459
+ j6 = 6*j
460
+ a *= ((m-5-j6)*(m-4-j6)*(m-3-j6)*(m-2-j6)*(m-1-j6)*(m-j6))
461
+ a //= ((4+j6)*(5+j6)*(6+j6)*(7+j6)*(8+j6)*(9+j6))
462
+ if case == 0: b = mpf_rdiv_int(m+3, f3, wp)
463
+ if case == 2: b = mpf_rdiv_int(m+3, f3, wp)
464
+ if case == 4: b = mpf_rdiv_int(-m-3, f6, wp)
465
+ s = from_man_exp(s, sexp, wp)
466
+ b = mpf_div(mpf_sub(b, s, wp), from_int(bin), wp)
467
+ numbers[m] = b
468
+ m += 2
469
+ # Update outer binomial coefficient
470
+ bin = bin * ((m+2)*(m+3)) // (m*(m-1))
471
+ if m > 6:
472
+ bin1 = bin1 * ((2+m)*(3+m)) // ((m-7)*(m-6))
473
+ state[:] = [m, bin, bin1]
474
+ return numbers[n]
475
+
476
+ def mpf_bernoulli_huge(n, prec, rnd=None):
477
+ wp = prec + 10
478
+ piprec = wp + int(math.log(n,2))
479
+ v = mpf_gamma_int(n+1, wp)
480
+ v = mpf_mul(v, mpf_zeta_int(n, wp), wp)
481
+ v = mpf_mul(v, mpf_pow_int(mpf_pi(piprec), -n, wp))
482
+ v = mpf_shift(v, 1-n)
483
+ if not n & 3:
484
+ v = mpf_neg(v)
485
+ return mpf_pos(v, prec, rnd or round_fast)
486
+
487
+ def bernfrac(n):
488
+ r"""
489
+ Returns a tuple of integers `(p, q)` such that `p/q = B_n` exactly,
490
+ where `B_n` denotes the `n`-th Bernoulli number. The fraction is
491
+ always reduced to lowest terms. Note that for `n > 1` and `n` odd,
492
+ `B_n = 0`, and `(0, 1)` is returned.
493
+
494
+ **Examples**
495
+
496
+ The first few Bernoulli numbers are exactly::
497
+
498
+ >>> from mpmath import *
499
+ >>> for n in range(15):
500
+ ... p, q = bernfrac(n)
501
+ ... print("%s %s/%s" % (n, p, q))
502
+ ...
503
+ 0 1/1
504
+ 1 -1/2
505
+ 2 1/6
506
+ 3 0/1
507
+ 4 -1/30
508
+ 5 0/1
509
+ 6 1/42
510
+ 7 0/1
511
+ 8 -1/30
512
+ 9 0/1
513
+ 10 5/66
514
+ 11 0/1
515
+ 12 -691/2730
516
+ 13 0/1
517
+ 14 7/6
518
+
519
+ This function works for arbitrarily large `n`::
520
+
521
+ >>> p, q = bernfrac(10**4)
522
+ >>> print(q)
523
+ 2338224387510
524
+ >>> print(len(str(p)))
525
+ 27692
526
+ >>> mp.dps = 15
527
+ >>> print(mpf(p) / q)
528
+ -9.04942396360948e+27677
529
+ >>> print(bernoulli(10**4))
530
+ -9.04942396360948e+27677
531
+
532
+ .. note ::
533
+
534
+ :func:`~mpmath.bernoulli` computes a floating-point approximation
535
+ directly, without computing the exact fraction first.
536
+ This is much faster for large `n`.
537
+
538
+ **Algorithm**
539
+
540
+ :func:`~mpmath.bernfrac` works by computing the value of `B_n` numerically
541
+ and then using the von Staudt-Clausen theorem [1] to reconstruct
542
+ the exact fraction. For large `n`, this is significantly faster than
543
+ computing `B_1, B_2, \ldots, B_2` recursively with exact arithmetic.
544
+ The implementation has been tested for `n = 10^m` up to `m = 6`.
545
+
546
+ In practice, :func:`~mpmath.bernfrac` appears to be about three times
547
+ slower than the specialized program calcbn.exe [2]
548
+
549
+ **References**
550
+
551
+ 1. MathWorld, von Staudt-Clausen Theorem:
552
+ http://mathworld.wolfram.com/vonStaudt-ClausenTheorem.html
553
+
554
+ 2. The Bernoulli Number Page:
555
+ http://www.bernoulli.org/
556
+
557
+ """
558
+ n = int(n)
559
+ if n < 3:
560
+ return [(1, 1), (-1, 2), (1, 6)][n]
561
+ if n & 1:
562
+ return (0, 1)
563
+ q = 1
564
+ for k in list_primes(n+1):
565
+ if not (n % (k-1)):
566
+ q *= k
567
+ prec = bernoulli_size(n) + int(math.log(q,2)) + 20
568
+ b = mpf_bernoulli(n, prec)
569
+ p = mpf_mul(b, from_int(q))
570
+ pint = to_int(p, round_nearest)
571
+ return (pint, q)
572
+
573
+
574
+ #-----------------------------------------------------------------------#
575
+ # #
576
+ # Polygamma functions #
577
+ # #
578
+ #-----------------------------------------------------------------------#
579
+
580
+ r"""
581
+ For all polygamma (psi) functions, we use the Euler-Maclaurin summation
582
+ formula. It looks slightly different in the m = 0 and m > 0 cases.
583
+
584
+ For m = 0, we have
585
+ oo
586
+ ___ B
587
+ (0) 1 \ 2 k -2 k
588
+ psi (z) ~ log z + --- - ) ------ z
589
+ 2 z /___ (2 k)!
590
+ k = 1
591
+
592
+ Experiment shows that the minimum term of the asymptotic series
593
+ reaches 2^(-p) when Re(z) > 0.11*p. So we simply use the recurrence
594
+ for psi (equivalent, in fact, to summing to the first few terms
595
+ directly before applying E-M) to obtain z large enough.
596
+
597
+ Since, very crudely, log z ~= 1 for Re(z) > 1, we can use
598
+ fixed-point arithmetic (if z is extremely large, log(z) itself
599
+ is a sufficient approximation, so we can stop there already).
600
+
601
+ For Re(z) << 0, we could use recurrence, but this is of course
602
+ inefficient for large negative z, so there we use the
603
+ reflection formula instead.
604
+
605
+ For m > 0, we have
606
+
607
+ N - 1
608
+ ___
609
+ ~~~(m) [ \ 1 ] 1 1
610
+ psi (z) ~ [ ) -------- ] + ---------- + -------- +
611
+ [ /___ m+1 ] m+1 m
612
+ k = 1 (z+k) ] 2 (z+N) m (z+N)
613
+
614
+ oo
615
+ ___ B
616
+ \ 2 k (m+1) (m+2) ... (m+2k-1)
617
+ + ) ------ ------------------------
618
+ /___ (2 k)! m + 2 k
619
+ k = 1 (z+N)
620
+
621
+ where ~~~ denotes the function rescaled by 1/((-1)^(m+1) m!).
622
+
623
+ Here again N is chosen to make z+N large enough for the minimum
624
+ term in the last series to become smaller than eps.
625
+
626
+ TODO: the current estimation of N for m > 0 is *very suboptimal*.
627
+
628
+ TODO: implement the reflection formula for m > 0, Re(z) << 0.
629
+ It is generally a combination of multiple cotangents. Need to
630
+ figure out a reasonably simple way to generate these formulas
631
+ on the fly.
632
+
633
+ TODO: maybe use exact algorithms to compute psi for integral
634
+ and certain rational arguments, as this can be much more
635
+ efficient. (On the other hand, the availability of these
636
+ special values provides a convenient way to test the general
637
+ algorithm.)
638
+ """
639
+
640
+ # Harmonic numbers are just shifted digamma functions
641
+ # We should calculate these exactly when x is an integer
642
+ # and when doing so is faster.
643
+
644
+ def mpf_harmonic(x, prec, rnd):
645
+ if x in (fzero, fnan, finf):
646
+ return x
647
+ a = mpf_psi0(mpf_add(fone, x, prec+5), prec)
648
+ return mpf_add(a, mpf_euler(prec+5, rnd), prec, rnd)
649
+
650
+ def mpc_harmonic(z, prec, rnd):
651
+ if z[1] == fzero:
652
+ return (mpf_harmonic(z[0], prec, rnd), fzero)
653
+ a = mpc_psi0(mpc_add_mpf(z, fone, prec+5), prec)
654
+ return mpc_add_mpf(a, mpf_euler(prec+5, rnd), prec, rnd)
655
+
656
+ def mpf_psi0(x, prec, rnd=round_fast):
657
+ """
658
+ Computation of the digamma function (psi function of order 0)
659
+ of a real argument.
660
+ """
661
+ sign, man, exp, bc = x
662
+ wp = prec + 10
663
+ if not man:
664
+ if x == finf: return x
665
+ if x == fninf or x == fnan: return fnan
666
+ if x == fzero or (exp >= 0 and sign):
667
+ raise ValueError("polygamma pole")
668
+ # Near 0 -- fixed-point arithmetic becomes bad
669
+ if exp+bc < -5:
670
+ v = mpf_psi0(mpf_add(x, fone, prec, rnd), prec, rnd)
671
+ return mpf_sub(v, mpf_div(fone, x, wp, rnd), prec, rnd)
672
+ # Reflection formula
673
+ if sign and exp+bc > 3:
674
+ c, s = mpf_cos_sin_pi(x, wp)
675
+ q = mpf_mul(mpf_div(c, s, wp), mpf_pi(wp), wp)
676
+ p = mpf_psi0(mpf_sub(fone, x, wp), wp)
677
+ return mpf_sub(p, q, prec, rnd)
678
+ # The logarithmic term is accurate enough
679
+ if (not sign) and bc + exp > wp:
680
+ return mpf_log(mpf_sub(x, fone, wp), prec, rnd)
681
+ # Initial recurrence to obtain a large enough x
682
+ m = to_int(x)
683
+ n = int(0.11*wp) + 2
684
+ s = MPZ_ZERO
685
+ x = to_fixed(x, wp)
686
+ one = MPZ_ONE << wp
687
+ if m < n:
688
+ for k in xrange(m, n):
689
+ s -= (one << wp) // x
690
+ x += one
691
+ x -= one
692
+ # Logarithmic term
693
+ s += to_fixed(mpf_log(from_man_exp(x, -wp, wp), wp), wp)
694
+ # Endpoint term in Euler-Maclaurin expansion
695
+ s += (one << wp) // (2*x)
696
+ # Euler-Maclaurin remainder sum
697
+ x2 = (x*x) >> wp
698
+ t = one
699
+ prev = 0
700
+ k = 1
701
+ while 1:
702
+ t = (t*x2) >> wp
703
+ bsign, bman, bexp, bbc = mpf_bernoulli(2*k, wp)
704
+ offset = (bexp + 2*wp)
705
+ if offset >= 0: term = (bman << offset) // (t*(2*k))
706
+ else: term = (bman >> (-offset)) // (t*(2*k))
707
+ if k & 1: s -= term
708
+ else: s += term
709
+ if k > 2 and term >= prev:
710
+ break
711
+ prev = term
712
+ k += 1
713
+ return from_man_exp(s, -wp, wp, rnd)
714
+
715
+ def mpc_psi0(z, prec, rnd=round_fast):
716
+ """
717
+ Computation of the digamma function (psi function of order 0)
718
+ of a complex argument.
719
+ """
720
+ re, im = z
721
+ # Fall back to the real case
722
+ if im == fzero:
723
+ return (mpf_psi0(re, prec, rnd), fzero)
724
+ wp = prec + 20
725
+ sign, man, exp, bc = re
726
+ # Reflection formula
727
+ if sign and exp+bc > 3:
728
+ c = mpc_cos_pi(z, wp)
729
+ s = mpc_sin_pi(z, wp)
730
+ q = mpc_mul_mpf(mpc_div(c, s, wp), mpf_pi(wp), wp)
731
+ p = mpc_psi0(mpc_sub(mpc_one, z, wp), wp)
732
+ return mpc_sub(p, q, prec, rnd)
733
+ # Just the logarithmic term
734
+ if (not sign) and bc + exp > wp:
735
+ return mpc_log(mpc_sub(z, mpc_one, wp), prec, rnd)
736
+ # Initial recurrence to obtain a large enough z
737
+ w = to_int(re)
738
+ n = int(0.11*wp) + 2
739
+ s = mpc_zero
740
+ if w < n:
741
+ for k in xrange(w, n):
742
+ s = mpc_sub(s, mpc_reciprocal(z, wp), wp)
743
+ z = mpc_add_mpf(z, fone, wp)
744
+ z = mpc_sub(z, mpc_one, wp)
745
+ # Logarithmic and endpoint term
746
+ s = mpc_add(s, mpc_log(z, wp), wp)
747
+ s = mpc_add(s, mpc_div(mpc_half, z, wp), wp)
748
+ # Euler-Maclaurin remainder sum
749
+ z2 = mpc_square(z, wp)
750
+ t = mpc_one
751
+ prev = mpc_zero
752
+ szprev = fzero
753
+ k = 1
754
+ eps = mpf_shift(fone, -wp+2)
755
+ while 1:
756
+ t = mpc_mul(t, z2, wp)
757
+ bern = mpf_bernoulli(2*k, wp)
758
+ term = mpc_mpf_div(bern, mpc_mul_int(t, 2*k, wp), wp)
759
+ s = mpc_sub(s, term, wp)
760
+ szterm = mpc_abs(term, 10)
761
+ if k > 2 and (mpf_le(szterm, eps) or mpf_le(szprev, szterm)):
762
+ break
763
+ prev = term
764
+ szprev = szterm
765
+ k += 1
766
+ return s
767
+
768
+ # Currently unoptimized
769
+ def mpf_psi(m, x, prec, rnd=round_fast):
770
+ """
771
+ Computation of the polygamma function of arbitrary integer order
772
+ m >= 0, for a real argument x.
773
+ """
774
+ if m == 0:
775
+ return mpf_psi0(x, prec, rnd=round_fast)
776
+ return mpc_psi(m, (x, fzero), prec, rnd)[0]
777
+
778
+ def mpc_psi(m, z, prec, rnd=round_fast):
779
+ """
780
+ Computation of the polygamma function of arbitrary integer order
781
+ m >= 0, for a complex argument z.
782
+ """
783
+ if m == 0:
784
+ return mpc_psi0(z, prec, rnd)
785
+ re, im = z
786
+ wp = prec + 20
787
+ sign, man, exp, bc = re
788
+ if not im[1]:
789
+ if im in (finf, fninf, fnan):
790
+ return (fnan, fnan)
791
+ if not man:
792
+ if re == finf and im == fzero:
793
+ return (fzero, fzero)
794
+ if re == fnan:
795
+ return (fnan, fnan)
796
+ # Recurrence
797
+ w = to_int(re)
798
+ n = int(0.4*wp + 4*m)
799
+ s = mpc_zero
800
+ if w < n:
801
+ for k in xrange(w, n):
802
+ t = mpc_pow_int(z, -m-1, wp)
803
+ s = mpc_add(s, t, wp)
804
+ z = mpc_add_mpf(z, fone, wp)
805
+ zm = mpc_pow_int(z, -m, wp)
806
+ z2 = mpc_pow_int(z, -2, wp)
807
+ # 1/m*(z+N)^m
808
+ integral_term = mpc_div_mpf(zm, from_int(m), wp)
809
+ s = mpc_add(s, integral_term, wp)
810
+ # 1/2*(z+N)^(-(m+1))
811
+ s = mpc_add(s, mpc_mul_mpf(mpc_div(zm, z, wp), fhalf, wp), wp)
812
+ a = m + 1
813
+ b = 2
814
+ k = 1
815
+ # Important: we want to sum up to the *relative* error,
816
+ # not the absolute error, because psi^(m)(z) might be tiny
817
+ magn = mpc_abs(s, 10)
818
+ magn = magn[2]+magn[3]
819
+ eps = mpf_shift(fone, magn-wp+2)
820
+ while 1:
821
+ zm = mpc_mul(zm, z2, wp)
822
+ bern = mpf_bernoulli(2*k, wp)
823
+ scal = mpf_mul_int(bern, a, wp)
824
+ scal = mpf_div(scal, from_int(b), wp)
825
+ term = mpc_mul_mpf(zm, scal, wp)
826
+ s = mpc_add(s, term, wp)
827
+ szterm = mpc_abs(term, 10)
828
+ if k > 2 and mpf_le(szterm, eps):
829
+ break
830
+ #print k, to_str(szterm, 10), to_str(eps, 10)
831
+ a *= (m+2*k)*(m+2*k+1)
832
+ b *= (2*k+1)*(2*k+2)
833
+ k += 1
834
+ # Scale and sign factor
835
+ v = mpc_mul_mpf(s, mpf_gamma(from_int(m+1), wp), prec, rnd)
836
+ if not (m & 1):
837
+ v = mpf_neg(v[0]), mpf_neg(v[1])
838
+ return v
839
+
840
+
841
+ #-----------------------------------------------------------------------#
842
+ # #
843
+ # Riemann zeta function #
844
+ # #
845
+ #-----------------------------------------------------------------------#
846
+
847
+ r"""
848
+ We use zeta(s) = eta(s) / (1 - 2**(1-s)) and Borwein's approximation
849
+
850
+ n-1
851
+ ___ k
852
+ -1 \ (-1) (d_k - d_n)
853
+ eta(s) ~= ---- ) ------------------
854
+ d_n /___ s
855
+ k = 0 (k + 1)
856
+ where
857
+ k
858
+ ___ i
859
+ \ (n + i - 1)! 4
860
+ d_k = n ) ---------------.
861
+ /___ (n - i)! (2i)!
862
+ i = 0
863
+
864
+ If s = a + b*I, the absolute error for eta(s) is bounded by
865
+
866
+ 3 (1 + 2|b|)
867
+ ------------ * exp(|b| pi/2)
868
+ n
869
+ (3+sqrt(8))
870
+
871
+ Disregarding the linear term, we have approximately,
872
+
873
+ log(err) ~= log(exp(1.58*|b|)) - log(5.8**n)
874
+ log(err) ~= 1.58*|b| - log(5.8)*n
875
+ log(err) ~= 1.58*|b| - 1.76*n
876
+ log2(err) ~= 2.28*|b| - 2.54*n
877
+
878
+ So for p bits, we should choose n > (p + 2.28*|b|) / 2.54.
879
+
880
+ References:
881
+ -----------
882
+
883
+ Peter Borwein, "An Efficient Algorithm for the Riemann Zeta Function"
884
+ http://www.cecm.sfu.ca/personal/pborwein/PAPERS/P117.ps
885
+
886
+ http://en.wikipedia.org/wiki/Dirichlet_eta_function
887
+ """
888
+
889
+ borwein_cache = {}
890
+
891
+ def borwein_coefficients(n):
892
+ if n in borwein_cache:
893
+ return borwein_cache[n]
894
+ ds = [MPZ_ZERO] * (n+1)
895
+ d = MPZ_ONE
896
+ s = ds[0] = MPZ_ONE
897
+ for i in range(1, n+1):
898
+ d = d * 4 * (n+i-1) * (n-i+1)
899
+ d //= ((2*i) * ((2*i)-1))
900
+ s += d
901
+ ds[i] = s
902
+ borwein_cache[n] = ds
903
+ return ds
904
+
905
+ ZETA_INT_CACHE_MAX_PREC = 1000
906
+ zeta_int_cache = {}
907
+
908
+ def mpf_zeta_int(s, prec, rnd=round_fast):
909
+ """
910
+ Optimized computation of zeta(s) for an integer s.
911
+ """
912
+ wp = prec + 20
913
+ s = int(s)
914
+ if s in zeta_int_cache and zeta_int_cache[s][0] >= wp:
915
+ return mpf_pos(zeta_int_cache[s][1], prec, rnd)
916
+ if s < 2:
917
+ if s == 1:
918
+ raise ValueError("zeta(1) pole")
919
+ if not s:
920
+ return mpf_neg(fhalf)
921
+ return mpf_div(mpf_bernoulli(-s+1, wp), from_int(s-1), prec, rnd)
922
+ # 2^-s term vanishes?
923
+ if s >= wp:
924
+ return mpf_perturb(fone, 0, prec, rnd)
925
+ # 5^-s term vanishes?
926
+ elif s >= wp*0.431:
927
+ t = one = 1 << wp
928
+ t += 1 << (wp - s)
929
+ t += one // (MPZ_THREE ** s)
930
+ t += 1 << max(0, wp - s*2)
931
+ return from_man_exp(t, -wp, prec, rnd)
932
+ else:
933
+ # Fast enough to sum directly?
934
+ # Even better, we use the Euler product (idea stolen from pari)
935
+ m = (float(wp)/(s-1) + 1)
936
+ if m < 30:
937
+ needed_terms = int(2.0**m + 1)
938
+ if needed_terms < int(wp/2.54 + 5) / 10:
939
+ t = fone
940
+ for k in list_primes(needed_terms):
941
+ #print k, needed_terms
942
+ powprec = int(wp - s*math.log(k,2))
943
+ if powprec < 2:
944
+ break
945
+ a = mpf_sub(fone, mpf_pow_int(from_int(k), -s, powprec), wp)
946
+ t = mpf_mul(t, a, wp)
947
+ return mpf_div(fone, t, wp)
948
+ # Use Borwein's algorithm
949
+ n = int(wp/2.54 + 5)
950
+ d = borwein_coefficients(n)
951
+ t = MPZ_ZERO
952
+ s = MPZ(s)
953
+ for k in xrange(n):
954
+ t += (((-1)**k * (d[k] - d[n])) << wp) // (k+1)**s
955
+ t = (t << wp) // (-d[n])
956
+ t = (t << wp) // ((1 << wp) - (1 << (wp+1-s)))
957
+ if (s in zeta_int_cache and zeta_int_cache[s][0] < wp) or (s not in zeta_int_cache):
958
+ zeta_int_cache[s] = (wp, from_man_exp(t, -wp-wp))
959
+ return from_man_exp(t, -wp-wp, prec, rnd)
960
+
961
+ def mpf_zeta(s, prec, rnd=round_fast, alt=0):
962
+ sign, man, exp, bc = s
963
+ if not man:
964
+ if s == fzero:
965
+ if alt:
966
+ return fhalf
967
+ else:
968
+ return mpf_neg(fhalf)
969
+ if s == finf:
970
+ return fone
971
+ return fnan
972
+ wp = prec + 20
973
+ # First term vanishes?
974
+ if (not sign) and (exp + bc > (math.log(wp,2) + 2)):
975
+ return mpf_perturb(fone, alt, prec, rnd)
976
+ # Optimize for integer arguments
977
+ elif exp >= 0:
978
+ if alt:
979
+ if s == fone:
980
+ return mpf_ln2(prec, rnd)
981
+ z = mpf_zeta_int(to_int(s), wp, negative_rnd[rnd])
982
+ q = mpf_sub(fone, mpf_pow(ftwo, mpf_sub(fone, s, wp), wp), wp)
983
+ return mpf_mul(z, q, prec, rnd)
984
+ else:
985
+ return mpf_zeta_int(to_int(s), prec, rnd)
986
+ # Negative: use the reflection formula
987
+ # Borwein only proves the accuracy bound for x >= 1/2. However, based on
988
+ # tests, the accuracy without reflection is quite good even some distance
989
+ # to the left of 1/2. XXX: verify this.
990
+ if sign:
991
+ # XXX: could use the separate refl. formula for Dirichlet eta
992
+ if alt:
993
+ q = mpf_sub(fone, mpf_pow(ftwo, mpf_sub(fone, s, wp), wp), wp)
994
+ return mpf_mul(mpf_zeta(s, wp), q, prec, rnd)
995
+ # XXX: -1 should be done exactly
996
+ y = mpf_sub(fone, s, 10*wp)
997
+ a = mpf_gamma(y, wp)
998
+ b = mpf_zeta(y, wp)
999
+ c = mpf_sin_pi(mpf_shift(s, -1), wp)
1000
+ wp2 = wp + max(0,exp+bc)
1001
+ pi = mpf_pi(wp+wp2)
1002
+ d = mpf_div(mpf_pow(mpf_shift(pi, 1), s, wp2), pi, wp2)
1003
+ return mpf_mul(a,mpf_mul(b,mpf_mul(c,d,wp),wp),prec,rnd)
1004
+
1005
+ # Near pole
1006
+ r = mpf_sub(fone, s, wp)
1007
+ asign, aman, aexp, abc = mpf_abs(r)
1008
+ pole_dist = -2*(aexp+abc)
1009
+ if pole_dist > wp:
1010
+ if alt:
1011
+ return mpf_ln2(prec, rnd)
1012
+ else:
1013
+ q = mpf_neg(mpf_div(fone, r, wp))
1014
+ return mpf_add(q, mpf_euler(wp), prec, rnd)
1015
+ else:
1016
+ wp += max(0, pole_dist)
1017
+
1018
+ t = MPZ_ZERO
1019
+ #wp += 16 - (prec & 15)
1020
+ # Use Borwein's algorithm
1021
+ n = int(wp/2.54 + 5)
1022
+ d = borwein_coefficients(n)
1023
+ t = MPZ_ZERO
1024
+ sf = to_fixed(s, wp)
1025
+ ln2 = ln2_fixed(wp)
1026
+ for k in xrange(n):
1027
+ u = (-sf*log_int_fixed(k+1, wp, ln2)) >> wp
1028
+ #esign, eman, eexp, ebc = mpf_exp(u, wp)
1029
+ #offset = eexp + wp
1030
+ #if offset >= 0:
1031
+ # w = ((d[k] - d[n]) * eman) << offset
1032
+ #else:
1033
+ # w = ((d[k] - d[n]) * eman) >> (-offset)
1034
+ eman = exp_fixed(u, wp, ln2)
1035
+ w = (d[k] - d[n]) * eman
1036
+ if k & 1:
1037
+ t -= w
1038
+ else:
1039
+ t += w
1040
+ t = t // (-d[n])
1041
+ t = from_man_exp(t, -wp, wp)
1042
+ if alt:
1043
+ return mpf_pos(t, prec, rnd)
1044
+ else:
1045
+ q = mpf_sub(fone, mpf_pow(ftwo, mpf_sub(fone, s, wp), wp), wp)
1046
+ return mpf_div(t, q, prec, rnd)
1047
+
1048
+ def mpc_zeta(s, prec, rnd=round_fast, alt=0, force=False):
1049
+ re, im = s
1050
+ if im == fzero:
1051
+ return mpf_zeta(re, prec, rnd, alt), fzero
1052
+
1053
+ # slow for large s
1054
+ if (not force) and mpf_gt(mpc_abs(s, 10), from_int(prec)):
1055
+ raise NotImplementedError
1056
+
1057
+ wp = prec + 20
1058
+
1059
+ # Near pole
1060
+ r = mpc_sub(mpc_one, s, wp)
1061
+ asign, aman, aexp, abc = mpc_abs(r, 10)
1062
+ pole_dist = -2*(aexp+abc)
1063
+ if pole_dist > wp:
1064
+ if alt:
1065
+ q = mpf_ln2(wp)
1066
+ y = mpf_mul(q, mpf_euler(wp), wp)
1067
+ g = mpf_shift(mpf_mul(q, q, wp), -1)
1068
+ g = mpf_sub(y, g)
1069
+ z = mpc_mul_mpf(r, mpf_neg(g), wp)
1070
+ z = mpc_add_mpf(z, q, wp)
1071
+ return mpc_pos(z, prec, rnd)
1072
+ else:
1073
+ q = mpc_neg(mpc_div(mpc_one, r, wp))
1074
+ q = mpc_add_mpf(q, mpf_euler(wp), wp)
1075
+ return mpc_pos(q, prec, rnd)
1076
+ else:
1077
+ wp += max(0, pole_dist)
1078
+
1079
+ # Reflection formula. To be rigorous, we should reflect to the left of
1080
+ # re = 1/2 (see comments for mpf_zeta), but this leads to unnecessary
1081
+ # slowdown for interesting values of s
1082
+ if mpf_lt(re, fzero):
1083
+ # XXX: could use the separate refl. formula for Dirichlet eta
1084
+ if alt:
1085
+ q = mpc_sub(mpc_one, mpc_pow(mpc_two, mpc_sub(mpc_one, s, wp),
1086
+ wp), wp)
1087
+ return mpc_mul(mpc_zeta(s, wp), q, prec, rnd)
1088
+ # XXX: -1 should be done exactly
1089
+ y = mpc_sub(mpc_one, s, 10*wp)
1090
+ a = mpc_gamma(y, wp)
1091
+ b = mpc_zeta(y, wp)
1092
+ c = mpc_sin_pi(mpc_shift(s, -1), wp)
1093
+ rsign, rman, rexp, rbc = re
1094
+ isign, iman, iexp, ibc = im
1095
+ mag = max(rexp+rbc, iexp+ibc)
1096
+ wp2 = wp + max(0, mag)
1097
+ pi = mpf_pi(wp+wp2)
1098
+ pi2 = (mpf_shift(pi, 1), fzero)
1099
+ d = mpc_div_mpf(mpc_pow(pi2, s, wp2), pi, wp2)
1100
+ return mpc_mul(a,mpc_mul(b,mpc_mul(c,d,wp),wp),prec,rnd)
1101
+ n = int(wp/2.54 + 5)
1102
+ n += int(0.9*abs(to_int(im)))
1103
+ d = borwein_coefficients(n)
1104
+ ref = to_fixed(re, wp)
1105
+ imf = to_fixed(im, wp)
1106
+ tre = MPZ_ZERO
1107
+ tim = MPZ_ZERO
1108
+ one = MPZ_ONE << wp
1109
+ one_2wp = MPZ_ONE << (2*wp)
1110
+ critical_line = re == fhalf
1111
+ ln2 = ln2_fixed(wp)
1112
+ pi2 = pi_fixed(wp-1)
1113
+ wp2 = wp+wp
1114
+ for k in xrange(n):
1115
+ log = log_int_fixed(k+1, wp, ln2)
1116
+ # A square root is much cheaper than an exp
1117
+ if critical_line:
1118
+ w = one_2wp // isqrt_fast((k+1) << wp2)
1119
+ else:
1120
+ w = exp_fixed((-ref*log) >> wp, wp)
1121
+ if k & 1:
1122
+ w *= (d[n] - d[k])
1123
+ else:
1124
+ w *= (d[k] - d[n])
1125
+ wre, wim = cos_sin_fixed((-imf*log)>>wp, wp, pi2)
1126
+ tre += (w * wre) >> wp
1127
+ tim += (w * wim) >> wp
1128
+ tre //= (-d[n])
1129
+ tim //= (-d[n])
1130
+ tre = from_man_exp(tre, -wp, wp)
1131
+ tim = from_man_exp(tim, -wp, wp)
1132
+ if alt:
1133
+ return mpc_pos((tre, tim), prec, rnd)
1134
+ else:
1135
+ q = mpc_sub(mpc_one, mpc_pow(mpc_two, r, wp), wp)
1136
+ return mpc_div((tre, tim), q, prec, rnd)
1137
+
1138
+ def mpf_altzeta(s, prec, rnd=round_fast):
1139
+ return mpf_zeta(s, prec, rnd, 1)
1140
+
1141
+ def mpc_altzeta(s, prec, rnd=round_fast):
1142
+ return mpc_zeta(s, prec, rnd, 1)
1143
+
1144
+ # Not optimized currently
1145
+ mpf_zetasum = None
1146
+
1147
+
1148
+ def pow_fixed(x, n, wp):
1149
+ if n == 1:
1150
+ return x
1151
+ y = MPZ_ONE << wp
1152
+ while n:
1153
+ if n & 1:
1154
+ y = (y*x) >> wp
1155
+ n -= 1
1156
+ x = (x*x) >> wp
1157
+ n //= 2
1158
+ return y
1159
+
1160
+ # TODO: optimize / cleanup interface / unify with list_primes
1161
+ sieve_cache = []
1162
+ primes_cache = []
1163
+ mult_cache = []
1164
+
1165
+ def primesieve(n):
1166
+ global sieve_cache, primes_cache, mult_cache
1167
+ if n < len(sieve_cache):
1168
+ sieve = sieve_cache#[:n+1]
1169
+ primes = primes_cache[:primes_cache.index(max(sieve))+1]
1170
+ mult = mult_cache#[:n+1]
1171
+ return sieve, primes, mult
1172
+ sieve = [0] * (n+1)
1173
+ mult = [0] * (n+1)
1174
+ primes = list_primes(n)
1175
+ for p in primes:
1176
+ #sieve[p::p] = p
1177
+ for k in xrange(p,n+1,p):
1178
+ sieve[k] = p
1179
+ for i, p in enumerate(sieve):
1180
+ if i >= 2:
1181
+ m = 1
1182
+ n = i // p
1183
+ while not n % p:
1184
+ n //= p
1185
+ m += 1
1186
+ mult[i] = m
1187
+ sieve_cache = sieve
1188
+ primes_cache = primes
1189
+ mult_cache = mult
1190
+ return sieve, primes, mult
1191
+
1192
+ def zetasum_sieved(critical_line, sre, sim, a, n, wp):
1193
+ if a < 1:
1194
+ raise ValueError("a cannot be less than 1")
1195
+ sieve, primes, mult = primesieve(a+n)
1196
+ basic_powers = {}
1197
+ one = MPZ_ONE << wp
1198
+ one_2wp = MPZ_ONE << (2*wp)
1199
+ wp2 = wp+wp
1200
+ ln2 = ln2_fixed(wp)
1201
+ pi2 = pi_fixed(wp-1)
1202
+ for p in primes:
1203
+ if p*2 > a+n:
1204
+ break
1205
+ log = log_int_fixed(p, wp, ln2)
1206
+ cos, sin = cos_sin_fixed((-sim*log)>>wp, wp, pi2)
1207
+ if critical_line:
1208
+ u = one_2wp // isqrt_fast(p<<wp2)
1209
+ else:
1210
+ u = exp_fixed((-sre*log)>>wp, wp)
1211
+ pre = (u*cos) >> wp
1212
+ pim = (u*sin) >> wp
1213
+ basic_powers[p] = [(pre, pim)]
1214
+ tre, tim = pre, pim
1215
+ for m in range(1,int(math.log(a+n,p)+0.01)+1):
1216
+ tre, tim = ((pre*tre-pim*tim)>>wp), ((pim*tre+pre*tim)>>wp)
1217
+ basic_powers[p].append((tre,tim))
1218
+ xre = MPZ_ZERO
1219
+ xim = MPZ_ZERO
1220
+ if a == 1:
1221
+ xre += one
1222
+ aa = max(a,2)
1223
+ for k in xrange(aa, a+n+1):
1224
+ p = sieve[k]
1225
+ if p in basic_powers:
1226
+ m = mult[k]
1227
+ tre, tim = basic_powers[p][m-1]
1228
+ while 1:
1229
+ k //= p**m
1230
+ if k == 1:
1231
+ break
1232
+ p = sieve[k]
1233
+ m = mult[k]
1234
+ pre, pim = basic_powers[p][m-1]
1235
+ tre, tim = ((pre*tre-pim*tim)>>wp), ((pim*tre+pre*tim)>>wp)
1236
+ else:
1237
+ log = log_int_fixed(k, wp, ln2)
1238
+ cos, sin = cos_sin_fixed((-sim*log)>>wp, wp, pi2)
1239
+ if critical_line:
1240
+ u = one_2wp // isqrt_fast(k<<wp2)
1241
+ else:
1242
+ u = exp_fixed((-sre*log)>>wp, wp)
1243
+ tre = (u*cos) >> wp
1244
+ tim = (u*sin) >> wp
1245
+ xre += tre
1246
+ xim += tim
1247
+ return xre, xim
1248
+
1249
+ # Set to something large to disable
1250
+ ZETASUM_SIEVE_CUTOFF = 10
1251
+
1252
+ def mpc_zetasum(s, a, n, derivatives, reflect, prec):
1253
+ """
1254
+ Fast version of mp._zetasum, assuming s = complex, a = integer.
1255
+ """
1256
+
1257
+ wp = prec + 10
1258
+ derivatives = list(derivatives)
1259
+ have_derivatives = derivatives != [0]
1260
+ have_one_derivative = len(derivatives) == 1
1261
+
1262
+ # parse s
1263
+ sre, sim = s
1264
+ critical_line = (sre == fhalf)
1265
+ sre = to_fixed(sre, wp)
1266
+ sim = to_fixed(sim, wp)
1267
+
1268
+ if a > 0 and n > ZETASUM_SIEVE_CUTOFF and not have_derivatives \
1269
+ and not reflect and (n < 4e7 or sys.maxsize > 2**32):
1270
+ re, im = zetasum_sieved(critical_line, sre, sim, a, n, wp)
1271
+ xs = [(from_man_exp(re, -wp, prec, 'n'), from_man_exp(im, -wp, prec, 'n'))]
1272
+ return xs, []
1273
+
1274
+ maxd = max(derivatives)
1275
+ if not have_one_derivative:
1276
+ derivatives = range(maxd+1)
1277
+
1278
+ # x_d = 0, y_d = 0
1279
+ xre = [MPZ_ZERO for d in derivatives]
1280
+ xim = [MPZ_ZERO for d in derivatives]
1281
+ if reflect:
1282
+ yre = [MPZ_ZERO for d in derivatives]
1283
+ yim = [MPZ_ZERO for d in derivatives]
1284
+ else:
1285
+ yre = yim = []
1286
+
1287
+ one = MPZ_ONE << wp
1288
+ one_2wp = MPZ_ONE << (2*wp)
1289
+
1290
+ ln2 = ln2_fixed(wp)
1291
+ pi2 = pi_fixed(wp-1)
1292
+ wp2 = wp+wp
1293
+
1294
+ for w in xrange(a, a+n+1):
1295
+ log = log_int_fixed(w, wp, ln2)
1296
+ cos, sin = cos_sin_fixed((-sim*log)>>wp, wp, pi2)
1297
+ if critical_line:
1298
+ u = one_2wp // isqrt_fast(w<<wp2)
1299
+ else:
1300
+ u = exp_fixed((-sre*log)>>wp, wp)
1301
+ xterm_re = (u * cos) >> wp
1302
+ xterm_im = (u * sin) >> wp
1303
+ if reflect:
1304
+ reciprocal = (one_2wp // (u*w))
1305
+ yterm_re = (reciprocal * cos) >> wp
1306
+ yterm_im = (reciprocal * sin) >> wp
1307
+
1308
+ if have_derivatives:
1309
+ if have_one_derivative:
1310
+ log = pow_fixed(log, maxd, wp)
1311
+ xre[0] += (xterm_re * log) >> wp
1312
+ xim[0] += (xterm_im * log) >> wp
1313
+ if reflect:
1314
+ yre[0] += (yterm_re * log) >> wp
1315
+ yim[0] += (yterm_im * log) >> wp
1316
+ else:
1317
+ t = MPZ_ONE << wp
1318
+ for d in derivatives:
1319
+ xre[d] += (xterm_re * t) >> wp
1320
+ xim[d] += (xterm_im * t) >> wp
1321
+ if reflect:
1322
+ yre[d] += (yterm_re * t) >> wp
1323
+ yim[d] += (yterm_im * t) >> wp
1324
+ t = (t * log) >> wp
1325
+ else:
1326
+ xre[0] += xterm_re
1327
+ xim[0] += xterm_im
1328
+ if reflect:
1329
+ yre[0] += yterm_re
1330
+ yim[0] += yterm_im
1331
+ if have_derivatives:
1332
+ if have_one_derivative:
1333
+ if maxd % 2:
1334
+ xre[0] = -xre[0]
1335
+ xim[0] = -xim[0]
1336
+ if reflect:
1337
+ yre[0] = -yre[0]
1338
+ yim[0] = -yim[0]
1339
+ else:
1340
+ xre = [(-1)**d * xre[d] for d in derivatives]
1341
+ xim = [(-1)**d * xim[d] for d in derivatives]
1342
+ if reflect:
1343
+ yre = [(-1)**d * yre[d] for d in derivatives]
1344
+ yim = [(-1)**d * yim[d] for d in derivatives]
1345
+ xs = [(from_man_exp(xa, -wp, prec, 'n'), from_man_exp(xb, -wp, prec, 'n'))
1346
+ for (xa, xb) in zip(xre, xim)]
1347
+ ys = [(from_man_exp(ya, -wp, prec, 'n'), from_man_exp(yb, -wp, prec, 'n'))
1348
+ for (ya, yb) in zip(yre, yim)]
1349
+ return xs, ys
1350
+
1351
+
1352
+ #-----------------------------------------------------------------------#
1353
+ # #
1354
+ # The gamma function (NEW IMPLEMENTATION) #
1355
+ # #
1356
+ #-----------------------------------------------------------------------#
1357
+
1358
+ # Higher means faster, but more precomputation time
1359
+ MAX_GAMMA_TAYLOR_PREC = 5000
1360
+ # Need to derive higher bounds for Taylor series to go higher
1361
+ assert MAX_GAMMA_TAYLOR_PREC < 15000
1362
+
1363
+ # Use Stirling's series if abs(x) > beta*prec
1364
+ # Important: must be large enough for convergence!
1365
+ GAMMA_STIRLING_BETA = 0.2
1366
+
1367
+ SMALL_FACTORIAL_CACHE_SIZE = 150
1368
+
1369
+ gamma_taylor_cache = {}
1370
+ gamma_stirling_cache = {}
1371
+
1372
+ small_factorial_cache = [from_int(ifac(n)) for \
1373
+ n in range(SMALL_FACTORIAL_CACHE_SIZE+1)]
1374
+
1375
+ def zeta_array(N, prec):
1376
+ """
1377
+ zeta(n) = A * pi**n / n! + B
1378
+
1379
+ where A is a rational number (A = Bernoulli number
1380
+ for n even) and B is an infinite sum over powers of exp(2*pi).
1381
+ (B = 0 for n even).
1382
+
1383
+ TODO: this is currently only used for gamma, but could
1384
+ be very useful elsewhere.
1385
+ """
1386
+ extra = 30
1387
+ wp = prec+extra
1388
+ zeta_values = [MPZ_ZERO] * (N+2)
1389
+ pi = pi_fixed(wp)
1390
+ # STEP 1:
1391
+ one = MPZ_ONE << wp
1392
+ zeta_values[0] = -one//2
1393
+ f_2pi = mpf_shift(mpf_pi(wp),1)
1394
+ exp_2pi_k = exp_2pi = mpf_exp(f_2pi, wp)
1395
+ # Compute exponential series
1396
+ # Store values of 1/(exp(2*pi*k)-1),
1397
+ # exp(2*pi*k)/(exp(2*pi*k)-1)**2, 1/(exp(2*pi*k)-1)**2
1398
+ # pi*k*exp(2*pi*k)/(exp(2*pi*k)-1)**2
1399
+ exps3 = []
1400
+ k = 1
1401
+ while 1:
1402
+ tp = wp - 9*k
1403
+ if tp < 1:
1404
+ break
1405
+ # 1/(exp(2*pi*k-1)
1406
+ q1 = mpf_div(fone, mpf_sub(exp_2pi_k, fone, tp), tp)
1407
+ # pi*k*exp(2*pi*k)/(exp(2*pi*k)-1)**2
1408
+ q2 = mpf_mul(exp_2pi_k, mpf_mul(q1,q1,tp), tp)
1409
+ q1 = to_fixed(q1, wp)
1410
+ q2 = to_fixed(q2, wp)
1411
+ q2 = (k * q2 * pi) >> wp
1412
+ exps3.append((q1, q2))
1413
+ # Multiply for next round
1414
+ exp_2pi_k = mpf_mul(exp_2pi_k, exp_2pi, wp)
1415
+ k += 1
1416
+ # Exponential sum
1417
+ for n in xrange(3, N+1, 2):
1418
+ s = MPZ_ZERO
1419
+ k = 1
1420
+ for e1, e2 in exps3:
1421
+ if n%4 == 3:
1422
+ t = e1 // k**n
1423
+ else:
1424
+ U = (n-1)//4
1425
+ t = (e1 + e2//U) // k**n
1426
+ if not t:
1427
+ break
1428
+ s += t
1429
+ k += 1
1430
+ zeta_values[n] = -2*s
1431
+ # Even zeta values
1432
+ B = [mpf_abs(mpf_bernoulli(k,wp)) for k in xrange(N+2)]
1433
+ pi_pow = fpi = mpf_pow_int(mpf_shift(mpf_pi(wp), 1), 2, wp)
1434
+ pi_pow = mpf_div(pi_pow, from_int(4), wp)
1435
+ for n in xrange(2,N+2,2):
1436
+ z = mpf_mul(B[n], pi_pow, wp)
1437
+ zeta_values[n] = to_fixed(z, wp)
1438
+ pi_pow = mpf_mul(pi_pow, fpi, wp)
1439
+ pi_pow = mpf_div(pi_pow, from_int((n+1)*(n+2)), wp)
1440
+ # Zeta sum
1441
+ reciprocal_pi = (one << wp) // pi
1442
+ for n in xrange(3, N+1, 4):
1443
+ U = (n-3)//4
1444
+ s = zeta_values[4*U+4]*(4*U+7)//4
1445
+ for k in xrange(1, U+1):
1446
+ s -= (zeta_values[4*k] * zeta_values[4*U+4-4*k]) >> wp
1447
+ zeta_values[n] += (2*s*reciprocal_pi) >> wp
1448
+ for n in xrange(5, N+1, 4):
1449
+ U = (n-1)//4
1450
+ s = zeta_values[4*U+2]*(2*U+1)
1451
+ for k in xrange(1, 2*U+1):
1452
+ s += ((-1)**k*2*k* zeta_values[2*k] * zeta_values[4*U+2-2*k])>>wp
1453
+ zeta_values[n] += ((s*reciprocal_pi)>>wp)//(2*U)
1454
+ return [x>>extra for x in zeta_values]
1455
+
1456
+ def gamma_taylor_coefficients(inprec):
1457
+ """
1458
+ Gives the Taylor coefficients of 1/gamma(1+x) as
1459
+ a list of fixed-point numbers. Enough coefficients are returned
1460
+ to ensure that the series converges to the given precision
1461
+ when x is in [0.5, 1.5].
1462
+ """
1463
+ # Reuse nearby cache values (small case)
1464
+ if inprec < 400:
1465
+ prec = inprec + (10-(inprec%10))
1466
+ elif inprec < 1000:
1467
+ prec = inprec + (30-(inprec%30))
1468
+ else:
1469
+ prec = inprec
1470
+ if prec in gamma_taylor_cache:
1471
+ return gamma_taylor_cache[prec], prec
1472
+
1473
+ # Experimentally determined bounds
1474
+ if prec < 1000:
1475
+ N = int(prec**0.76 + 2)
1476
+ else:
1477
+ # Valid to at least 15000 bits
1478
+ N = int(prec**0.787 + 2)
1479
+
1480
+ # Reuse higher precision values
1481
+ for cprec in gamma_taylor_cache:
1482
+ if cprec > prec:
1483
+ coeffs = [x>>(cprec-prec) for x in gamma_taylor_cache[cprec][-N:]]
1484
+ if inprec < 1000:
1485
+ gamma_taylor_cache[prec] = coeffs
1486
+ return coeffs, prec
1487
+
1488
+ # Cache at a higher precision (large case)
1489
+ if prec > 1000:
1490
+ prec = int(prec * 1.2)
1491
+
1492
+ wp = prec + 20
1493
+ A = [0] * N
1494
+ A[0] = MPZ_ZERO
1495
+ A[1] = MPZ_ONE << wp
1496
+ A[2] = euler_fixed(wp)
1497
+ # SLOW, reference implementation
1498
+ #zeta_values = [0,0]+[to_fixed(mpf_zeta_int(k,wp),wp) for k in xrange(2,N)]
1499
+ zeta_values = zeta_array(N, wp)
1500
+ for k in xrange(3, N):
1501
+ a = (-A[2]*A[k-1])>>wp
1502
+ for j in xrange(2,k):
1503
+ a += ((-1)**j * zeta_values[j] * A[k-j]) >> wp
1504
+ a //= (1-k)
1505
+ A[k] = a
1506
+ A = [a>>20 for a in A]
1507
+ A = A[::-1]
1508
+ A = A[:-1]
1509
+ gamma_taylor_cache[prec] = A
1510
+ #return A, prec
1511
+ return gamma_taylor_coefficients(inprec)
1512
+
1513
+ def gamma_fixed_taylor(xmpf, x, wp, prec, rnd, type):
1514
+ # Determine nearest multiple of N/2
1515
+ #n = int(x >> (wp-1))
1516
+ #steps = (n-1)>>1
1517
+ nearest_int = ((x >> (wp-1)) + MPZ_ONE) >> 1
1518
+ one = MPZ_ONE << wp
1519
+ coeffs, cwp = gamma_taylor_coefficients(wp)
1520
+ if nearest_int > 0:
1521
+ r = one
1522
+ for i in xrange(nearest_int-1):
1523
+ x -= one
1524
+ r = (r*x) >> wp
1525
+ x -= one
1526
+ p = MPZ_ZERO
1527
+ for c in coeffs:
1528
+ p = c + ((x*p)>>wp)
1529
+ p >>= (cwp-wp)
1530
+ if type == 0:
1531
+ return from_man_exp((r<<wp)//p, -wp, prec, rnd)
1532
+ if type == 2:
1533
+ return mpf_shift(from_rational(p, (r<<wp), prec, rnd), wp)
1534
+ if type == 3:
1535
+ return mpf_log(mpf_abs(from_man_exp((r<<wp)//p, -wp)), prec, rnd)
1536
+ else:
1537
+ r = one
1538
+ for i in xrange(-nearest_int):
1539
+ r = (r*x) >> wp
1540
+ x += one
1541
+ p = MPZ_ZERO
1542
+ for c in coeffs:
1543
+ p = c + ((x*p)>>wp)
1544
+ p >>= (cwp-wp)
1545
+ if wp - bitcount(abs(x)) > 10:
1546
+ # pass very close to 0, so do floating-point multiply
1547
+ g = mpf_add(xmpf, from_int(-nearest_int)) # exact
1548
+ r = from_man_exp(p*r,-wp-wp)
1549
+ r = mpf_mul(r, g, wp)
1550
+ if type == 0:
1551
+ return mpf_div(fone, r, prec, rnd)
1552
+ if type == 2:
1553
+ return mpf_pos(r, prec, rnd)
1554
+ if type == 3:
1555
+ return mpf_log(mpf_abs(mpf_div(fone, r, wp)), prec, rnd)
1556
+ else:
1557
+ r = from_man_exp(x*p*r,-3*wp)
1558
+ if type == 0: return mpf_div(fone, r, prec, rnd)
1559
+ if type == 2: return mpf_pos(r, prec, rnd)
1560
+ if type == 3: return mpf_neg(mpf_log(mpf_abs(r), prec, rnd))
1561
+
1562
+ def stirling_coefficient(n):
1563
+ if n in gamma_stirling_cache:
1564
+ return gamma_stirling_cache[n]
1565
+ p, q = bernfrac(n)
1566
+ q *= MPZ(n*(n-1))
1567
+ gamma_stirling_cache[n] = p, q, bitcount(abs(p)), bitcount(q)
1568
+ return gamma_stirling_cache[n]
1569
+
1570
+ def real_stirling_series(x, prec):
1571
+ """
1572
+ Sums the rational part of Stirling's expansion,
1573
+
1574
+ log(sqrt(2*pi)) - z + 1/(12*z) - 1/(360*z^3) + ...
1575
+
1576
+ """
1577
+ t = (MPZ_ONE<<(prec+prec)) // x # t = 1/x
1578
+ u = (t*t)>>prec # u = 1/x**2
1579
+ s = ln_sqrt2pi_fixed(prec) - x
1580
+ # Add initial terms of Stirling's series
1581
+ s += t//12; t = (t*u)>>prec
1582
+ s -= t//360; t = (t*u)>>prec
1583
+ s += t//1260; t = (t*u)>>prec
1584
+ s -= t//1680; t = (t*u)>>prec
1585
+ if not t: return s
1586
+ s += t//1188; t = (t*u)>>prec
1587
+ s -= 691*t//360360; t = (t*u)>>prec
1588
+ s += t//156; t = (t*u)>>prec
1589
+ if not t: return s
1590
+ s -= 3617*t//122400; t = (t*u)>>prec
1591
+ s += 43867*t//244188; t = (t*u)>>prec
1592
+ s -= 174611*t//125400; t = (t*u)>>prec
1593
+ if not t: return s
1594
+ k = 22
1595
+ # From here on, the coefficients are growing, so we
1596
+ # have to keep t at a roughly constant size
1597
+ usize = bitcount(abs(u))
1598
+ tsize = bitcount(abs(t))
1599
+ texp = 0
1600
+ while 1:
1601
+ p, q, pb, qb = stirling_coefficient(k)
1602
+ term_mag = tsize + pb + texp
1603
+ shift = -texp
1604
+ m = pb - term_mag
1605
+ if m > 0 and shift < m:
1606
+ p >>= m
1607
+ shift -= m
1608
+ m = tsize - term_mag
1609
+ if m > 0 and shift < m:
1610
+ w = t >> m
1611
+ shift -= m
1612
+ else:
1613
+ w = t
1614
+ term = (t*p//q) >> shift
1615
+ if not term:
1616
+ break
1617
+ s += term
1618
+ t = (t*u) >> usize
1619
+ texp -= (prec - usize)
1620
+ k += 2
1621
+ return s
1622
+
1623
+ def complex_stirling_series(x, y, prec):
1624
+ # t = 1/z
1625
+ _m = (x*x + y*y) >> prec
1626
+ tre = (x << prec) // _m
1627
+ tim = (-y << prec) // _m
1628
+ # u = 1/z**2
1629
+ ure = (tre*tre - tim*tim) >> prec
1630
+ uim = tim*tre >> (prec-1)
1631
+ # s = log(sqrt(2*pi)) - z
1632
+ sre = ln_sqrt2pi_fixed(prec) - x
1633
+ sim = -y
1634
+
1635
+ # Add initial terms of Stirling's series
1636
+ sre += tre//12; sim += tim//12;
1637
+ tre, tim = ((tre*ure-tim*uim)>>prec), ((tre*uim+tim*ure)>>prec)
1638
+ sre -= tre//360; sim -= tim//360;
1639
+ tre, tim = ((tre*ure-tim*uim)>>prec), ((tre*uim+tim*ure)>>prec)
1640
+ sre += tre//1260; sim += tim//1260;
1641
+ tre, tim = ((tre*ure-tim*uim)>>prec), ((tre*uim+tim*ure)>>prec)
1642
+ sre -= tre//1680; sim -= tim//1680;
1643
+ tre, tim = ((tre*ure-tim*uim)>>prec), ((tre*uim+tim*ure)>>prec)
1644
+ if abs(tre) + abs(tim) < 5: return sre, sim
1645
+ sre += tre//1188; sim += tim//1188;
1646
+ tre, tim = ((tre*ure-tim*uim)>>prec), ((tre*uim+tim*ure)>>prec)
1647
+ sre -= 691*tre//360360; sim -= 691*tim//360360;
1648
+ tre, tim = ((tre*ure-tim*uim)>>prec), ((tre*uim+tim*ure)>>prec)
1649
+ sre += tre//156; sim += tim//156;
1650
+ tre, tim = ((tre*ure-tim*uim)>>prec), ((tre*uim+tim*ure)>>prec)
1651
+ if abs(tre) + abs(tim) < 5: return sre, sim
1652
+ sre -= 3617*tre//122400; sim -= 3617*tim//122400;
1653
+ tre, tim = ((tre*ure-tim*uim)>>prec), ((tre*uim+tim*ure)>>prec)
1654
+ sre += 43867*tre//244188; sim += 43867*tim//244188;
1655
+ tre, tim = ((tre*ure-tim*uim)>>prec), ((tre*uim+tim*ure)>>prec)
1656
+ sre -= 174611*tre//125400; sim -= 174611*tim//125400;
1657
+ tre, tim = ((tre*ure-tim*uim)>>prec), ((tre*uim+tim*ure)>>prec)
1658
+ if abs(tre) + abs(tim) < 5: return sre, sim
1659
+
1660
+ k = 22
1661
+ # From here on, the coefficients are growing, so we
1662
+ # have to keep t at a roughly constant size
1663
+ usize = bitcount(max(abs(ure), abs(uim)))
1664
+ tsize = bitcount(max(abs(tre), abs(tim)))
1665
+ texp = 0
1666
+ while 1:
1667
+ p, q, pb, qb = stirling_coefficient(k)
1668
+ term_mag = tsize + pb + texp
1669
+ shift = -texp
1670
+ m = pb - term_mag
1671
+ if m > 0 and shift < m:
1672
+ p >>= m
1673
+ shift -= m
1674
+ m = tsize - term_mag
1675
+ if m > 0 and shift < m:
1676
+ wre = tre >> m
1677
+ wim = tim >> m
1678
+ shift -= m
1679
+ else:
1680
+ wre = tre
1681
+ wim = tim
1682
+ termre = (tre*p//q) >> shift
1683
+ termim = (tim*p//q) >> shift
1684
+ if abs(termre) + abs(termim) < 5:
1685
+ break
1686
+ sre += termre
1687
+ sim += termim
1688
+ tre, tim = ((tre*ure - tim*uim)>>usize), \
1689
+ ((tre*uim + tim*ure)>>usize)
1690
+ texp -= (prec - usize)
1691
+ k += 2
1692
+ return sre, sim
1693
+
1694
+
1695
+ def mpf_gamma(x, prec, rnd='d', type=0):
1696
+ """
1697
+ This function implements multipurpose evaluation of the gamma
1698
+ function, G(x), as well as the following versions of the same:
1699
+
1700
+ type = 0 -- G(x) [standard gamma function]
1701
+ type = 1 -- G(x+1) = x*G(x+1) = x! [factorial]
1702
+ type = 2 -- 1/G(x) [reciprocal gamma function]
1703
+ type = 3 -- log(|G(x)|) [log-gamma function, real part]
1704
+ """
1705
+
1706
+ # Specal values
1707
+ sign, man, exp, bc = x
1708
+ if not man:
1709
+ if x == fzero:
1710
+ if type == 1: return fone
1711
+ if type == 2: return fzero
1712
+ raise ValueError("gamma function pole")
1713
+ if x == finf:
1714
+ if type == 2: return fzero
1715
+ return finf
1716
+ return fnan
1717
+
1718
+ # First of all, for log gamma, numbers can be well beyond the fixed-point
1719
+ # range, so we must take care of huge numbers before e.g. trying
1720
+ # to convert x to the nearest integer
1721
+ if type == 3:
1722
+ wp = prec+20
1723
+ if exp+bc > wp and not sign:
1724
+ return mpf_sub(mpf_mul(x, mpf_log(x, wp), wp), x, prec, rnd)
1725
+
1726
+ # We strongly want to special-case small integers
1727
+ is_integer = exp >= 0
1728
+ if is_integer:
1729
+ # Poles
1730
+ if sign:
1731
+ if type == 2:
1732
+ return fzero
1733
+ raise ValueError("gamma function pole")
1734
+ # n = x
1735
+ n = man << exp
1736
+ if n < SMALL_FACTORIAL_CACHE_SIZE:
1737
+ if type == 0:
1738
+ return mpf_pos(small_factorial_cache[n-1], prec, rnd)
1739
+ if type == 1:
1740
+ return mpf_pos(small_factorial_cache[n], prec, rnd)
1741
+ if type == 2:
1742
+ return mpf_div(fone, small_factorial_cache[n-1], prec, rnd)
1743
+ if type == 3:
1744
+ return mpf_log(small_factorial_cache[n-1], prec, rnd)
1745
+ else:
1746
+ # floor(abs(x))
1747
+ n = int(man >> (-exp))
1748
+
1749
+ # Estimate size and precision
1750
+ # Estimate log(gamma(|x|),2) as x*log(x,2)
1751
+ mag = exp + bc
1752
+ gamma_size = n*mag
1753
+
1754
+ if type == 3:
1755
+ wp = prec + 20
1756
+ else:
1757
+ wp = prec + bitcount(gamma_size) + 20
1758
+
1759
+ # Very close to 0, pole
1760
+ if mag < -wp:
1761
+ if type == 0:
1762
+ return mpf_sub(mpf_div(fone,x, wp),mpf_shift(fone,-wp),prec,rnd)
1763
+ if type == 1: return mpf_sub(fone, x, prec, rnd)
1764
+ if type == 2: return mpf_add(x, mpf_shift(fone,mag-wp), prec, rnd)
1765
+ if type == 3: return mpf_neg(mpf_log(mpf_abs(x), prec, rnd))
1766
+
1767
+ # From now on, we assume having a gamma function
1768
+ if type == 1:
1769
+ return mpf_gamma(mpf_add(x, fone), prec, rnd, 0)
1770
+
1771
+ # Special case integers (those not small enough to be caught above,
1772
+ # but still small enough for an exact factorial to be faster
1773
+ # than an approximate algorithm), and half-integers
1774
+ if exp >= -1:
1775
+ if is_integer:
1776
+ if gamma_size < 10*wp:
1777
+ if type == 0:
1778
+ return from_int(ifac(n-1), prec, rnd)
1779
+ if type == 2:
1780
+ return from_rational(MPZ_ONE, ifac(n-1), prec, rnd)
1781
+ if type == 3:
1782
+ return mpf_log(from_int(ifac(n-1)), prec, rnd)
1783
+ # half-integer
1784
+ if n < 100 or gamma_size < 10*wp:
1785
+ if sign:
1786
+ w = sqrtpi_fixed(wp)
1787
+ if n % 2: f = ifac2(2*n+1)
1788
+ else: f = -ifac2(2*n+1)
1789
+ if type == 0:
1790
+ return mpf_shift(from_rational(w, f, prec, rnd), -wp+n+1)
1791
+ if type == 2:
1792
+ return mpf_shift(from_rational(f, w, prec, rnd), wp-n-1)
1793
+ if type == 3:
1794
+ return mpf_log(mpf_shift(from_rational(w, abs(f),
1795
+ prec, rnd), -wp+n+1), prec, rnd)
1796
+ elif n == 0:
1797
+ if type == 0: return mpf_sqrtpi(prec, rnd)
1798
+ if type == 2: return mpf_div(fone, mpf_sqrtpi(wp), prec, rnd)
1799
+ if type == 3: return mpf_log(mpf_sqrtpi(wp), prec, rnd)
1800
+ else:
1801
+ w = sqrtpi_fixed(wp)
1802
+ w = from_man_exp(w * ifac2(2*n-1), -wp-n)
1803
+ if type == 0: return mpf_pos(w, prec, rnd)
1804
+ if type == 2: return mpf_div(fone, w, prec, rnd)
1805
+ if type == 3: return mpf_log(mpf_abs(w), prec, rnd)
1806
+
1807
+ # Convert to fixed point
1808
+ offset = exp + wp
1809
+ if offset >= 0: absxman = man << offset
1810
+ else: absxman = man >> (-offset)
1811
+
1812
+ # For log gamma, provide accurate evaluation for x = 1+eps and 2+eps
1813
+ if type == 3 and not sign:
1814
+ one = MPZ_ONE << wp
1815
+ one_dist = abs(absxman-one)
1816
+ two_dist = abs(absxman-2*one)
1817
+ cancellation = (wp - bitcount(min(one_dist, two_dist)))
1818
+ if cancellation > 10:
1819
+ xsub1 = mpf_sub(fone, x)
1820
+ xsub2 = mpf_sub(ftwo, x)
1821
+ xsub1mag = xsub1[2]+xsub1[3]
1822
+ xsub2mag = xsub2[2]+xsub2[3]
1823
+ if xsub1mag < -wp:
1824
+ return mpf_mul(mpf_euler(wp), mpf_sub(fone, x), prec, rnd)
1825
+ if xsub2mag < -wp:
1826
+ return mpf_mul(mpf_sub(fone, mpf_euler(wp)),
1827
+ mpf_sub(x, ftwo), prec, rnd)
1828
+ # Proceed but increase precision
1829
+ wp += max(-xsub1mag, -xsub2mag)
1830
+ offset = exp + wp
1831
+ if offset >= 0: absxman = man << offset
1832
+ else: absxman = man >> (-offset)
1833
+
1834
+ # Use Taylor series if appropriate
1835
+ n_for_stirling = int(GAMMA_STIRLING_BETA*wp)
1836
+ if n < max(100, n_for_stirling) and wp < MAX_GAMMA_TAYLOR_PREC:
1837
+ if sign:
1838
+ absxman = -absxman
1839
+ return gamma_fixed_taylor(x, absxman, wp, prec, rnd, type)
1840
+
1841
+ # Use Stirling's series
1842
+ # First ensure that |x| is large enough for rapid convergence
1843
+ xorig = x
1844
+
1845
+ # Argument reduction
1846
+ r = 0
1847
+ if n < n_for_stirling:
1848
+ r = one = MPZ_ONE << wp
1849
+ d = n_for_stirling - n
1850
+ for k in xrange(d):
1851
+ r = (r * absxman) >> wp
1852
+ absxman += one
1853
+ x = xabs = from_man_exp(absxman, -wp)
1854
+ if sign:
1855
+ x = mpf_neg(x)
1856
+ else:
1857
+ xabs = mpf_abs(x)
1858
+
1859
+ # Asymptotic series
1860
+ y = real_stirling_series(absxman, wp)
1861
+ u = to_fixed(mpf_log(xabs, wp), wp)
1862
+ u = ((absxman - (MPZ_ONE<<(wp-1))) * u) >> wp
1863
+ y += u
1864
+ w = from_man_exp(y, -wp)
1865
+
1866
+ # Compute final value
1867
+ if sign:
1868
+ # Reflection formula
1869
+ A = mpf_mul(mpf_sin_pi(xorig, wp), xorig, wp)
1870
+ B = mpf_neg(mpf_pi(wp))
1871
+ if type == 0 or type == 2:
1872
+ A = mpf_mul(A, mpf_exp(w, wp))
1873
+ if r:
1874
+ B = mpf_mul(B, from_man_exp(r, -wp), wp)
1875
+ if type == 0:
1876
+ return mpf_div(B, A, prec, rnd)
1877
+ if type == 2:
1878
+ return mpf_div(A, B, prec, rnd)
1879
+ if type == 3:
1880
+ if r:
1881
+ B = mpf_mul(B, from_man_exp(r, -wp), wp)
1882
+ A = mpf_add(mpf_log(mpf_abs(A), wp), w, wp)
1883
+ return mpf_sub(mpf_log(mpf_abs(B), wp), A, prec, rnd)
1884
+ else:
1885
+ if type == 0:
1886
+ if r:
1887
+ return mpf_div(mpf_exp(w, wp),
1888
+ from_man_exp(r, -wp), prec, rnd)
1889
+ return mpf_exp(w, prec, rnd)
1890
+ if type == 2:
1891
+ if r:
1892
+ return mpf_div(from_man_exp(r, -wp),
1893
+ mpf_exp(w, wp), prec, rnd)
1894
+ return mpf_exp(mpf_neg(w), prec, rnd)
1895
+ if type == 3:
1896
+ if r:
1897
+ return mpf_sub(w, mpf_log(from_man_exp(r,-wp), wp), prec, rnd)
1898
+ return mpf_pos(w, prec, rnd)
1899
+
1900
+
1901
+ def mpc_gamma(z, prec, rnd='d', type=0):
1902
+ a, b = z
1903
+ asign, aman, aexp, abc = a
1904
+ bsign, bman, bexp, bbc = b
1905
+
1906
+ if b == fzero:
1907
+ # Imaginary part on negative half-axis for log-gamma function
1908
+ if type == 3 and asign:
1909
+ re = mpf_gamma(a, prec, rnd, 3)
1910
+ n = (-aman) >> (-aexp)
1911
+ im = mpf_mul_int(mpf_pi(prec+10), n, prec, rnd)
1912
+ return re, im
1913
+ return mpf_gamma(a, prec, rnd, type), fzero
1914
+
1915
+ # Some kind of complex inf/nan
1916
+ if (not aman and aexp) or (not bman and bexp):
1917
+ return (fnan, fnan)
1918
+
1919
+ # Initial working precision
1920
+ wp = prec + 20
1921
+
1922
+ amag = aexp+abc
1923
+ bmag = bexp+bbc
1924
+ if aman:
1925
+ mag = max(amag, bmag)
1926
+ else:
1927
+ mag = bmag
1928
+
1929
+ # Close to 0
1930
+ if mag < -8:
1931
+ if mag < -wp:
1932
+ # 1/gamma(z) = z + euler*z^2 + O(z^3)
1933
+ v = mpc_add(z, mpc_mul_mpf(mpc_mul(z,z,wp),mpf_euler(wp),wp), wp)
1934
+ if type == 0: return mpc_reciprocal(v, prec, rnd)
1935
+ if type == 1: return mpc_div(z, v, prec, rnd)
1936
+ if type == 2: return mpc_pos(v, prec, rnd)
1937
+ if type == 3: return mpc_log(mpc_reciprocal(v, prec), prec, rnd)
1938
+ elif type != 1:
1939
+ wp += (-mag)
1940
+
1941
+ # Handle huge log-gamma values; must do this before converting to
1942
+ # a fixed-point value. TODO: determine a precise cutoff of validity
1943
+ # depending on amag and bmag
1944
+ if type == 3 and mag > wp and ((not asign) or (bmag >= amag)):
1945
+ return mpc_sub(mpc_mul(z, mpc_log(z, wp), wp), z, prec, rnd)
1946
+
1947
+ # From now on, we assume having a gamma function
1948
+ if type == 1:
1949
+ return mpc_gamma((mpf_add(a, fone), b), prec, rnd, 0)
1950
+
1951
+ an = abs(to_int(a))
1952
+ bn = abs(to_int(b))
1953
+ absn = max(an, bn)
1954
+ gamma_size = absn*mag
1955
+ if type == 3:
1956
+ pass
1957
+ else:
1958
+ wp += bitcount(gamma_size)
1959
+
1960
+ # Reflect to the right half-plane. Note that Stirling's expansion
1961
+ # is valid in the left half-plane too, as long as we're not too close
1962
+ # to the real axis, but in order to use this argument reduction
1963
+ # in the negative direction must be implemented.
1964
+ #need_reflection = asign and ((bmag < 0) or (amag-bmag > 4))
1965
+ need_reflection = asign
1966
+ zorig = z
1967
+ if need_reflection:
1968
+ z = mpc_neg(z)
1969
+ asign, aman, aexp, abc = a = z[0]
1970
+ bsign, bman, bexp, bbc = b = z[1]
1971
+
1972
+ # Imaginary part very small compared to real one?
1973
+ yfinal = 0
1974
+ balance_prec = 0
1975
+ if bmag < -10:
1976
+ # Check z ~= 1 and z ~= 2 for loggamma
1977
+ if type == 3:
1978
+ zsub1 = mpc_sub_mpf(z, fone)
1979
+ if zsub1[0] == fzero:
1980
+ cancel1 = -bmag
1981
+ else:
1982
+ cancel1 = -max(zsub1[0][2]+zsub1[0][3], bmag)
1983
+ if cancel1 > wp:
1984
+ pi = mpf_pi(wp)
1985
+ x = mpc_mul_mpf(zsub1, pi, wp)
1986
+ x = mpc_mul(x, x, wp)
1987
+ x = mpc_div_mpf(x, from_int(12), wp)
1988
+ y = mpc_mul_mpf(zsub1, mpf_neg(mpf_euler(wp)), wp)
1989
+ yfinal = mpc_add(x, y, wp)
1990
+ if not need_reflection:
1991
+ return mpc_pos(yfinal, prec, rnd)
1992
+ elif cancel1 > 0:
1993
+ wp += cancel1
1994
+ zsub2 = mpc_sub_mpf(z, ftwo)
1995
+ if zsub2[0] == fzero:
1996
+ cancel2 = -bmag
1997
+ else:
1998
+ cancel2 = -max(zsub2[0][2]+zsub2[0][3], bmag)
1999
+ if cancel2 > wp:
2000
+ pi = mpf_pi(wp)
2001
+ t = mpf_sub(mpf_mul(pi, pi), from_int(6))
2002
+ x = mpc_mul_mpf(mpc_mul(zsub2, zsub2, wp), t, wp)
2003
+ x = mpc_div_mpf(x, from_int(12), wp)
2004
+ y = mpc_mul_mpf(zsub2, mpf_sub(fone, mpf_euler(wp)), wp)
2005
+ yfinal = mpc_add(x, y, wp)
2006
+ if not need_reflection:
2007
+ return mpc_pos(yfinal, prec, rnd)
2008
+ elif cancel2 > 0:
2009
+ wp += cancel2
2010
+ if bmag < -wp:
2011
+ # Compute directly from the real gamma function.
2012
+ pp = 2*(wp+10)
2013
+ aabs = mpf_abs(a)
2014
+ eps = mpf_shift(fone, amag-wp)
2015
+ x1 = mpf_gamma(aabs, pp, type=type)
2016
+ x2 = mpf_gamma(mpf_add(aabs, eps), pp, type=type)
2017
+ xprime = mpf_div(mpf_sub(x2, x1, pp), eps, pp)
2018
+ y = mpf_mul(b, xprime, prec, rnd)
2019
+ yfinal = (x1, y)
2020
+ # Note: we still need to use the reflection formula for
2021
+ # near-poles, and the correct branch of the log-gamma function
2022
+ if not need_reflection:
2023
+ return mpc_pos(yfinal, prec, rnd)
2024
+ else:
2025
+ balance_prec += (-bmag)
2026
+
2027
+ wp += balance_prec
2028
+ n_for_stirling = int(GAMMA_STIRLING_BETA*wp)
2029
+ need_reduction = absn < n_for_stirling
2030
+
2031
+ afix = to_fixed(a, wp)
2032
+ bfix = to_fixed(b, wp)
2033
+
2034
+ r = 0
2035
+ if not yfinal:
2036
+ zprered = z
2037
+ # Argument reduction
2038
+ if absn < n_for_stirling:
2039
+ absn = complex(an, bn)
2040
+ d = int((1 + n_for_stirling**2 - bn**2)**0.5 - an)
2041
+ rre = one = MPZ_ONE << wp
2042
+ rim = MPZ_ZERO
2043
+ for k in xrange(d):
2044
+ rre, rim = ((afix*rre-bfix*rim)>>wp), ((afix*rim + bfix*rre)>>wp)
2045
+ afix += one
2046
+ r = from_man_exp(rre, -wp), from_man_exp(rim, -wp)
2047
+ a = from_man_exp(afix, -wp)
2048
+ z = a, b
2049
+
2050
+ yre, yim = complex_stirling_series(afix, bfix, wp)
2051
+ # (z-1/2)*log(z) + S
2052
+ lre, lim = mpc_log(z, wp)
2053
+ lre = to_fixed(lre, wp)
2054
+ lim = to_fixed(lim, wp)
2055
+ yre = ((lre*afix - lim*bfix)>>wp) - (lre>>1) + yre
2056
+ yim = ((lre*bfix + lim*afix)>>wp) - (lim>>1) + yim
2057
+ y = from_man_exp(yre, -wp), from_man_exp(yim, -wp)
2058
+
2059
+ if r and type == 3:
2060
+ # If re(z) > 0 and abs(z) <= 4, the branches of loggamma(z)
2061
+ # and log(gamma(z)) coincide. Otherwise, use the zeroth order
2062
+ # Stirling expansion to compute the correct imaginary part.
2063
+ y = mpc_sub(y, mpc_log(r, wp), wp)
2064
+ zfa = to_float(zprered[0])
2065
+ zfb = to_float(zprered[1])
2066
+ zfabs = math.hypot(zfa,zfb)
2067
+ #if not (zfa > 0.0 and zfabs <= 4):
2068
+ yfb = to_float(y[1])
2069
+ u = math.atan2(zfb, zfa)
2070
+ if zfabs <= 0.5:
2071
+ gi = 0.577216*zfb - u
2072
+ else:
2073
+ gi = -zfb - 0.5*u + zfa*u + zfb*math.log(zfabs)
2074
+ n = int(math.floor((gi-yfb)/(2*math.pi)+0.5))
2075
+ y = (y[0], mpf_add(y[1], mpf_mul_int(mpf_pi(wp), 2*n, wp), wp))
2076
+
2077
+ if need_reflection:
2078
+ if type == 0 or type == 2:
2079
+ A = mpc_mul(mpc_sin_pi(zorig, wp), zorig, wp)
2080
+ B = (mpf_neg(mpf_pi(wp)), fzero)
2081
+ if yfinal:
2082
+ if type == 2:
2083
+ A = mpc_div(A, yfinal, wp)
2084
+ else:
2085
+ A = mpc_mul(A, yfinal, wp)
2086
+ else:
2087
+ A = mpc_mul(A, mpc_exp(y, wp), wp)
2088
+ if r:
2089
+ B = mpc_mul(B, r, wp)
2090
+ if type == 0: return mpc_div(B, A, prec, rnd)
2091
+ if type == 2: return mpc_div(A, B, prec, rnd)
2092
+
2093
+ # Reflection formula for the log-gamma function with correct branch
2094
+ # http://functions.wolfram.com/GammaBetaErf/LogGamma/16/01/01/0006/
2095
+ # LogGamma[z] == -LogGamma[-z] - Log[-z] +
2096
+ # Sign[Im[z]] Floor[Re[z]] Pi I + Log[Pi] -
2097
+ # Log[Sin[Pi (z - Floor[Re[z]])]] -
2098
+ # Pi I (1 - Abs[Sign[Im[z]]]) Abs[Floor[Re[z]]]
2099
+ if type == 3:
2100
+ if yfinal:
2101
+ s1 = mpc_neg(yfinal)
2102
+ else:
2103
+ s1 = mpc_neg(y)
2104
+ # s -= log(-z)
2105
+ s1 = mpc_sub(s1, mpc_log(mpc_neg(zorig), wp), wp)
2106
+ # floor(re(z))
2107
+ rezfloor = mpf_floor(zorig[0])
2108
+ imzsign = mpf_sign(zorig[1])
2109
+ pi = mpf_pi(wp)
2110
+ t = mpf_mul(pi, rezfloor)
2111
+ t = mpf_mul_int(t, imzsign, wp)
2112
+ s1 = (s1[0], mpf_add(s1[1], t, wp))
2113
+ s1 = mpc_add_mpf(s1, mpf_log(pi, wp), wp)
2114
+ t = mpc_sin_pi(mpc_sub_mpf(zorig, rezfloor), wp)
2115
+ t = mpc_log(t, wp)
2116
+ s1 = mpc_sub(s1, t, wp)
2117
+ # Note: may actually be unused, because we fall back
2118
+ # to the mpf_ function for real arguments
2119
+ if not imzsign:
2120
+ t = mpf_mul(pi, mpf_floor(rezfloor), wp)
2121
+ s1 = (s1[0], mpf_sub(s1[1], t, wp))
2122
+ return mpc_pos(s1, prec, rnd)
2123
+ else:
2124
+ if type == 0:
2125
+ if r:
2126
+ return mpc_div(mpc_exp(y, wp), r, prec, rnd)
2127
+ return mpc_exp(y, prec, rnd)
2128
+ if type == 2:
2129
+ if r:
2130
+ return mpc_div(r, mpc_exp(y, wp), prec, rnd)
2131
+ return mpc_exp(mpc_neg(y), prec, rnd)
2132
+ if type == 3:
2133
+ return mpc_pos(y, prec, rnd)
2134
+
2135
+ def mpf_factorial(x, prec, rnd='d'):
2136
+ return mpf_gamma(x, prec, rnd, 1)
2137
+
2138
+ def mpc_factorial(x, prec, rnd='d'):
2139
+ return mpc_gamma(x, prec, rnd, 1)
2140
+
2141
+ def mpf_rgamma(x, prec, rnd='d'):
2142
+ return mpf_gamma(x, prec, rnd, 2)
2143
+
2144
+ def mpc_rgamma(x, prec, rnd='d'):
2145
+ return mpc_gamma(x, prec, rnd, 2)
2146
+
2147
+ def mpf_loggamma(x, prec, rnd='d'):
2148
+ sign, man, exp, bc = x
2149
+ if sign:
2150
+ raise ComplexResult
2151
+ return mpf_gamma(x, prec, rnd, 3)
2152
+
2153
+ def mpc_loggamma(z, prec, rnd='d'):
2154
+ a, b = z
2155
+ asign, aman, aexp, abc = a
2156
+ bsign, bman, bexp, bbc = b
2157
+ if b == fzero and asign:
2158
+ re = mpf_gamma(a, prec, rnd, 3)
2159
+ n = (-aman) >> (-aexp)
2160
+ im = mpf_mul_int(mpf_pi(prec+10), n, prec, rnd)
2161
+ return re, im
2162
+ return mpc_gamma(z, prec, rnd, 3)
2163
+
2164
+ def mpf_gamma_int(n, prec, rnd=round_fast):
2165
+ if n < SMALL_FACTORIAL_CACHE_SIZE:
2166
+ return mpf_pos(small_factorial_cache[n-1], prec, rnd)
2167
+ return mpf_gamma(from_int(n), prec, rnd)
lib/python3.11/site-packages/mpmath/libmp/libelefun.py ADDED
@@ -0,0 +1,1428 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This module implements computation of elementary transcendental
3
+ functions (powers, logarithms, trigonometric and hyperbolic
4
+ functions, inverse trigonometric and hyperbolic) for real
5
+ floating-point numbers.
6
+
7
+ For complex and interval implementations of the same functions,
8
+ see libmpc and libmpi.
9
+
10
+ """
11
+
12
+ import math
13
+ from bisect import bisect
14
+
15
+ from .backend import xrange
16
+ from .backend import MPZ, MPZ_ZERO, MPZ_ONE, MPZ_TWO, MPZ_FIVE, BACKEND
17
+
18
+ from .libmpf import (
19
+ round_floor, round_ceiling, round_down, round_up,
20
+ round_nearest, round_fast,
21
+ ComplexResult,
22
+ bitcount, bctable, lshift, rshift, giant_steps, sqrt_fixed,
23
+ from_int, to_int, from_man_exp, to_fixed, to_float, from_float,
24
+ from_rational, normalize,
25
+ fzero, fone, fnone, fhalf, finf, fninf, fnan,
26
+ mpf_cmp, mpf_sign, mpf_abs,
27
+ mpf_pos, mpf_neg, mpf_add, mpf_sub, mpf_mul, mpf_div, mpf_shift,
28
+ mpf_rdiv_int, mpf_pow_int, mpf_sqrt,
29
+ reciprocal_rnd, negative_rnd, mpf_perturb,
30
+ isqrt_fast
31
+ )
32
+
33
+ from .libintmath import ifib
34
+
35
+
36
+ #-------------------------------------------------------------------------------
37
+ # Tuning parameters
38
+ #-------------------------------------------------------------------------------
39
+
40
+ # Cutoff for computing exp from cosh+sinh. This reduces the
41
+ # number of terms by half, but also requires a square root which
42
+ # is expensive with the pure-Python square root code.
43
+ if BACKEND == 'python':
44
+ EXP_COSH_CUTOFF = 600
45
+ else:
46
+ EXP_COSH_CUTOFF = 400
47
+ # Cutoff for using more than 2 series
48
+ EXP_SERIES_U_CUTOFF = 1500
49
+
50
+ # Also basically determined by sqrt
51
+ if BACKEND == 'python':
52
+ COS_SIN_CACHE_PREC = 400
53
+ else:
54
+ COS_SIN_CACHE_PREC = 200
55
+ COS_SIN_CACHE_STEP = 8
56
+ cos_sin_cache = {}
57
+
58
+ # Number of integer logarithms to cache (for zeta sums)
59
+ MAX_LOG_INT_CACHE = 2000
60
+ log_int_cache = {}
61
+
62
+ LOG_TAYLOR_PREC = 2500 # Use Taylor series with caching up to this prec
63
+ LOG_TAYLOR_SHIFT = 9 # Cache log values in steps of size 2^-N
64
+ log_taylor_cache = {}
65
+ # prec/size ratio of x for fastest convergence in AGM formula
66
+ LOG_AGM_MAG_PREC_RATIO = 20
67
+
68
+ ATAN_TAYLOR_PREC = 3000 # Same as for log
69
+ ATAN_TAYLOR_SHIFT = 7 # steps of size 2^-N
70
+ atan_taylor_cache = {}
71
+
72
+
73
+ # ~= next power of two + 20
74
+ cache_prec_steps = [22,22]
75
+ for k in xrange(1, bitcount(LOG_TAYLOR_PREC)+1):
76
+ cache_prec_steps += [min(2**k,LOG_TAYLOR_PREC)+20] * 2**(k-1)
77
+
78
+
79
+ #----------------------------------------------------------------------------#
80
+ # #
81
+ # Elementary mathematical constants #
82
+ # #
83
+ #----------------------------------------------------------------------------#
84
+
85
+ def constant_memo(f):
86
+ """
87
+ Decorator for caching computed values of mathematical
88
+ constants. This decorator should be applied to a
89
+ function taking a single argument prec as input and
90
+ returning a fixed-point value with the given precision.
91
+ """
92
+ f.memo_prec = -1
93
+ f.memo_val = None
94
+ def g(prec, **kwargs):
95
+ memo_prec = f.memo_prec
96
+ if prec <= memo_prec:
97
+ return f.memo_val >> (memo_prec-prec)
98
+ newprec = int(prec*1.05+10)
99
+ f.memo_val = f(newprec, **kwargs)
100
+ f.memo_prec = newprec
101
+ return f.memo_val >> (newprec-prec)
102
+ g.__name__ = f.__name__
103
+ g.__doc__ = f.__doc__
104
+ return g
105
+
106
+ def def_mpf_constant(fixed):
107
+ """
108
+ Create a function that computes the mpf value for a mathematical
109
+ constant, given a function that computes the fixed-point value.
110
+
111
+ Assumptions: the constant is positive and has magnitude ~= 1;
112
+ the fixed-point function rounds to floor.
113
+ """
114
+ def f(prec, rnd=round_fast):
115
+ wp = prec + 20
116
+ v = fixed(wp)
117
+ if rnd in (round_up, round_ceiling):
118
+ v += 1
119
+ return normalize(0, v, -wp, bitcount(v), prec, rnd)
120
+ f.__doc__ = fixed.__doc__
121
+ return f
122
+
123
+ def bsp_acot(q, a, b, hyperbolic):
124
+ if b - a == 1:
125
+ a1 = MPZ(2*a + 3)
126
+ if hyperbolic or a&1:
127
+ return MPZ_ONE, a1 * q**2, a1
128
+ else:
129
+ return -MPZ_ONE, a1 * q**2, a1
130
+ m = (a+b)//2
131
+ p1, q1, r1 = bsp_acot(q, a, m, hyperbolic)
132
+ p2, q2, r2 = bsp_acot(q, m, b, hyperbolic)
133
+ return q2*p1 + r1*p2, q1*q2, r1*r2
134
+
135
+ # the acoth(x) series converges like the geometric series for x^2
136
+ # N = ceil(p*log(2)/(2*log(x)))
137
+ def acot_fixed(a, prec, hyperbolic):
138
+ """
139
+ Compute acot(a) or acoth(a) for an integer a with binary splitting; see
140
+ http://numbers.computation.free.fr/Constants/Algorithms/splitting.html
141
+ """
142
+ N = int(0.35 * prec/math.log(a) + 20)
143
+ p, q, r = bsp_acot(a, 0,N, hyperbolic)
144
+ return ((p+q)<<prec)//(q*a)
145
+
146
+ def machin(coefs, prec, hyperbolic=False):
147
+ """
148
+ Evaluate a Machin-like formula, i.e., a linear combination of
149
+ acot(n) or acoth(n) for specific integer values of n, using fixed-
150
+ point arithmetic. The input should be a list [(c, n), ...], giving
151
+ c*acot[h](n) + ...
152
+ """
153
+ extraprec = 10
154
+ s = MPZ_ZERO
155
+ for a, b in coefs:
156
+ s += MPZ(a) * acot_fixed(MPZ(b), prec+extraprec, hyperbolic)
157
+ return (s >> extraprec)
158
+
159
+ # Logarithms of integers are needed for various computations involving
160
+ # logarithms, powers, radix conversion, etc
161
+
162
+ @constant_memo
163
+ def ln2_fixed(prec):
164
+ """
165
+ Computes ln(2). This is done with a hyperbolic Machin-type formula,
166
+ with binary splitting at high precision.
167
+ """
168
+ return machin([(18, 26), (-2, 4801), (8, 8749)], prec, True)
169
+
170
+ @constant_memo
171
+ def ln10_fixed(prec):
172
+ """
173
+ Computes ln(10). This is done with a hyperbolic Machin-type formula.
174
+ """
175
+ return machin([(46, 31), (34, 49), (20, 161)], prec, True)
176
+
177
+
178
+ r"""
179
+ For computation of pi, we use the Chudnovsky series:
180
+
181
+ oo
182
+ ___ k
183
+ 1 \ (-1) (6 k)! (A + B k)
184
+ ----- = ) -----------------------
185
+ 12 pi /___ 3 3k+3/2
186
+ (3 k)! (k!) C
187
+ k = 0
188
+
189
+ where A, B, and C are certain integer constants. This series adds roughly
190
+ 14 digits per term. Note that C^(3/2) can be extracted so that the
191
+ series contains only rational terms. This makes binary splitting very
192
+ efficient.
193
+
194
+ The recurrence formulas for the binary splitting were taken from
195
+ ftp://ftp.gmplib.org/pub/src/gmp-chudnovsky.c
196
+
197
+ Previously, Machin's formula was used at low precision and the AGM iteration
198
+ was used at high precision. However, the Chudnovsky series is essentially as
199
+ fast as the Machin formula at low precision and in practice about 3x faster
200
+ than the AGM at high precision (despite theoretically having a worse
201
+ asymptotic complexity), so there is no reason not to use it in all cases.
202
+
203
+ """
204
+
205
+ # Constants in Chudnovsky's series
206
+ CHUD_A = MPZ(13591409)
207
+ CHUD_B = MPZ(545140134)
208
+ CHUD_C = MPZ(640320)
209
+ CHUD_D = MPZ(12)
210
+
211
+ def bs_chudnovsky(a, b, level, verbose):
212
+ """
213
+ Computes the sum from a to b of the series in the Chudnovsky
214
+ formula. Returns g, p, q where p/q is the sum as an exact
215
+ fraction and g is a temporary value used to save work
216
+ for recursive calls.
217
+ """
218
+ if b-a == 1:
219
+ g = MPZ((6*b-5)*(2*b-1)*(6*b-1))
220
+ p = b**3 * CHUD_C**3 // 24
221
+ q = (-1)**b * g * (CHUD_A+CHUD_B*b)
222
+ else:
223
+ if verbose and level < 4:
224
+ print(" binary splitting", a, b)
225
+ mid = (a+b)//2
226
+ g1, p1, q1 = bs_chudnovsky(a, mid, level+1, verbose)
227
+ g2, p2, q2 = bs_chudnovsky(mid, b, level+1, verbose)
228
+ p = p1*p2
229
+ g = g1*g2
230
+ q = q1*p2 + q2*g1
231
+ return g, p, q
232
+
233
+ @constant_memo
234
+ def pi_fixed(prec, verbose=False, verbose_base=None):
235
+ """
236
+ Compute floor(pi * 2**prec) as a big integer.
237
+
238
+ This is done using Chudnovsky's series (see comments in
239
+ libelefun.py for details).
240
+ """
241
+ # The Chudnovsky series gives 14.18 digits per term
242
+ N = int(prec/3.3219280948/14.181647462 + 2)
243
+ if verbose:
244
+ print("binary splitting with N =", N)
245
+ g, p, q = bs_chudnovsky(0, N, 0, verbose)
246
+ sqrtC = isqrt_fast(CHUD_C<<(2*prec))
247
+ v = p*CHUD_C*sqrtC//((q+CHUD_A*p)*CHUD_D)
248
+ return v
249
+
250
+ def degree_fixed(prec):
251
+ return pi_fixed(prec)//180
252
+
253
+ def bspe(a, b):
254
+ """
255
+ Sum series for exp(1)-1 between a, b, returning the result
256
+ as an exact fraction (p, q).
257
+ """
258
+ if b-a == 1:
259
+ return MPZ_ONE, MPZ(b)
260
+ m = (a+b)//2
261
+ p1, q1 = bspe(a, m)
262
+ p2, q2 = bspe(m, b)
263
+ return p1*q2+p2, q1*q2
264
+
265
+ @constant_memo
266
+ def e_fixed(prec):
267
+ """
268
+ Computes exp(1). This is done using the ordinary Taylor series for
269
+ exp, with binary splitting. For a description of the algorithm,
270
+ see:
271
+
272
+ http://numbers.computation.free.fr/Constants/
273
+ Algorithms/splitting.html
274
+ """
275
+ # Slight overestimate of N needed for 1/N! < 2**(-prec)
276
+ # This could be tightened for large N.
277
+ N = int(1.1*prec/math.log(prec) + 20)
278
+ p, q = bspe(0,N)
279
+ return ((p+q)<<prec)//q
280
+
281
+ @constant_memo
282
+ def phi_fixed(prec):
283
+ """
284
+ Computes the golden ratio, (1+sqrt(5))/2
285
+ """
286
+ prec += 10
287
+ a = isqrt_fast(MPZ_FIVE<<(2*prec)) + (MPZ_ONE << prec)
288
+ return a >> 11
289
+
290
+ mpf_phi = def_mpf_constant(phi_fixed)
291
+ mpf_pi = def_mpf_constant(pi_fixed)
292
+ mpf_e = def_mpf_constant(e_fixed)
293
+ mpf_degree = def_mpf_constant(degree_fixed)
294
+ mpf_ln2 = def_mpf_constant(ln2_fixed)
295
+ mpf_ln10 = def_mpf_constant(ln10_fixed)
296
+
297
+
298
+ @constant_memo
299
+ def ln_sqrt2pi_fixed(prec):
300
+ wp = prec + 10
301
+ # ln(sqrt(2*pi)) = ln(2*pi)/2
302
+ return to_fixed(mpf_log(mpf_shift(mpf_pi(wp), 1), wp), prec-1)
303
+
304
+ @constant_memo
305
+ def sqrtpi_fixed(prec):
306
+ return sqrt_fixed(pi_fixed(prec), prec)
307
+
308
+ mpf_sqrtpi = def_mpf_constant(sqrtpi_fixed)
309
+ mpf_ln_sqrt2pi = def_mpf_constant(ln_sqrt2pi_fixed)
310
+
311
+
312
+ #----------------------------------------------------------------------------#
313
+ # #
314
+ # Powers #
315
+ # #
316
+ #----------------------------------------------------------------------------#
317
+
318
+ def mpf_pow(s, t, prec, rnd=round_fast):
319
+ """
320
+ Compute s**t. Raises ComplexResult if s is negative and t is
321
+ fractional.
322
+ """
323
+ ssign, sman, sexp, sbc = s
324
+ tsign, tman, texp, tbc = t
325
+ if ssign and texp < 0:
326
+ raise ComplexResult("negative number raised to a fractional power")
327
+ if texp >= 0:
328
+ return mpf_pow_int(s, (-1)**tsign * (tman<<texp), prec, rnd)
329
+ # s**(n/2) = sqrt(s)**n
330
+ if texp == -1:
331
+ if tman == 1:
332
+ if tsign:
333
+ return mpf_div(fone, mpf_sqrt(s, prec+10,
334
+ reciprocal_rnd[rnd]), prec, rnd)
335
+ return mpf_sqrt(s, prec, rnd)
336
+ else:
337
+ if tsign:
338
+ return mpf_pow_int(mpf_sqrt(s, prec+10,
339
+ reciprocal_rnd[rnd]), -tman, prec, rnd)
340
+ return mpf_pow_int(mpf_sqrt(s, prec+10, rnd), tman, prec, rnd)
341
+ # General formula: s**t = exp(t*log(s))
342
+ # TODO: handle rnd direction of the logarithm carefully
343
+ c = mpf_log(s, prec+10, rnd)
344
+ return mpf_exp(mpf_mul(t, c), prec, rnd)
345
+
346
+ def int_pow_fixed(y, n, prec):
347
+ """n-th power of a fixed point number with precision prec
348
+
349
+ Returns the power in the form man, exp,
350
+ man * 2**exp ~= y**n
351
+ """
352
+ if n == 2:
353
+ return (y*y), 0
354
+ bc = bitcount(y)
355
+ exp = 0
356
+ workprec = 2 * (prec + 4*bitcount(n) + 4)
357
+ _, pm, pe, pbc = fone
358
+ while 1:
359
+ if n & 1:
360
+ pm = pm*y
361
+ pe = pe+exp
362
+ pbc += bc - 2
363
+ pbc = pbc + bctable[int(pm >> pbc)]
364
+ if pbc > workprec:
365
+ pm = pm >> (pbc-workprec)
366
+ pe += pbc - workprec
367
+ pbc = workprec
368
+ n -= 1
369
+ if not n:
370
+ break
371
+ y = y*y
372
+ exp = exp+exp
373
+ bc = bc + bc - 2
374
+ bc = bc + bctable[int(y >> bc)]
375
+ if bc > workprec:
376
+ y = y >> (bc-workprec)
377
+ exp += bc - workprec
378
+ bc = workprec
379
+ n = n // 2
380
+ return pm, pe
381
+
382
+ # froot(s, n, prec, rnd) computes the real n-th root of a
383
+ # positive mpf tuple s.
384
+ # To compute the root we start from a 50-bit estimate for r
385
+ # generated with ordinary floating-point arithmetic, and then refine
386
+ # the value to full accuracy using the iteration
387
+
388
+ # 1 / y \
389
+ # r = --- | (n-1) * r + ---------- |
390
+ # n+1 n \ n r_n**(n-1) /
391
+
392
+ # which is simply Newton's method applied to the equation r**n = y.
393
+ # With giant_steps(start, prec+extra) = [p0,...,pm, prec+extra]
394
+ # and y = man * 2**-shift one has
395
+ # (man * 2**exp)**(1/n) =
396
+ # y**(1/n) * 2**(start-prec/n) * 2**(p0-start) * ... * 2**(prec+extra-pm) *
397
+ # 2**((exp+shift-(n-1)*prec)/n -extra))
398
+ # The last factor is accounted for in the last line of froot.
399
+
400
+ def nthroot_fixed(y, n, prec, exp1):
401
+ start = 50
402
+ try:
403
+ y1 = rshift(y, prec - n*start)
404
+ r = MPZ(int(y1**(1.0/n)))
405
+ except OverflowError:
406
+ y1 = from_int(y1, start)
407
+ fn = from_int(n)
408
+ fn = mpf_rdiv_int(1, fn, start)
409
+ r = mpf_pow(y1, fn, start)
410
+ r = to_int(r)
411
+ extra = 10
412
+ extra1 = n
413
+ prevp = start
414
+ for p in giant_steps(start, prec+extra):
415
+ pm, pe = int_pow_fixed(r, n-1, prevp)
416
+ r2 = rshift(pm, (n-1)*prevp - p - pe - extra1)
417
+ B = lshift(y, 2*p-prec+extra1)//r2
418
+ r = (B + (n-1) * lshift(r, p-prevp))//n
419
+ prevp = p
420
+ return r
421
+
422
+ def mpf_nthroot(s, n, prec, rnd=round_fast):
423
+ """nth-root of a positive number
424
+
425
+ Use the Newton method when faster, otherwise use x**(1/n)
426
+ """
427
+ sign, man, exp, bc = s
428
+ if sign:
429
+ raise ComplexResult("nth root of a negative number")
430
+ if not man:
431
+ if s == fnan:
432
+ return fnan
433
+ if s == fzero:
434
+ if n > 0:
435
+ return fzero
436
+ if n == 0:
437
+ return fone
438
+ return finf
439
+ # Infinity
440
+ if not n:
441
+ return fnan
442
+ if n < 0:
443
+ return fzero
444
+ return finf
445
+ flag_inverse = False
446
+ if n < 2:
447
+ if n == 0:
448
+ return fone
449
+ if n == 1:
450
+ return mpf_pos(s, prec, rnd)
451
+ if n == -1:
452
+ return mpf_div(fone, s, prec, rnd)
453
+ # n < 0
454
+ rnd = reciprocal_rnd[rnd]
455
+ flag_inverse = True
456
+ extra_inverse = 5
457
+ prec += extra_inverse
458
+ n = -n
459
+ if n > 20 and (n >= 20000 or prec < int(233 + 28.3 * n**0.62)):
460
+ prec2 = prec + 10
461
+ fn = from_int(n)
462
+ nth = mpf_rdiv_int(1, fn, prec2)
463
+ r = mpf_pow(s, nth, prec2, rnd)
464
+ s = normalize(r[0], r[1], r[2], r[3], prec, rnd)
465
+ if flag_inverse:
466
+ return mpf_div(fone, s, prec-extra_inverse, rnd)
467
+ else:
468
+ return s
469
+ # Convert to a fixed-point number with prec2 bits.
470
+ prec2 = prec + 2*n - (prec%n)
471
+ # a few tests indicate that
472
+ # for 10 < n < 10**4 a bit more precision is needed
473
+ if n > 10:
474
+ prec2 += prec2//10
475
+ prec2 = prec2 - prec2%n
476
+ # Mantissa may have more bits than we need. Trim it down.
477
+ shift = bc - prec2
478
+ # Adjust exponents to make prec2 and exp+shift multiples of n.
479
+ sign1 = 0
480
+ es = exp+shift
481
+ if es < 0:
482
+ sign1 = 1
483
+ es = -es
484
+ if sign1:
485
+ shift += es%n
486
+ else:
487
+ shift -= es%n
488
+ man = rshift(man, shift)
489
+ extra = 10
490
+ exp1 = ((exp+shift-(n-1)*prec2)//n) - extra
491
+ rnd_shift = 0
492
+ if flag_inverse:
493
+ if rnd == 'u' or rnd == 'c':
494
+ rnd_shift = 1
495
+ else:
496
+ if rnd == 'd' or rnd == 'f':
497
+ rnd_shift = 1
498
+ man = nthroot_fixed(man+rnd_shift, n, prec2, exp1)
499
+ s = from_man_exp(man, exp1, prec, rnd)
500
+ if flag_inverse:
501
+ return mpf_div(fone, s, prec-extra_inverse, rnd)
502
+ else:
503
+ return s
504
+
505
+ def mpf_cbrt(s, prec, rnd=round_fast):
506
+ """cubic root of a positive number"""
507
+ return mpf_nthroot(s, 3, prec, rnd)
508
+
509
+ #----------------------------------------------------------------------------#
510
+ # #
511
+ # Logarithms #
512
+ # #
513
+ #----------------------------------------------------------------------------#
514
+
515
+
516
+ def log_int_fixed(n, prec, ln2=None):
517
+ """
518
+ Fast computation of log(n), caching the value for small n,
519
+ intended for zeta sums.
520
+ """
521
+ if n in log_int_cache:
522
+ value, vprec = log_int_cache[n]
523
+ if vprec >= prec:
524
+ return value >> (vprec - prec)
525
+ wp = prec + 10
526
+ if wp <= LOG_TAYLOR_SHIFT:
527
+ if ln2 is None:
528
+ ln2 = ln2_fixed(wp)
529
+ r = bitcount(n)
530
+ x = n << (wp-r)
531
+ v = log_taylor_cached(x, wp) + r*ln2
532
+ else:
533
+ v = to_fixed(mpf_log(from_int(n), wp+5), wp)
534
+ if n < MAX_LOG_INT_CACHE:
535
+ log_int_cache[n] = (v, wp)
536
+ return v >> (wp-prec)
537
+
538
+ def agm_fixed(a, b, prec):
539
+ """
540
+ Fixed-point computation of agm(a,b), assuming
541
+ a, b both close to unit magnitude.
542
+ """
543
+ i = 0
544
+ while 1:
545
+ anew = (a+b)>>1
546
+ if i > 4 and abs(a-anew) < 8:
547
+ return a
548
+ b = isqrt_fast(a*b)
549
+ a = anew
550
+ i += 1
551
+ return a
552
+
553
+ def log_agm(x, prec):
554
+ """
555
+ Fixed-point computation of -log(x) = log(1/x), suitable
556
+ for large precision. It is required that 0 < x < 1. The
557
+ algorithm used is the Sasaki-Kanada formula
558
+
559
+ -log(x) = pi/agm(theta2(x)^2,theta3(x)^2). [1]
560
+
561
+ For faster convergence in the theta functions, x should
562
+ be chosen closer to 0.
563
+
564
+ Guard bits must be added by the caller.
565
+
566
+ HYPOTHESIS: if x = 2^(-n), n bits need to be added to
567
+ account for the truncation to a fixed-point number,
568
+ and this is the only significant cancellation error.
569
+
570
+ The number of bits lost to roundoff is small and can be
571
+ considered constant.
572
+
573
+ [1] Richard P. Brent, "Fast Algorithms for High-Precision
574
+ Computation of Elementary Functions (extended abstract)",
575
+ http://wwwmaths.anu.edu.au/~brent/pd/RNC7-Brent.pdf
576
+
577
+ """
578
+ x2 = (x*x) >> prec
579
+ # Compute jtheta2(x)**2
580
+ s = a = b = x2
581
+ while a:
582
+ b = (b*x2) >> prec
583
+ a = (a*b) >> prec
584
+ s += a
585
+ s += (MPZ_ONE<<prec)
586
+ s = (s*s)>>(prec-2)
587
+ s = (s*isqrt_fast(x<<prec))>>prec
588
+ # Compute jtheta3(x)**2
589
+ t = a = b = x
590
+ while a:
591
+ b = (b*x2) >> prec
592
+ a = (a*b) >> prec
593
+ t += a
594
+ t = (MPZ_ONE<<prec) + (t<<1)
595
+ t = (t*t)>>prec
596
+ # Final formula
597
+ p = agm_fixed(s, t, prec)
598
+ return (pi_fixed(prec) << prec) // p
599
+
600
+ def log_taylor(x, prec, r=0):
601
+ """
602
+ Fixed-point calculation of log(x). It is assumed that x is close
603
+ enough to 1 for the Taylor series to converge quickly. Convergence
604
+ can be improved by specifying r > 0 to compute
605
+ log(x^(1/2^r))*2^r, at the cost of performing r square roots.
606
+
607
+ The caller must provide sufficient guard bits.
608
+ """
609
+ for i in xrange(r):
610
+ x = isqrt_fast(x<<prec)
611
+ one = MPZ_ONE << prec
612
+ v = ((x-one)<<prec)//(x+one)
613
+ sign = v < 0
614
+ if sign:
615
+ v = -v
616
+ v2 = (v*v) >> prec
617
+ v4 = (v2*v2) >> prec
618
+ s0 = v
619
+ s1 = v//3
620
+ v = (v*v4) >> prec
621
+ k = 5
622
+ while v:
623
+ s0 += v // k
624
+ k += 2
625
+ s1 += v // k
626
+ v = (v*v4) >> prec
627
+ k += 2
628
+ s1 = (s1*v2) >> prec
629
+ s = (s0+s1) << (1+r)
630
+ if sign:
631
+ return -s
632
+ return s
633
+
634
+ def log_taylor_cached(x, prec):
635
+ """
636
+ Fixed-point computation of log(x), assuming x in (0.5, 2)
637
+ and prec <= LOG_TAYLOR_PREC.
638
+ """
639
+ n = x >> (prec-LOG_TAYLOR_SHIFT)
640
+ cached_prec = cache_prec_steps[prec]
641
+ dprec = cached_prec - prec
642
+ if (n, cached_prec) in log_taylor_cache:
643
+ a, log_a = log_taylor_cache[n, cached_prec]
644
+ else:
645
+ a = n << (cached_prec - LOG_TAYLOR_SHIFT)
646
+ log_a = log_taylor(a, cached_prec, 8)
647
+ log_taylor_cache[n, cached_prec] = (a, log_a)
648
+ a >>= dprec
649
+ log_a >>= dprec
650
+ u = ((x - a) << prec) // a
651
+ v = (u << prec) // ((MPZ_TWO << prec) + u)
652
+ v2 = (v*v) >> prec
653
+ v4 = (v2*v2) >> prec
654
+ s0 = v
655
+ s1 = v//3
656
+ v = (v*v4) >> prec
657
+ k = 5
658
+ while v:
659
+ s0 += v//k
660
+ k += 2
661
+ s1 += v//k
662
+ v = (v*v4) >> prec
663
+ k += 2
664
+ s1 = (s1*v2) >> prec
665
+ s = (s0+s1) << 1
666
+ return log_a + s
667
+
668
+ def mpf_log(x, prec, rnd=round_fast):
669
+ """
670
+ Compute the natural logarithm of the mpf value x. If x is negative,
671
+ ComplexResult is raised.
672
+ """
673
+ sign, man, exp, bc = x
674
+ #------------------------------------------------------------------
675
+ # Handle special values
676
+ if not man:
677
+ if x == fzero: return fninf
678
+ if x == finf: return finf
679
+ if x == fnan: return fnan
680
+ if sign:
681
+ raise ComplexResult("logarithm of a negative number")
682
+ wp = prec + 20
683
+ #------------------------------------------------------------------
684
+ # Handle log(2^n) = log(n)*2.
685
+ # Here we catch the only possible exact value, log(1) = 0
686
+ if man == 1:
687
+ if not exp:
688
+ return fzero
689
+ return from_man_exp(exp*ln2_fixed(wp), -wp, prec, rnd)
690
+ mag = exp+bc
691
+ abs_mag = abs(mag)
692
+ #------------------------------------------------------------------
693
+ # Handle x = 1+eps, where log(x) ~ x. We need to check for
694
+ # cancellation when moving to fixed-point math and compensate
695
+ # by increasing the precision. Note that abs_mag in (0, 1) <=>
696
+ # 0.5 < x < 2 and x != 1
697
+ if abs_mag <= 1:
698
+ # Calculate t = x-1 to measure distance from 1 in bits
699
+ tsign = 1-abs_mag
700
+ if tsign:
701
+ tman = (MPZ_ONE<<bc) - man
702
+ else:
703
+ tman = man - (MPZ_ONE<<(bc-1))
704
+ tbc = bitcount(tman)
705
+ cancellation = bc - tbc
706
+ if cancellation > wp:
707
+ t = normalize(tsign, tman, abs_mag-bc, tbc, tbc, 'n')
708
+ return mpf_perturb(t, tsign, prec, rnd)
709
+ else:
710
+ wp += cancellation
711
+ # TODO: if close enough to 1, we could use Taylor series
712
+ # even in the AGM precision range, since the Taylor series
713
+ # converges rapidly
714
+ #------------------------------------------------------------------
715
+ # Another special case:
716
+ # n*log(2) is a good enough approximation
717
+ if abs_mag > 10000:
718
+ if bitcount(abs_mag) > wp:
719
+ return from_man_exp(exp*ln2_fixed(wp), -wp, prec, rnd)
720
+ #------------------------------------------------------------------
721
+ # General case.
722
+ # Perform argument reduction using log(x) = log(x*2^n) - n*log(2):
723
+ # If we are in the Taylor precision range, choose magnitude 0 or 1.
724
+ # If we are in the AGM precision range, choose magnitude -m for
725
+ # some large m; benchmarking on one machine showed m = prec/20 to be
726
+ # optimal between 1000 and 100,000 digits.
727
+ if wp <= LOG_TAYLOR_PREC:
728
+ m = log_taylor_cached(lshift(man, wp-bc), wp)
729
+ if mag:
730
+ m += mag*ln2_fixed(wp)
731
+ else:
732
+ optimal_mag = -wp//LOG_AGM_MAG_PREC_RATIO
733
+ n = optimal_mag - mag
734
+ x = mpf_shift(x, n)
735
+ wp += (-optimal_mag)
736
+ m = -log_agm(to_fixed(x, wp), wp)
737
+ m -= n*ln2_fixed(wp)
738
+ return from_man_exp(m, -wp, prec, rnd)
739
+
740
+ def mpf_log_hypot(a, b, prec, rnd):
741
+ """
742
+ Computes log(sqrt(a^2+b^2)) accurately.
743
+ """
744
+ # If either a or b is inf/nan/0, assume it to be a
745
+ if not b[1]:
746
+ a, b = b, a
747
+ # a is inf/nan/0
748
+ if not a[1]:
749
+ # both are inf/nan/0
750
+ if not b[1]:
751
+ if a == b == fzero:
752
+ return fninf
753
+ if fnan in (a, b):
754
+ return fnan
755
+ # at least one term is (+/- inf)^2
756
+ return finf
757
+ # only a is inf/nan/0
758
+ if a == fzero:
759
+ # log(sqrt(0+b^2)) = log(|b|)
760
+ return mpf_log(mpf_abs(b), prec, rnd)
761
+ if a == fnan:
762
+ return fnan
763
+ return finf
764
+ # Exact
765
+ a2 = mpf_mul(a,a)
766
+ b2 = mpf_mul(b,b)
767
+ extra = 20
768
+ # Not exact
769
+ h2 = mpf_add(a2, b2, prec+extra)
770
+ cancelled = mpf_add(h2, fnone, 10)
771
+ mag_cancelled = cancelled[2]+cancelled[3]
772
+ # Just redo the sum exactly if necessary (could be smarter
773
+ # and avoid memory allocation when a or b is precisely 1
774
+ # and the other is tiny...)
775
+ if cancelled == fzero or mag_cancelled < -extra//2:
776
+ h2 = mpf_add(a2, b2, prec+extra-min(a2[2],b2[2]))
777
+ return mpf_shift(mpf_log(h2, prec, rnd), -1)
778
+
779
+
780
+ #----------------------------------------------------------------------
781
+ # Inverse tangent
782
+ #
783
+
784
+ def atan_newton(x, prec):
785
+ if prec >= 100:
786
+ r = math.atan(int((x>>(prec-53)))/2.0**53)
787
+ else:
788
+ r = math.atan(int(x)/2.0**prec)
789
+ prevp = 50
790
+ r = MPZ(int(r * 2.0**53) >> (53-prevp))
791
+ extra_p = 50
792
+ for wp in giant_steps(prevp, prec):
793
+ wp += extra_p
794
+ r = r << (wp-prevp)
795
+ cos, sin = cos_sin_fixed(r, wp)
796
+ tan = (sin << wp) // cos
797
+ a = ((tan-rshift(x, prec-wp)) << wp) // ((MPZ_ONE<<wp) + ((tan**2)>>wp))
798
+ r = r - a
799
+ prevp = wp
800
+ return rshift(r, prevp-prec)
801
+
802
+ def atan_taylor_get_cached(n, prec):
803
+ # Taylor series with caching wins up to huge precisions
804
+ # To avoid unnecessary precomputation at low precision, we
805
+ # do it in steps
806
+ # Round to next power of 2
807
+ prec2 = (1<<(bitcount(prec-1))) + 20
808
+ dprec = prec2 - prec
809
+ if (n, prec2) in atan_taylor_cache:
810
+ a, atan_a = atan_taylor_cache[n, prec2]
811
+ else:
812
+ a = n << (prec2 - ATAN_TAYLOR_SHIFT)
813
+ atan_a = atan_newton(a, prec2)
814
+ atan_taylor_cache[n, prec2] = (a, atan_a)
815
+ return (a >> dprec), (atan_a >> dprec)
816
+
817
+ def atan_taylor(x, prec):
818
+ n = (x >> (prec-ATAN_TAYLOR_SHIFT))
819
+ a, atan_a = atan_taylor_get_cached(n, prec)
820
+ d = x - a
821
+ s0 = v = (d << prec) // ((a**2 >> prec) + (a*d >> prec) + (MPZ_ONE << prec))
822
+ v2 = (v**2 >> prec)
823
+ v4 = (v2 * v2) >> prec
824
+ s1 = v//3
825
+ v = (v * v4) >> prec
826
+ k = 5
827
+ while v:
828
+ s0 += v // k
829
+ k += 2
830
+ s1 += v // k
831
+ v = (v * v4) >> prec
832
+ k += 2
833
+ s1 = (s1 * v2) >> prec
834
+ s = s0 - s1
835
+ return atan_a + s
836
+
837
+ def atan_inf(sign, prec, rnd):
838
+ if not sign:
839
+ return mpf_shift(mpf_pi(prec, rnd), -1)
840
+ return mpf_neg(mpf_shift(mpf_pi(prec, negative_rnd[rnd]), -1))
841
+
842
+ def mpf_atan(x, prec, rnd=round_fast):
843
+ sign, man, exp, bc = x
844
+ if not man:
845
+ if x == fzero: return fzero
846
+ if x == finf: return atan_inf(0, prec, rnd)
847
+ if x == fninf: return atan_inf(1, prec, rnd)
848
+ return fnan
849
+ mag = exp + bc
850
+ # Essentially infinity
851
+ if mag > prec+20:
852
+ return atan_inf(sign, prec, rnd)
853
+ # Essentially ~ x
854
+ if -mag > prec+20:
855
+ return mpf_perturb(x, 1-sign, prec, rnd)
856
+ wp = prec + 30 + abs(mag)
857
+ # For large x, use atan(x) = pi/2 - atan(1/x)
858
+ if mag >= 2:
859
+ x = mpf_rdiv_int(1, x, wp)
860
+ reciprocal = True
861
+ else:
862
+ reciprocal = False
863
+ t = to_fixed(x, wp)
864
+ if sign:
865
+ t = -t
866
+ if wp < ATAN_TAYLOR_PREC:
867
+ a = atan_taylor(t, wp)
868
+ else:
869
+ a = atan_newton(t, wp)
870
+ if reciprocal:
871
+ a = ((pi_fixed(wp)>>1)+1) - a
872
+ if sign:
873
+ a = -a
874
+ return from_man_exp(a, -wp, prec, rnd)
875
+
876
+ # TODO: cleanup the special cases
877
+ def mpf_atan2(y, x, prec, rnd=round_fast):
878
+ xsign, xman, xexp, xbc = x
879
+ ysign, yman, yexp, ybc = y
880
+ if not yman:
881
+ if y == fzero and x != fnan:
882
+ if mpf_sign(x) >= 0:
883
+ return fzero
884
+ return mpf_pi(prec, rnd)
885
+ if y in (finf, fninf):
886
+ if x in (finf, fninf):
887
+ return fnan
888
+ # pi/2
889
+ if y == finf:
890
+ return mpf_shift(mpf_pi(prec, rnd), -1)
891
+ # -pi/2
892
+ return mpf_neg(mpf_shift(mpf_pi(prec, negative_rnd[rnd]), -1))
893
+ return fnan
894
+ if ysign:
895
+ return mpf_neg(mpf_atan2(mpf_neg(y), x, prec, negative_rnd[rnd]))
896
+ if not xman:
897
+ if x == fnan:
898
+ return fnan
899
+ if x == finf:
900
+ return fzero
901
+ if x == fninf:
902
+ return mpf_pi(prec, rnd)
903
+ if y == fzero:
904
+ return fzero
905
+ return mpf_shift(mpf_pi(prec, rnd), -1)
906
+ tquo = mpf_atan(mpf_div(y, x, prec+4), prec+4)
907
+ if xsign:
908
+ return mpf_add(mpf_pi(prec+4), tquo, prec, rnd)
909
+ else:
910
+ return mpf_pos(tquo, prec, rnd)
911
+
912
+ def mpf_asin(x, prec, rnd=round_fast):
913
+ sign, man, exp, bc = x
914
+ if bc+exp > 0 and x not in (fone, fnone):
915
+ raise ComplexResult("asin(x) is real only for -1 <= x <= 1")
916
+ # asin(x) = 2*atan(x/(1+sqrt(1-x**2)))
917
+ wp = prec + 15
918
+ a = mpf_mul(x, x)
919
+ b = mpf_add(fone, mpf_sqrt(mpf_sub(fone, a, wp), wp), wp)
920
+ c = mpf_div(x, b, wp)
921
+ return mpf_shift(mpf_atan(c, prec, rnd), 1)
922
+
923
+ def mpf_acos(x, prec, rnd=round_fast):
924
+ # acos(x) = 2*atan(sqrt(1-x**2)/(1+x))
925
+ sign, man, exp, bc = x
926
+ if bc + exp > 0:
927
+ if x not in (fone, fnone):
928
+ raise ComplexResult("acos(x) is real only for -1 <= x <= 1")
929
+ if x == fnone:
930
+ return mpf_pi(prec, rnd)
931
+ wp = prec + 15
932
+ a = mpf_mul(x, x)
933
+ b = mpf_sqrt(mpf_sub(fone, a, wp), wp)
934
+ c = mpf_div(b, mpf_add(fone, x, wp), wp)
935
+ return mpf_shift(mpf_atan(c, prec, rnd), 1)
936
+
937
+ def mpf_asinh(x, prec, rnd=round_fast):
938
+ wp = prec + 20
939
+ sign, man, exp, bc = x
940
+ mag = exp+bc
941
+ if mag < -8:
942
+ if mag < -wp:
943
+ return mpf_perturb(x, 1-sign, prec, rnd)
944
+ wp += (-mag)
945
+ # asinh(x) = log(x+sqrt(x**2+1))
946
+ # use reflection symmetry to avoid cancellation
947
+ q = mpf_sqrt(mpf_add(mpf_mul(x, x), fone, wp), wp)
948
+ q = mpf_add(mpf_abs(x), q, wp)
949
+ if sign:
950
+ return mpf_neg(mpf_log(q, prec, negative_rnd[rnd]))
951
+ else:
952
+ return mpf_log(q, prec, rnd)
953
+
954
+ def mpf_acosh(x, prec, rnd=round_fast):
955
+ # acosh(x) = log(x+sqrt(x**2-1))
956
+ wp = prec + 15
957
+ if mpf_cmp(x, fone) == -1:
958
+ raise ComplexResult("acosh(x) is real only for x >= 1")
959
+ q = mpf_sqrt(mpf_add(mpf_mul(x,x), fnone, wp), wp)
960
+ return mpf_log(mpf_add(x, q, wp), prec, rnd)
961
+
962
+ def mpf_atanh(x, prec, rnd=round_fast):
963
+ # atanh(x) = log((1+x)/(1-x))/2
964
+ sign, man, exp, bc = x
965
+ if (not man) and exp:
966
+ if x in (fzero, fnan):
967
+ return x
968
+ raise ComplexResult("atanh(x) is real only for -1 <= x <= 1")
969
+ mag = bc + exp
970
+ if mag > 0:
971
+ if mag == 1 and man == 1:
972
+ return [finf, fninf][sign]
973
+ raise ComplexResult("atanh(x) is real only for -1 <= x <= 1")
974
+ wp = prec + 15
975
+ if mag < -8:
976
+ if mag < -wp:
977
+ return mpf_perturb(x, sign, prec, rnd)
978
+ wp += (-mag)
979
+ a = mpf_add(x, fone, wp)
980
+ b = mpf_sub(fone, x, wp)
981
+ return mpf_shift(mpf_log(mpf_div(a, b, wp), prec, rnd), -1)
982
+
983
+ def mpf_fibonacci(x, prec, rnd=round_fast):
984
+ sign, man, exp, bc = x
985
+ if not man:
986
+ if x == fninf:
987
+ return fnan
988
+ return x
989
+ # F(2^n) ~= 2^(2^n)
990
+ size = abs(exp+bc)
991
+ if exp >= 0:
992
+ # Exact
993
+ if size < 10 or size <= bitcount(prec):
994
+ return from_int(ifib(to_int(x)), prec, rnd)
995
+ # Use the modified Binet formula
996
+ wp = prec + size + 20
997
+ a = mpf_phi(wp)
998
+ b = mpf_add(mpf_shift(a, 1), fnone, wp)
999
+ u = mpf_pow(a, x, wp)
1000
+ v = mpf_cos_pi(x, wp)
1001
+ v = mpf_div(v, u, wp)
1002
+ u = mpf_sub(u, v, wp)
1003
+ u = mpf_div(u, b, prec, rnd)
1004
+ return u
1005
+
1006
+
1007
+ #-------------------------------------------------------------------------------
1008
+ # Exponential-type functions
1009
+ #-------------------------------------------------------------------------------
1010
+
1011
+ def exponential_series(x, prec, type=0):
1012
+ """
1013
+ Taylor series for cosh/sinh or cos/sin.
1014
+
1015
+ type = 0 -- returns exp(x) (slightly faster than cosh+sinh)
1016
+ type = 1 -- returns (cosh(x), sinh(x))
1017
+ type = 2 -- returns (cos(x), sin(x))
1018
+ """
1019
+ if x < 0:
1020
+ x = -x
1021
+ sign = 1
1022
+ else:
1023
+ sign = 0
1024
+ r = int(0.5*prec**0.5)
1025
+ xmag = bitcount(x) - prec
1026
+ r = max(0, xmag + r)
1027
+ extra = 10 + 2*max(r,-xmag)
1028
+ wp = prec + extra
1029
+ x <<= (extra - r)
1030
+ one = MPZ_ONE << wp
1031
+ alt = (type == 2)
1032
+ if prec < EXP_SERIES_U_CUTOFF:
1033
+ x2 = a = (x*x) >> wp
1034
+ x4 = (x2*x2) >> wp
1035
+ s0 = s1 = MPZ_ZERO
1036
+ k = 2
1037
+ while a:
1038
+ a //= (k-1)*k; s0 += a; k += 2
1039
+ a //= (k-1)*k; s1 += a; k += 2
1040
+ a = (a*x4) >> wp
1041
+ s1 = (x2*s1) >> wp
1042
+ if alt:
1043
+ c = s1 - s0 + one
1044
+ else:
1045
+ c = s1 + s0 + one
1046
+ else:
1047
+ u = int(0.3*prec**0.35)
1048
+ x2 = a = (x*x) >> wp
1049
+ xpowers = [one, x2]
1050
+ for i in xrange(1, u):
1051
+ xpowers.append((xpowers[-1]*x2)>>wp)
1052
+ sums = [MPZ_ZERO] * u
1053
+ k = 2
1054
+ while a:
1055
+ for i in xrange(u):
1056
+ a //= (k-1)*k
1057
+ if alt and k & 2: sums[i] -= a
1058
+ else: sums[i] += a
1059
+ k += 2
1060
+ a = (a*xpowers[-1]) >> wp
1061
+ for i in xrange(1, u):
1062
+ sums[i] = (sums[i]*xpowers[i]) >> wp
1063
+ c = sum(sums) + one
1064
+ if type == 0:
1065
+ s = isqrt_fast(c*c - (one<<wp))
1066
+ if sign:
1067
+ v = c - s
1068
+ else:
1069
+ v = c + s
1070
+ for i in xrange(r):
1071
+ v = (v*v) >> wp
1072
+ return v >> extra
1073
+ else:
1074
+ # Repeatedly apply the double-angle formula
1075
+ # cosh(2*x) = 2*cosh(x)^2 - 1
1076
+ # cos(2*x) = 2*cos(x)^2 - 1
1077
+ pshift = wp-1
1078
+ for i in xrange(r):
1079
+ c = ((c*c) >> pshift) - one
1080
+ # With the abs, this is the same for sinh and sin
1081
+ s = isqrt_fast(abs((one<<wp) - c*c))
1082
+ if sign:
1083
+ s = -s
1084
+ return (c>>extra), (s>>extra)
1085
+
1086
+ def exp_basecase(x, prec):
1087
+ """
1088
+ Compute exp(x) as a fixed-point number. Works for any x,
1089
+ but for speed should have |x| < 1. For an arbitrary number,
1090
+ use exp(x) = exp(x-m*log(2)) * 2^m where m = floor(x/log(2)).
1091
+ """
1092
+ if prec > EXP_COSH_CUTOFF:
1093
+ return exponential_series(x, prec, 0)
1094
+ r = int(prec**0.5)
1095
+ prec += r
1096
+ s0 = s1 = (MPZ_ONE << prec)
1097
+ k = 2
1098
+ a = x2 = (x*x) >> prec
1099
+ while a:
1100
+ a //= k; s0 += a; k += 1
1101
+ a //= k; s1 += a; k += 1
1102
+ a = (a*x2) >> prec
1103
+ s1 = (s1*x) >> prec
1104
+ s = s0 + s1
1105
+ u = r
1106
+ while r:
1107
+ s = (s*s) >> prec
1108
+ r -= 1
1109
+ return s >> u
1110
+
1111
+ def exp_expneg_basecase(x, prec):
1112
+ """
1113
+ Computation of exp(x), exp(-x)
1114
+ """
1115
+ if prec > EXP_COSH_CUTOFF:
1116
+ cosh, sinh = exponential_series(x, prec, 1)
1117
+ return cosh+sinh, cosh-sinh
1118
+ a = exp_basecase(x, prec)
1119
+ b = (MPZ_ONE << (prec+prec)) // a
1120
+ return a, b
1121
+
1122
+ def cos_sin_basecase(x, prec):
1123
+ """
1124
+ Compute cos(x), sin(x) as fixed-point numbers, assuming x
1125
+ in [0, pi/2). For an arbitrary number, use x' = x - m*(pi/2)
1126
+ where m = floor(x/(pi/2)) along with quarter-period symmetries.
1127
+ """
1128
+ if prec > COS_SIN_CACHE_PREC:
1129
+ return exponential_series(x, prec, 2)
1130
+ precs = prec - COS_SIN_CACHE_STEP
1131
+ t = x >> precs
1132
+ n = int(t)
1133
+ if n not in cos_sin_cache:
1134
+ w = t<<(10+COS_SIN_CACHE_PREC-COS_SIN_CACHE_STEP)
1135
+ cos_t, sin_t = exponential_series(w, 10+COS_SIN_CACHE_PREC, 2)
1136
+ cos_sin_cache[n] = (cos_t>>10), (sin_t>>10)
1137
+ cos_t, sin_t = cos_sin_cache[n]
1138
+ offset = COS_SIN_CACHE_PREC - prec
1139
+ cos_t >>= offset
1140
+ sin_t >>= offset
1141
+ x -= t << precs
1142
+ cos = MPZ_ONE << prec
1143
+ sin = x
1144
+ k = 2
1145
+ a = -((x*x) >> prec)
1146
+ while a:
1147
+ a //= k; cos += a; k += 1; a = (a*x) >> prec
1148
+ a //= k; sin += a; k += 1; a = -((a*x) >> prec)
1149
+ return ((cos*cos_t-sin*sin_t) >> prec), ((sin*cos_t+cos*sin_t) >> prec)
1150
+
1151
+ def mpf_exp(x, prec, rnd=round_fast):
1152
+ sign, man, exp, bc = x
1153
+ if man:
1154
+ mag = bc + exp
1155
+ wp = prec + 14
1156
+ if sign:
1157
+ man = -man
1158
+ # TODO: the best cutoff depends on both x and the precision.
1159
+ if prec > 600 and exp >= 0:
1160
+ # Need about log2(exp(n)) ~= 1.45*mag extra precision
1161
+ e = mpf_e(wp+int(1.45*mag))
1162
+ return mpf_pow_int(e, man<<exp, prec, rnd)
1163
+ if mag < -wp:
1164
+ return mpf_perturb(fone, sign, prec, rnd)
1165
+ # |x| >= 2
1166
+ if mag > 1:
1167
+ # For large arguments: exp(2^mag*(1+eps)) =
1168
+ # exp(2^mag)*exp(2^mag*eps) = exp(2^mag)*(1 + 2^mag*eps + ...)
1169
+ # so about mag extra bits is required.
1170
+ wpmod = wp + mag
1171
+ offset = exp + wpmod
1172
+ if offset >= 0:
1173
+ t = man << offset
1174
+ else:
1175
+ t = man >> (-offset)
1176
+ lg2 = ln2_fixed(wpmod)
1177
+ n, t = divmod(t, lg2)
1178
+ n = int(n)
1179
+ t >>= mag
1180
+ else:
1181
+ offset = exp + wp
1182
+ if offset >= 0:
1183
+ t = man << offset
1184
+ else:
1185
+ t = man >> (-offset)
1186
+ n = 0
1187
+ man = exp_basecase(t, wp)
1188
+ return from_man_exp(man, n-wp, prec, rnd)
1189
+ if not exp:
1190
+ return fone
1191
+ if x == fninf:
1192
+ return fzero
1193
+ return x
1194
+
1195
+
1196
+ def mpf_cosh_sinh(x, prec, rnd=round_fast, tanh=0):
1197
+ """Simultaneously compute (cosh(x), sinh(x)) for real x"""
1198
+ sign, man, exp, bc = x
1199
+ if (not man) and exp:
1200
+ if tanh:
1201
+ if x == finf: return fone
1202
+ if x == fninf: return fnone
1203
+ return fnan
1204
+ if x == finf: return (finf, finf)
1205
+ if x == fninf: return (finf, fninf)
1206
+ return fnan, fnan
1207
+ mag = exp+bc
1208
+ wp = prec+14
1209
+ if mag < -4:
1210
+ # Extremely close to 0, sinh(x) ~= x and cosh(x) ~= 1
1211
+ if mag < -wp:
1212
+ if tanh:
1213
+ return mpf_perturb(x, 1-sign, prec, rnd)
1214
+ cosh = mpf_perturb(fone, 0, prec, rnd)
1215
+ sinh = mpf_perturb(x, sign, prec, rnd)
1216
+ return cosh, sinh
1217
+ # Fix for cancellation when computing sinh
1218
+ wp += (-mag)
1219
+ # Does exp(-2*x) vanish?
1220
+ if mag > 10:
1221
+ if 3*(1<<(mag-1)) > wp:
1222
+ # XXX: rounding
1223
+ if tanh:
1224
+ return mpf_perturb([fone,fnone][sign], 1-sign, prec, rnd)
1225
+ c = s = mpf_shift(mpf_exp(mpf_abs(x), prec, rnd), -1)
1226
+ if sign:
1227
+ s = mpf_neg(s)
1228
+ return c, s
1229
+ # |x| > 1
1230
+ if mag > 1:
1231
+ wpmod = wp + mag
1232
+ offset = exp + wpmod
1233
+ if offset >= 0:
1234
+ t = man << offset
1235
+ else:
1236
+ t = man >> (-offset)
1237
+ lg2 = ln2_fixed(wpmod)
1238
+ n, t = divmod(t, lg2)
1239
+ n = int(n)
1240
+ t >>= mag
1241
+ else:
1242
+ offset = exp + wp
1243
+ if offset >= 0:
1244
+ t = man << offset
1245
+ else:
1246
+ t = man >> (-offset)
1247
+ n = 0
1248
+ a, b = exp_expneg_basecase(t, wp)
1249
+ # TODO: optimize division precision
1250
+ cosh = a + (b>>(2*n))
1251
+ sinh = a - (b>>(2*n))
1252
+ if sign:
1253
+ sinh = -sinh
1254
+ if tanh:
1255
+ man = (sinh << wp) // cosh
1256
+ return from_man_exp(man, -wp, prec, rnd)
1257
+ else:
1258
+ cosh = from_man_exp(cosh, n-wp-1, prec, rnd)
1259
+ sinh = from_man_exp(sinh, n-wp-1, prec, rnd)
1260
+ return cosh, sinh
1261
+
1262
+
1263
+ def mod_pi2(man, exp, mag, wp):
1264
+ # Reduce to standard interval
1265
+ if mag > 0:
1266
+ i = 0
1267
+ while 1:
1268
+ cancellation_prec = 20 << i
1269
+ wpmod = wp + mag + cancellation_prec
1270
+ pi2 = pi_fixed(wpmod-1)
1271
+ pi4 = pi2 >> 1
1272
+ offset = wpmod + exp
1273
+ if offset >= 0:
1274
+ t = man << offset
1275
+ else:
1276
+ t = man >> (-offset)
1277
+ n, y = divmod(t, pi2)
1278
+ if y > pi4:
1279
+ small = pi2 - y
1280
+ else:
1281
+ small = y
1282
+ if small >> (wp+mag-10):
1283
+ n = int(n)
1284
+ t = y >> mag
1285
+ wp = wpmod - mag
1286
+ break
1287
+ i += 1
1288
+ else:
1289
+ wp += (-mag)
1290
+ offset = exp + wp
1291
+ if offset >= 0:
1292
+ t = man << offset
1293
+ else:
1294
+ t = man >> (-offset)
1295
+ n = 0
1296
+ return t, n, wp
1297
+
1298
+
1299
+ def mpf_cos_sin(x, prec, rnd=round_fast, which=0, pi=False):
1300
+ """
1301
+ which:
1302
+ 0 -- return cos(x), sin(x)
1303
+ 1 -- return cos(x)
1304
+ 2 -- return sin(x)
1305
+ 3 -- return tan(x)
1306
+
1307
+ if pi=True, compute for pi*x
1308
+ """
1309
+ sign, man, exp, bc = x
1310
+ if not man:
1311
+ if exp:
1312
+ c, s = fnan, fnan
1313
+ else:
1314
+ c, s = fone, fzero
1315
+ if which == 0: return c, s
1316
+ if which == 1: return c
1317
+ if which == 2: return s
1318
+ if which == 3: return s
1319
+
1320
+ mag = bc + exp
1321
+ wp = prec + 10
1322
+
1323
+ # Extremely small?
1324
+ if mag < 0:
1325
+ if mag < -wp:
1326
+ if pi:
1327
+ x = mpf_mul(x, mpf_pi(wp))
1328
+ c = mpf_perturb(fone, 1, prec, rnd)
1329
+ s = mpf_perturb(x, 1-sign, prec, rnd)
1330
+ if which == 0: return c, s
1331
+ if which == 1: return c
1332
+ if which == 2: return s
1333
+ if which == 3: return mpf_perturb(x, sign, prec, rnd)
1334
+ if pi:
1335
+ if exp >= -1:
1336
+ if exp == -1:
1337
+ c = fzero
1338
+ s = (fone, fnone)[bool(man & 2) ^ sign]
1339
+ elif exp == 0:
1340
+ c, s = (fnone, fzero)
1341
+ else:
1342
+ c, s = (fone, fzero)
1343
+ if which == 0: return c, s
1344
+ if which == 1: return c
1345
+ if which == 2: return s
1346
+ if which == 3: return mpf_div(s, c, prec, rnd)
1347
+ # Subtract nearest half-integer (= mod by pi/2)
1348
+ n = ((man >> (-exp-2)) + 1) >> 1
1349
+ man = man - (n << (-exp-1))
1350
+ mag2 = bitcount(man) + exp
1351
+ wp = prec + 10 - mag2
1352
+ offset = exp + wp
1353
+ if offset >= 0:
1354
+ t = man << offset
1355
+ else:
1356
+ t = man >> (-offset)
1357
+ t = (t*pi_fixed(wp)) >> wp
1358
+ else:
1359
+ t, n, wp = mod_pi2(man, exp, mag, wp)
1360
+ c, s = cos_sin_basecase(t, wp)
1361
+ m = n & 3
1362
+ if m == 1: c, s = -s, c
1363
+ elif m == 2: c, s = -c, -s
1364
+ elif m == 3: c, s = s, -c
1365
+ if sign:
1366
+ s = -s
1367
+ if which == 0:
1368
+ c = from_man_exp(c, -wp, prec, rnd)
1369
+ s = from_man_exp(s, -wp, prec, rnd)
1370
+ return c, s
1371
+ if which == 1:
1372
+ return from_man_exp(c, -wp, prec, rnd)
1373
+ if which == 2:
1374
+ return from_man_exp(s, -wp, prec, rnd)
1375
+ if which == 3:
1376
+ return from_rational(s, c, prec, rnd)
1377
+
1378
+ def mpf_cos(x, prec, rnd=round_fast): return mpf_cos_sin(x, prec, rnd, 1)
1379
+ def mpf_sin(x, prec, rnd=round_fast): return mpf_cos_sin(x, prec, rnd, 2)
1380
+ def mpf_tan(x, prec, rnd=round_fast): return mpf_cos_sin(x, prec, rnd, 3)
1381
+ def mpf_cos_sin_pi(x, prec, rnd=round_fast): return mpf_cos_sin(x, prec, rnd, 0, 1)
1382
+ def mpf_cos_pi(x, prec, rnd=round_fast): return mpf_cos_sin(x, prec, rnd, 1, 1)
1383
+ def mpf_sin_pi(x, prec, rnd=round_fast): return mpf_cos_sin(x, prec, rnd, 2, 1)
1384
+ def mpf_cosh(x, prec, rnd=round_fast): return mpf_cosh_sinh(x, prec, rnd)[0]
1385
+ def mpf_sinh(x, prec, rnd=round_fast): return mpf_cosh_sinh(x, prec, rnd)[1]
1386
+ def mpf_tanh(x, prec, rnd=round_fast): return mpf_cosh_sinh(x, prec, rnd, tanh=1)
1387
+
1388
+
1389
+ # Low-overhead fixed-point versions
1390
+
1391
+ def cos_sin_fixed(x, prec, pi2=None):
1392
+ if pi2 is None:
1393
+ pi2 = pi_fixed(prec-1)
1394
+ n, t = divmod(x, pi2)
1395
+ n = int(n)
1396
+ c, s = cos_sin_basecase(t, prec)
1397
+ m = n & 3
1398
+ if m == 0: return c, s
1399
+ if m == 1: return -s, c
1400
+ if m == 2: return -c, -s
1401
+ if m == 3: return s, -c
1402
+
1403
+ def exp_fixed(x, prec, ln2=None):
1404
+ if ln2 is None:
1405
+ ln2 = ln2_fixed(prec)
1406
+ n, t = divmod(x, ln2)
1407
+ n = int(n)
1408
+ v = exp_basecase(t, prec)
1409
+ if n >= 0:
1410
+ return v << n
1411
+ else:
1412
+ return v >> (-n)
1413
+
1414
+
1415
+ if BACKEND == 'sage':
1416
+ try:
1417
+ import sage.libs.mpmath.ext_libmp as _lbmp
1418
+ mpf_sqrt = _lbmp.mpf_sqrt
1419
+ mpf_exp = _lbmp.mpf_exp
1420
+ mpf_log = _lbmp.mpf_log
1421
+ mpf_cos = _lbmp.mpf_cos
1422
+ mpf_sin = _lbmp.mpf_sin
1423
+ mpf_pow = _lbmp.mpf_pow
1424
+ exp_fixed = _lbmp.exp_fixed
1425
+ cos_sin_fixed = _lbmp.cos_sin_fixed
1426
+ log_int_fixed = _lbmp.log_int_fixed
1427
+ except (ImportError, AttributeError):
1428
+ print("Warning: Sage imports in libelefun failed")
lib/python3.11/site-packages/mpmath/libmp/libhyper.py ADDED
@@ -0,0 +1,1150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This module implements computation of hypergeometric and related
3
+ functions. In particular, it provides code for generic summation
4
+ of hypergeometric series. Optimized versions for various special
5
+ cases are also provided.
6
+ """
7
+
8
+ import operator
9
+ import math
10
+
11
+ from .backend import MPZ_ZERO, MPZ_ONE, BACKEND, xrange, exec_
12
+
13
+ from .libintmath import gcd
14
+
15
+ from .libmpf import (\
16
+ ComplexResult, round_fast, round_nearest,
17
+ negative_rnd, bitcount, to_fixed, from_man_exp, from_int, to_int,
18
+ from_rational,
19
+ fzero, fone, fnone, ftwo, finf, fninf, fnan,
20
+ mpf_sign, mpf_add, mpf_abs, mpf_pos,
21
+ mpf_cmp, mpf_lt, mpf_le, mpf_gt, mpf_min_max,
22
+ mpf_perturb, mpf_neg, mpf_shift, mpf_sub, mpf_mul, mpf_div,
23
+ sqrt_fixed, mpf_sqrt, mpf_rdiv_int, mpf_pow_int,
24
+ to_rational,
25
+ )
26
+
27
+ from .libelefun import (\
28
+ mpf_pi, mpf_exp, mpf_log, pi_fixed, mpf_cos_sin, mpf_cos, mpf_sin,
29
+ mpf_sqrt, agm_fixed,
30
+ )
31
+
32
+ from .libmpc import (\
33
+ mpc_one, mpc_sub, mpc_mul_mpf, mpc_mul, mpc_neg, complex_int_pow,
34
+ mpc_div, mpc_add_mpf, mpc_sub_mpf,
35
+ mpc_log, mpc_add, mpc_pos, mpc_shift,
36
+ mpc_is_infnan, mpc_zero, mpc_sqrt, mpc_abs,
37
+ mpc_mpf_div, mpc_square, mpc_exp
38
+ )
39
+
40
+ from .libintmath import ifac
41
+ from .gammazeta import mpf_gamma_int, mpf_euler, euler_fixed
42
+
43
+ class NoConvergence(Exception):
44
+ pass
45
+
46
+
47
+ #-----------------------------------------------------------------------#
48
+ # #
49
+ # Generic hypergeometric series #
50
+ # #
51
+ #-----------------------------------------------------------------------#
52
+
53
+ """
54
+ TODO:
55
+
56
+ 1. proper mpq parsing
57
+ 2. imaginary z special-cased (also: rational, integer?)
58
+ 3. more clever handling of series that don't converge because of stupid
59
+ upwards rounding
60
+ 4. checking for cancellation
61
+
62
+ """
63
+
64
+ def make_hyp_summator(key):
65
+ """
66
+ Returns a function that sums a generalized hypergeometric series,
67
+ for given parameter types (integer, rational, real, complex).
68
+
69
+ """
70
+ p, q, param_types, ztype = key
71
+
72
+ pstring = "".join(param_types)
73
+ fname = "hypsum_%i_%i_%s_%s_%s" % (p, q, pstring[:p], pstring[p:], ztype)
74
+ #print "generating hypsum", fname
75
+
76
+ have_complex_param = 'C' in param_types
77
+ have_complex_arg = ztype == 'C'
78
+ have_complex = have_complex_param or have_complex_arg
79
+
80
+ source = []
81
+ add = source.append
82
+
83
+ aint = []
84
+ arat = []
85
+ bint = []
86
+ brat = []
87
+ areal = []
88
+ breal = []
89
+ acomplex = []
90
+ bcomplex = []
91
+
92
+ #add("wp = prec + 40")
93
+ add("MAX = kwargs.get('maxterms', wp*100)")
94
+ add("HIGH = MPZ_ONE<<epsshift")
95
+ add("LOW = -HIGH")
96
+
97
+ # Setup code
98
+ add("SRE = PRE = one = (MPZ_ONE << wp)")
99
+ if have_complex:
100
+ add("SIM = PIM = MPZ_ZERO")
101
+
102
+ if have_complex_arg:
103
+ add("xsign, xm, xe, xbc = z[0]")
104
+ add("if xsign: xm = -xm")
105
+ add("ysign, ym, ye, ybc = z[1]")
106
+ add("if ysign: ym = -ym")
107
+ else:
108
+ add("xsign, xm, xe, xbc = z")
109
+ add("if xsign: xm = -xm")
110
+
111
+ add("offset = xe + wp")
112
+ add("if offset >= 0:")
113
+ add(" ZRE = xm << offset")
114
+ add("else:")
115
+ add(" ZRE = xm >> (-offset)")
116
+ if have_complex_arg:
117
+ add("offset = ye + wp")
118
+ add("if offset >= 0:")
119
+ add(" ZIM = ym << offset")
120
+ add("else:")
121
+ add(" ZIM = ym >> (-offset)")
122
+
123
+ for i, flag in enumerate(param_types):
124
+ W = ["A", "B"][i >= p]
125
+ if flag == 'Z':
126
+ ([aint,bint][i >= p]).append(i)
127
+ add("%sINT_%i = coeffs[%i]" % (W, i, i))
128
+ elif flag == 'Q':
129
+ ([arat,brat][i >= p]).append(i)
130
+ add("%sP_%i, %sQ_%i = coeffs[%i]._mpq_" % (W, i, W, i, i))
131
+ elif flag == 'R':
132
+ ([areal,breal][i >= p]).append(i)
133
+ add("xsign, xm, xe, xbc = coeffs[%i]._mpf_" % i)
134
+ add("if xsign: xm = -xm")
135
+ add("offset = xe + wp")
136
+ add("if offset >= 0:")
137
+ add(" %sREAL_%i = xm << offset" % (W, i))
138
+ add("else:")
139
+ add(" %sREAL_%i = xm >> (-offset)" % (W, i))
140
+ elif flag == 'C':
141
+ ([acomplex,bcomplex][i >= p]).append(i)
142
+ add("__re, __im = coeffs[%i]._mpc_" % i)
143
+ add("xsign, xm, xe, xbc = __re")
144
+ add("if xsign: xm = -xm")
145
+ add("ysign, ym, ye, ybc = __im")
146
+ add("if ysign: ym = -ym")
147
+
148
+ add("offset = xe + wp")
149
+ add("if offset >= 0:")
150
+ add(" %sCRE_%i = xm << offset" % (W, i))
151
+ add("else:")
152
+ add(" %sCRE_%i = xm >> (-offset)" % (W, i))
153
+ add("offset = ye + wp")
154
+ add("if offset >= 0:")
155
+ add(" %sCIM_%i = ym << offset" % (W, i))
156
+ add("else:")
157
+ add(" %sCIM_%i = ym >> (-offset)" % (W, i))
158
+ else:
159
+ raise ValueError
160
+
161
+ l_areal = len(areal)
162
+ l_breal = len(breal)
163
+ cancellable_real = min(l_areal, l_breal)
164
+ noncancellable_real_num = areal[cancellable_real:]
165
+ noncancellable_real_den = breal[cancellable_real:]
166
+
167
+ # LOOP
168
+ add("for n in xrange(1,10**8):")
169
+
170
+ add(" if n in magnitude_check:")
171
+ add(" p_mag = bitcount(abs(PRE))")
172
+ if have_complex:
173
+ add(" p_mag = max(p_mag, bitcount(abs(PIM)))")
174
+ add(" magnitude_check[n] = wp-p_mag")
175
+
176
+ # Real factors
177
+ multiplier = " * ".join(["AINT_#".replace("#", str(i)) for i in aint] + \
178
+ ["AP_#".replace("#", str(i)) for i in arat] + \
179
+ ["BQ_#".replace("#", str(i)) for i in brat])
180
+
181
+ divisor = " * ".join(["BINT_#".replace("#", str(i)) for i in bint] + \
182
+ ["BP_#".replace("#", str(i)) for i in brat] + \
183
+ ["AQ_#".replace("#", str(i)) for i in arat] + ["n"])
184
+
185
+ if multiplier:
186
+ add(" mul = " + multiplier)
187
+ add(" div = " + divisor)
188
+
189
+ # Check for singular terms
190
+ add(" if not div:")
191
+ if multiplier:
192
+ add(" if not mul:")
193
+ add(" break")
194
+ add(" raise ZeroDivisionError")
195
+
196
+ # Update product
197
+ if have_complex:
198
+
199
+ # TODO: when there are several real parameters and just a few complex
200
+ # (maybe just the complex argument), we only need to do about
201
+ # half as many ops if we accumulate the real factor in a single real variable
202
+ for k in range(cancellable_real): add(" PRE = PRE * AREAL_%i // BREAL_%i" % (areal[k], breal[k]))
203
+ for i in noncancellable_real_num: add(" PRE = (PRE * AREAL_#) >> wp".replace("#", str(i)))
204
+ for i in noncancellable_real_den: add(" PRE = (PRE << wp) // BREAL_#".replace("#", str(i)))
205
+ for k in range(cancellable_real): add(" PIM = PIM * AREAL_%i // BREAL_%i" % (areal[k], breal[k]))
206
+ for i in noncancellable_real_num: add(" PIM = (PIM * AREAL_#) >> wp".replace("#", str(i)))
207
+ for i in noncancellable_real_den: add(" PIM = (PIM << wp) // BREAL_#".replace("#", str(i)))
208
+
209
+ if multiplier:
210
+ if have_complex_arg:
211
+ add(" PRE, PIM = (mul*(PRE*ZRE-PIM*ZIM))//div, (mul*(PIM*ZRE+PRE*ZIM))//div")
212
+ add(" PRE >>= wp")
213
+ add(" PIM >>= wp")
214
+ else:
215
+ add(" PRE = ((mul * PRE * ZRE) >> wp) // div")
216
+ add(" PIM = ((mul * PIM * ZRE) >> wp) // div")
217
+ else:
218
+ if have_complex_arg:
219
+ add(" PRE, PIM = (PRE*ZRE-PIM*ZIM)//div, (PIM*ZRE+PRE*ZIM)//div")
220
+ add(" PRE >>= wp")
221
+ add(" PIM >>= wp")
222
+ else:
223
+ add(" PRE = ((PRE * ZRE) >> wp) // div")
224
+ add(" PIM = ((PIM * ZRE) >> wp) // div")
225
+
226
+ for i in acomplex:
227
+ add(" PRE, PIM = PRE*ACRE_#-PIM*ACIM_#, PIM*ACRE_#+PRE*ACIM_#".replace("#", str(i)))
228
+ add(" PRE >>= wp")
229
+ add(" PIM >>= wp")
230
+
231
+ for i in bcomplex:
232
+ add(" mag = BCRE_#*BCRE_#+BCIM_#*BCIM_#".replace("#", str(i)))
233
+ add(" re = PRE*BCRE_# + PIM*BCIM_#".replace("#", str(i)))
234
+ add(" im = PIM*BCRE_# - PRE*BCIM_#".replace("#", str(i)))
235
+ add(" PRE = (re << wp) // mag".replace("#", str(i)))
236
+ add(" PIM = (im << wp) // mag".replace("#", str(i)))
237
+
238
+ else:
239
+ for k in range(cancellable_real): add(" PRE = PRE * AREAL_%i // BREAL_%i" % (areal[k], breal[k]))
240
+ for i in noncancellable_real_num: add(" PRE = (PRE * AREAL_#) >> wp".replace("#", str(i)))
241
+ for i in noncancellable_real_den: add(" PRE = (PRE << wp) // BREAL_#".replace("#", str(i)))
242
+ if multiplier:
243
+ add(" PRE = ((PRE * mul * ZRE) >> wp) // div")
244
+ else:
245
+ add(" PRE = ((PRE * ZRE) >> wp) // div")
246
+
247
+ # Add product to sum
248
+ if have_complex:
249
+ add(" SRE += PRE")
250
+ add(" SIM += PIM")
251
+ add(" if (HIGH > PRE > LOW) and (HIGH > PIM > LOW):")
252
+ add(" break")
253
+ else:
254
+ add(" SRE += PRE")
255
+ add(" if HIGH > PRE > LOW:")
256
+ add(" break")
257
+
258
+ #add(" from mpmath import nprint, log, ldexp")
259
+ #add(" nprint([n, log(abs(PRE),2), ldexp(PRE,-wp)])")
260
+
261
+ add(" if n > MAX:")
262
+ add(" raise NoConvergence('Hypergeometric series converges too slowly. Try increasing maxterms.')")
263
+
264
+ # +1 all parameters for next loop
265
+ for i in aint: add(" AINT_# += 1".replace("#", str(i)))
266
+ for i in bint: add(" BINT_# += 1".replace("#", str(i)))
267
+ for i in arat: add(" AP_# += AQ_#".replace("#", str(i)))
268
+ for i in brat: add(" BP_# += BQ_#".replace("#", str(i)))
269
+ for i in areal: add(" AREAL_# += one".replace("#", str(i)))
270
+ for i in breal: add(" BREAL_# += one".replace("#", str(i)))
271
+ for i in acomplex: add(" ACRE_# += one".replace("#", str(i)))
272
+ for i in bcomplex: add(" BCRE_# += one".replace("#", str(i)))
273
+
274
+ if have_complex:
275
+ add("a = from_man_exp(SRE, -wp, prec, 'n')")
276
+ add("b = from_man_exp(SIM, -wp, prec, 'n')")
277
+
278
+ add("if SRE:")
279
+ add(" if SIM:")
280
+ add(" magn = max(a[2]+a[3], b[2]+b[3])")
281
+ add(" else:")
282
+ add(" magn = a[2]+a[3]")
283
+ add("elif SIM:")
284
+ add(" magn = b[2]+b[3]")
285
+ add("else:")
286
+ add(" magn = -wp+1")
287
+
288
+ add("return (a, b), True, magn")
289
+ else:
290
+ add("a = from_man_exp(SRE, -wp, prec, 'n')")
291
+
292
+ add("if SRE:")
293
+ add(" magn = a[2]+a[3]")
294
+ add("else:")
295
+ add(" magn = -wp+1")
296
+
297
+ add("return a, False, magn")
298
+
299
+ source = "\n".join((" " + line) for line in source)
300
+ source = ("def %s(coeffs, z, prec, wp, epsshift, magnitude_check, **kwargs):\n" % fname) + source
301
+
302
+ namespace = {}
303
+
304
+ exec_(source, globals(), namespace)
305
+
306
+ #print source
307
+ return source, namespace[fname]
308
+
309
+
310
+ if BACKEND == 'sage':
311
+
312
+ def make_hyp_summator(key):
313
+ """
314
+ Returns a function that sums a generalized hypergeometric series,
315
+ for given parameter types (integer, rational, real, complex).
316
+ """
317
+ from sage.libs.mpmath.ext_main import hypsum_internal
318
+ p, q, param_types, ztype = key
319
+ def _hypsum(coeffs, z, prec, wp, epsshift, magnitude_check, **kwargs):
320
+ return hypsum_internal(p, q, param_types, ztype, coeffs, z,
321
+ prec, wp, epsshift, magnitude_check, kwargs)
322
+
323
+ return "(none)", _hypsum
324
+
325
+
326
+ #-----------------------------------------------------------------------#
327
+ # #
328
+ # Error functions #
329
+ # #
330
+ #-----------------------------------------------------------------------#
331
+
332
+ # TODO: mpf_erf should call mpf_erfc when appropriate (currently
333
+ # only the converse delegation is implemented)
334
+
335
+ def mpf_erf(x, prec, rnd=round_fast):
336
+ sign, man, exp, bc = x
337
+ if not man:
338
+ if x == fzero: return fzero
339
+ if x == finf: return fone
340
+ if x== fninf: return fnone
341
+ return fnan
342
+ size = exp + bc
343
+ lg = math.log
344
+ # The approximation erf(x) = 1 is accurate to > x^2 * log(e,2) bits
345
+ if size > 3 and 2*(size-1) + 0.528766 > lg(prec,2):
346
+ if sign:
347
+ return mpf_perturb(fnone, 0, prec, rnd)
348
+ else:
349
+ return mpf_perturb(fone, 1, prec, rnd)
350
+ # erf(x) ~ 2*x/sqrt(pi) close to 0
351
+ if size < -prec:
352
+ # 2*x
353
+ x = mpf_shift(x,1)
354
+ c = mpf_sqrt(mpf_pi(prec+20), prec+20)
355
+ # TODO: interval rounding
356
+ return mpf_div(x, c, prec, rnd)
357
+ wp = prec + abs(size) + 25
358
+ # Taylor series for erf, fixed-point summation
359
+ t = abs(to_fixed(x, wp))
360
+ t2 = (t*t) >> wp
361
+ s, term, k = t, 12345, 1
362
+ while term:
363
+ t = ((t * t2) >> wp) // k
364
+ term = t // (2*k+1)
365
+ if k & 1:
366
+ s -= term
367
+ else:
368
+ s += term
369
+ k += 1
370
+ s = (s << (wp+1)) // sqrt_fixed(pi_fixed(wp), wp)
371
+ if sign:
372
+ s = -s
373
+ return from_man_exp(s, -wp, prec, rnd)
374
+
375
+ # If possible, we use the asymptotic series for erfc.
376
+ # This is an alternating divergent asymptotic series, so
377
+ # the error is at most equal to the first omitted term.
378
+ # Here we check if the smallest term is small enough
379
+ # for a given x and precision
380
+ def erfc_check_series(x, prec):
381
+ n = to_int(x)
382
+ if n**2 * 1.44 > prec:
383
+ return True
384
+ return False
385
+
386
+ def mpf_erfc(x, prec, rnd=round_fast):
387
+ sign, man, exp, bc = x
388
+ if not man:
389
+ if x == fzero: return fone
390
+ if x == finf: return fzero
391
+ if x == fninf: return ftwo
392
+ return fnan
393
+ wp = prec + 20
394
+ mag = bc+exp
395
+ # Preserve full accuracy when exponent grows huge
396
+ wp += max(0, 2*mag)
397
+ regular_erf = sign or mag < 2
398
+ if regular_erf or not erfc_check_series(x, wp):
399
+ if regular_erf:
400
+ return mpf_sub(fone, mpf_erf(x, prec+10, negative_rnd[rnd]), prec, rnd)
401
+ # 1-erf(x) ~ exp(-x^2), increase prec to deal with cancellation
402
+ n = to_int(x)+1
403
+ return mpf_sub(fone, mpf_erf(x, prec + int(n**2*1.44) + 10), prec, rnd)
404
+ s = term = MPZ_ONE << wp
405
+ term_prev = 0
406
+ t = (2 * to_fixed(x, wp) ** 2) >> wp
407
+ k = 1
408
+ while 1:
409
+ term = ((term * (2*k - 1)) << wp) // t
410
+ if k > 4 and term > term_prev or not term:
411
+ break
412
+ if k & 1:
413
+ s -= term
414
+ else:
415
+ s += term
416
+ term_prev = term
417
+ #print k, to_str(from_man_exp(term, -wp, 50), 10)
418
+ k += 1
419
+ s = (s << wp) // sqrt_fixed(pi_fixed(wp), wp)
420
+ s = from_man_exp(s, -wp, wp)
421
+ z = mpf_exp(mpf_neg(mpf_mul(x,x,wp),wp),wp)
422
+ y = mpf_div(mpf_mul(z, s, wp), x, prec, rnd)
423
+ return y
424
+
425
+
426
+ #-----------------------------------------------------------------------#
427
+ # #
428
+ # Exponential integrals #
429
+ # #
430
+ #-----------------------------------------------------------------------#
431
+
432
+ def ei_taylor(x, prec):
433
+ s = t = x
434
+ k = 2
435
+ while t:
436
+ t = ((t*x) >> prec) // k
437
+ s += t // k
438
+ k += 1
439
+ return s
440
+
441
+ def complex_ei_taylor(zre, zim, prec):
442
+ _abs = abs
443
+ sre = tre = zre
444
+ sim = tim = zim
445
+ k = 2
446
+ while _abs(tre) + _abs(tim) > 5:
447
+ tre, tim = ((tre*zre-tim*zim)//k)>>prec, ((tre*zim+tim*zre)//k)>>prec
448
+ sre += tre // k
449
+ sim += tim // k
450
+ k += 1
451
+ return sre, sim
452
+
453
+ def ei_asymptotic(x, prec):
454
+ one = MPZ_ONE << prec
455
+ x = t = ((one << prec) // x)
456
+ s = one + x
457
+ k = 2
458
+ while t:
459
+ t = (k*t*x) >> prec
460
+ s += t
461
+ k += 1
462
+ return s
463
+
464
+ def complex_ei_asymptotic(zre, zim, prec):
465
+ _abs = abs
466
+ one = MPZ_ONE << prec
467
+ M = (zim*zim + zre*zre) >> prec
468
+ # 1 / z
469
+ xre = tre = (zre << prec) // M
470
+ xim = tim = ((-zim) << prec) // M
471
+ sre = one + xre
472
+ sim = xim
473
+ k = 2
474
+ while _abs(tre) + _abs(tim) > 1000:
475
+ #print tre, tim
476
+ tre, tim = ((tre*xre-tim*xim)*k)>>prec, ((tre*xim+tim*xre)*k)>>prec
477
+ sre += tre
478
+ sim += tim
479
+ k += 1
480
+ if k > prec:
481
+ raise NoConvergence
482
+ return sre, sim
483
+
484
+ def mpf_ei(x, prec, rnd=round_fast, e1=False):
485
+ if e1:
486
+ x = mpf_neg(x)
487
+ sign, man, exp, bc = x
488
+ if e1 and not sign:
489
+ if x == fzero:
490
+ return finf
491
+ raise ComplexResult("E1(x) for x < 0")
492
+ if man:
493
+ xabs = 0, man, exp, bc
494
+ xmag = exp+bc
495
+ wp = prec + 20
496
+ can_use_asymp = xmag > wp
497
+ if not can_use_asymp:
498
+ if exp >= 0:
499
+ xabsint = man << exp
500
+ else:
501
+ xabsint = man >> (-exp)
502
+ can_use_asymp = xabsint > int(wp*0.693) + 10
503
+ if can_use_asymp:
504
+ if xmag > wp:
505
+ v = fone
506
+ else:
507
+ v = from_man_exp(ei_asymptotic(to_fixed(x, wp), wp), -wp)
508
+ v = mpf_mul(v, mpf_exp(x, wp), wp)
509
+ v = mpf_div(v, x, prec, rnd)
510
+ else:
511
+ wp += 2*int(to_int(xabs))
512
+ u = to_fixed(x, wp)
513
+ v = ei_taylor(u, wp) + euler_fixed(wp)
514
+ t1 = from_man_exp(v,-wp)
515
+ t2 = mpf_log(xabs,wp)
516
+ v = mpf_add(t1, t2, prec, rnd)
517
+ else:
518
+ if x == fzero: v = fninf
519
+ elif x == finf: v = finf
520
+ elif x == fninf: v = fzero
521
+ else: v = fnan
522
+ if e1:
523
+ v = mpf_neg(v)
524
+ return v
525
+
526
+ def mpc_ei(z, prec, rnd=round_fast, e1=False):
527
+ if e1:
528
+ z = mpc_neg(z)
529
+ a, b = z
530
+ asign, aman, aexp, abc = a
531
+ bsign, bman, bexp, bbc = b
532
+ if b == fzero:
533
+ if e1:
534
+ x = mpf_neg(mpf_ei(a, prec, rnd))
535
+ if not asign:
536
+ y = mpf_neg(mpf_pi(prec, rnd))
537
+ else:
538
+ y = fzero
539
+ return x, y
540
+ else:
541
+ return mpf_ei(a, prec, rnd), fzero
542
+ if a != fzero:
543
+ if not aman or not bman:
544
+ return (fnan, fnan)
545
+ wp = prec + 40
546
+ amag = aexp+abc
547
+ bmag = bexp+bbc
548
+ zmag = max(amag, bmag)
549
+ can_use_asymp = zmag > wp
550
+ if not can_use_asymp:
551
+ zabsint = abs(to_int(a)) + abs(to_int(b))
552
+ can_use_asymp = zabsint > int(wp*0.693) + 20
553
+ try:
554
+ if can_use_asymp:
555
+ if zmag > wp:
556
+ v = fone, fzero
557
+ else:
558
+ zre = to_fixed(a, wp)
559
+ zim = to_fixed(b, wp)
560
+ vre, vim = complex_ei_asymptotic(zre, zim, wp)
561
+ v = from_man_exp(vre, -wp), from_man_exp(vim, -wp)
562
+ v = mpc_mul(v, mpc_exp(z, wp), wp)
563
+ v = mpc_div(v, z, wp)
564
+ if e1:
565
+ v = mpc_neg(v, prec, rnd)
566
+ else:
567
+ x, y = v
568
+ if bsign:
569
+ v = mpf_pos(x, prec, rnd), mpf_sub(y, mpf_pi(wp), prec, rnd)
570
+ else:
571
+ v = mpf_pos(x, prec, rnd), mpf_add(y, mpf_pi(wp), prec, rnd)
572
+ return v
573
+ except NoConvergence:
574
+ pass
575
+ #wp += 2*max(0,zmag)
576
+ wp += 2*int(to_int(mpc_abs(z, 5)))
577
+ zre = to_fixed(a, wp)
578
+ zim = to_fixed(b, wp)
579
+ vre, vim = complex_ei_taylor(zre, zim, wp)
580
+ vre += euler_fixed(wp)
581
+ v = from_man_exp(vre,-wp), from_man_exp(vim,-wp)
582
+ if e1:
583
+ u = mpc_log(mpc_neg(z),wp)
584
+ else:
585
+ u = mpc_log(z,wp)
586
+ v = mpc_add(v, u, prec, rnd)
587
+ if e1:
588
+ v = mpc_neg(v)
589
+ return v
590
+
591
+ def mpf_e1(x, prec, rnd=round_fast):
592
+ return mpf_ei(x, prec, rnd, True)
593
+
594
+ def mpc_e1(x, prec, rnd=round_fast):
595
+ return mpc_ei(x, prec, rnd, True)
596
+
597
+ def mpf_expint(n, x, prec, rnd=round_fast, gamma=False):
598
+ """
599
+ E_n(x), n an integer, x real
600
+
601
+ With gamma=True, computes Gamma(n,x) (upper incomplete gamma function)
602
+
603
+ Returns (real, None) if real, otherwise (real, imag)
604
+ The imaginary part is an optional branch cut term
605
+
606
+ """
607
+ sign, man, exp, bc = x
608
+ if not man:
609
+ if gamma:
610
+ if x == fzero:
611
+ # Actually gamma function pole
612
+ if n <= 0:
613
+ return finf, None
614
+ return mpf_gamma_int(n, prec, rnd), None
615
+ if x == finf:
616
+ return fzero, None
617
+ # TODO: could return finite imaginary value at -inf
618
+ return fnan, fnan
619
+ else:
620
+ if x == fzero:
621
+ if n > 1:
622
+ return from_rational(1, n-1, prec, rnd), None
623
+ else:
624
+ return finf, None
625
+ if x == finf:
626
+ return fzero, None
627
+ return fnan, fnan
628
+ n_orig = n
629
+ if gamma:
630
+ n = 1-n
631
+ wp = prec + 20
632
+ xmag = exp + bc
633
+ # Beware of near-poles
634
+ if xmag < -10:
635
+ raise NotImplementedError
636
+ nmag = bitcount(abs(n))
637
+ have_imag = n > 0 and sign
638
+ negx = mpf_neg(x)
639
+ # Skip series if direct convergence
640
+ if n == 0 or 2*nmag - xmag < -wp:
641
+ if gamma:
642
+ v = mpf_exp(negx, wp)
643
+ re = mpf_mul(v, mpf_pow_int(x, n_orig-1, wp), prec, rnd)
644
+ else:
645
+ v = mpf_exp(negx, wp)
646
+ re = mpf_div(v, x, prec, rnd)
647
+ else:
648
+ # Finite number of terms, or...
649
+ can_use_asymptotic_series = -3*wp < n <= 0
650
+ # ...large enough?
651
+ if not can_use_asymptotic_series:
652
+ xi = abs(to_int(x))
653
+ m = min(max(1, xi-n), 2*wp)
654
+ siz = -n*nmag + (m+n)*bitcount(abs(m+n)) - m*xmag - (144*m//100)
655
+ tol = -wp-10
656
+ can_use_asymptotic_series = siz < tol
657
+ if can_use_asymptotic_series:
658
+ r = ((-MPZ_ONE) << (wp+wp)) // to_fixed(x, wp)
659
+ m = n
660
+ t = r*m
661
+ s = MPZ_ONE << wp
662
+ while m and t:
663
+ s += t
664
+ m += 1
665
+ t = (m*r*t) >> wp
666
+ v = mpf_exp(negx, wp)
667
+ if gamma:
668
+ # ~ exp(-x) * x^(n-1) * (1 + ...)
669
+ v = mpf_mul(v, mpf_pow_int(x, n_orig-1, wp), wp)
670
+ else:
671
+ # ~ exp(-x)/x * (1 + ...)
672
+ v = mpf_div(v, x, wp)
673
+ re = mpf_mul(v, from_man_exp(s, -wp), prec, rnd)
674
+ elif n == 1:
675
+ re = mpf_neg(mpf_ei(negx, prec, rnd))
676
+ elif n > 0 and n < 3*wp:
677
+ T1 = mpf_neg(mpf_ei(negx, wp))
678
+ if gamma:
679
+ if n_orig & 1:
680
+ T1 = mpf_neg(T1)
681
+ else:
682
+ T1 = mpf_mul(T1, mpf_pow_int(negx, n-1, wp), wp)
683
+ r = t = to_fixed(x, wp)
684
+ facs = [1] * (n-1)
685
+ for k in range(1,n-1):
686
+ facs[k] = facs[k-1] * k
687
+ facs = facs[::-1]
688
+ s = facs[0] << wp
689
+ for k in range(1, n-1):
690
+ if k & 1:
691
+ s -= facs[k] * t
692
+ else:
693
+ s += facs[k] * t
694
+ t = (t*r) >> wp
695
+ T2 = from_man_exp(s, -wp, wp)
696
+ T2 = mpf_mul(T2, mpf_exp(negx, wp))
697
+ if gamma:
698
+ T2 = mpf_mul(T2, mpf_pow_int(x, n_orig, wp), wp)
699
+ R = mpf_add(T1, T2)
700
+ re = mpf_div(R, from_int(ifac(n-1)), prec, rnd)
701
+ else:
702
+ raise NotImplementedError
703
+ if have_imag:
704
+ M = from_int(-ifac(n-1))
705
+ if gamma:
706
+ im = mpf_div(mpf_pi(wp), M, prec, rnd)
707
+ if n_orig & 1:
708
+ im = mpf_neg(im)
709
+ else:
710
+ im = mpf_div(mpf_mul(mpf_pi(wp), mpf_pow_int(negx, n_orig-1, wp), wp), M, prec, rnd)
711
+ return re, im
712
+ else:
713
+ return re, None
714
+
715
+ def mpf_ci_si_taylor(x, wp, which=0):
716
+ """
717
+ 0 - Ci(x) - (euler+log(x))
718
+ 1 - Si(x)
719
+ """
720
+ x = to_fixed(x, wp)
721
+ x2 = -(x*x) >> wp
722
+ if which == 0:
723
+ s, t, k = 0, (MPZ_ONE<<wp), 2
724
+ else:
725
+ s, t, k = x, x, 3
726
+ while t:
727
+ t = (t*x2//(k*(k-1)))>>wp
728
+ s += t//k
729
+ k += 2
730
+ return from_man_exp(s, -wp)
731
+
732
+ def mpc_ci_si_taylor(re, im, wp, which=0):
733
+ # The following code is only designed for small arguments,
734
+ # and not too small arguments (for relative accuracy)
735
+ if re[1]:
736
+ mag = re[2]+re[3]
737
+ elif im[1]:
738
+ mag = im[2]+im[3]
739
+ if im[1]:
740
+ mag = max(mag, im[2]+im[3])
741
+ if mag > 2 or mag < -wp:
742
+ raise NotImplementedError
743
+ wp += (2-mag)
744
+ zre = to_fixed(re, wp)
745
+ zim = to_fixed(im, wp)
746
+ z2re = (zim*zim-zre*zre)>>wp
747
+ z2im = (-2*zre*zim)>>wp
748
+ tre = zre
749
+ tim = zim
750
+ one = MPZ_ONE<<wp
751
+ if which == 0:
752
+ sre, sim, tre, tim, k = 0, 0, (MPZ_ONE<<wp), 0, 2
753
+ else:
754
+ sre, sim, tre, tim, k = zre, zim, zre, zim, 3
755
+ while max(abs(tre), abs(tim)) > 2:
756
+ f = k*(k-1)
757
+ tre, tim = ((tre*z2re-tim*z2im)//f)>>wp, ((tre*z2im+tim*z2re)//f)>>wp
758
+ sre += tre//k
759
+ sim += tim//k
760
+ k += 2
761
+ return from_man_exp(sre, -wp), from_man_exp(sim, -wp)
762
+
763
+ def mpf_ci_si(x, prec, rnd=round_fast, which=2):
764
+ """
765
+ Calculation of Ci(x), Si(x) for real x.
766
+
767
+ which = 0 -- returns (Ci(x), -)
768
+ which = 1 -- returns (Si(x), -)
769
+ which = 2 -- returns (Ci(x), Si(x))
770
+
771
+ Note: if x < 0, Ci(x) needs an additional imaginary term, pi*i.
772
+ """
773
+ wp = prec + 20
774
+ sign, man, exp, bc = x
775
+ ci, si = None, None
776
+ if not man:
777
+ if x == fzero:
778
+ return (fninf, fzero)
779
+ if x == fnan:
780
+ return (x, x)
781
+ ci = fzero
782
+ if which != 0:
783
+ if x == finf:
784
+ si = mpf_shift(mpf_pi(prec, rnd), -1)
785
+ if x == fninf:
786
+ si = mpf_neg(mpf_shift(mpf_pi(prec, negative_rnd[rnd]), -1))
787
+ return (ci, si)
788
+ # For small x: Ci(x) ~ euler + log(x), Si(x) ~ x
789
+ mag = exp+bc
790
+ if mag < -wp:
791
+ if which != 0:
792
+ si = mpf_perturb(x, 1-sign, prec, rnd)
793
+ if which != 1:
794
+ y = mpf_euler(wp)
795
+ xabs = mpf_abs(x)
796
+ ci = mpf_add(y, mpf_log(xabs, wp), prec, rnd)
797
+ return ci, si
798
+ # For huge x: Ci(x) ~ sin(x)/x, Si(x) ~ pi/2
799
+ elif mag > wp:
800
+ if which != 0:
801
+ if sign:
802
+ si = mpf_neg(mpf_pi(prec, negative_rnd[rnd]))
803
+ else:
804
+ si = mpf_pi(prec, rnd)
805
+ si = mpf_shift(si, -1)
806
+ if which != 1:
807
+ ci = mpf_div(mpf_sin(x, wp), x, prec, rnd)
808
+ return ci, si
809
+ else:
810
+ wp += abs(mag)
811
+ # Use an asymptotic series? The smallest value of n!/x^n
812
+ # occurs for n ~ x, where the magnitude is ~ exp(-x).
813
+ asymptotic = mag-1 > math.log(wp, 2)
814
+ # Case 1: convergent series near 0
815
+ if not asymptotic:
816
+ if which != 0:
817
+ si = mpf_pos(mpf_ci_si_taylor(x, wp, 1), prec, rnd)
818
+ if which != 1:
819
+ ci = mpf_ci_si_taylor(x, wp, 0)
820
+ ci = mpf_add(ci, mpf_euler(wp), wp)
821
+ ci = mpf_add(ci, mpf_log(mpf_abs(x), wp), prec, rnd)
822
+ return ci, si
823
+ x = mpf_abs(x)
824
+ # Case 2: asymptotic series for x >> 1
825
+ xf = to_fixed(x, wp)
826
+ xr = (MPZ_ONE<<(2*wp)) // xf # 1/x
827
+ s1 = (MPZ_ONE << wp)
828
+ s2 = xr
829
+ t = xr
830
+ k = 2
831
+ while t:
832
+ t = -t
833
+ t = (t*xr*k)>>wp
834
+ k += 1
835
+ s1 += t
836
+ t = (t*xr*k)>>wp
837
+ k += 1
838
+ s2 += t
839
+ s1 = from_man_exp(s1, -wp)
840
+ s2 = from_man_exp(s2, -wp)
841
+ s1 = mpf_div(s1, x, wp)
842
+ s2 = mpf_div(s2, x, wp)
843
+ cos, sin = mpf_cos_sin(x, wp)
844
+ # Ci(x) = sin(x)*s1-cos(x)*s2
845
+ # Si(x) = pi/2-cos(x)*s1-sin(x)*s2
846
+ if which != 0:
847
+ si = mpf_add(mpf_mul(cos, s1), mpf_mul(sin, s2), wp)
848
+ si = mpf_sub(mpf_shift(mpf_pi(wp), -1), si, wp)
849
+ if sign:
850
+ si = mpf_neg(si)
851
+ si = mpf_pos(si, prec, rnd)
852
+ if which != 1:
853
+ ci = mpf_sub(mpf_mul(sin, s1), mpf_mul(cos, s2), prec, rnd)
854
+ return ci, si
855
+
856
+ def mpf_ci(x, prec, rnd=round_fast):
857
+ if mpf_sign(x) < 0:
858
+ raise ComplexResult
859
+ return mpf_ci_si(x, prec, rnd, 0)[0]
860
+
861
+ def mpf_si(x, prec, rnd=round_fast):
862
+ return mpf_ci_si(x, prec, rnd, 1)[1]
863
+
864
+ def mpc_ci(z, prec, rnd=round_fast):
865
+ re, im = z
866
+ if im == fzero:
867
+ ci = mpf_ci_si(re, prec, rnd, 0)[0]
868
+ if mpf_sign(re) < 0:
869
+ return (ci, mpf_pi(prec, rnd))
870
+ return (ci, fzero)
871
+ wp = prec + 20
872
+ cre, cim = mpc_ci_si_taylor(re, im, wp, 0)
873
+ cre = mpf_add(cre, mpf_euler(wp), wp)
874
+ ci = mpc_add((cre, cim), mpc_log(z, wp), prec, rnd)
875
+ return ci
876
+
877
+ def mpc_si(z, prec, rnd=round_fast):
878
+ re, im = z
879
+ if im == fzero:
880
+ return (mpf_ci_si(re, prec, rnd, 1)[1], fzero)
881
+ wp = prec + 20
882
+ z = mpc_ci_si_taylor(re, im, wp, 1)
883
+ return mpc_pos(z, prec, rnd)
884
+
885
+
886
+ #-----------------------------------------------------------------------#
887
+ # #
888
+ # Bessel functions #
889
+ # #
890
+ #-----------------------------------------------------------------------#
891
+
892
+ # A Bessel function of the first kind of integer order, J_n(x), is
893
+ # given by the power series
894
+
895
+ # oo
896
+ # ___ k 2 k + n
897
+ # \ (-1) / x \
898
+ # J_n(x) = ) ----------- | - |
899
+ # /___ k! (k + n)! \ 2 /
900
+ # k = 0
901
+
902
+ # Simplifying the quotient between two successive terms gives the
903
+ # ratio x^2 / (-4*k*(k+n)). Hence, we only need one full-precision
904
+ # multiplication and one division by a small integer per term.
905
+ # The complex version is very similar, the only difference being
906
+ # that the multiplication is actually 4 multiplies.
907
+
908
+ # In the general case, we have
909
+ # J_v(x) = (x/2)**v / v! * 0F1(v+1, (-1/4)*z**2)
910
+
911
+ # TODO: for extremely large x, we could use an asymptotic
912
+ # trigonometric approximation.
913
+
914
+ # TODO: recompute at higher precision if the fixed-point mantissa
915
+ # is very small
916
+
917
+ def mpf_besseljn(n, x, prec, rounding=round_fast):
918
+ prec += 50
919
+ negate = n < 0 and n & 1
920
+ mag = x[2]+x[3]
921
+ n = abs(n)
922
+ wp = prec + 20 + n*bitcount(n)
923
+ if mag < 0:
924
+ wp -= n * mag
925
+ x = to_fixed(x, wp)
926
+ x2 = (x**2) >> wp
927
+ if not n:
928
+ s = t = MPZ_ONE << wp
929
+ else:
930
+ s = t = (x**n // ifac(n)) >> ((n-1)*wp + n)
931
+ k = 1
932
+ while t:
933
+ t = ((t * x2) // (-4*k*(k+n))) >> wp
934
+ s += t
935
+ k += 1
936
+ if negate:
937
+ s = -s
938
+ return from_man_exp(s, -wp, prec, rounding)
939
+
940
+ def mpc_besseljn(n, z, prec, rounding=round_fast):
941
+ negate = n < 0 and n & 1
942
+ n = abs(n)
943
+ origprec = prec
944
+ zre, zim = z
945
+ mag = max(zre[2]+zre[3], zim[2]+zim[3])
946
+ prec += 20 + n*bitcount(n) + abs(mag)
947
+ if mag < 0:
948
+ prec -= n * mag
949
+ zre = to_fixed(zre, prec)
950
+ zim = to_fixed(zim, prec)
951
+ z2re = (zre**2 - zim**2) >> prec
952
+ z2im = (zre*zim) >> (prec-1)
953
+ if not n:
954
+ sre = tre = MPZ_ONE << prec
955
+ sim = tim = MPZ_ZERO
956
+ else:
957
+ re, im = complex_int_pow(zre, zim, n)
958
+ sre = tre = (re // ifac(n)) >> ((n-1)*prec + n)
959
+ sim = tim = (im // ifac(n)) >> ((n-1)*prec + n)
960
+ k = 1
961
+ while abs(tre) + abs(tim) > 3:
962
+ p = -4*k*(k+n)
963
+ tre, tim = tre*z2re - tim*z2im, tim*z2re + tre*z2im
964
+ tre = (tre // p) >> prec
965
+ tim = (tim // p) >> prec
966
+ sre += tre
967
+ sim += tim
968
+ k += 1
969
+ if negate:
970
+ sre = -sre
971
+ sim = -sim
972
+ re = from_man_exp(sre, -prec, origprec, rounding)
973
+ im = from_man_exp(sim, -prec, origprec, rounding)
974
+ return (re, im)
975
+
976
+ def mpf_agm(a, b, prec, rnd=round_fast):
977
+ """
978
+ Computes the arithmetic-geometric mean agm(a,b) for
979
+ nonnegative mpf values a, b.
980
+ """
981
+ asign, aman, aexp, abc = a
982
+ bsign, bman, bexp, bbc = b
983
+ if asign or bsign:
984
+ raise ComplexResult("agm of a negative number")
985
+ # Handle inf, nan or zero in either operand
986
+ if not (aman and bman):
987
+ if a == fnan or b == fnan:
988
+ return fnan
989
+ if a == finf:
990
+ if b == fzero:
991
+ return fnan
992
+ return finf
993
+ if b == finf:
994
+ if a == fzero:
995
+ return fnan
996
+ return finf
997
+ # agm(0,x) = agm(x,0) = 0
998
+ return fzero
999
+ wp = prec + 20
1000
+ amag = aexp+abc
1001
+ bmag = bexp+bbc
1002
+ mag_delta = amag - bmag
1003
+ # Reduce to roughly the same magnitude using floating-point AGM
1004
+ abs_mag_delta = abs(mag_delta)
1005
+ if abs_mag_delta > 10:
1006
+ while abs_mag_delta > 10:
1007
+ a, b = mpf_shift(mpf_add(a,b,wp),-1), \
1008
+ mpf_sqrt(mpf_mul(a,b,wp),wp)
1009
+ abs_mag_delta //= 2
1010
+ asign, aman, aexp, abc = a
1011
+ bsign, bman, bexp, bbc = b
1012
+ amag = aexp+abc
1013
+ bmag = bexp+bbc
1014
+ mag_delta = amag - bmag
1015
+ #print to_float(a), to_float(b)
1016
+ # Use agm(a,b) = agm(x*a,x*b)/x to obtain a, b ~= 1
1017
+ min_mag = min(amag,bmag)
1018
+ max_mag = max(amag,bmag)
1019
+ n = 0
1020
+ # If too small, we lose precision when going to fixed-point
1021
+ if min_mag < -8:
1022
+ n = -min_mag
1023
+ # If too large, we waste time using fixed-point with large numbers
1024
+ elif max_mag > 20:
1025
+ n = -max_mag
1026
+ if n:
1027
+ a = mpf_shift(a, n)
1028
+ b = mpf_shift(b, n)
1029
+ #print to_float(a), to_float(b)
1030
+ af = to_fixed(a, wp)
1031
+ bf = to_fixed(b, wp)
1032
+ g = agm_fixed(af, bf, wp)
1033
+ return from_man_exp(g, -wp-n, prec, rnd)
1034
+
1035
+ def mpf_agm1(a, prec, rnd=round_fast):
1036
+ """
1037
+ Computes the arithmetic-geometric mean agm(1,a) for a nonnegative
1038
+ mpf value a.
1039
+ """
1040
+ return mpf_agm(fone, a, prec, rnd)
1041
+
1042
+ def mpc_agm(a, b, prec, rnd=round_fast):
1043
+ """
1044
+ Complex AGM.
1045
+
1046
+ TODO:
1047
+ * check that convergence works as intended
1048
+ * optimize
1049
+ * select a nonarbitrary branch
1050
+ """
1051
+ if mpc_is_infnan(a) or mpc_is_infnan(b):
1052
+ return fnan, fnan
1053
+ if mpc_zero in (a, b):
1054
+ return fzero, fzero
1055
+ if mpc_neg(a) == b:
1056
+ return fzero, fzero
1057
+ wp = prec+20
1058
+ eps = mpf_shift(fone, -wp+10)
1059
+ while 1:
1060
+ a1 = mpc_shift(mpc_add(a, b, wp), -1)
1061
+ b1 = mpc_sqrt(mpc_mul(a, b, wp), wp)
1062
+ a, b = a1, b1
1063
+ size = mpf_min_max([mpc_abs(a,10), mpc_abs(b,10)])[1]
1064
+ err = mpc_abs(mpc_sub(a, b, 10), 10)
1065
+ if size == fzero or mpf_lt(err, mpf_mul(eps, size)):
1066
+ return a
1067
+
1068
+ def mpc_agm1(a, prec, rnd=round_fast):
1069
+ return mpc_agm(mpc_one, a, prec, rnd)
1070
+
1071
+ def mpf_ellipk(x, prec, rnd=round_fast):
1072
+ if not x[1]:
1073
+ if x == fzero:
1074
+ return mpf_shift(mpf_pi(prec, rnd), -1)
1075
+ if x == fninf:
1076
+ return fzero
1077
+ if x == fnan:
1078
+ return x
1079
+ if x == fone:
1080
+ return finf
1081
+ # TODO: for |x| << 1/2, one could use fall back to
1082
+ # pi/2 * hyp2f1_rat((1,2),(1,2),(1,1), x)
1083
+ wp = prec + 15
1084
+ # Use K(x) = pi/2/agm(1,a) where a = sqrt(1-x)
1085
+ # The sqrt raises ComplexResult if x > 0
1086
+ a = mpf_sqrt(mpf_sub(fone, x, wp), wp)
1087
+ v = mpf_agm1(a, wp)
1088
+ r = mpf_div(mpf_pi(wp), v, prec, rnd)
1089
+ return mpf_shift(r, -1)
1090
+
1091
+ def mpc_ellipk(z, prec, rnd=round_fast):
1092
+ re, im = z
1093
+ if im == fzero:
1094
+ if re == finf:
1095
+ return mpc_zero
1096
+ if mpf_le(re, fone):
1097
+ return mpf_ellipk(re, prec, rnd), fzero
1098
+ wp = prec + 15
1099
+ a = mpc_sqrt(mpc_sub(mpc_one, z, wp), wp)
1100
+ v = mpc_agm1(a, wp)
1101
+ r = mpc_mpf_div(mpf_pi(wp), v, prec, rnd)
1102
+ return mpc_shift(r, -1)
1103
+
1104
+ def mpf_ellipe(x, prec, rnd=round_fast):
1105
+ # http://functions.wolfram.com/EllipticIntegrals/
1106
+ # EllipticK/20/01/0001/
1107
+ # E = (1-m)*(K'(m)*2*m + K(m))
1108
+ sign, man, exp, bc = x
1109
+ if not man:
1110
+ if x == fzero:
1111
+ return mpf_shift(mpf_pi(prec, rnd), -1)
1112
+ if x == fninf:
1113
+ return finf
1114
+ if x == fnan:
1115
+ return x
1116
+ if x == finf:
1117
+ raise ComplexResult
1118
+ if x == fone:
1119
+ return fone
1120
+ wp = prec+20
1121
+ mag = exp+bc
1122
+ if mag < -wp:
1123
+ return mpf_shift(mpf_pi(prec, rnd), -1)
1124
+ # Compute a finite difference for K'
1125
+ p = max(mag, 0) - wp
1126
+ h = mpf_shift(fone, p)
1127
+ K = mpf_ellipk(x, 2*wp)
1128
+ Kh = mpf_ellipk(mpf_sub(x, h), 2*wp)
1129
+ Kdiff = mpf_shift(mpf_sub(K, Kh), -p)
1130
+ t = mpf_sub(fone, x)
1131
+ b = mpf_mul(Kdiff, mpf_shift(x,1), wp)
1132
+ return mpf_mul(t, mpf_add(K, b), prec, rnd)
1133
+
1134
+ def mpc_ellipe(z, prec, rnd=round_fast):
1135
+ re, im = z
1136
+ if im == fzero:
1137
+ if re == finf:
1138
+ return (fzero, finf)
1139
+ if mpf_le(re, fone):
1140
+ return mpf_ellipe(re, prec, rnd), fzero
1141
+ wp = prec + 15
1142
+ mag = mpc_abs(z, 1)
1143
+ p = max(mag[2]+mag[3], 0) - wp
1144
+ h = mpf_shift(fone, p)
1145
+ K = mpc_ellipk(z, 2*wp)
1146
+ Kh = mpc_ellipk(mpc_add_mpf(z, h, 2*wp), 2*wp)
1147
+ Kdiff = mpc_shift(mpc_sub(Kh, K, wp), -p)
1148
+ t = mpc_sub(mpc_one, z, wp)
1149
+ b = mpc_mul(Kdiff, mpc_shift(z,1), wp)
1150
+ return mpc_mul(t, mpc_add(K, b, wp), prec, rnd)
lib/python3.11/site-packages/mpmath/libmp/libintmath.py ADDED
@@ -0,0 +1,584 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utility functions for integer math.
3
+
4
+ TODO: rename, cleanup, perhaps move the gmpy wrapper code
5
+ here from settings.py
6
+
7
+ """
8
+
9
+ import math
10
+ from bisect import bisect
11
+
12
+ from .backend import xrange
13
+ from .backend import BACKEND, gmpy, sage, sage_utils, MPZ, MPZ_ONE, MPZ_ZERO
14
+
15
+ small_trailing = [0] * 256
16
+ for j in range(1,8):
17
+ small_trailing[1<<j::1<<(j+1)] = [j] * (1<<(7-j))
18
+
19
+ def giant_steps(start, target, n=2):
20
+ """
21
+ Return a list of integers ~=
22
+
23
+ [start, n*start, ..., target/n^2, target/n, target]
24
+
25
+ but conservatively rounded so that the quotient between two
26
+ successive elements is actually slightly less than n.
27
+
28
+ With n = 2, this describes suitable precision steps for a
29
+ quadratically convergent algorithm such as Newton's method;
30
+ with n = 3 steps for cubic convergence (Halley's method), etc.
31
+
32
+ >>> giant_steps(50,1000)
33
+ [66, 128, 253, 502, 1000]
34
+ >>> giant_steps(50,1000,4)
35
+ [65, 252, 1000]
36
+
37
+ """
38
+ L = [target]
39
+ while L[-1] > start*n:
40
+ L = L + [L[-1]//n + 2]
41
+ return L[::-1]
42
+
43
+ def rshift(x, n):
44
+ """For an integer x, calculate x >> n with the fastest (floor)
45
+ rounding. Unlike the plain Python expression (x >> n), n is
46
+ allowed to be negative, in which case a left shift is performed."""
47
+ if n >= 0: return x >> n
48
+ else: return x << (-n)
49
+
50
+ def lshift(x, n):
51
+ """For an integer x, calculate x << n. Unlike the plain Python
52
+ expression (x << n), n is allowed to be negative, in which case a
53
+ right shift with default (floor) rounding is performed."""
54
+ if n >= 0: return x << n
55
+ else: return x >> (-n)
56
+
57
+ if BACKEND == 'sage':
58
+ import operator
59
+ rshift = operator.rshift
60
+ lshift = operator.lshift
61
+
62
+ def python_trailing(n):
63
+ """Count the number of trailing zero bits in abs(n)."""
64
+ if not n:
65
+ return 0
66
+ low_byte = n & 0xff
67
+ if low_byte:
68
+ return small_trailing[low_byte]
69
+ t = 8
70
+ n >>= 8
71
+ while not n & 0xff:
72
+ n >>= 8
73
+ t += 8
74
+ return t + small_trailing[n & 0xff]
75
+
76
+ if BACKEND == 'gmpy':
77
+ if gmpy.version() >= '2':
78
+ def gmpy_trailing(n):
79
+ """Count the number of trailing zero bits in abs(n) using gmpy."""
80
+ if n: return MPZ(n).bit_scan1()
81
+ else: return 0
82
+ else:
83
+ def gmpy_trailing(n):
84
+ """Count the number of trailing zero bits in abs(n) using gmpy."""
85
+ if n: return MPZ(n).scan1()
86
+ else: return 0
87
+
88
+ # Small powers of 2
89
+ powers = [1<<_ for _ in range(300)]
90
+
91
+ def python_bitcount(n):
92
+ """Calculate bit size of the nonnegative integer n."""
93
+ bc = bisect(powers, n)
94
+ if bc != 300:
95
+ return bc
96
+ bc = int(math.log(n, 2)) - 4
97
+ return bc + bctable[n>>bc]
98
+
99
+ def gmpy_bitcount(n):
100
+ """Calculate bit size of the nonnegative integer n."""
101
+ if n: return MPZ(n).numdigits(2)
102
+ else: return 0
103
+
104
+ #def sage_bitcount(n):
105
+ # if n: return MPZ(n).nbits()
106
+ # else: return 0
107
+
108
+ def sage_trailing(n):
109
+ return MPZ(n).trailing_zero_bits()
110
+
111
+ if BACKEND == 'gmpy':
112
+ bitcount = gmpy_bitcount
113
+ trailing = gmpy_trailing
114
+ elif BACKEND == 'sage':
115
+ sage_bitcount = sage_utils.bitcount
116
+ bitcount = sage_bitcount
117
+ trailing = sage_trailing
118
+ else:
119
+ bitcount = python_bitcount
120
+ trailing = python_trailing
121
+
122
+ if BACKEND == 'gmpy' and 'bit_length' in dir(gmpy):
123
+ bitcount = gmpy.bit_length
124
+
125
+ # Used to avoid slow function calls as far as possible
126
+ trailtable = [trailing(n) for n in range(256)]
127
+ bctable = [bitcount(n) for n in range(1024)]
128
+
129
+ # TODO: speed up for bases 2, 4, 8, 16, ...
130
+
131
+ def bin_to_radix(x, xbits, base, bdigits):
132
+ """Changes radix of a fixed-point number; i.e., converts
133
+ x * 2**xbits to floor(x * 10**bdigits)."""
134
+ return x * (MPZ(base)**bdigits) >> xbits
135
+
136
+ stddigits = '0123456789abcdefghijklmnopqrstuvwxyz'
137
+
138
+ def small_numeral(n, base=10, digits=stddigits):
139
+ """Return the string numeral of a positive integer in an arbitrary
140
+ base. Most efficient for small input."""
141
+ if base == 10:
142
+ return str(n)
143
+ digs = []
144
+ while n:
145
+ n, digit = divmod(n, base)
146
+ digs.append(digits[digit])
147
+ return "".join(digs[::-1])
148
+
149
+ def numeral_python(n, base=10, size=0, digits=stddigits):
150
+ """Represent the integer n as a string of digits in the given base.
151
+ Recursive division is used to make this function about 3x faster
152
+ than Python's str() for converting integers to decimal strings.
153
+
154
+ The 'size' parameters specifies the number of digits in n; this
155
+ number is only used to determine splitting points and need not be
156
+ exact."""
157
+ if n <= 0:
158
+ if not n:
159
+ return "0"
160
+ return "-" + numeral(-n, base, size, digits)
161
+ # Fast enough to do directly
162
+ if size < 250:
163
+ return small_numeral(n, base, digits)
164
+ # Divide in half
165
+ half = (size // 2) + (size & 1)
166
+ A, B = divmod(n, base**half)
167
+ ad = numeral(A, base, half, digits)
168
+ bd = numeral(B, base, half, digits).rjust(half, "0")
169
+ return ad + bd
170
+
171
+ def numeral_gmpy(n, base=10, size=0, digits=stddigits):
172
+ """Represent the integer n as a string of digits in the given base.
173
+ Recursive division is used to make this function about 3x faster
174
+ than Python's str() for converting integers to decimal strings.
175
+
176
+ The 'size' parameters specifies the number of digits in n; this
177
+ number is only used to determine splitting points and need not be
178
+ exact."""
179
+ if n < 0:
180
+ return "-" + numeral(-n, base, size, digits)
181
+ # gmpy.digits() may cause a segmentation fault when trying to convert
182
+ # extremely large values to a string. The size limit may need to be
183
+ # adjusted on some platforms, but 1500000 works on Windows and Linux.
184
+ if size < 1500000:
185
+ return gmpy.digits(n, base)
186
+ # Divide in half
187
+ half = (size // 2) + (size & 1)
188
+ A, B = divmod(n, MPZ(base)**half)
189
+ ad = numeral(A, base, half, digits)
190
+ bd = numeral(B, base, half, digits).rjust(half, "0")
191
+ return ad + bd
192
+
193
+ if BACKEND == "gmpy":
194
+ numeral = numeral_gmpy
195
+ else:
196
+ numeral = numeral_python
197
+
198
+ _1_800 = 1<<800
199
+ _1_600 = 1<<600
200
+ _1_400 = 1<<400
201
+ _1_200 = 1<<200
202
+ _1_100 = 1<<100
203
+ _1_50 = 1<<50
204
+
205
+ def isqrt_small_python(x):
206
+ """
207
+ Correctly (floor) rounded integer square root, using
208
+ division. Fast up to ~200 digits.
209
+ """
210
+ if not x:
211
+ return x
212
+ if x < _1_800:
213
+ # Exact with IEEE double precision arithmetic
214
+ if x < _1_50:
215
+ return int(x**0.5)
216
+ # Initial estimate can be any integer >= the true root; round up
217
+ r = int(x**0.5 * 1.00000000000001) + 1
218
+ else:
219
+ bc = bitcount(x)
220
+ n = bc//2
221
+ r = int((x>>(2*n-100))**0.5+2)<<(n-50) # +2 is to round up
222
+ # The following iteration now precisely computes floor(sqrt(x))
223
+ # See e.g. Crandall & Pomerance, "Prime Numbers: A Computational
224
+ # Perspective"
225
+ while 1:
226
+ y = (r+x//r)>>1
227
+ if y >= r:
228
+ return r
229
+ r = y
230
+
231
+ def isqrt_fast_python(x):
232
+ """
233
+ Fast approximate integer square root, computed using division-free
234
+ Newton iteration for large x. For random integers the result is almost
235
+ always correct (floor(sqrt(x))), but is 1 ulp too small with a roughly
236
+ 0.1% probability. If x is very close to an exact square, the answer is
237
+ 1 ulp wrong with high probability.
238
+
239
+ With 0 guard bits, the largest error over a set of 10^5 random
240
+ inputs of size 1-10^5 bits was 3 ulp. The use of 10 guard bits
241
+ almost certainly guarantees a max 1 ulp error.
242
+ """
243
+ # Use direct division-based iteration if sqrt(x) < 2^400
244
+ # Assume floating-point square root accurate to within 1 ulp, then:
245
+ # 0 Newton iterations good to 52 bits
246
+ # 1 Newton iterations good to 104 bits
247
+ # 2 Newton iterations good to 208 bits
248
+ # 3 Newton iterations good to 416 bits
249
+ if x < _1_800:
250
+ y = int(x**0.5)
251
+ if x >= _1_100:
252
+ y = (y + x//y) >> 1
253
+ if x >= _1_200:
254
+ y = (y + x//y) >> 1
255
+ if x >= _1_400:
256
+ y = (y + x//y) >> 1
257
+ return y
258
+ bc = bitcount(x)
259
+ guard_bits = 10
260
+ x <<= 2*guard_bits
261
+ bc += 2*guard_bits
262
+ bc += (bc&1)
263
+ hbc = bc//2
264
+ startprec = min(50, hbc)
265
+ # Newton iteration for 1/sqrt(x), with floating-point starting value
266
+ r = int(2.0**(2*startprec) * (x >> (bc-2*startprec)) ** -0.5)
267
+ pp = startprec
268
+ for p in giant_steps(startprec, hbc):
269
+ # r**2, scaled from real size 2**(-bc) to 2**p
270
+ r2 = (r*r) >> (2*pp - p)
271
+ # x*r**2, scaled from real size ~1.0 to 2**p
272
+ xr2 = ((x >> (bc-p)) * r2) >> p
273
+ # New value of r, scaled from real size 2**(-bc/2) to 2**p
274
+ r = (r * ((3<<p) - xr2)) >> (pp+1)
275
+ pp = p
276
+ # (1/sqrt(x))*x = sqrt(x)
277
+ return (r*(x>>hbc)) >> (p+guard_bits)
278
+
279
+ def sqrtrem_python(x):
280
+ """Correctly rounded integer (floor) square root with remainder."""
281
+ # to check cutoff:
282
+ # plot(lambda x: timing(isqrt, 2**int(x)), [0,2000])
283
+ if x < _1_600:
284
+ y = isqrt_small_python(x)
285
+ return y, x - y*y
286
+ y = isqrt_fast_python(x) + 1
287
+ rem = x - y*y
288
+ # Correct remainder
289
+ while rem < 0:
290
+ y -= 1
291
+ rem += (1+2*y)
292
+ else:
293
+ if rem:
294
+ while rem > 2*(1+y):
295
+ y += 1
296
+ rem -= (1+2*y)
297
+ return y, rem
298
+
299
+ def isqrt_python(x):
300
+ """Integer square root with correct (floor) rounding."""
301
+ return sqrtrem_python(x)[0]
302
+
303
+ def sqrt_fixed(x, prec):
304
+ return isqrt_fast(x<<prec)
305
+
306
+ sqrt_fixed2 = sqrt_fixed
307
+
308
+ if BACKEND == 'gmpy':
309
+ if gmpy.version() >= '2':
310
+ isqrt_small = isqrt_fast = isqrt = gmpy.isqrt
311
+ sqrtrem = gmpy.isqrt_rem
312
+ else:
313
+ isqrt_small = isqrt_fast = isqrt = gmpy.sqrt
314
+ sqrtrem = gmpy.sqrtrem
315
+ elif BACKEND == 'sage':
316
+ isqrt_small = isqrt_fast = isqrt = \
317
+ getattr(sage_utils, "isqrt", lambda n: MPZ(n).isqrt())
318
+ sqrtrem = lambda n: MPZ(n).sqrtrem()
319
+ else:
320
+ isqrt_small = isqrt_small_python
321
+ isqrt_fast = isqrt_fast_python
322
+ isqrt = isqrt_python
323
+ sqrtrem = sqrtrem_python
324
+
325
+
326
+ def ifib(n, _cache={}):
327
+ """Computes the nth Fibonacci number as an integer, for
328
+ integer n."""
329
+ if n < 0:
330
+ return (-1)**(-n+1) * ifib(-n)
331
+ if n in _cache:
332
+ return _cache[n]
333
+ m = n
334
+ # Use Dijkstra's logarithmic algorithm
335
+ # The following implementation is basically equivalent to
336
+ # http://en.literateprograms.org/Fibonacci_numbers_(Scheme)
337
+ a, b, p, q = MPZ_ONE, MPZ_ZERO, MPZ_ZERO, MPZ_ONE
338
+ while n:
339
+ if n & 1:
340
+ aq = a*q
341
+ a, b = b*q+aq+a*p, b*p+aq
342
+ n -= 1
343
+ else:
344
+ qq = q*q
345
+ p, q = p*p+qq, qq+2*p*q
346
+ n >>= 1
347
+ if m < 250:
348
+ _cache[m] = b
349
+ return b
350
+
351
+ MAX_FACTORIAL_CACHE = 1000
352
+
353
+ def ifac(n, memo={0:1, 1:1}):
354
+ """Return n factorial (for integers n >= 0 only)."""
355
+ f = memo.get(n)
356
+ if f:
357
+ return f
358
+ k = len(memo)
359
+ p = memo[k-1]
360
+ MAX = MAX_FACTORIAL_CACHE
361
+ while k <= n:
362
+ p *= k
363
+ if k <= MAX:
364
+ memo[k] = p
365
+ k += 1
366
+ return p
367
+
368
+ def ifac2(n, memo_pair=[{0:1}, {1:1}]):
369
+ """Return n!! (double factorial), integers n >= 0 only."""
370
+ memo = memo_pair[n&1]
371
+ f = memo.get(n)
372
+ if f:
373
+ return f
374
+ k = max(memo)
375
+ p = memo[k]
376
+ MAX = MAX_FACTORIAL_CACHE
377
+ while k < n:
378
+ k += 2
379
+ p *= k
380
+ if k <= MAX:
381
+ memo[k] = p
382
+ return p
383
+
384
+ if BACKEND == 'gmpy':
385
+ ifac = gmpy.fac
386
+ elif BACKEND == 'sage':
387
+ ifac = lambda n: int(sage.factorial(n))
388
+ ifib = sage.fibonacci
389
+
390
+ def list_primes(n):
391
+ n = n + 1
392
+ sieve = list(xrange(n))
393
+ sieve[:2] = [0, 0]
394
+ for i in xrange(2, int(n**0.5)+1):
395
+ if sieve[i]:
396
+ for j in xrange(i**2, n, i):
397
+ sieve[j] = 0
398
+ return [p for p in sieve if p]
399
+
400
+ if BACKEND == 'sage':
401
+ # Note: it is *VERY* important for performance that we convert
402
+ # the list to Python ints.
403
+ def list_primes(n):
404
+ return [int(_) for _ in sage.primes(n+1)]
405
+
406
+ small_odd_primes = (3,5,7,11,13,17,19,23,29,31,37,41,43,47)
407
+ small_odd_primes_set = set(small_odd_primes)
408
+
409
+ def isprime(n):
410
+ """
411
+ Determines whether n is a prime number. A probabilistic test is
412
+ performed if n is very large. No special trick is used for detecting
413
+ perfect powers.
414
+
415
+ >>> sum(list_primes(100000))
416
+ 454396537
417
+ >>> sum(n*isprime(n) for n in range(100000))
418
+ 454396537
419
+
420
+ """
421
+ n = int(n)
422
+ if not n & 1:
423
+ return n == 2
424
+ if n < 50:
425
+ return n in small_odd_primes_set
426
+ for p in small_odd_primes:
427
+ if not n % p:
428
+ return False
429
+ m = n-1
430
+ s = trailing(m)
431
+ d = m >> s
432
+ def test(a):
433
+ x = pow(a,d,n)
434
+ if x == 1 or x == m:
435
+ return True
436
+ for r in xrange(1,s):
437
+ x = x**2 % n
438
+ if x == m:
439
+ return True
440
+ return False
441
+ # See http://primes.utm.edu/prove/prove2_3.html
442
+ if n < 1373653:
443
+ witnesses = [2,3]
444
+ elif n < 341550071728321:
445
+ witnesses = [2,3,5,7,11,13,17]
446
+ else:
447
+ witnesses = small_odd_primes
448
+ for a in witnesses:
449
+ if not test(a):
450
+ return False
451
+ return True
452
+
453
+ def moebius(n):
454
+ """
455
+ Evaluates the Moebius function which is `mu(n) = (-1)^k` if `n`
456
+ is a product of `k` distinct primes and `mu(n) = 0` otherwise.
457
+
458
+ TODO: speed up using factorization
459
+ """
460
+ n = abs(int(n))
461
+ if n < 2:
462
+ return n
463
+ factors = []
464
+ for p in xrange(2, n+1):
465
+ if not (n % p):
466
+ if not (n % p**2):
467
+ return 0
468
+ if not sum(p % f for f in factors):
469
+ factors.append(p)
470
+ return (-1)**len(factors)
471
+
472
+ def gcd(*args):
473
+ a = 0
474
+ for b in args:
475
+ if a:
476
+ while b:
477
+ a, b = b, a % b
478
+ else:
479
+ a = b
480
+ return a
481
+
482
+
483
+ # Comment by Juan Arias de Reyna:
484
+ #
485
+ # I learn this method to compute EulerE[2n] from van de Lune.
486
+ #
487
+ # We apply the formula EulerE[2n] = (-1)^n 2**(-2n) sum_{j=0}^n a(2n,2j+1)
488
+ #
489
+ # where the numbers a(n,j) vanish for j > n+1 or j <= -1 and satisfies
490
+ #
491
+ # a(0,-1) = a(0,0) = 0; a(0,1)= 1; a(0,2) = a(0,3) = 0
492
+ #
493
+ # a(n,j) = a(n-1,j) when n+j is even
494
+ # a(n,j) = (j-1) a(n-1,j-1) + (j+1) a(n-1,j+1) when n+j is odd
495
+ #
496
+ #
497
+ # But we can use only one array unidimensional a(j) since to compute
498
+ # a(n,j) we only need to know a(n-1,k) where k and j are of different parity
499
+ # and we have not to conserve the used values.
500
+ #
501
+ # We cached up the values of Euler numbers to sufficiently high order.
502
+ #
503
+ # Important Observation: If we pretend to use the numbers
504
+ # EulerE[1], EulerE[2], ... , EulerE[n]
505
+ # it is convenient to compute first EulerE[n], since the algorithm
506
+ # computes first all
507
+ # the previous ones, and keeps them in the CACHE
508
+
509
+ MAX_EULER_CACHE = 500
510
+
511
+ def eulernum(m, _cache={0:MPZ_ONE}):
512
+ r"""
513
+ Computes the Euler numbers `E(n)`, which can be defined as
514
+ coefficients of the Taylor expansion of `1/cosh x`:
515
+
516
+ .. math ::
517
+
518
+ \frac{1}{\cosh x} = \sum_{n=0}^\infty \frac{E_n}{n!} x^n
519
+
520
+ Example::
521
+
522
+ >>> [int(eulernum(n)) for n in range(11)]
523
+ [1, 0, -1, 0, 5, 0, -61, 0, 1385, 0, -50521]
524
+ >>> [int(eulernum(n)) for n in range(11)] # test cache
525
+ [1, 0, -1, 0, 5, 0, -61, 0, 1385, 0, -50521]
526
+
527
+ """
528
+ # for odd m > 1, the Euler numbers are zero
529
+ if m & 1:
530
+ return MPZ_ZERO
531
+ f = _cache.get(m)
532
+ if f:
533
+ return f
534
+ MAX = MAX_EULER_CACHE
535
+ n = m
536
+ a = [MPZ(_) for _ in [0,0,1,0,0,0]]
537
+ for n in range(1, m+1):
538
+ for j in range(n+1, -1, -2):
539
+ a[j+1] = (j-1)*a[j] + (j+1)*a[j+2]
540
+ a.append(0)
541
+ suma = 0
542
+ for k in range(n+1, -1, -2):
543
+ suma += a[k+1]
544
+ if n <= MAX:
545
+ _cache[n] = ((-1)**(n//2))*(suma // 2**n)
546
+ if n == m:
547
+ return ((-1)**(n//2))*suma // 2**n
548
+
549
+ def stirling1(n, k):
550
+ """
551
+ Stirling number of the first kind.
552
+ """
553
+ if n < 0 or k < 0:
554
+ raise ValueError
555
+ if k >= n:
556
+ return MPZ(n == k)
557
+ if k < 1:
558
+ return MPZ_ZERO
559
+ L = [MPZ_ZERO] * (k+1)
560
+ L[1] = MPZ_ONE
561
+ for m in xrange(2, n+1):
562
+ for j in xrange(min(k, m), 0, -1):
563
+ L[j] = (m-1) * L[j] + L[j-1]
564
+ return (-1)**(n+k) * L[k]
565
+
566
+ def stirling2(n, k):
567
+ """
568
+ Stirling number of the second kind.
569
+ """
570
+ if n < 0 or k < 0:
571
+ raise ValueError
572
+ if k >= n:
573
+ return MPZ(n == k)
574
+ if k <= 1:
575
+ return MPZ(k == 1)
576
+ s = MPZ_ZERO
577
+ t = MPZ_ONE
578
+ for j in xrange(k+1):
579
+ if (k + j) & 1:
580
+ s -= t * MPZ(j)**n
581
+ else:
582
+ s += t * MPZ(j)**n
583
+ t = t * (k - j) // (j + 1)
584
+ return s // ifac(k)
lib/python3.11/site-packages/mpmath/libmp/libmpc.py ADDED
@@ -0,0 +1,835 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Low-level functions for complex arithmetic.
3
+ """
4
+
5
+ import sys
6
+
7
+ from .backend import MPZ, MPZ_ZERO, MPZ_ONE, MPZ_TWO, BACKEND
8
+
9
+ from .libmpf import (\
10
+ round_floor, round_ceiling, round_down, round_up,
11
+ round_nearest, round_fast, bitcount,
12
+ bctable, normalize, normalize1, reciprocal_rnd, rshift, lshift, giant_steps,
13
+ negative_rnd,
14
+ to_str, to_fixed, from_man_exp, from_float, to_float, from_int, to_int,
15
+ fzero, fone, ftwo, fhalf, finf, fninf, fnan, fnone,
16
+ mpf_abs, mpf_pos, mpf_neg, mpf_add, mpf_sub, mpf_mul,
17
+ mpf_div, mpf_mul_int, mpf_shift, mpf_sqrt, mpf_hypot,
18
+ mpf_rdiv_int, mpf_floor, mpf_ceil, mpf_nint, mpf_frac,
19
+ mpf_sign, mpf_hash,
20
+ ComplexResult
21
+ )
22
+
23
+ from .libelefun import (\
24
+ mpf_pi, mpf_exp, mpf_log, mpf_cos_sin, mpf_cosh_sinh, mpf_tan, mpf_pow_int,
25
+ mpf_log_hypot,
26
+ mpf_cos_sin_pi, mpf_phi,
27
+ mpf_cos, mpf_sin, mpf_cos_pi, mpf_sin_pi,
28
+ mpf_atan, mpf_atan2, mpf_cosh, mpf_sinh, mpf_tanh,
29
+ mpf_asin, mpf_acos, mpf_acosh, mpf_nthroot, mpf_fibonacci
30
+ )
31
+
32
+ # An mpc value is a (real, imag) tuple
33
+ mpc_one = fone, fzero
34
+ mpc_zero = fzero, fzero
35
+ mpc_two = ftwo, fzero
36
+ mpc_half = (fhalf, fzero)
37
+
38
+ _infs = (finf, fninf)
39
+ _infs_nan = (finf, fninf, fnan)
40
+
41
+ def mpc_is_inf(z):
42
+ """Check if either real or imaginary part is infinite"""
43
+ re, im = z
44
+ if re in _infs: return True
45
+ if im in _infs: return True
46
+ return False
47
+
48
+ def mpc_is_infnan(z):
49
+ """Check if either real or imaginary part is infinite or nan"""
50
+ re, im = z
51
+ if re in _infs_nan: return True
52
+ if im in _infs_nan: return True
53
+ return False
54
+
55
+ def mpc_to_str(z, dps, **kwargs):
56
+ re, im = z
57
+ rs = to_str(re, dps)
58
+ if im[0]:
59
+ return rs + " - " + to_str(mpf_neg(im), dps, **kwargs) + "j"
60
+ else:
61
+ return rs + " + " + to_str(im, dps, **kwargs) + "j"
62
+
63
+ def mpc_to_complex(z, strict=False, rnd=round_fast):
64
+ re, im = z
65
+ return complex(to_float(re, strict, rnd), to_float(im, strict, rnd))
66
+
67
+ def mpc_hash(z):
68
+ if sys.version_info >= (3, 2):
69
+ re, im = z
70
+ h = mpf_hash(re) + sys.hash_info.imag * mpf_hash(im)
71
+ # Need to reduce either module 2^32 or 2^64
72
+ h = h % (2**sys.hash_info.width)
73
+ return int(h)
74
+ else:
75
+ try:
76
+ return hash(mpc_to_complex(z, strict=True))
77
+ except OverflowError:
78
+ return hash(z)
79
+
80
+ def mpc_conjugate(z, prec, rnd=round_fast):
81
+ re, im = z
82
+ return re, mpf_neg(im, prec, rnd)
83
+
84
+ def mpc_is_nonzero(z):
85
+ return z != mpc_zero
86
+
87
+ def mpc_add(z, w, prec, rnd=round_fast):
88
+ a, b = z
89
+ c, d = w
90
+ return mpf_add(a, c, prec, rnd), mpf_add(b, d, prec, rnd)
91
+
92
+ def mpc_add_mpf(z, x, prec, rnd=round_fast):
93
+ a, b = z
94
+ return mpf_add(a, x, prec, rnd), b
95
+
96
+ def mpc_sub(z, w, prec=0, rnd=round_fast):
97
+ a, b = z
98
+ c, d = w
99
+ return mpf_sub(a, c, prec, rnd), mpf_sub(b, d, prec, rnd)
100
+
101
+ def mpc_sub_mpf(z, p, prec=0, rnd=round_fast):
102
+ a, b = z
103
+ return mpf_sub(a, p, prec, rnd), b
104
+
105
+ def mpc_pos(z, prec, rnd=round_fast):
106
+ a, b = z
107
+ return mpf_pos(a, prec, rnd), mpf_pos(b, prec, rnd)
108
+
109
+ def mpc_neg(z, prec=None, rnd=round_fast):
110
+ a, b = z
111
+ return mpf_neg(a, prec, rnd), mpf_neg(b, prec, rnd)
112
+
113
+ def mpc_shift(z, n):
114
+ a, b = z
115
+ return mpf_shift(a, n), mpf_shift(b, n)
116
+
117
+ def mpc_abs(z, prec, rnd=round_fast):
118
+ """Absolute value of a complex number, |a+bi|.
119
+ Returns an mpf value."""
120
+ a, b = z
121
+ return mpf_hypot(a, b, prec, rnd)
122
+
123
+ def mpc_arg(z, prec, rnd=round_fast):
124
+ """Argument of a complex number. Returns an mpf value."""
125
+ a, b = z
126
+ return mpf_atan2(b, a, prec, rnd)
127
+
128
+ def mpc_floor(z, prec, rnd=round_fast):
129
+ a, b = z
130
+ return mpf_floor(a, prec, rnd), mpf_floor(b, prec, rnd)
131
+
132
+ def mpc_ceil(z, prec, rnd=round_fast):
133
+ a, b = z
134
+ return mpf_ceil(a, prec, rnd), mpf_ceil(b, prec, rnd)
135
+
136
+ def mpc_nint(z, prec, rnd=round_fast):
137
+ a, b = z
138
+ return mpf_nint(a, prec, rnd), mpf_nint(b, prec, rnd)
139
+
140
+ def mpc_frac(z, prec, rnd=round_fast):
141
+ a, b = z
142
+ return mpf_frac(a, prec, rnd), mpf_frac(b, prec, rnd)
143
+
144
+
145
+ def mpc_mul(z, w, prec, rnd=round_fast):
146
+ """
147
+ Complex multiplication.
148
+
149
+ Returns the real and imaginary part of (a+bi)*(c+di), rounded to
150
+ the specified precision. The rounding mode applies to the real and
151
+ imaginary parts separately.
152
+ """
153
+ a, b = z
154
+ c, d = w
155
+ p = mpf_mul(a, c)
156
+ q = mpf_mul(b, d)
157
+ r = mpf_mul(a, d)
158
+ s = mpf_mul(b, c)
159
+ re = mpf_sub(p, q, prec, rnd)
160
+ im = mpf_add(r, s, prec, rnd)
161
+ return re, im
162
+
163
+ def mpc_square(z, prec, rnd=round_fast):
164
+ # (a+b*I)**2 == a**2 - b**2 + 2*I*a*b
165
+ a, b = z
166
+ p = mpf_mul(a,a)
167
+ q = mpf_mul(b,b)
168
+ r = mpf_mul(a,b, prec, rnd)
169
+ re = mpf_sub(p, q, prec, rnd)
170
+ im = mpf_shift(r, 1)
171
+ return re, im
172
+
173
+ def mpc_mul_mpf(z, p, prec, rnd=round_fast):
174
+ a, b = z
175
+ re = mpf_mul(a, p, prec, rnd)
176
+ im = mpf_mul(b, p, prec, rnd)
177
+ return re, im
178
+
179
+ def mpc_mul_imag_mpf(z, x, prec, rnd=round_fast):
180
+ """
181
+ Multiply the mpc value z by I*x where x is an mpf value.
182
+ """
183
+ a, b = z
184
+ re = mpf_neg(mpf_mul(b, x, prec, rnd))
185
+ im = mpf_mul(a, x, prec, rnd)
186
+ return re, im
187
+
188
+ def mpc_mul_int(z, n, prec, rnd=round_fast):
189
+ a, b = z
190
+ re = mpf_mul_int(a, n, prec, rnd)
191
+ im = mpf_mul_int(b, n, prec, rnd)
192
+ return re, im
193
+
194
+ def mpc_div(z, w, prec, rnd=round_fast):
195
+ a, b = z
196
+ c, d = w
197
+ wp = prec + 10
198
+ # mag = c*c + d*d
199
+ mag = mpf_add(mpf_mul(c, c), mpf_mul(d, d), wp)
200
+ # (a*c+b*d)/mag, (b*c-a*d)/mag
201
+ t = mpf_add(mpf_mul(a,c), mpf_mul(b,d), wp)
202
+ u = mpf_sub(mpf_mul(b,c), mpf_mul(a,d), wp)
203
+ return mpf_div(t,mag,prec,rnd), mpf_div(u,mag,prec,rnd)
204
+
205
+ def mpc_div_mpf(z, p, prec, rnd=round_fast):
206
+ """Calculate z/p where p is real"""
207
+ a, b = z
208
+ re = mpf_div(a, p, prec, rnd)
209
+ im = mpf_div(b, p, prec, rnd)
210
+ return re, im
211
+
212
+ def mpc_reciprocal(z, prec, rnd=round_fast):
213
+ """Calculate 1/z efficiently"""
214
+ a, b = z
215
+ m = mpf_add(mpf_mul(a,a),mpf_mul(b,b),prec+10)
216
+ re = mpf_div(a, m, prec, rnd)
217
+ im = mpf_neg(mpf_div(b, m, prec, rnd))
218
+ return re, im
219
+
220
+ def mpc_mpf_div(p, z, prec, rnd=round_fast):
221
+ """Calculate p/z where p is real efficiently"""
222
+ a, b = z
223
+ m = mpf_add(mpf_mul(a,a),mpf_mul(b,b), prec+10)
224
+ re = mpf_div(mpf_mul(a,p), m, prec, rnd)
225
+ im = mpf_div(mpf_neg(mpf_mul(b,p)), m, prec, rnd)
226
+ return re, im
227
+
228
+ def complex_int_pow(a, b, n):
229
+ """Complex integer power: computes (a+b*I)**n exactly for
230
+ nonnegative n (a and b must be Python ints)."""
231
+ wre = 1
232
+ wim = 0
233
+ while n:
234
+ if n & 1:
235
+ wre, wim = wre*a - wim*b, wim*a + wre*b
236
+ n -= 1
237
+ a, b = a*a - b*b, 2*a*b
238
+ n //= 2
239
+ return wre, wim
240
+
241
+ def mpc_pow(z, w, prec, rnd=round_fast):
242
+ if w[1] == fzero:
243
+ return mpc_pow_mpf(z, w[0], prec, rnd)
244
+ return mpc_exp(mpc_mul(mpc_log(z, prec+10), w, prec+10), prec, rnd)
245
+
246
+ def mpc_pow_mpf(z, p, prec, rnd=round_fast):
247
+ psign, pman, pexp, pbc = p
248
+ if pexp >= 0:
249
+ return mpc_pow_int(z, (-1)**psign * (pman<<pexp), prec, rnd)
250
+ if pexp == -1:
251
+ sqrtz = mpc_sqrt(z, prec+10)
252
+ return mpc_pow_int(sqrtz, (-1)**psign * pman, prec, rnd)
253
+ return mpc_exp(mpc_mul_mpf(mpc_log(z, prec+10), p, prec+10), prec, rnd)
254
+
255
+ def mpc_pow_int(z, n, prec, rnd=round_fast):
256
+ a, b = z
257
+ if b == fzero:
258
+ return mpf_pow_int(a, n, prec, rnd), fzero
259
+ if a == fzero:
260
+ v = mpf_pow_int(b, n, prec, rnd)
261
+ n %= 4
262
+ if n == 0:
263
+ return v, fzero
264
+ elif n == 1:
265
+ return fzero, v
266
+ elif n == 2:
267
+ return mpf_neg(v), fzero
268
+ elif n == 3:
269
+ return fzero, mpf_neg(v)
270
+ if n == 0: return mpc_one
271
+ if n == 1: return mpc_pos(z, prec, rnd)
272
+ if n == 2: return mpc_square(z, prec, rnd)
273
+ if n == -1: return mpc_reciprocal(z, prec, rnd)
274
+ if n < 0: return mpc_reciprocal(mpc_pow_int(z, -n, prec+4), prec, rnd)
275
+ asign, aman, aexp, abc = a
276
+ bsign, bman, bexp, bbc = b
277
+ if asign: aman = -aman
278
+ if bsign: bman = -bman
279
+ de = aexp - bexp
280
+ abs_de = abs(de)
281
+ exact_size = n*(abs_de + max(abc, bbc))
282
+ if exact_size < 10000:
283
+ if de > 0:
284
+ aman <<= de
285
+ aexp = bexp
286
+ else:
287
+ bman <<= (-de)
288
+ bexp = aexp
289
+ re, im = complex_int_pow(aman, bman, n)
290
+ re = from_man_exp(re, int(n*aexp), prec, rnd)
291
+ im = from_man_exp(im, int(n*bexp), prec, rnd)
292
+ return re, im
293
+ return mpc_exp(mpc_mul_int(mpc_log(z, prec+10), n, prec+10), prec, rnd)
294
+
295
+ def mpc_sqrt(z, prec, rnd=round_fast):
296
+ """Complex square root (principal branch).
297
+
298
+ We have sqrt(a+bi) = sqrt((r+a)/2) + b/sqrt(2*(r+a))*i where
299
+ r = abs(a+bi), when a+bi is not a negative real number."""
300
+ a, b = z
301
+ if b == fzero:
302
+ if a == fzero:
303
+ return (a, b)
304
+ # When a+bi is a negative real number, we get a real sqrt times i
305
+ if a[0]:
306
+ im = mpf_sqrt(mpf_neg(a), prec, rnd)
307
+ return (fzero, im)
308
+ else:
309
+ re = mpf_sqrt(a, prec, rnd)
310
+ return (re, fzero)
311
+ wp = prec+20
312
+ if not a[0]: # case a positive
313
+ t = mpf_add(mpc_abs((a, b), wp), a, wp) # t = abs(a+bi) + a
314
+ u = mpf_shift(t, -1) # u = t/2
315
+ re = mpf_sqrt(u, prec, rnd) # re = sqrt(u)
316
+ v = mpf_shift(t, 1) # v = 2*t
317
+ w = mpf_sqrt(v, wp) # w = sqrt(v)
318
+ im = mpf_div(b, w, prec, rnd) # im = b / w
319
+ else: # case a negative
320
+ t = mpf_sub(mpc_abs((a, b), wp), a, wp) # t = abs(a+bi) - a
321
+ u = mpf_shift(t, -1) # u = t/2
322
+ im = mpf_sqrt(u, prec, rnd) # im = sqrt(u)
323
+ v = mpf_shift(t, 1) # v = 2*t
324
+ w = mpf_sqrt(v, wp) # w = sqrt(v)
325
+ re = mpf_div(b, w, prec, rnd) # re = b/w
326
+ if b[0]:
327
+ re = mpf_neg(re)
328
+ im = mpf_neg(im)
329
+ return re, im
330
+
331
+ def mpc_nthroot_fixed(a, b, n, prec):
332
+ # a, b signed integers at fixed precision prec
333
+ start = 50
334
+ a1 = int(rshift(a, prec - n*start))
335
+ b1 = int(rshift(b, prec - n*start))
336
+ try:
337
+ r = (a1 + 1j * b1)**(1.0/n)
338
+ re = r.real
339
+ im = r.imag
340
+ re = MPZ(int(re))
341
+ im = MPZ(int(im))
342
+ except OverflowError:
343
+ a1 = from_int(a1, start)
344
+ b1 = from_int(b1, start)
345
+ fn = from_int(n)
346
+ nth = mpf_rdiv_int(1, fn, start)
347
+ re, im = mpc_pow((a1, b1), (nth, fzero), start)
348
+ re = to_int(re)
349
+ im = to_int(im)
350
+ extra = 10
351
+ prevp = start
352
+ extra1 = n
353
+ for p in giant_steps(start, prec+extra):
354
+ # this is slow for large n, unlike int_pow_fixed
355
+ re2, im2 = complex_int_pow(re, im, n-1)
356
+ re2 = rshift(re2, (n-1)*prevp - p - extra1)
357
+ im2 = rshift(im2, (n-1)*prevp - p - extra1)
358
+ r4 = (re2*re2 + im2*im2) >> (p + extra1)
359
+ ap = rshift(a, prec - p)
360
+ bp = rshift(b, prec - p)
361
+ rec = (ap * re2 + bp * im2) >> p
362
+ imc = (-ap * im2 + bp * re2) >> p
363
+ reb = (rec << p) // r4
364
+ imb = (imc << p) // r4
365
+ re = (reb + (n-1)*lshift(re, p-prevp))//n
366
+ im = (imb + (n-1)*lshift(im, p-prevp))//n
367
+ prevp = p
368
+ return re, im
369
+
370
+ def mpc_nthroot(z, n, prec, rnd=round_fast):
371
+ """
372
+ Complex n-th root.
373
+
374
+ Use Newton method as in the real case when it is faster,
375
+ otherwise use z**(1/n)
376
+ """
377
+ a, b = z
378
+ if a[0] == 0 and b == fzero:
379
+ re = mpf_nthroot(a, n, prec, rnd)
380
+ return (re, fzero)
381
+ if n < 2:
382
+ if n == 0:
383
+ return mpc_one
384
+ if n == 1:
385
+ return mpc_pos((a, b), prec, rnd)
386
+ if n == -1:
387
+ return mpc_div(mpc_one, (a, b), prec, rnd)
388
+ inverse = mpc_nthroot((a, b), -n, prec+5, reciprocal_rnd[rnd])
389
+ return mpc_div(mpc_one, inverse, prec, rnd)
390
+ if n <= 20:
391
+ prec2 = int(1.2 * (prec + 10))
392
+ asign, aman, aexp, abc = a
393
+ bsign, bman, bexp, bbc = b
394
+ pf = mpc_abs((a,b), prec)
395
+ if pf[-2] + pf[-1] > -10 and pf[-2] + pf[-1] < prec:
396
+ af = to_fixed(a, prec2)
397
+ bf = to_fixed(b, prec2)
398
+ re, im = mpc_nthroot_fixed(af, bf, n, prec2)
399
+ extra = 10
400
+ re = from_man_exp(re, -prec2-extra, prec2, rnd)
401
+ im = from_man_exp(im, -prec2-extra, prec2, rnd)
402
+ return re, im
403
+ fn = from_int(n)
404
+ prec2 = prec+10 + 10
405
+ nth = mpf_rdiv_int(1, fn, prec2)
406
+ re, im = mpc_pow((a, b), (nth, fzero), prec2, rnd)
407
+ re = normalize(re[0], re[1], re[2], re[3], prec, rnd)
408
+ im = normalize(im[0], im[1], im[2], im[3], prec, rnd)
409
+ return re, im
410
+
411
+ def mpc_cbrt(z, prec, rnd=round_fast):
412
+ """
413
+ Complex cubic root.
414
+ """
415
+ return mpc_nthroot(z, 3, prec, rnd)
416
+
417
+ def mpc_exp(z, prec, rnd=round_fast):
418
+ """
419
+ Complex exponential function.
420
+
421
+ We use the direct formula exp(a+bi) = exp(a) * (cos(b) + sin(b)*i)
422
+ for the computation. This formula is very nice because it is
423
+ pefectly stable; since we just do real multiplications, the only
424
+ numerical errors that can creep in are single-ulp rounding errors.
425
+
426
+ The formula is efficient since mpmath's real exp is quite fast and
427
+ since we can compute cos and sin simultaneously.
428
+
429
+ It is no problem if a and b are large; if the implementations of
430
+ exp/cos/sin are accurate and efficient for all real numbers, then
431
+ so is this function for all complex numbers.
432
+ """
433
+ a, b = z
434
+ if a == fzero:
435
+ return mpf_cos_sin(b, prec, rnd)
436
+ if b == fzero:
437
+ return mpf_exp(a, prec, rnd), fzero
438
+ mag = mpf_exp(a, prec+4, rnd)
439
+ c, s = mpf_cos_sin(b, prec+4, rnd)
440
+ re = mpf_mul(mag, c, prec, rnd)
441
+ im = mpf_mul(mag, s, prec, rnd)
442
+ return re, im
443
+
444
+ def mpc_log(z, prec, rnd=round_fast):
445
+ re = mpf_log_hypot(z[0], z[1], prec, rnd)
446
+ im = mpc_arg(z, prec, rnd)
447
+ return re, im
448
+
449
+ def mpc_cos(z, prec, rnd=round_fast):
450
+ """Complex cosine. The formula used is cos(a+bi) = cos(a)*cosh(b) -
451
+ sin(a)*sinh(b)*i.
452
+
453
+ The same comments apply as for the complex exp: only real
454
+ multiplications are pewrormed, so no cancellation errors are
455
+ possible. The formula is also efficient since we can compute both
456
+ pairs (cos, sin) and (cosh, sinh) in single stwps."""
457
+ a, b = z
458
+ if b == fzero:
459
+ return mpf_cos(a, prec, rnd), fzero
460
+ if a == fzero:
461
+ return mpf_cosh(b, prec, rnd), fzero
462
+ wp = prec + 6
463
+ c, s = mpf_cos_sin(a, wp)
464
+ ch, sh = mpf_cosh_sinh(b, wp)
465
+ re = mpf_mul(c, ch, prec, rnd)
466
+ im = mpf_mul(s, sh, prec, rnd)
467
+ return re, mpf_neg(im)
468
+
469
+ def mpc_sin(z, prec, rnd=round_fast):
470
+ """Complex sine. We have sin(a+bi) = sin(a)*cosh(b) +
471
+ cos(a)*sinh(b)*i. See the docstring for mpc_cos for additional
472
+ comments."""
473
+ a, b = z
474
+ if b == fzero:
475
+ return mpf_sin(a, prec, rnd), fzero
476
+ if a == fzero:
477
+ return fzero, mpf_sinh(b, prec, rnd)
478
+ wp = prec + 6
479
+ c, s = mpf_cos_sin(a, wp)
480
+ ch, sh = mpf_cosh_sinh(b, wp)
481
+ re = mpf_mul(s, ch, prec, rnd)
482
+ im = mpf_mul(c, sh, prec, rnd)
483
+ return re, im
484
+
485
+ def mpc_tan(z, prec, rnd=round_fast):
486
+ """Complex tangent. Computed as tan(a+bi) = sin(2a)/M + sinh(2b)/M*i
487
+ where M = cos(2a) + cosh(2b)."""
488
+ a, b = z
489
+ asign, aman, aexp, abc = a
490
+ bsign, bman, bexp, bbc = b
491
+ if b == fzero: return mpf_tan(a, prec, rnd), fzero
492
+ if a == fzero: return fzero, mpf_tanh(b, prec, rnd)
493
+ wp = prec + 15
494
+ a = mpf_shift(a, 1)
495
+ b = mpf_shift(b, 1)
496
+ c, s = mpf_cos_sin(a, wp)
497
+ ch, sh = mpf_cosh_sinh(b, wp)
498
+ # TODO: handle cancellation when c ~= -1 and ch ~= 1
499
+ mag = mpf_add(c, ch, wp)
500
+ re = mpf_div(s, mag, prec, rnd)
501
+ im = mpf_div(sh, mag, prec, rnd)
502
+ return re, im
503
+
504
+ def mpc_cos_pi(z, prec, rnd=round_fast):
505
+ a, b = z
506
+ if b == fzero:
507
+ return mpf_cos_pi(a, prec, rnd), fzero
508
+ b = mpf_mul(b, mpf_pi(prec+5), prec+5)
509
+ if a == fzero:
510
+ return mpf_cosh(b, prec, rnd), fzero
511
+ wp = prec + 6
512
+ c, s = mpf_cos_sin_pi(a, wp)
513
+ ch, sh = mpf_cosh_sinh(b, wp)
514
+ re = mpf_mul(c, ch, prec, rnd)
515
+ im = mpf_mul(s, sh, prec, rnd)
516
+ return re, mpf_neg(im)
517
+
518
+ def mpc_sin_pi(z, prec, rnd=round_fast):
519
+ a, b = z
520
+ if b == fzero:
521
+ return mpf_sin_pi(a, prec, rnd), fzero
522
+ b = mpf_mul(b, mpf_pi(prec+5), prec+5)
523
+ if a == fzero:
524
+ return fzero, mpf_sinh(b, prec, rnd)
525
+ wp = prec + 6
526
+ c, s = mpf_cos_sin_pi(a, wp)
527
+ ch, sh = mpf_cosh_sinh(b, wp)
528
+ re = mpf_mul(s, ch, prec, rnd)
529
+ im = mpf_mul(c, sh, prec, rnd)
530
+ return re, im
531
+
532
+ def mpc_cos_sin(z, prec, rnd=round_fast):
533
+ a, b = z
534
+ if a == fzero:
535
+ ch, sh = mpf_cosh_sinh(b, prec, rnd)
536
+ return (ch, fzero), (fzero, sh)
537
+ if b == fzero:
538
+ c, s = mpf_cos_sin(a, prec, rnd)
539
+ return (c, fzero), (s, fzero)
540
+ wp = prec + 6
541
+ c, s = mpf_cos_sin(a, wp)
542
+ ch, sh = mpf_cosh_sinh(b, wp)
543
+ cre = mpf_mul(c, ch, prec, rnd)
544
+ cim = mpf_mul(s, sh, prec, rnd)
545
+ sre = mpf_mul(s, ch, prec, rnd)
546
+ sim = mpf_mul(c, sh, prec, rnd)
547
+ return (cre, mpf_neg(cim)), (sre, sim)
548
+
549
+ def mpc_cos_sin_pi(z, prec, rnd=round_fast):
550
+ a, b = z
551
+ if b == fzero:
552
+ c, s = mpf_cos_sin_pi(a, prec, rnd)
553
+ return (c, fzero), (s, fzero)
554
+ b = mpf_mul(b, mpf_pi(prec+5), prec+5)
555
+ if a == fzero:
556
+ ch, sh = mpf_cosh_sinh(b, prec, rnd)
557
+ return (ch, fzero), (fzero, sh)
558
+ wp = prec + 6
559
+ c, s = mpf_cos_sin_pi(a, wp)
560
+ ch, sh = mpf_cosh_sinh(b, wp)
561
+ cre = mpf_mul(c, ch, prec, rnd)
562
+ cim = mpf_mul(s, sh, prec, rnd)
563
+ sre = mpf_mul(s, ch, prec, rnd)
564
+ sim = mpf_mul(c, sh, prec, rnd)
565
+ return (cre, mpf_neg(cim)), (sre, sim)
566
+
567
+ def mpc_cosh(z, prec, rnd=round_fast):
568
+ """Complex hyperbolic cosine. Computed as cosh(z) = cos(z*i)."""
569
+ a, b = z
570
+ return mpc_cos((b, mpf_neg(a)), prec, rnd)
571
+
572
+ def mpc_sinh(z, prec, rnd=round_fast):
573
+ """Complex hyperbolic sine. Computed as sinh(z) = -i*sin(z*i)."""
574
+ a, b = z
575
+ b, a = mpc_sin((b, a), prec, rnd)
576
+ return a, b
577
+
578
+ def mpc_tanh(z, prec, rnd=round_fast):
579
+ """Complex hyperbolic tangent. Computed as tanh(z) = -i*tan(z*i)."""
580
+ a, b = z
581
+ b, a = mpc_tan((b, a), prec, rnd)
582
+ return a, b
583
+
584
+ # TODO: avoid loss of accuracy
585
+ def mpc_atan(z, prec, rnd=round_fast):
586
+ a, b = z
587
+ # atan(z) = (I/2)*(log(1-I*z) - log(1+I*z))
588
+ # x = 1-I*z = 1 + b - I*a
589
+ # y = 1+I*z = 1 - b + I*a
590
+ wp = prec + 15
591
+ x = mpf_add(fone, b, wp), mpf_neg(a)
592
+ y = mpf_sub(fone, b, wp), a
593
+ l1 = mpc_log(x, wp)
594
+ l2 = mpc_log(y, wp)
595
+ a, b = mpc_sub(l1, l2, prec, rnd)
596
+ # (I/2) * (a+b*I) = (-b/2 + a/2*I)
597
+ v = mpf_neg(mpf_shift(b,-1)), mpf_shift(a,-1)
598
+ # Subtraction at infinity gives correct real part but
599
+ # wrong imaginary part (should be zero)
600
+ if v[1] == fnan and mpc_is_inf(z):
601
+ v = (v[0], fzero)
602
+ return v
603
+
604
+ beta_crossover = from_float(0.6417)
605
+ alpha_crossover = from_float(1.5)
606
+
607
+ def acos_asin(z, prec, rnd, n):
608
+ """ complex acos for n = 0, asin for n = 1
609
+ The algorithm is described in
610
+ T.E. Hull, T.F. Fairgrieve and P.T.P. Tang
611
+ 'Implementing the Complex Arcsine and Arcosine Functions
612
+ using Exception Handling',
613
+ ACM Trans. on Math. Software Vol. 23 (1997), p299
614
+ The complex acos and asin can be defined as
615
+ acos(z) = acos(beta) - I*sign(a)* log(alpha + sqrt(alpha**2 -1))
616
+ asin(z) = asin(beta) + I*sign(a)* log(alpha + sqrt(alpha**2 -1))
617
+ where z = a + I*b
618
+ alpha = (1/2)*(r + s); beta = (1/2)*(r - s) = a/alpha
619
+ r = sqrt((a+1)**2 + y**2); s = sqrt((a-1)**2 + y**2)
620
+ These expressions are rewritten in different ways in different
621
+ regions, delimited by two crossovers alpha_crossover and beta_crossover,
622
+ and by abs(a) <= 1, in order to improve the numerical accuracy.
623
+ """
624
+ a, b = z
625
+ wp = prec + 10
626
+ # special cases with real argument
627
+ if b == fzero:
628
+ am = mpf_sub(fone, mpf_abs(a), wp)
629
+ # case abs(a) <= 1
630
+ if not am[0]:
631
+ if n == 0:
632
+ return mpf_acos(a, prec, rnd), fzero
633
+ else:
634
+ return mpf_asin(a, prec, rnd), fzero
635
+ # cases abs(a) > 1
636
+ else:
637
+ # case a < -1
638
+ if a[0]:
639
+ pi = mpf_pi(prec, rnd)
640
+ c = mpf_acosh(mpf_neg(a), prec, rnd)
641
+ if n == 0:
642
+ return pi, mpf_neg(c)
643
+ else:
644
+ return mpf_neg(mpf_shift(pi, -1)), c
645
+ # case a > 1
646
+ else:
647
+ c = mpf_acosh(a, prec, rnd)
648
+ if n == 0:
649
+ return fzero, c
650
+ else:
651
+ pi = mpf_pi(prec, rnd)
652
+ return mpf_shift(pi, -1), mpf_neg(c)
653
+ asign = bsign = 0
654
+ if a[0]:
655
+ a = mpf_neg(a)
656
+ asign = 1
657
+ if b[0]:
658
+ b = mpf_neg(b)
659
+ bsign = 1
660
+ am = mpf_sub(fone, a, wp)
661
+ ap = mpf_add(fone, a, wp)
662
+ r = mpf_hypot(ap, b, wp)
663
+ s = mpf_hypot(am, b, wp)
664
+ alpha = mpf_shift(mpf_add(r, s, wp), -1)
665
+ beta = mpf_div(a, alpha, wp)
666
+ b2 = mpf_mul(b,b, wp)
667
+ # case beta <= beta_crossover
668
+ if not mpf_sub(beta_crossover, beta, wp)[0]:
669
+ if n == 0:
670
+ re = mpf_acos(beta, wp)
671
+ else:
672
+ re = mpf_asin(beta, wp)
673
+ else:
674
+ # to compute the real part in this region use the identity
675
+ # asin(beta) = atan(beta/sqrt(1-beta**2))
676
+ # beta/sqrt(1-beta**2) = (alpha + a) * (alpha - a)
677
+ # alpha + a is numerically accurate; alpha - a can have
678
+ # cancellations leading to numerical inaccuracies, so rewrite
679
+ # it in differente ways according to the region
680
+ Ax = mpf_add(alpha, a, wp)
681
+ # case a <= 1
682
+ if not am[0]:
683
+ # c = b*b/(r + (a+1)); d = (s + (1-a))
684
+ # alpha - a = (1/2)*(c + d)
685
+ # case n=0: re = atan(sqrt((1/2) * Ax * (c + d))/a)
686
+ # case n=1: re = atan(a/sqrt((1/2) * Ax * (c + d)))
687
+ c = mpf_div(b2, mpf_add(r, ap, wp), wp)
688
+ d = mpf_add(s, am, wp)
689
+ re = mpf_shift(mpf_mul(Ax, mpf_add(c, d, wp), wp), -1)
690
+ if n == 0:
691
+ re = mpf_atan(mpf_div(mpf_sqrt(re, wp), a, wp), wp)
692
+ else:
693
+ re = mpf_atan(mpf_div(a, mpf_sqrt(re, wp), wp), wp)
694
+ else:
695
+ # c = Ax/(r + (a+1)); d = Ax/(s - (1-a))
696
+ # alpha - a = (1/2)*(c + d)
697
+ # case n = 0: re = atan(b*sqrt(c + d)/2/a)
698
+ # case n = 1: re = atan(a/(b*sqrt(c + d)/2)
699
+ c = mpf_div(Ax, mpf_add(r, ap, wp), wp)
700
+ d = mpf_div(Ax, mpf_sub(s, am, wp), wp)
701
+ re = mpf_shift(mpf_add(c, d, wp), -1)
702
+ re = mpf_mul(b, mpf_sqrt(re, wp), wp)
703
+ if n == 0:
704
+ re = mpf_atan(mpf_div(re, a, wp), wp)
705
+ else:
706
+ re = mpf_atan(mpf_div(a, re, wp), wp)
707
+ # to compute alpha + sqrt(alpha**2 - 1), if alpha <= alpha_crossover
708
+ # replace it with 1 + Am1 + sqrt(Am1*(alpha+1)))
709
+ # where Am1 = alpha -1
710
+ # if alpha <= alpha_crossover:
711
+ if not mpf_sub(alpha_crossover, alpha, wp)[0]:
712
+ c1 = mpf_div(b2, mpf_add(r, ap, wp), wp)
713
+ # case a < 1
714
+ if mpf_neg(am)[0]:
715
+ # Am1 = (1/2) * (b*b/(r + (a+1)) + b*b/(s + (1-a))
716
+ c2 = mpf_add(s, am, wp)
717
+ c2 = mpf_div(b2, c2, wp)
718
+ Am1 = mpf_shift(mpf_add(c1, c2, wp), -1)
719
+ else:
720
+ # Am1 = (1/2) * (b*b/(r + (a+1)) + (s - (1-a)))
721
+ c2 = mpf_sub(s, am, wp)
722
+ Am1 = mpf_shift(mpf_add(c1, c2, wp), -1)
723
+ # im = log(1 + Am1 + sqrt(Am1*(alpha+1)))
724
+ im = mpf_mul(Am1, mpf_add(alpha, fone, wp), wp)
725
+ im = mpf_log(mpf_add(fone, mpf_add(Am1, mpf_sqrt(im, wp), wp), wp), wp)
726
+ else:
727
+ # im = log(alpha + sqrt(alpha*alpha - 1))
728
+ im = mpf_sqrt(mpf_sub(mpf_mul(alpha, alpha, wp), fone, wp), wp)
729
+ im = mpf_log(mpf_add(alpha, im, wp), wp)
730
+ if asign:
731
+ if n == 0:
732
+ re = mpf_sub(mpf_pi(wp), re, wp)
733
+ else:
734
+ re = mpf_neg(re)
735
+ if not bsign and n == 0:
736
+ im = mpf_neg(im)
737
+ if bsign and n == 1:
738
+ im = mpf_neg(im)
739
+ re = normalize(re[0], re[1], re[2], re[3], prec, rnd)
740
+ im = normalize(im[0], im[1], im[2], im[3], prec, rnd)
741
+ return re, im
742
+
743
+ def mpc_acos(z, prec, rnd=round_fast):
744
+ return acos_asin(z, prec, rnd, 0)
745
+
746
+ def mpc_asin(z, prec, rnd=round_fast):
747
+ return acos_asin(z, prec, rnd, 1)
748
+
749
+ def mpc_asinh(z, prec, rnd=round_fast):
750
+ # asinh(z) = I * asin(-I z)
751
+ a, b = z
752
+ a, b = mpc_asin((b, mpf_neg(a)), prec, rnd)
753
+ return mpf_neg(b), a
754
+
755
+ def mpc_acosh(z, prec, rnd=round_fast):
756
+ # acosh(z) = -I * acos(z) for Im(acos(z)) <= 0
757
+ # +I * acos(z) otherwise
758
+ a, b = mpc_acos(z, prec, rnd)
759
+ if b[0] or b == fzero:
760
+ return mpf_neg(b), a
761
+ else:
762
+ return b, mpf_neg(a)
763
+
764
+ def mpc_atanh(z, prec, rnd=round_fast):
765
+ # atanh(z) = (log(1+z)-log(1-z))/2
766
+ wp = prec + 15
767
+ a = mpc_add(z, mpc_one, wp)
768
+ b = mpc_sub(mpc_one, z, wp)
769
+ a = mpc_log(a, wp)
770
+ b = mpc_log(b, wp)
771
+ v = mpc_shift(mpc_sub(a, b, wp), -1)
772
+ # Subtraction at infinity gives correct imaginary part but
773
+ # wrong real part (should be zero)
774
+ if v[0] == fnan and mpc_is_inf(z):
775
+ v = (fzero, v[1])
776
+ return v
777
+
778
+ def mpc_fibonacci(z, prec, rnd=round_fast):
779
+ re, im = z
780
+ if im == fzero:
781
+ return (mpf_fibonacci(re, prec, rnd), fzero)
782
+ size = max(abs(re[2]+re[3]), abs(re[2]+re[3]))
783
+ wp = prec + size + 20
784
+ a = mpf_phi(wp)
785
+ b = mpf_add(mpf_shift(a, 1), fnone, wp)
786
+ u = mpc_pow((a, fzero), z, wp)
787
+ v = mpc_cos_pi(z, wp)
788
+ v = mpc_div(v, u, wp)
789
+ u = mpc_sub(u, v, wp)
790
+ u = mpc_div_mpf(u, b, prec, rnd)
791
+ return u
792
+
793
+ def mpf_expj(x, prec, rnd='f'):
794
+ raise ComplexResult
795
+
796
+ def mpc_expj(z, prec, rnd='f'):
797
+ re, im = z
798
+ if im == fzero:
799
+ return mpf_cos_sin(re, prec, rnd)
800
+ if re == fzero:
801
+ return mpf_exp(mpf_neg(im), prec, rnd), fzero
802
+ ey = mpf_exp(mpf_neg(im), prec+10)
803
+ c, s = mpf_cos_sin(re, prec+10)
804
+ re = mpf_mul(ey, c, prec, rnd)
805
+ im = mpf_mul(ey, s, prec, rnd)
806
+ return re, im
807
+
808
+ def mpf_expjpi(x, prec, rnd='f'):
809
+ raise ComplexResult
810
+
811
+ def mpc_expjpi(z, prec, rnd='f'):
812
+ re, im = z
813
+ if im == fzero:
814
+ return mpf_cos_sin_pi(re, prec, rnd)
815
+ sign, man, exp, bc = im
816
+ wp = prec+10
817
+ if man:
818
+ wp += max(0, exp+bc)
819
+ im = mpf_neg(mpf_mul(mpf_pi(wp), im, wp))
820
+ if re == fzero:
821
+ return mpf_exp(im, prec, rnd), fzero
822
+ ey = mpf_exp(im, prec+10)
823
+ c, s = mpf_cos_sin_pi(re, prec+10)
824
+ re = mpf_mul(ey, c, prec, rnd)
825
+ im = mpf_mul(ey, s, prec, rnd)
826
+ return re, im
827
+
828
+
829
+ if BACKEND == 'sage':
830
+ try:
831
+ import sage.libs.mpmath.ext_libmp as _lbmp
832
+ mpc_exp = _lbmp.mpc_exp
833
+ mpc_sqrt = _lbmp.mpc_sqrt
834
+ except (ImportError, AttributeError):
835
+ print("Warning: Sage imports in libmpc failed")
lib/python3.11/site-packages/mpmath/libmp/libmpf.py ADDED
@@ -0,0 +1,1414 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Low-level functions for arbitrary-precision floating-point arithmetic.
3
+ """
4
+
5
+ __docformat__ = 'plaintext'
6
+
7
+ import math
8
+
9
+ from bisect import bisect
10
+
11
+ import sys
12
+
13
+ # Importing random is slow
14
+ #from random import getrandbits
15
+ getrandbits = None
16
+
17
+ from .backend import (MPZ, MPZ_TYPE, MPZ_ZERO, MPZ_ONE, MPZ_TWO, MPZ_FIVE,
18
+ BACKEND, STRICT, HASH_MODULUS, HASH_BITS, gmpy, sage, sage_utils)
19
+
20
+ from .libintmath import (giant_steps,
21
+ trailtable, bctable, lshift, rshift, bitcount, trailing,
22
+ sqrt_fixed, numeral, isqrt, isqrt_fast, sqrtrem,
23
+ bin_to_radix)
24
+
25
+ # We don't pickle tuples directly for the following reasons:
26
+ # 1: pickle uses str() for ints, which is inefficient when they are large
27
+ # 2: pickle doesn't work for gmpy mpzs
28
+ # Both problems are solved by using hex()
29
+
30
+ if BACKEND == 'sage':
31
+ def to_pickable(x):
32
+ sign, man, exp, bc = x
33
+ return sign, hex(man), exp, bc
34
+ else:
35
+ def to_pickable(x):
36
+ sign, man, exp, bc = x
37
+ return sign, hex(man)[2:], exp, bc
38
+
39
+ def from_pickable(x):
40
+ sign, man, exp, bc = x
41
+ return (sign, MPZ(man, 16), exp, bc)
42
+
43
+ class ComplexResult(ValueError):
44
+ pass
45
+
46
+ try:
47
+ intern
48
+ except NameError:
49
+ intern = lambda x: x
50
+
51
+ # All supported rounding modes
52
+ round_nearest = intern('n')
53
+ round_floor = intern('f')
54
+ round_ceiling = intern('c')
55
+ round_up = intern('u')
56
+ round_down = intern('d')
57
+ round_fast = round_down
58
+
59
+ def prec_to_dps(n):
60
+ """Return number of accurate decimals that can be represented
61
+ with a precision of n bits."""
62
+ return max(1, int(round(int(n)/3.3219280948873626)-1))
63
+
64
+ def dps_to_prec(n):
65
+ """Return the number of bits required to represent n decimals
66
+ accurately."""
67
+ return max(1, int(round((int(n)+1)*3.3219280948873626)))
68
+
69
+ def repr_dps(n):
70
+ """Return the number of decimal digits required to represent
71
+ a number with n-bit precision so that it can be uniquely
72
+ reconstructed from the representation."""
73
+ dps = prec_to_dps(n)
74
+ if dps == 15:
75
+ return 17
76
+ return dps + 3
77
+
78
+ #----------------------------------------------------------------------------#
79
+ # Some commonly needed float values #
80
+ #----------------------------------------------------------------------------#
81
+
82
+ # Regular number format:
83
+ # (-1)**sign * mantissa * 2**exponent, plus bitcount of mantissa
84
+ fzero = (0, MPZ_ZERO, 0, 0)
85
+ fnzero = (1, MPZ_ZERO, 0, 0)
86
+ fone = (0, MPZ_ONE, 0, 1)
87
+ fnone = (1, MPZ_ONE, 0, 1)
88
+ ftwo = (0, MPZ_ONE, 1, 1)
89
+ ften = (0, MPZ_FIVE, 1, 3)
90
+ fhalf = (0, MPZ_ONE, -1, 1)
91
+
92
+ # Arbitrary encoding for special numbers: zero mantissa, nonzero exponent
93
+ fnan = (0, MPZ_ZERO, -123, -1)
94
+ finf = (0, MPZ_ZERO, -456, -2)
95
+ fninf = (1, MPZ_ZERO, -789, -3)
96
+
97
+ # Was 1e1000; this is broken in Python 2.4
98
+ math_float_inf = 1e300 * 1e300
99
+
100
+
101
+ #----------------------------------------------------------------------------#
102
+ # Rounding #
103
+ #----------------------------------------------------------------------------#
104
+
105
+ # This function can be used to round a mantissa generally. However,
106
+ # we will try to do most rounding inline for efficiency.
107
+ def round_int(x, n, rnd):
108
+ if rnd == round_nearest:
109
+ if x >= 0:
110
+ t = x >> (n-1)
111
+ if t & 1 and ((t & 2) or (x & h_mask[n<300][n])):
112
+ return (t>>1)+1
113
+ else:
114
+ return t>>1
115
+ else:
116
+ return -round_int(-x, n, rnd)
117
+ if rnd == round_floor:
118
+ return x >> n
119
+ if rnd == round_ceiling:
120
+ return -((-x) >> n)
121
+ if rnd == round_down:
122
+ if x >= 0:
123
+ return x >> n
124
+ return -((-x) >> n)
125
+ if rnd == round_up:
126
+ if x >= 0:
127
+ return -((-x) >> n)
128
+ return x >> n
129
+
130
+ # These masks are used to pick out segments of numbers to determine
131
+ # which direction to round when rounding to nearest.
132
+ class h_mask_big:
133
+ def __getitem__(self, n):
134
+ return (MPZ_ONE<<(n-1))-1
135
+
136
+ h_mask_small = [0]+[((MPZ_ONE<<(_-1))-1) for _ in range(1, 300)]
137
+ h_mask = [h_mask_big(), h_mask_small]
138
+
139
+ # The >> operator rounds to floor. shifts_down[rnd][sign]
140
+ # tells whether this is the right direction to use, or if the
141
+ # number should be negated before shifting
142
+ shifts_down = {round_floor:(1,0), round_ceiling:(0,1),
143
+ round_down:(1,1), round_up:(0,0)}
144
+
145
+
146
+ #----------------------------------------------------------------------------#
147
+ # Normalization of raw mpfs #
148
+ #----------------------------------------------------------------------------#
149
+
150
+ # This function is called almost every time an mpf is created.
151
+ # It has been optimized accordingly.
152
+
153
+ def _normalize(sign, man, exp, bc, prec, rnd):
154
+ """
155
+ Create a raw mpf tuple with value (-1)**sign * man * 2**exp and
156
+ normalized mantissa. The mantissa is rounded in the specified
157
+ direction if its size exceeds the precision. Trailing zero bits
158
+ are also stripped from the mantissa to ensure that the
159
+ representation is canonical.
160
+
161
+ Conditions on the input:
162
+ * The input must represent a regular (finite) number
163
+ * The sign bit must be 0 or 1
164
+ * The mantissa must be positive
165
+ * The exponent must be an integer
166
+ * The bitcount must be exact
167
+
168
+ If these conditions are not met, use from_man_exp, mpf_pos, or any
169
+ of the conversion functions to create normalized raw mpf tuples.
170
+ """
171
+ if not man:
172
+ return fzero
173
+ # Cut mantissa down to size if larger than target precision
174
+ n = bc - prec
175
+ if n > 0:
176
+ if rnd == round_nearest:
177
+ t = man >> (n-1)
178
+ if t & 1 and ((t & 2) or (man & h_mask[n<300][n])):
179
+ man = (t>>1)+1
180
+ else:
181
+ man = t>>1
182
+ elif shifts_down[rnd][sign]:
183
+ man >>= n
184
+ else:
185
+ man = -((-man)>>n)
186
+ exp += n
187
+ bc = prec
188
+ # Strip trailing bits
189
+ if not man & 1:
190
+ t = trailtable[int(man & 255)]
191
+ if not t:
192
+ while not man & 255:
193
+ man >>= 8
194
+ exp += 8
195
+ bc -= 8
196
+ t = trailtable[int(man & 255)]
197
+ man >>= t
198
+ exp += t
199
+ bc -= t
200
+ # Bit count can be wrong if the input mantissa was 1 less than
201
+ # a power of 2 and got rounded up, thereby adding an extra bit.
202
+ # With trailing bits removed, all powers of two have mantissa 1,
203
+ # so this is easy to check for.
204
+ if man == 1:
205
+ bc = 1
206
+ return sign, man, exp, bc
207
+
208
+ def _normalize1(sign, man, exp, bc, prec, rnd):
209
+ """same as normalize, but with the added condition that
210
+ man is odd or zero
211
+ """
212
+ if not man:
213
+ return fzero
214
+ if bc <= prec:
215
+ return sign, man, exp, bc
216
+ n = bc - prec
217
+ if rnd == round_nearest:
218
+ t = man >> (n-1)
219
+ if t & 1 and ((t & 2) or (man & h_mask[n<300][n])):
220
+ man = (t>>1)+1
221
+ else:
222
+ man = t>>1
223
+ elif shifts_down[rnd][sign]:
224
+ man >>= n
225
+ else:
226
+ man = -((-man)>>n)
227
+ exp += n
228
+ bc = prec
229
+ # Strip trailing bits
230
+ if not man & 1:
231
+ t = trailtable[int(man & 255)]
232
+ if not t:
233
+ while not man & 255:
234
+ man >>= 8
235
+ exp += 8
236
+ bc -= 8
237
+ t = trailtable[int(man & 255)]
238
+ man >>= t
239
+ exp += t
240
+ bc -= t
241
+ # Bit count can be wrong if the input mantissa was 1 less than
242
+ # a power of 2 and got rounded up, thereby adding an extra bit.
243
+ # With trailing bits removed, all powers of two have mantissa 1,
244
+ # so this is easy to check for.
245
+ if man == 1:
246
+ bc = 1
247
+ return sign, man, exp, bc
248
+
249
+ try:
250
+ _exp_types = (int, long)
251
+ except NameError:
252
+ _exp_types = (int,)
253
+
254
+ def strict_normalize(sign, man, exp, bc, prec, rnd):
255
+ """Additional checks on the components of an mpf. Enable tests by setting
256
+ the environment variable MPMATH_STRICT to Y."""
257
+ assert type(man) == MPZ_TYPE
258
+ assert type(bc) in _exp_types
259
+ assert type(exp) in _exp_types
260
+ assert bc == bitcount(man)
261
+ return _normalize(sign, man, exp, bc, prec, rnd)
262
+
263
+ def strict_normalize1(sign, man, exp, bc, prec, rnd):
264
+ """Additional checks on the components of an mpf. Enable tests by setting
265
+ the environment variable MPMATH_STRICT to Y."""
266
+ assert type(man) == MPZ_TYPE
267
+ assert type(bc) in _exp_types
268
+ assert type(exp) in _exp_types
269
+ assert bc == bitcount(man)
270
+ assert (not man) or (man & 1)
271
+ return _normalize1(sign, man, exp, bc, prec, rnd)
272
+
273
+ if BACKEND == 'gmpy' and '_mpmath_normalize' in dir(gmpy):
274
+ _normalize = gmpy._mpmath_normalize
275
+ _normalize1 = gmpy._mpmath_normalize
276
+
277
+ if BACKEND == 'sage':
278
+ _normalize = _normalize1 = sage_utils.normalize
279
+
280
+ if STRICT:
281
+ normalize = strict_normalize
282
+ normalize1 = strict_normalize1
283
+ else:
284
+ normalize = _normalize
285
+ normalize1 = _normalize1
286
+
287
+ #----------------------------------------------------------------------------#
288
+ # Conversion functions #
289
+ #----------------------------------------------------------------------------#
290
+
291
+ def from_man_exp(man, exp, prec=None, rnd=round_fast):
292
+ """Create raw mpf from (man, exp) pair. The mantissa may be signed.
293
+ If no precision is specified, the mantissa is stored exactly."""
294
+ man = MPZ(man)
295
+ sign = 0
296
+ if man < 0:
297
+ sign = 1
298
+ man = -man
299
+ if man < 1024:
300
+ bc = bctable[int(man)]
301
+ else:
302
+ bc = bitcount(man)
303
+ if not prec:
304
+ if not man:
305
+ return fzero
306
+ if not man & 1:
307
+ if man & 2:
308
+ return (sign, man >> 1, exp + 1, bc - 1)
309
+ t = trailtable[int(man & 255)]
310
+ if not t:
311
+ while not man & 255:
312
+ man >>= 8
313
+ exp += 8
314
+ bc -= 8
315
+ t = trailtable[int(man & 255)]
316
+ man >>= t
317
+ exp += t
318
+ bc -= t
319
+ return (sign, man, exp, bc)
320
+ return normalize(sign, man, exp, bc, prec, rnd)
321
+
322
+ int_cache = dict((n, from_man_exp(n, 0)) for n in range(-10, 257))
323
+
324
+ if BACKEND == 'gmpy' and '_mpmath_create' in dir(gmpy):
325
+ from_man_exp = gmpy._mpmath_create
326
+
327
+ if BACKEND == 'sage':
328
+ from_man_exp = sage_utils.from_man_exp
329
+
330
+ def from_int(n, prec=0, rnd=round_fast):
331
+ """Create a raw mpf from an integer. If no precision is specified,
332
+ the mantissa is stored exactly."""
333
+ if not prec:
334
+ if n in int_cache:
335
+ return int_cache[n]
336
+ return from_man_exp(n, 0, prec, rnd)
337
+
338
+ def to_man_exp(s):
339
+ """Return (man, exp) of a raw mpf. Raise an error if inf/nan."""
340
+ sign, man, exp, bc = s
341
+ if (not man) and exp:
342
+ raise ValueError("mantissa and exponent are undefined for %s" % man)
343
+ return man, exp
344
+
345
+ def to_int(s, rnd=None):
346
+ """Convert a raw mpf to the nearest int. Rounding is done down by
347
+ default (same as int(float) in Python), but can be changed. If the
348
+ input is inf/nan, an exception is raised."""
349
+ sign, man, exp, bc = s
350
+ if (not man) and exp:
351
+ raise ValueError("cannot convert inf or nan to int")
352
+ if exp >= 0:
353
+ if sign:
354
+ return (-man) << exp
355
+ return man << exp
356
+ # Make default rounding fast
357
+ if not rnd:
358
+ if sign:
359
+ return -(man >> (-exp))
360
+ else:
361
+ return man >> (-exp)
362
+ if sign:
363
+ return round_int(-man, -exp, rnd)
364
+ else:
365
+ return round_int(man, -exp, rnd)
366
+
367
+ def mpf_round_int(s, rnd):
368
+ sign, man, exp, bc = s
369
+ if (not man) and exp:
370
+ return s
371
+ if exp >= 0:
372
+ return s
373
+ mag = exp+bc
374
+ if mag < 1:
375
+ if rnd == round_ceiling:
376
+ if sign: return fzero
377
+ else: return fone
378
+ elif rnd == round_floor:
379
+ if sign: return fnone
380
+ else: return fzero
381
+ elif rnd == round_nearest:
382
+ if mag < 0 or man == MPZ_ONE: return fzero
383
+ elif sign: return fnone
384
+ else: return fone
385
+ else:
386
+ raise NotImplementedError
387
+ return mpf_pos(s, min(bc, mag), rnd)
388
+
389
+ def mpf_floor(s, prec=0, rnd=round_fast):
390
+ v = mpf_round_int(s, round_floor)
391
+ if prec:
392
+ v = mpf_pos(v, prec, rnd)
393
+ return v
394
+
395
+ def mpf_ceil(s, prec=0, rnd=round_fast):
396
+ v = mpf_round_int(s, round_ceiling)
397
+ if prec:
398
+ v = mpf_pos(v, prec, rnd)
399
+ return v
400
+
401
+ def mpf_nint(s, prec=0, rnd=round_fast):
402
+ v = mpf_round_int(s, round_nearest)
403
+ if prec:
404
+ v = mpf_pos(v, prec, rnd)
405
+ return v
406
+
407
+ def mpf_frac(s, prec=0, rnd=round_fast):
408
+ return mpf_sub(s, mpf_floor(s), prec, rnd)
409
+
410
+ def from_float(x, prec=53, rnd=round_fast):
411
+ """Create a raw mpf from a Python float, rounding if necessary.
412
+ If prec >= 53, the result is guaranteed to represent exactly the
413
+ same number as the input. If prec is not specified, use prec=53."""
414
+ # frexp only raises an exception for nan on some platforms
415
+ if x != x:
416
+ return fnan
417
+ # in Python2.5 math.frexp gives an exception for float infinity
418
+ # in Python2.6 it returns (float infinity, 0)
419
+ try:
420
+ m, e = math.frexp(x)
421
+ except:
422
+ if x == math_float_inf: return finf
423
+ if x == -math_float_inf: return fninf
424
+ return fnan
425
+ if x == math_float_inf: return finf
426
+ if x == -math_float_inf: return fninf
427
+ return from_man_exp(int(m*(1<<53)), e-53, prec, rnd)
428
+
429
+ def from_npfloat(x, prec=113, rnd=round_fast):
430
+ """Create a raw mpf from a numpy float, rounding if necessary.
431
+ If prec >= 113, the result is guaranteed to represent exactly the
432
+ same number as the input. If prec is not specified, use prec=113."""
433
+ y = float(x)
434
+ if x == y: # ldexp overflows for float16
435
+ return from_float(y, prec, rnd)
436
+ import numpy as np
437
+ if np.isfinite(x):
438
+ m, e = np.frexp(x)
439
+ return from_man_exp(int(np.ldexp(m, 113)), int(e-113), prec, rnd)
440
+ if np.isposinf(x): return finf
441
+ if np.isneginf(x): return fninf
442
+ return fnan
443
+
444
+ def from_Decimal(x, prec=None, rnd=round_fast):
445
+ """Create a raw mpf from a Decimal, rounding if necessary.
446
+ If prec is not specified, use the equivalent bit precision
447
+ of the number of significant digits in x."""
448
+ if x.is_nan(): return fnan
449
+ if x.is_infinite(): return fninf if x.is_signed() else finf
450
+ if prec is None:
451
+ prec = int(len(x.as_tuple()[1])*3.3219280948873626)
452
+ return from_str(str(x), prec, rnd)
453
+
454
+ def to_float(s, strict=False, rnd=round_fast):
455
+ """
456
+ Convert a raw mpf to a Python float. The result is exact if the
457
+ bitcount of s is <= 53 and no underflow/overflow occurs.
458
+
459
+ If the number is too large or too small to represent as a regular
460
+ float, it will be converted to inf or 0.0. Setting strict=True
461
+ forces an OverflowError to be raised instead.
462
+
463
+ Warning: with a directed rounding mode, the correct nearest representable
464
+ floating-point number in the specified direction might not be computed
465
+ in case of overflow or (gradual) underflow.
466
+ """
467
+ sign, man, exp, bc = s
468
+ if not man:
469
+ if s == fzero: return 0.0
470
+ if s == finf: return math_float_inf
471
+ if s == fninf: return -math_float_inf
472
+ return math_float_inf/math_float_inf
473
+ if bc > 53:
474
+ sign, man, exp, bc = normalize1(sign, man, exp, bc, 53, rnd)
475
+ if sign:
476
+ man = -man
477
+ try:
478
+ return math.ldexp(man, exp)
479
+ except OverflowError:
480
+ if strict:
481
+ raise
482
+ # Overflow to infinity
483
+ if exp + bc > 0:
484
+ if sign:
485
+ return -math_float_inf
486
+ else:
487
+ return math_float_inf
488
+ # Underflow to zero
489
+ return 0.0
490
+
491
+ def from_rational(p, q, prec, rnd=round_fast):
492
+ """Create a raw mpf from a rational number p/q, round if
493
+ necessary."""
494
+ return mpf_div(from_int(p), from_int(q), prec, rnd)
495
+
496
+ def to_rational(s):
497
+ """Convert a raw mpf to a rational number. Return integers (p, q)
498
+ such that s = p/q exactly."""
499
+ sign, man, exp, bc = s
500
+ if sign:
501
+ man = -man
502
+ if bc == -1:
503
+ raise ValueError("cannot convert %s to a rational number" % man)
504
+ if exp >= 0:
505
+ return man * (1<<exp), 1
506
+ else:
507
+ return man, 1<<(-exp)
508
+
509
+ def to_fixed(s, prec):
510
+ """Convert a raw mpf to a fixed-point big integer"""
511
+ sign, man, exp, bc = s
512
+ offset = exp + prec
513
+ if sign:
514
+ if offset >= 0: return (-man) << offset
515
+ else: return (-man) >> (-offset)
516
+ else:
517
+ if offset >= 0: return man << offset
518
+ else: return man >> (-offset)
519
+
520
+
521
+ ##############################################################################
522
+ ##############################################################################
523
+
524
+ #----------------------------------------------------------------------------#
525
+ # Arithmetic operations, etc. #
526
+ #----------------------------------------------------------------------------#
527
+
528
+ def mpf_rand(prec):
529
+ """Return a raw mpf chosen randomly from [0, 1), with prec bits
530
+ in the mantissa."""
531
+ global getrandbits
532
+ if not getrandbits:
533
+ import random
534
+ getrandbits = random.getrandbits
535
+ return from_man_exp(getrandbits(prec), -prec, prec, round_floor)
536
+
537
+ def mpf_eq(s, t):
538
+ """Test equality of two raw mpfs. This is simply tuple comparison
539
+ unless either number is nan, in which case the result is False."""
540
+ if not s[1] or not t[1]:
541
+ if s == fnan or t == fnan:
542
+ return False
543
+ return s == t
544
+
545
+ def mpf_hash(s):
546
+ # Duplicate the new hash algorithm introduces in Python 3.2.
547
+ if sys.version_info >= (3, 2):
548
+ ssign, sman, sexp, sbc = s
549
+
550
+ # Handle special numbers
551
+ if not sman:
552
+ if s == fnan: return sys.hash_info.nan
553
+ if s == finf: return sys.hash_info.inf
554
+ if s == fninf: return -sys.hash_info.inf
555
+ h = sman % HASH_MODULUS
556
+ if sexp >= 0:
557
+ sexp = sexp % HASH_BITS
558
+ else:
559
+ sexp = HASH_BITS - 1 - ((-1 - sexp) % HASH_BITS)
560
+ h = (h << sexp) % HASH_MODULUS
561
+ if ssign: h = -h
562
+ if h == -1: h = -2
563
+ return int(h)
564
+ else:
565
+ try:
566
+ # Try to be compatible with hash values for floats and ints
567
+ return hash(to_float(s, strict=1))
568
+ except OverflowError:
569
+ # We must unfortunately sacrifice compatibility with ints here.
570
+ # We could do hash(man << exp) when the exponent is positive, but
571
+ # this would cause unreasonable inefficiency for large numbers.
572
+ return hash(s)
573
+
574
+ def mpf_cmp(s, t):
575
+ """Compare the raw mpfs s and t. Return -1 if s < t, 0 if s == t,
576
+ and 1 if s > t. (Same convention as Python's cmp() function.)"""
577
+
578
+ # In principle, a comparison amounts to determining the sign of s-t.
579
+ # A full subtraction is relatively slow, however, so we first try to
580
+ # look at the components.
581
+ ssign, sman, sexp, sbc = s
582
+ tsign, tman, texp, tbc = t
583
+
584
+ # Handle zeros and special numbers
585
+ if not sman or not tman:
586
+ if s == fzero: return -mpf_sign(t)
587
+ if t == fzero: return mpf_sign(s)
588
+ if s == t: return 0
589
+ # Follow same convention as Python's cmp for float nan
590
+ if t == fnan: return 1
591
+ if s == finf: return 1
592
+ if t == fninf: return 1
593
+ return -1
594
+ # Different sides of zero
595
+ if ssign != tsign:
596
+ if not ssign: return 1
597
+ return -1
598
+ # This reduces to direct integer comparison
599
+ if sexp == texp:
600
+ if sman == tman:
601
+ return 0
602
+ if sman > tman:
603
+ if ssign: return -1
604
+ else: return 1
605
+ else:
606
+ if ssign: return 1
607
+ else: return -1
608
+ # Check position of the highest set bit in each number. If
609
+ # different, there is certainly an inequality.
610
+ a = sbc + sexp
611
+ b = tbc + texp
612
+ if ssign:
613
+ if a < b: return 1
614
+ if a > b: return -1
615
+ else:
616
+ if a < b: return -1
617
+ if a > b: return 1
618
+
619
+ # Both numbers have the same highest bit. Subtract to find
620
+ # how the lower bits compare.
621
+ delta = mpf_sub(s, t, 5, round_floor)
622
+ if delta[0]:
623
+ return -1
624
+ return 1
625
+
626
+ def mpf_lt(s, t):
627
+ if s == fnan or t == fnan:
628
+ return False
629
+ return mpf_cmp(s, t) < 0
630
+
631
+ def mpf_le(s, t):
632
+ if s == fnan or t == fnan:
633
+ return False
634
+ return mpf_cmp(s, t) <= 0
635
+
636
+ def mpf_gt(s, t):
637
+ if s == fnan or t == fnan:
638
+ return False
639
+ return mpf_cmp(s, t) > 0
640
+
641
+ def mpf_ge(s, t):
642
+ if s == fnan or t == fnan:
643
+ return False
644
+ return mpf_cmp(s, t) >= 0
645
+
646
+ def mpf_min_max(seq):
647
+ min = max = seq[0]
648
+ for x in seq[1:]:
649
+ if mpf_lt(x, min): min = x
650
+ if mpf_gt(x, max): max = x
651
+ return min, max
652
+
653
+ def mpf_pos(s, prec=0, rnd=round_fast):
654
+ """Calculate 0+s for a raw mpf (i.e., just round s to the specified
655
+ precision)."""
656
+ if prec:
657
+ sign, man, exp, bc = s
658
+ if (not man) and exp:
659
+ return s
660
+ return normalize1(sign, man, exp, bc, prec, rnd)
661
+ return s
662
+
663
+ def mpf_neg(s, prec=None, rnd=round_fast):
664
+ """Negate a raw mpf (return -s), rounding the result to the
665
+ specified precision. The prec argument can be omitted to do the
666
+ operation exactly."""
667
+ sign, man, exp, bc = s
668
+ if not man:
669
+ if exp:
670
+ if s == finf: return fninf
671
+ if s == fninf: return finf
672
+ return s
673
+ if not prec:
674
+ return (1-sign, man, exp, bc)
675
+ return normalize1(1-sign, man, exp, bc, prec, rnd)
676
+
677
+ def mpf_abs(s, prec=None, rnd=round_fast):
678
+ """Return abs(s) of the raw mpf s, rounded to the specified
679
+ precision. The prec argument can be omitted to generate an
680
+ exact result."""
681
+ sign, man, exp, bc = s
682
+ if (not man) and exp:
683
+ if s == fninf:
684
+ return finf
685
+ return s
686
+ if not prec:
687
+ if sign:
688
+ return (0, man, exp, bc)
689
+ return s
690
+ return normalize1(0, man, exp, bc, prec, rnd)
691
+
692
+ def mpf_sign(s):
693
+ """Return -1, 0, or 1 (as a Python int, not a raw mpf) depending on
694
+ whether s is negative, zero, or positive. (Nan is taken to give 0.)"""
695
+ sign, man, exp, bc = s
696
+ if not man:
697
+ if s == finf: return 1
698
+ if s == fninf: return -1
699
+ return 0
700
+ return (-1) ** sign
701
+
702
+ def mpf_add(s, t, prec=0, rnd=round_fast, _sub=0):
703
+ """
704
+ Add the two raw mpf values s and t.
705
+
706
+ With prec=0, no rounding is performed. Note that this can
707
+ produce a very large mantissa (potentially too large to fit
708
+ in memory) if exponents are far apart.
709
+ """
710
+ ssign, sman, sexp, sbc = s
711
+ tsign, tman, texp, tbc = t
712
+ tsign ^= _sub
713
+ # Standard case: two nonzero, regular numbers
714
+ if sman and tman:
715
+ offset = sexp - texp
716
+ if offset:
717
+ if offset > 0:
718
+ # Outside precision range; only need to perturb
719
+ if offset > 100 and prec:
720
+ delta = sbc + sexp - tbc - texp
721
+ if delta > prec + 4:
722
+ offset = prec + 4
723
+ sman <<= offset
724
+ if tsign == ssign: sman += 1
725
+ else: sman -= 1
726
+ return normalize1(ssign, sman, sexp-offset,
727
+ bitcount(sman), prec, rnd)
728
+ # Add
729
+ if ssign == tsign:
730
+ man = tman + (sman << offset)
731
+ # Subtract
732
+ else:
733
+ if ssign: man = tman - (sman << offset)
734
+ else: man = (sman << offset) - tman
735
+ if man >= 0:
736
+ ssign = 0
737
+ else:
738
+ man = -man
739
+ ssign = 1
740
+ bc = bitcount(man)
741
+ return normalize1(ssign, man, texp, bc, prec or bc, rnd)
742
+ elif offset < 0:
743
+ # Outside precision range; only need to perturb
744
+ if offset < -100 and prec:
745
+ delta = tbc + texp - sbc - sexp
746
+ if delta > prec + 4:
747
+ offset = prec + 4
748
+ tman <<= offset
749
+ if ssign == tsign: tman += 1
750
+ else: tman -= 1
751
+ return normalize1(tsign, tman, texp-offset,
752
+ bitcount(tman), prec, rnd)
753
+ # Add
754
+ if ssign == tsign:
755
+ man = sman + (tman << -offset)
756
+ # Subtract
757
+ else:
758
+ if tsign: man = sman - (tman << -offset)
759
+ else: man = (tman << -offset) - sman
760
+ if man >= 0:
761
+ ssign = 0
762
+ else:
763
+ man = -man
764
+ ssign = 1
765
+ bc = bitcount(man)
766
+ return normalize1(ssign, man, sexp, bc, prec or bc, rnd)
767
+ # Equal exponents; no shifting necessary
768
+ if ssign == tsign:
769
+ man = tman + sman
770
+ else:
771
+ if ssign: man = tman - sman
772
+ else: man = sman - tman
773
+ if man >= 0:
774
+ ssign = 0
775
+ else:
776
+ man = -man
777
+ ssign = 1
778
+ bc = bitcount(man)
779
+ return normalize(ssign, man, texp, bc, prec or bc, rnd)
780
+ # Handle zeros and special numbers
781
+ if _sub:
782
+ t = mpf_neg(t)
783
+ if not sman:
784
+ if sexp:
785
+ if s == t or tman or not texp:
786
+ return s
787
+ return fnan
788
+ if tman:
789
+ return normalize1(tsign, tman, texp, tbc, prec or tbc, rnd)
790
+ return t
791
+ if texp:
792
+ return t
793
+ if sman:
794
+ return normalize1(ssign, sman, sexp, sbc, prec or sbc, rnd)
795
+ return s
796
+
797
+ def mpf_sub(s, t, prec=0, rnd=round_fast):
798
+ """Return the difference of two raw mpfs, s-t. This function is
799
+ simply a wrapper of mpf_add that changes the sign of t."""
800
+ return mpf_add(s, t, prec, rnd, 1)
801
+
802
+ def mpf_sum(xs, prec=0, rnd=round_fast, absolute=False):
803
+ """
804
+ Sum a list of mpf values efficiently and accurately
805
+ (typically no temporary roundoff occurs). If prec=0,
806
+ the final result will not be rounded either.
807
+
808
+ There may be roundoff error or cancellation if extremely
809
+ large exponent differences occur.
810
+
811
+ With absolute=True, sums the absolute values.
812
+ """
813
+ man = 0
814
+ exp = 0
815
+ max_extra_prec = prec*2 or 1000000 # XXX
816
+ special = None
817
+ for x in xs:
818
+ xsign, xman, xexp, xbc = x
819
+ if xman:
820
+ if xsign and not absolute:
821
+ xman = -xman
822
+ delta = xexp - exp
823
+ if xexp >= exp:
824
+ # x much larger than existing sum?
825
+ # first: quick test
826
+ if (delta > max_extra_prec) and \
827
+ ((not man) or delta-bitcount(abs(man)) > max_extra_prec):
828
+ man = xman
829
+ exp = xexp
830
+ else:
831
+ man += (xman << delta)
832
+ else:
833
+ delta = -delta
834
+ # x much smaller than existing sum?
835
+ if delta-xbc > max_extra_prec:
836
+ if not man:
837
+ man, exp = xman, xexp
838
+ else:
839
+ man = (man << delta) + xman
840
+ exp = xexp
841
+ elif xexp:
842
+ if absolute:
843
+ x = mpf_abs(x)
844
+ special = mpf_add(special or fzero, x, 1)
845
+ # Will be inf or nan
846
+ if special:
847
+ return special
848
+ return from_man_exp(man, exp, prec, rnd)
849
+
850
+ def gmpy_mpf_mul(s, t, prec=0, rnd=round_fast):
851
+ """Multiply two raw mpfs"""
852
+ ssign, sman, sexp, sbc = s
853
+ tsign, tman, texp, tbc = t
854
+ sign = ssign ^ tsign
855
+ man = sman*tman
856
+ if man:
857
+ bc = bitcount(man)
858
+ if prec:
859
+ return normalize1(sign, man, sexp+texp, bc, prec, rnd)
860
+ else:
861
+ return (sign, man, sexp+texp, bc)
862
+ s_special = (not sman) and sexp
863
+ t_special = (not tman) and texp
864
+ if not s_special and not t_special:
865
+ return fzero
866
+ if fnan in (s, t): return fnan
867
+ if (not tman) and texp: s, t = t, s
868
+ if t == fzero: return fnan
869
+ return {1:finf, -1:fninf}[mpf_sign(s) * mpf_sign(t)]
870
+
871
+ def gmpy_mpf_mul_int(s, n, prec, rnd=round_fast):
872
+ """Multiply by a Python integer."""
873
+ sign, man, exp, bc = s
874
+ if not man:
875
+ return mpf_mul(s, from_int(n), prec, rnd)
876
+ if not n:
877
+ return fzero
878
+ if n < 0:
879
+ sign ^= 1
880
+ n = -n
881
+ man *= n
882
+ return normalize(sign, man, exp, bitcount(man), prec, rnd)
883
+
884
+ def python_mpf_mul(s, t, prec=0, rnd=round_fast):
885
+ """Multiply two raw mpfs"""
886
+ ssign, sman, sexp, sbc = s
887
+ tsign, tman, texp, tbc = t
888
+ sign = ssign ^ tsign
889
+ man = sman*tman
890
+ if man:
891
+ bc = sbc + tbc - 1
892
+ bc += int(man>>bc)
893
+ if prec:
894
+ return normalize1(sign, man, sexp+texp, bc, prec, rnd)
895
+ else:
896
+ return (sign, man, sexp+texp, bc)
897
+ s_special = (not sman) and sexp
898
+ t_special = (not tman) and texp
899
+ if not s_special and not t_special:
900
+ return fzero
901
+ if fnan in (s, t): return fnan
902
+ if (not tman) and texp: s, t = t, s
903
+ if t == fzero: return fnan
904
+ return {1:finf, -1:fninf}[mpf_sign(s) * mpf_sign(t)]
905
+
906
+ def python_mpf_mul_int(s, n, prec, rnd=round_fast):
907
+ """Multiply by a Python integer."""
908
+ sign, man, exp, bc = s
909
+ if not man:
910
+ return mpf_mul(s, from_int(n), prec, rnd)
911
+ if not n:
912
+ return fzero
913
+ if n < 0:
914
+ sign ^= 1
915
+ n = -n
916
+ man *= n
917
+ # Generally n will be small
918
+ if n < 1024:
919
+ bc += bctable[int(n)] - 1
920
+ else:
921
+ bc += bitcount(n) - 1
922
+ bc += int(man>>bc)
923
+ return normalize(sign, man, exp, bc, prec, rnd)
924
+
925
+
926
+ if BACKEND == 'gmpy':
927
+ mpf_mul = gmpy_mpf_mul
928
+ mpf_mul_int = gmpy_mpf_mul_int
929
+ else:
930
+ mpf_mul = python_mpf_mul
931
+ mpf_mul_int = python_mpf_mul_int
932
+
933
+ def mpf_shift(s, n):
934
+ """Quickly multiply the raw mpf s by 2**n without rounding."""
935
+ sign, man, exp, bc = s
936
+ if not man:
937
+ return s
938
+ return sign, man, exp+n, bc
939
+
940
+ def mpf_frexp(x):
941
+ """Convert x = y*2**n to (y, n) with abs(y) in [0.5, 1) if nonzero"""
942
+ sign, man, exp, bc = x
943
+ if not man:
944
+ if x == fzero:
945
+ return (fzero, 0)
946
+ else:
947
+ raise ValueError
948
+ return mpf_shift(x, -bc-exp), bc+exp
949
+
950
+ def mpf_div(s, t, prec, rnd=round_fast):
951
+ """Floating-point division"""
952
+ ssign, sman, sexp, sbc = s
953
+ tsign, tman, texp, tbc = t
954
+ if not sman or not tman:
955
+ if s == fzero:
956
+ if t == fzero: raise ZeroDivisionError
957
+ if t == fnan: return fnan
958
+ return fzero
959
+ if t == fzero:
960
+ raise ZeroDivisionError
961
+ s_special = (not sman) and sexp
962
+ t_special = (not tman) and texp
963
+ if s_special and t_special:
964
+ return fnan
965
+ if s == fnan or t == fnan:
966
+ return fnan
967
+ if not t_special:
968
+ if t == fzero:
969
+ return fnan
970
+ return {1:finf, -1:fninf}[mpf_sign(s) * mpf_sign(t)]
971
+ return fzero
972
+ sign = ssign ^ tsign
973
+ if tman == 1:
974
+ return normalize1(sign, sman, sexp-texp, sbc, prec, rnd)
975
+ # Same strategy as for addition: if there is a remainder, perturb
976
+ # the result a few bits outside the precision range before rounding
977
+ extra = prec - sbc + tbc + 5
978
+ if extra < 5:
979
+ extra = 5
980
+ quot, rem = divmod(sman<<extra, tman)
981
+ if rem:
982
+ quot = (quot<<1) + 1
983
+ extra += 1
984
+ return normalize1(sign, quot, sexp-texp-extra, bitcount(quot), prec, rnd)
985
+ return normalize(sign, quot, sexp-texp-extra, bitcount(quot), prec, rnd)
986
+
987
+ def mpf_rdiv_int(n, t, prec, rnd=round_fast):
988
+ """Floating-point division n/t with a Python integer as numerator"""
989
+ sign, man, exp, bc = t
990
+ if not n or not man:
991
+ return mpf_div(from_int(n), t, prec, rnd)
992
+ if n < 0:
993
+ sign ^= 1
994
+ n = -n
995
+ extra = prec + bc + 5
996
+ quot, rem = divmod(n<<extra, man)
997
+ if rem:
998
+ quot = (quot<<1) + 1
999
+ extra += 1
1000
+ return normalize1(sign, quot, -exp-extra, bitcount(quot), prec, rnd)
1001
+ return normalize(sign, quot, -exp-extra, bitcount(quot), prec, rnd)
1002
+
1003
+ def mpf_mod(s, t, prec, rnd=round_fast):
1004
+ ssign, sman, sexp, sbc = s
1005
+ tsign, tman, texp, tbc = t
1006
+ if ((not sman) and sexp) or ((not tman) and texp):
1007
+ return fnan
1008
+ # Important special case: do nothing if t is larger
1009
+ if ssign == tsign and texp > sexp+sbc:
1010
+ return s
1011
+ # Another important special case: this allows us to do e.g. x % 1.0
1012
+ # to find the fractional part of x, and it will work when x is huge.
1013
+ if tman == 1 and sexp > texp+tbc:
1014
+ return fzero
1015
+ base = min(sexp, texp)
1016
+ sman = (-1)**ssign * sman
1017
+ tman = (-1)**tsign * tman
1018
+ man = (sman << (sexp-base)) % (tman << (texp-base))
1019
+ if man >= 0:
1020
+ sign = 0
1021
+ else:
1022
+ man = -man
1023
+ sign = 1
1024
+ return normalize(sign, man, base, bitcount(man), prec, rnd)
1025
+
1026
+ reciprocal_rnd = {
1027
+ round_down : round_up,
1028
+ round_up : round_down,
1029
+ round_floor : round_ceiling,
1030
+ round_ceiling : round_floor,
1031
+ round_nearest : round_nearest
1032
+ }
1033
+
1034
+ negative_rnd = {
1035
+ round_down : round_down,
1036
+ round_up : round_up,
1037
+ round_floor : round_ceiling,
1038
+ round_ceiling : round_floor,
1039
+ round_nearest : round_nearest
1040
+ }
1041
+
1042
+ def mpf_pow_int(s, n, prec, rnd=round_fast):
1043
+ """Compute s**n, where s is a raw mpf and n is a Python integer."""
1044
+ sign, man, exp, bc = s
1045
+
1046
+ if (not man) and exp:
1047
+ if s == finf:
1048
+ if n > 0: return s
1049
+ if n == 0: return fnan
1050
+ return fzero
1051
+ if s == fninf:
1052
+ if n > 0: return [finf, fninf][n & 1]
1053
+ if n == 0: return fnan
1054
+ return fzero
1055
+ return fnan
1056
+
1057
+ n = int(n)
1058
+ if n == 0: return fone
1059
+ if n == 1: return mpf_pos(s, prec, rnd)
1060
+ if n == 2:
1061
+ _, man, exp, bc = s
1062
+ if not man:
1063
+ return fzero
1064
+ man = man*man
1065
+ if man == 1:
1066
+ return (0, MPZ_ONE, exp+exp, 1)
1067
+ bc = bc + bc - 2
1068
+ bc += bctable[int(man>>bc)]
1069
+ return normalize1(0, man, exp+exp, bc, prec, rnd)
1070
+ if n == -1: return mpf_div(fone, s, prec, rnd)
1071
+ if n < 0:
1072
+ inverse = mpf_pow_int(s, -n, prec+5, reciprocal_rnd[rnd])
1073
+ return mpf_div(fone, inverse, prec, rnd)
1074
+
1075
+ result_sign = sign & n
1076
+
1077
+ # Use exact integer power when the exact mantissa is small
1078
+ if man == 1:
1079
+ return (result_sign, MPZ_ONE, exp*n, 1)
1080
+ if bc*n < 1000:
1081
+ man **= n
1082
+ return normalize1(result_sign, man, exp*n, bitcount(man), prec, rnd)
1083
+
1084
+ # Use directed rounding all the way through to maintain rigorous
1085
+ # bounds for interval arithmetic
1086
+ rounds_down = (rnd == round_nearest) or \
1087
+ shifts_down[rnd][result_sign]
1088
+
1089
+ # Now we perform binary exponentiation. Need to estimate precision
1090
+ # to avoid rounding errors from temporary operations. Roughly log_2(n)
1091
+ # operations are performed.
1092
+ workprec = prec + 4*bitcount(n) + 4
1093
+ _, pm, pe, pbc = fone
1094
+ while 1:
1095
+ if n & 1:
1096
+ pm = pm*man
1097
+ pe = pe+exp
1098
+ pbc += bc - 2
1099
+ pbc = pbc + bctable[int(pm >> pbc)]
1100
+ if pbc > workprec:
1101
+ if rounds_down:
1102
+ pm = pm >> (pbc-workprec)
1103
+ else:
1104
+ pm = -((-pm) >> (pbc-workprec))
1105
+ pe += pbc - workprec
1106
+ pbc = workprec
1107
+ n -= 1
1108
+ if not n:
1109
+ break
1110
+ man = man*man
1111
+ exp = exp+exp
1112
+ bc = bc + bc - 2
1113
+ bc = bc + bctable[int(man >> bc)]
1114
+ if bc > workprec:
1115
+ if rounds_down:
1116
+ man = man >> (bc-workprec)
1117
+ else:
1118
+ man = -((-man) >> (bc-workprec))
1119
+ exp += bc - workprec
1120
+ bc = workprec
1121
+ n = n // 2
1122
+
1123
+ return normalize(result_sign, pm, pe, pbc, prec, rnd)
1124
+
1125
+
1126
+ def mpf_perturb(x, eps_sign, prec, rnd):
1127
+ """
1128
+ For nonzero x, calculate x + eps with directed rounding, where
1129
+ eps < prec relatively and eps has the given sign (0 for
1130
+ positive, 1 for negative).
1131
+
1132
+ With rounding to nearest, this is taken to simply normalize
1133
+ x to the given precision.
1134
+ """
1135
+ if rnd == round_nearest:
1136
+ return mpf_pos(x, prec, rnd)
1137
+ sign, man, exp, bc = x
1138
+ eps = (eps_sign, MPZ_ONE, exp+bc-prec-1, 1)
1139
+ if sign:
1140
+ away = (rnd in (round_down, round_ceiling)) ^ eps_sign
1141
+ else:
1142
+ away = (rnd in (round_up, round_ceiling)) ^ eps_sign
1143
+ if away:
1144
+ return mpf_add(x, eps, prec, rnd)
1145
+ else:
1146
+ return mpf_pos(x, prec, rnd)
1147
+
1148
+
1149
+ #----------------------------------------------------------------------------#
1150
+ # Radix conversion #
1151
+ #----------------------------------------------------------------------------#
1152
+
1153
+ def to_digits_exp(s, dps):
1154
+ """Helper function for representing the floating-point number s as
1155
+ a decimal with dps digits. Returns (sign, string, exponent) where
1156
+ sign is '' or '-', string is the digit string, and exponent is
1157
+ the decimal exponent as an int.
1158
+
1159
+ If inexact, the decimal representation is rounded toward zero."""
1160
+
1161
+ # Extract sign first so it doesn't mess up the string digit count
1162
+ if s[0]:
1163
+ sign = '-'
1164
+ s = mpf_neg(s)
1165
+ else:
1166
+ sign = ''
1167
+ _sign, man, exp, bc = s
1168
+
1169
+ if not man:
1170
+ return '', '0', 0
1171
+
1172
+ bitprec = int(dps * math.log(10,2)) + 10
1173
+
1174
+ # Cut down to size
1175
+ # TODO: account for precision when doing this
1176
+ exp_from_1 = exp + bc
1177
+ if abs(exp_from_1) > 3500:
1178
+ from .libelefun import mpf_ln2, mpf_ln10
1179
+ # Set b = int(exp * log(2)/log(10))
1180
+ # If exp is huge, we must use high-precision arithmetic to
1181
+ # find the nearest power of ten
1182
+ expprec = bitcount(abs(exp)) + 5
1183
+ tmp = from_int(exp)
1184
+ tmp = mpf_mul(tmp, mpf_ln2(expprec))
1185
+ tmp = mpf_div(tmp, mpf_ln10(expprec), expprec)
1186
+ b = to_int(tmp)
1187
+ s = mpf_div(s, mpf_pow_int(ften, b, bitprec), bitprec)
1188
+ _sign, man, exp, bc = s
1189
+ exponent = b
1190
+ else:
1191
+ exponent = 0
1192
+
1193
+ # First, calculate mantissa digits by converting to a binary
1194
+ # fixed-point number and then converting that number to
1195
+ # a decimal fixed-point number.
1196
+ fixprec = max(bitprec - exp - bc, 0)
1197
+ fixdps = int(fixprec / math.log(10,2) + 0.5)
1198
+ sf = to_fixed(s, fixprec)
1199
+ sd = bin_to_radix(sf, fixprec, 10, fixdps)
1200
+ digits = numeral(sd, base=10, size=dps)
1201
+
1202
+ exponent += len(digits) - fixdps - 1
1203
+ return sign, digits, exponent
1204
+
1205
+ def to_str(s, dps, strip_zeros=True, min_fixed=None, max_fixed=None,
1206
+ show_zero_exponent=False):
1207
+ """
1208
+ Convert a raw mpf to a decimal floating-point literal with at
1209
+ most `dps` decimal digits in the mantissa (not counting extra zeros
1210
+ that may be inserted for visual purposes).
1211
+
1212
+ The number will be printed in fixed-point format if the position
1213
+ of the leading digit is strictly between min_fixed
1214
+ (default = min(-dps/3,-5)) and max_fixed (default = dps).
1215
+
1216
+ To force fixed-point format always, set min_fixed = -inf,
1217
+ max_fixed = +inf. To force floating-point format, set
1218
+ min_fixed >= max_fixed.
1219
+
1220
+ The literal is formatted so that it can be parsed back to a number
1221
+ by to_str, float() or Decimal().
1222
+ """
1223
+
1224
+ # Special numbers
1225
+ if not s[1]:
1226
+ if s == fzero:
1227
+ if dps: t = '0.0'
1228
+ else: t = '.0'
1229
+ if show_zero_exponent:
1230
+ t += 'e+0'
1231
+ return t
1232
+ if s == finf: return '+inf'
1233
+ if s == fninf: return '-inf'
1234
+ if s == fnan: return 'nan'
1235
+ raise ValueError
1236
+
1237
+ if min_fixed is None: min_fixed = min(-(dps//3), -5)
1238
+ if max_fixed is None: max_fixed = dps
1239
+
1240
+ # to_digits_exp rounds to floor.
1241
+ # This sometimes kills some instances of "...00001"
1242
+ sign, digits, exponent = to_digits_exp(s, dps+3)
1243
+
1244
+ # No digits: show only .0; round exponent to nearest
1245
+ if not dps:
1246
+ if digits[0] in '56789':
1247
+ exponent += 1
1248
+ digits = ".0"
1249
+
1250
+ else:
1251
+ # Rounding up kills some instances of "...99999"
1252
+ if len(digits) > dps and digits[dps] in '56789':
1253
+ digits = digits[:dps]
1254
+ i = dps - 1
1255
+ while i >= 0 and digits[i] == '9':
1256
+ i -= 1
1257
+ if i >= 0:
1258
+ digits = digits[:i] + str(int(digits[i]) + 1) + '0' * (dps - i - 1)
1259
+ else:
1260
+ digits = '1' + '0' * (dps - 1)
1261
+ exponent += 1
1262
+ else:
1263
+ digits = digits[:dps]
1264
+
1265
+ # Prettify numbers close to unit magnitude
1266
+ if min_fixed < exponent < max_fixed:
1267
+ if exponent < 0:
1268
+ digits = ("0"*int(-exponent)) + digits
1269
+ split = 1
1270
+ else:
1271
+ split = exponent + 1
1272
+ if split > dps:
1273
+ digits += "0"*(split-dps)
1274
+ exponent = 0
1275
+ else:
1276
+ split = 1
1277
+
1278
+ digits = (digits[:split] + "." + digits[split:])
1279
+
1280
+ if strip_zeros:
1281
+ # Clean up trailing zeros
1282
+ digits = digits.rstrip('0')
1283
+ if digits[-1] == ".":
1284
+ digits += "0"
1285
+
1286
+ if exponent == 0 and dps and not show_zero_exponent: return sign + digits
1287
+ if exponent >= 0: return sign + digits + "e+" + str(exponent)
1288
+ if exponent < 0: return sign + digits + "e" + str(exponent)
1289
+
1290
+ def str_to_man_exp(x, base=10):
1291
+ """Helper function for from_str."""
1292
+ x = x.lower().rstrip('l')
1293
+ # Verify that the input is a valid float literal
1294
+ float(x)
1295
+ # Split into mantissa, exponent
1296
+ parts = x.split('e')
1297
+ if len(parts) == 1:
1298
+ exp = 0
1299
+ else: # == 2
1300
+ x = parts[0]
1301
+ exp = int(parts[1])
1302
+ # Look for radix point in mantissa
1303
+ parts = x.split('.')
1304
+ if len(parts) == 2:
1305
+ a, b = parts[0], parts[1].rstrip('0')
1306
+ exp -= len(b)
1307
+ x = a + b
1308
+ x = MPZ(int(x, base))
1309
+ return x, exp
1310
+
1311
+ special_str = {'inf':finf, '+inf':finf, '-inf':fninf, 'nan':fnan}
1312
+
1313
+ def from_str(x, prec, rnd=round_fast):
1314
+ """Create a raw mpf from a decimal literal, rounding in the
1315
+ specified direction if the input number cannot be represented
1316
+ exactly as a binary floating-point number with the given number of
1317
+ bits. The literal syntax accepted is the same as for Python
1318
+ floats.
1319
+
1320
+ TODO: the rounding does not work properly for large exponents.
1321
+ """
1322
+ x = x.lower().strip()
1323
+ if x in special_str:
1324
+ return special_str[x]
1325
+
1326
+ if '/' in x:
1327
+ p, q = x.split('/')
1328
+ p, q = p.rstrip('l'), q.rstrip('l')
1329
+ return from_rational(int(p), int(q), prec, rnd)
1330
+
1331
+ man, exp = str_to_man_exp(x, base=10)
1332
+
1333
+ # XXX: appropriate cutoffs & track direction
1334
+ # note no factors of 5
1335
+ if abs(exp) > 400:
1336
+ s = from_int(man, prec+10)
1337
+ s = mpf_mul(s, mpf_pow_int(ften, exp, prec+10), prec, rnd)
1338
+ else:
1339
+ if exp >= 0:
1340
+ s = from_int(man * 10**exp, prec, rnd)
1341
+ else:
1342
+ s = from_rational(man, 10**-exp, prec, rnd)
1343
+ return s
1344
+
1345
+ # Binary string conversion. These are currently mainly used for debugging
1346
+ # and could use some improvement in the future
1347
+
1348
+ def from_bstr(x):
1349
+ man, exp = str_to_man_exp(x, base=2)
1350
+ man = MPZ(man)
1351
+ sign = 0
1352
+ if man < 0:
1353
+ man = -man
1354
+ sign = 1
1355
+ bc = bitcount(man)
1356
+ return normalize(sign, man, exp, bc, bc, round_floor)
1357
+
1358
+ def to_bstr(x):
1359
+ sign, man, exp, bc = x
1360
+ return ['','-'][sign] + numeral(man, size=bitcount(man), base=2) + ("e%i" % exp)
1361
+
1362
+
1363
+ #----------------------------------------------------------------------------#
1364
+ # Square roots #
1365
+ #----------------------------------------------------------------------------#
1366
+
1367
+
1368
+ def mpf_sqrt(s, prec, rnd=round_fast):
1369
+ """
1370
+ Compute the square root of a nonnegative mpf value. The
1371
+ result is correctly rounded.
1372
+ """
1373
+ sign, man, exp, bc = s
1374
+ if sign:
1375
+ raise ComplexResult("square root of a negative number")
1376
+ if not man:
1377
+ return s
1378
+ if exp & 1:
1379
+ exp -= 1
1380
+ man <<= 1
1381
+ bc += 1
1382
+ elif man == 1:
1383
+ return normalize1(sign, man, exp//2, bc, prec, rnd)
1384
+ shift = max(4, 2*prec-bc+4)
1385
+ shift += shift & 1
1386
+ if rnd in 'fd':
1387
+ man = isqrt(man<<shift)
1388
+ else:
1389
+ man, rem = sqrtrem(man<<shift)
1390
+ # Perturb up
1391
+ if rem:
1392
+ man = (man<<1)+1
1393
+ shift += 2
1394
+ return from_man_exp(man, (exp-shift)//2, prec, rnd)
1395
+
1396
+ def mpf_hypot(x, y, prec, rnd=round_fast):
1397
+ """Compute the Euclidean norm sqrt(x**2 + y**2) of two raw mpfs
1398
+ x and y."""
1399
+ if y == fzero: return mpf_abs(x, prec, rnd)
1400
+ if x == fzero: return mpf_abs(y, prec, rnd)
1401
+ hypot2 = mpf_add(mpf_mul(x,x), mpf_mul(y,y), prec+4)
1402
+ return mpf_sqrt(hypot2, prec, rnd)
1403
+
1404
+
1405
+ if BACKEND == 'sage':
1406
+ try:
1407
+ import sage.libs.mpmath.ext_libmp as ext_lib
1408
+ mpf_add = ext_lib.mpf_add
1409
+ mpf_sub = ext_lib.mpf_sub
1410
+ mpf_mul = ext_lib.mpf_mul
1411
+ mpf_div = ext_lib.mpf_div
1412
+ mpf_sqrt = ext_lib.mpf_sqrt
1413
+ except ImportError:
1414
+ pass
lib/python3.11/site-packages/mpmath/libmp/libmpi.py ADDED
@@ -0,0 +1,935 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Computational functions for interval arithmetic.
3
+
4
+ """
5
+
6
+ from .backend import xrange
7
+
8
+ from .libmpf import (
9
+ ComplexResult,
10
+ round_down, round_up, round_floor, round_ceiling, round_nearest,
11
+ prec_to_dps, repr_dps, dps_to_prec,
12
+ bitcount,
13
+ from_float,
14
+ fnan, finf, fninf, fzero, fhalf, fone, fnone,
15
+ mpf_sign, mpf_lt, mpf_le, mpf_gt, mpf_ge, mpf_eq, mpf_cmp,
16
+ mpf_min_max,
17
+ mpf_floor, from_int, to_int, to_str, from_str,
18
+ mpf_abs, mpf_neg, mpf_pos, mpf_add, mpf_sub, mpf_mul, mpf_mul_int,
19
+ mpf_div, mpf_shift, mpf_pow_int,
20
+ from_man_exp, MPZ_ONE)
21
+
22
+ from .libelefun import (
23
+ mpf_log, mpf_exp, mpf_sqrt, mpf_atan, mpf_atan2,
24
+ mpf_pi, mod_pi2, mpf_cos_sin
25
+ )
26
+
27
+ from .gammazeta import mpf_gamma, mpf_rgamma, mpf_loggamma, mpc_loggamma
28
+
29
+ def mpi_str(s, prec):
30
+ sa, sb = s
31
+ dps = prec_to_dps(prec) + 5
32
+ return "[%s, %s]" % (to_str(sa, dps), to_str(sb, dps))
33
+ #dps = prec_to_dps(prec)
34
+ #m = mpi_mid(s, prec)
35
+ #d = mpf_shift(mpi_delta(s, 20), -1)
36
+ #return "%s +/- %s" % (to_str(m, dps), to_str(d, 3))
37
+
38
+ mpi_zero = (fzero, fzero)
39
+ mpi_one = (fone, fone)
40
+
41
+ def mpi_eq(s, t):
42
+ return s == t
43
+
44
+ def mpi_ne(s, t):
45
+ return s != t
46
+
47
+ def mpi_lt(s, t):
48
+ sa, sb = s
49
+ ta, tb = t
50
+ if mpf_lt(sb, ta): return True
51
+ if mpf_ge(sa, tb): return False
52
+ return None
53
+
54
+ def mpi_le(s, t):
55
+ sa, sb = s
56
+ ta, tb = t
57
+ if mpf_le(sb, ta): return True
58
+ if mpf_gt(sa, tb): return False
59
+ return None
60
+
61
+ def mpi_gt(s, t): return mpi_lt(t, s)
62
+ def mpi_ge(s, t): return mpi_le(t, s)
63
+
64
+ def mpi_add(s, t, prec=0):
65
+ sa, sb = s
66
+ ta, tb = t
67
+ a = mpf_add(sa, ta, prec, round_floor)
68
+ b = mpf_add(sb, tb, prec, round_ceiling)
69
+ if a == fnan: a = fninf
70
+ if b == fnan: b = finf
71
+ return a, b
72
+
73
+ def mpi_sub(s, t, prec=0):
74
+ sa, sb = s
75
+ ta, tb = t
76
+ a = mpf_sub(sa, tb, prec, round_floor)
77
+ b = mpf_sub(sb, ta, prec, round_ceiling)
78
+ if a == fnan: a = fninf
79
+ if b == fnan: b = finf
80
+ return a, b
81
+
82
+ def mpi_delta(s, prec):
83
+ sa, sb = s
84
+ return mpf_sub(sb, sa, prec, round_up)
85
+
86
+ def mpi_mid(s, prec):
87
+ sa, sb = s
88
+ return mpf_shift(mpf_add(sa, sb, prec, round_nearest), -1)
89
+
90
+ def mpi_pos(s, prec):
91
+ sa, sb = s
92
+ a = mpf_pos(sa, prec, round_floor)
93
+ b = mpf_pos(sb, prec, round_ceiling)
94
+ return a, b
95
+
96
+ def mpi_neg(s, prec=0):
97
+ sa, sb = s
98
+ a = mpf_neg(sb, prec, round_floor)
99
+ b = mpf_neg(sa, prec, round_ceiling)
100
+ return a, b
101
+
102
+ def mpi_abs(s, prec=0):
103
+ sa, sb = s
104
+ sas = mpf_sign(sa)
105
+ sbs = mpf_sign(sb)
106
+ # Both points nonnegative?
107
+ if sas >= 0:
108
+ a = mpf_pos(sa, prec, round_floor)
109
+ b = mpf_pos(sb, prec, round_ceiling)
110
+ # Upper point nonnegative?
111
+ elif sbs >= 0:
112
+ a = fzero
113
+ negsa = mpf_neg(sa)
114
+ if mpf_lt(negsa, sb):
115
+ b = mpf_pos(sb, prec, round_ceiling)
116
+ else:
117
+ b = mpf_pos(negsa, prec, round_ceiling)
118
+ # Both negative?
119
+ else:
120
+ a = mpf_neg(sb, prec, round_floor)
121
+ b = mpf_neg(sa, prec, round_ceiling)
122
+ return a, b
123
+
124
+ # TODO: optimize
125
+ def mpi_mul_mpf(s, t, prec):
126
+ return mpi_mul(s, (t, t), prec)
127
+
128
+ def mpi_div_mpf(s, t, prec):
129
+ return mpi_div(s, (t, t), prec)
130
+
131
+ def mpi_mul(s, t, prec=0):
132
+ sa, sb = s
133
+ ta, tb = t
134
+ sas = mpf_sign(sa)
135
+ sbs = mpf_sign(sb)
136
+ tas = mpf_sign(ta)
137
+ tbs = mpf_sign(tb)
138
+ if sas == sbs == 0:
139
+ # Should maybe be undefined
140
+ if ta == fninf or tb == finf:
141
+ return fninf, finf
142
+ return fzero, fzero
143
+ if tas == tbs == 0:
144
+ # Should maybe be undefined
145
+ if sa == fninf or sb == finf:
146
+ return fninf, finf
147
+ return fzero, fzero
148
+ if sas >= 0:
149
+ # positive * positive
150
+ if tas >= 0:
151
+ a = mpf_mul(sa, ta, prec, round_floor)
152
+ b = mpf_mul(sb, tb, prec, round_ceiling)
153
+ if a == fnan: a = fzero
154
+ if b == fnan: b = finf
155
+ # positive * negative
156
+ elif tbs <= 0:
157
+ a = mpf_mul(sb, ta, prec, round_floor)
158
+ b = mpf_mul(sa, tb, prec, round_ceiling)
159
+ if a == fnan: a = fninf
160
+ if b == fnan: b = fzero
161
+ # positive * both signs
162
+ else:
163
+ a = mpf_mul(sb, ta, prec, round_floor)
164
+ b = mpf_mul(sb, tb, prec, round_ceiling)
165
+ if a == fnan: a = fninf
166
+ if b == fnan: b = finf
167
+ elif sbs <= 0:
168
+ # negative * positive
169
+ if tas >= 0:
170
+ a = mpf_mul(sa, tb, prec, round_floor)
171
+ b = mpf_mul(sb, ta, prec, round_ceiling)
172
+ if a == fnan: a = fninf
173
+ if b == fnan: b = fzero
174
+ # negative * negative
175
+ elif tbs <= 0:
176
+ a = mpf_mul(sb, tb, prec, round_floor)
177
+ b = mpf_mul(sa, ta, prec, round_ceiling)
178
+ if a == fnan: a = fzero
179
+ if b == fnan: b = finf
180
+ # negative * both signs
181
+ else:
182
+ a = mpf_mul(sa, tb, prec, round_floor)
183
+ b = mpf_mul(sa, ta, prec, round_ceiling)
184
+ if a == fnan: a = fninf
185
+ if b == fnan: b = finf
186
+ else:
187
+ # General case: perform all cross-multiplications and compare
188
+ # Since the multiplications can be done exactly, we need only
189
+ # do 4 (instead of 8: two for each rounding mode)
190
+ cases = [mpf_mul(sa, ta), mpf_mul(sa, tb), mpf_mul(sb, ta), mpf_mul(sb, tb)]
191
+ if fnan in cases:
192
+ a, b = (fninf, finf)
193
+ else:
194
+ a, b = mpf_min_max(cases)
195
+ a = mpf_pos(a, prec, round_floor)
196
+ b = mpf_pos(b, prec, round_ceiling)
197
+ return a, b
198
+
199
+ def mpi_square(s, prec=0):
200
+ sa, sb = s
201
+ if mpf_ge(sa, fzero):
202
+ a = mpf_mul(sa, sa, prec, round_floor)
203
+ b = mpf_mul(sb, sb, prec, round_ceiling)
204
+ elif mpf_le(sb, fzero):
205
+ a = mpf_mul(sb, sb, prec, round_floor)
206
+ b = mpf_mul(sa, sa, prec, round_ceiling)
207
+ else:
208
+ sa = mpf_neg(sa)
209
+ sa, sb = mpf_min_max([sa, sb])
210
+ a = fzero
211
+ b = mpf_mul(sb, sb, prec, round_ceiling)
212
+ return a, b
213
+
214
+ def mpi_div(s, t, prec):
215
+ sa, sb = s
216
+ ta, tb = t
217
+ sas = mpf_sign(sa)
218
+ sbs = mpf_sign(sb)
219
+ tas = mpf_sign(ta)
220
+ tbs = mpf_sign(tb)
221
+ # 0 / X
222
+ if sas == sbs == 0:
223
+ # 0 / <interval containing 0>
224
+ if (tas < 0 and tbs > 0) or (tas == 0 or tbs == 0):
225
+ return fninf, finf
226
+ return fzero, fzero
227
+ # Denominator contains both negative and positive numbers;
228
+ # this should properly be a multi-interval, but the closest
229
+ # match is the entire (extended) real line
230
+ if tas < 0 and tbs > 0:
231
+ return fninf, finf
232
+ # Assume denominator to be nonnegative
233
+ if tas < 0:
234
+ return mpi_div(mpi_neg(s), mpi_neg(t), prec)
235
+ # Division by zero
236
+ # XXX: make sure all results make sense
237
+ if tas == 0:
238
+ # Numerator contains both signs?
239
+ if sas < 0 and sbs > 0:
240
+ return fninf, finf
241
+ if tas == tbs:
242
+ return fninf, finf
243
+ # Numerator positive?
244
+ if sas >= 0:
245
+ a = mpf_div(sa, tb, prec, round_floor)
246
+ b = finf
247
+ if sbs <= 0:
248
+ a = fninf
249
+ b = mpf_div(sb, tb, prec, round_ceiling)
250
+ # Division with positive denominator
251
+ # We still have to handle nans resulting from inf/0 or inf/inf
252
+ else:
253
+ # Nonnegative numerator
254
+ if sas >= 0:
255
+ a = mpf_div(sa, tb, prec, round_floor)
256
+ b = mpf_div(sb, ta, prec, round_ceiling)
257
+ if a == fnan: a = fzero
258
+ if b == fnan: b = finf
259
+ # Nonpositive numerator
260
+ elif sbs <= 0:
261
+ a = mpf_div(sa, ta, prec, round_floor)
262
+ b = mpf_div(sb, tb, prec, round_ceiling)
263
+ if a == fnan: a = fninf
264
+ if b == fnan: b = fzero
265
+ # Numerator contains both signs?
266
+ else:
267
+ a = mpf_div(sa, ta, prec, round_floor)
268
+ b = mpf_div(sb, ta, prec, round_ceiling)
269
+ if a == fnan: a = fninf
270
+ if b == fnan: b = finf
271
+ return a, b
272
+
273
+ def mpi_pi(prec):
274
+ a = mpf_pi(prec, round_floor)
275
+ b = mpf_pi(prec, round_ceiling)
276
+ return a, b
277
+
278
+ def mpi_exp(s, prec):
279
+ sa, sb = s
280
+ # exp is monotonic
281
+ a = mpf_exp(sa, prec, round_floor)
282
+ b = mpf_exp(sb, prec, round_ceiling)
283
+ return a, b
284
+
285
+ def mpi_log(s, prec):
286
+ sa, sb = s
287
+ # log is monotonic
288
+ a = mpf_log(sa, prec, round_floor)
289
+ b = mpf_log(sb, prec, round_ceiling)
290
+ return a, b
291
+
292
+ def mpi_sqrt(s, prec):
293
+ sa, sb = s
294
+ # sqrt is monotonic
295
+ a = mpf_sqrt(sa, prec, round_floor)
296
+ b = mpf_sqrt(sb, prec, round_ceiling)
297
+ return a, b
298
+
299
+ def mpi_atan(s, prec):
300
+ sa, sb = s
301
+ a = mpf_atan(sa, prec, round_floor)
302
+ b = mpf_atan(sb, prec, round_ceiling)
303
+ return a, b
304
+
305
+ def mpi_pow_int(s, n, prec):
306
+ sa, sb = s
307
+ if n < 0:
308
+ return mpi_div((fone, fone), mpi_pow_int(s, -n, prec+20), prec)
309
+ if n == 0:
310
+ return (fone, fone)
311
+ if n == 1:
312
+ return s
313
+ if n == 2:
314
+ return mpi_square(s, prec)
315
+ # Odd -- signs are preserved
316
+ if n & 1:
317
+ a = mpf_pow_int(sa, n, prec, round_floor)
318
+ b = mpf_pow_int(sb, n, prec, round_ceiling)
319
+ # Even -- important to ensure positivity
320
+ else:
321
+ sas = mpf_sign(sa)
322
+ sbs = mpf_sign(sb)
323
+ # Nonnegative?
324
+ if sas >= 0:
325
+ a = mpf_pow_int(sa, n, prec, round_floor)
326
+ b = mpf_pow_int(sb, n, prec, round_ceiling)
327
+ # Nonpositive?
328
+ elif sbs <= 0:
329
+ a = mpf_pow_int(sb, n, prec, round_floor)
330
+ b = mpf_pow_int(sa, n, prec, round_ceiling)
331
+ # Mixed signs?
332
+ else:
333
+ a = fzero
334
+ # max(-a,b)**n
335
+ sa = mpf_neg(sa)
336
+ if mpf_ge(sa, sb):
337
+ b = mpf_pow_int(sa, n, prec, round_ceiling)
338
+ else:
339
+ b = mpf_pow_int(sb, n, prec, round_ceiling)
340
+ return a, b
341
+
342
+ def mpi_pow(s, t, prec):
343
+ ta, tb = t
344
+ if ta == tb and ta not in (finf, fninf):
345
+ if ta == from_int(to_int(ta)):
346
+ return mpi_pow_int(s, to_int(ta), prec)
347
+ if ta == fhalf:
348
+ return mpi_sqrt(s, prec)
349
+ u = mpi_log(s, prec + 20)
350
+ v = mpi_mul(u, t, prec + 20)
351
+ return mpi_exp(v, prec)
352
+
353
+ def MIN(x, y):
354
+ if mpf_le(x, y):
355
+ return x
356
+ return y
357
+
358
+ def MAX(x, y):
359
+ if mpf_ge(x, y):
360
+ return x
361
+ return y
362
+
363
+ def cos_sin_quadrant(x, wp):
364
+ sign, man, exp, bc = x
365
+ if x == fzero:
366
+ return fone, fzero, 0
367
+ # TODO: combine evaluation code to avoid duplicate modulo
368
+ c, s = mpf_cos_sin(x, wp)
369
+ t, n, wp_ = mod_pi2(man, exp, exp+bc, 15)
370
+ if sign:
371
+ n = -1-n
372
+ return c, s, n
373
+
374
+ def mpi_cos_sin(x, prec):
375
+ a, b = x
376
+ if a == b == fzero:
377
+ return (fone, fone), (fzero, fzero)
378
+ # Guaranteed to contain both -1 and 1
379
+ if (finf in x) or (fninf in x):
380
+ return (fnone, fone), (fnone, fone)
381
+ wp = prec + 20
382
+ ca, sa, na = cos_sin_quadrant(a, wp)
383
+ cb, sb, nb = cos_sin_quadrant(b, wp)
384
+ ca, cb = mpf_min_max([ca, cb])
385
+ sa, sb = mpf_min_max([sa, sb])
386
+ # Both functions are monotonic within one quadrant
387
+ if na == nb:
388
+ pass
389
+ # Guaranteed to contain both -1 and 1
390
+ elif nb - na >= 4:
391
+ return (fnone, fone), (fnone, fone)
392
+ else:
393
+ # cos has maximum between a and b
394
+ if na//4 != nb//4:
395
+ cb = fone
396
+ # cos has minimum
397
+ if (na-2)//4 != (nb-2)//4:
398
+ ca = fnone
399
+ # sin has maximum
400
+ if (na-1)//4 != (nb-1)//4:
401
+ sb = fone
402
+ # sin has minimum
403
+ if (na-3)//4 != (nb-3)//4:
404
+ sa = fnone
405
+ # Perturb to force interval rounding
406
+ more = from_man_exp((MPZ_ONE<<wp) + (MPZ_ONE<<10), -wp)
407
+ less = from_man_exp((MPZ_ONE<<wp) - (MPZ_ONE<<10), -wp)
408
+ def finalize(v, rounding):
409
+ if bool(v[0]) == (rounding == round_floor):
410
+ p = more
411
+ else:
412
+ p = less
413
+ v = mpf_mul(v, p, prec, rounding)
414
+ sign, man, exp, bc = v
415
+ if exp+bc >= 1:
416
+ if sign:
417
+ return fnone
418
+ return fone
419
+ return v
420
+ ca = finalize(ca, round_floor)
421
+ cb = finalize(cb, round_ceiling)
422
+ sa = finalize(sa, round_floor)
423
+ sb = finalize(sb, round_ceiling)
424
+ return (ca,cb), (sa,sb)
425
+
426
+ def mpi_cos(x, prec):
427
+ return mpi_cos_sin(x, prec)[0]
428
+
429
+ def mpi_sin(x, prec):
430
+ return mpi_cos_sin(x, prec)[1]
431
+
432
+ def mpi_tan(x, prec):
433
+ cos, sin = mpi_cos_sin(x, prec+20)
434
+ return mpi_div(sin, cos, prec)
435
+
436
+ def mpi_cot(x, prec):
437
+ cos, sin = mpi_cos_sin(x, prec+20)
438
+ return mpi_div(cos, sin, prec)
439
+
440
+ def mpi_from_str_a_b(x, y, percent, prec):
441
+ wp = prec + 20
442
+ xa = from_str(x, wp, round_floor)
443
+ xb = from_str(x, wp, round_ceiling)
444
+ #ya = from_str(y, wp, round_floor)
445
+ y = from_str(y, wp, round_ceiling)
446
+ assert mpf_ge(y, fzero)
447
+ if percent:
448
+ y = mpf_mul(MAX(mpf_abs(xa), mpf_abs(xb)), y, wp, round_ceiling)
449
+ y = mpf_div(y, from_int(100), wp, round_ceiling)
450
+ a = mpf_sub(xa, y, prec, round_floor)
451
+ b = mpf_add(xb, y, prec, round_ceiling)
452
+ return a, b
453
+
454
+ def mpi_from_str(s, prec):
455
+ """
456
+ Parse an interval number given as a string.
457
+
458
+ Allowed forms are
459
+
460
+ "-1.23e-27"
461
+ Any single decimal floating-point literal.
462
+ "a +- b" or "a (b)"
463
+ a is the midpoint of the interval and b is the half-width
464
+ "a +- b%" or "a (b%)"
465
+ a is the midpoint of the interval and the half-width
466
+ is b percent of a (`a \times b / 100`).
467
+ "[a, b]"
468
+ The interval indicated directly.
469
+ "x[y,z]e"
470
+ x are shared digits, y and z are unequal digits, e is the exponent.
471
+
472
+ """
473
+ e = ValueError("Improperly formed interval number '%s'" % s)
474
+ s = s.replace(" ", "")
475
+ wp = prec + 20
476
+ if "+-" in s:
477
+ x, y = s.split("+-")
478
+ return mpi_from_str_a_b(x, y, False, prec)
479
+ # case 2
480
+ elif "(" in s:
481
+ # Don't confuse with a complex number (x,y)
482
+ if s[0] == "(" or ")" not in s:
483
+ raise e
484
+ s = s.replace(")", "")
485
+ percent = False
486
+ if "%" in s:
487
+ if s[-1] != "%":
488
+ raise e
489
+ percent = True
490
+ s = s.replace("%", "")
491
+ x, y = s.split("(")
492
+ return mpi_from_str_a_b(x, y, percent, prec)
493
+ elif "," in s:
494
+ if ('[' not in s) or (']' not in s):
495
+ raise e
496
+ if s[0] == '[':
497
+ # case 3
498
+ s = s.replace("[", "")
499
+ s = s.replace("]", "")
500
+ a, b = s.split(",")
501
+ a = from_str(a, prec, round_floor)
502
+ b = from_str(b, prec, round_ceiling)
503
+ return a, b
504
+ else:
505
+ # case 4
506
+ x, y = s.split('[')
507
+ y, z = y.split(',')
508
+ if 'e' in s:
509
+ z, e = z.split(']')
510
+ else:
511
+ z, e = z.rstrip(']'), ''
512
+ a = from_str(x+y+e, prec, round_floor)
513
+ b = from_str(x+z+e, prec, round_ceiling)
514
+ return a, b
515
+ else:
516
+ a = from_str(s, prec, round_floor)
517
+ b = from_str(s, prec, round_ceiling)
518
+ return a, b
519
+
520
+ def mpi_to_str(x, dps, use_spaces=True, brackets='[]', mode='brackets', error_dps=4, **kwargs):
521
+ """
522
+ Convert a mpi interval to a string.
523
+
524
+ **Arguments**
525
+
526
+ *dps*
527
+ decimal places to use for printing
528
+ *use_spaces*
529
+ use spaces for more readable output, defaults to true
530
+ *brackets*
531
+ pair of strings (or two-character string) giving left and right brackets
532
+ *mode*
533
+ mode of display: 'plusminus', 'percent', 'brackets' (default) or 'diff'
534
+ *error_dps*
535
+ limit the error to *error_dps* digits (mode 'plusminus and 'percent')
536
+
537
+ Additional keyword arguments are forwarded to the mpf-to-string conversion
538
+ for the components of the output.
539
+
540
+ **Examples**
541
+
542
+ >>> from mpmath import mpi, mp
543
+ >>> mp.dps = 30
544
+ >>> x = mpi(1, 2)._mpi_
545
+ >>> mpi_to_str(x, 2, mode='plusminus')
546
+ '1.5 +- 0.5'
547
+ >>> mpi_to_str(x, 2, mode='percent')
548
+ '1.5 (33.33%)'
549
+ >>> mpi_to_str(x, 2, mode='brackets')
550
+ '[1.0, 2.0]'
551
+ >>> mpi_to_str(x, 2, mode='brackets' , brackets=('<', '>'))
552
+ '<1.0, 2.0>'
553
+ >>> x = mpi('5.2582327113062393041', '5.2582327113062749951')._mpi_
554
+ >>> mpi_to_str(x, 15, mode='diff')
555
+ '5.2582327113062[4, 7]'
556
+ >>> mpi_to_str(mpi(0)._mpi_, 2, mode='percent')
557
+ '0.0 (0.0%)'
558
+
559
+ """
560
+ prec = dps_to_prec(dps)
561
+ wp = prec + 20
562
+ a, b = x
563
+ mid = mpi_mid(x, prec)
564
+ delta = mpi_delta(x, prec)
565
+ a_str = to_str(a, dps, **kwargs)
566
+ b_str = to_str(b, dps, **kwargs)
567
+ mid_str = to_str(mid, dps, **kwargs)
568
+ sp = ""
569
+ if use_spaces:
570
+ sp = " "
571
+ br1, br2 = brackets
572
+ if mode == 'plusminus':
573
+ delta_str = to_str(mpf_shift(delta,-1), dps, **kwargs)
574
+ s = mid_str + sp + "+-" + sp + delta_str
575
+ elif mode == 'percent':
576
+ if mid == fzero:
577
+ p = fzero
578
+ else:
579
+ # p = 100 * delta(x) / (2*mid(x))
580
+ p = mpf_mul(delta, from_int(100))
581
+ p = mpf_div(p, mpf_mul(mid, from_int(2)), wp)
582
+ s = mid_str + sp + "(" + to_str(p, error_dps) + "%)"
583
+ elif mode == 'brackets':
584
+ s = br1 + a_str + "," + sp + b_str + br2
585
+ elif mode == 'diff':
586
+ # use more digits if str(x.a) and str(x.b) are equal
587
+ if a_str == b_str:
588
+ a_str = to_str(a, dps+3, **kwargs)
589
+ b_str = to_str(b, dps+3, **kwargs)
590
+ # separate mantissa and exponent
591
+ a = a_str.split('e')
592
+ if len(a) == 1:
593
+ a.append('')
594
+ b = b_str.split('e')
595
+ if len(b) == 1:
596
+ b.append('')
597
+ if a[1] == b[1]:
598
+ if a[0] != b[0]:
599
+ for i in xrange(len(a[0]) + 1):
600
+ if a[0][i] != b[0][i]:
601
+ break
602
+ s = (a[0][:i] + br1 + a[0][i:] + ',' + sp + b[0][i:] + br2
603
+ + 'e'*min(len(a[1]), 1) + a[1])
604
+ else: # no difference
605
+ s = a[0] + br1 + br2 + 'e'*min(len(a[1]), 1) + a[1]
606
+ else:
607
+ s = br1 + 'e'.join(a) + ',' + sp + 'e'.join(b) + br2
608
+ else:
609
+ raise ValueError("'%s' is unknown mode for printing mpi" % mode)
610
+ return s
611
+
612
+ def mpci_add(x, y, prec):
613
+ a, b = x
614
+ c, d = y
615
+ return mpi_add(a, c, prec), mpi_add(b, d, prec)
616
+
617
+ def mpci_sub(x, y, prec):
618
+ a, b = x
619
+ c, d = y
620
+ return mpi_sub(a, c, prec), mpi_sub(b, d, prec)
621
+
622
+ def mpci_neg(x, prec=0):
623
+ a, b = x
624
+ return mpi_neg(a, prec), mpi_neg(b, prec)
625
+
626
+ def mpci_pos(x, prec):
627
+ a, b = x
628
+ return mpi_pos(a, prec), mpi_pos(b, prec)
629
+
630
+ def mpci_mul(x, y, prec):
631
+ # TODO: optimize for real/imag cases
632
+ a, b = x
633
+ c, d = y
634
+ r1 = mpi_mul(a,c)
635
+ r2 = mpi_mul(b,d)
636
+ re = mpi_sub(r1,r2,prec)
637
+ i1 = mpi_mul(a,d)
638
+ i2 = mpi_mul(b,c)
639
+ im = mpi_add(i1,i2,prec)
640
+ return re, im
641
+
642
+ def mpci_div(x, y, prec):
643
+ # TODO: optimize for real/imag cases
644
+ a, b = x
645
+ c, d = y
646
+ wp = prec+20
647
+ m1 = mpi_square(c)
648
+ m2 = mpi_square(d)
649
+ m = mpi_add(m1,m2,wp)
650
+ re = mpi_add(mpi_mul(a,c), mpi_mul(b,d), wp)
651
+ im = mpi_sub(mpi_mul(b,c), mpi_mul(a,d), wp)
652
+ re = mpi_div(re, m, prec)
653
+ im = mpi_div(im, m, prec)
654
+ return re, im
655
+
656
+ def mpci_exp(x, prec):
657
+ a, b = x
658
+ wp = prec+20
659
+ r = mpi_exp(a, wp)
660
+ c, s = mpi_cos_sin(b, wp)
661
+ a = mpi_mul(r, c, prec)
662
+ b = mpi_mul(r, s, prec)
663
+ return a, b
664
+
665
+ def mpi_shift(x, n):
666
+ a, b = x
667
+ return mpf_shift(a,n), mpf_shift(b,n)
668
+
669
+ def mpi_cosh_sinh(x, prec):
670
+ # TODO: accuracy for small x
671
+ wp = prec+20
672
+ e1 = mpi_exp(x, wp)
673
+ e2 = mpi_div(mpi_one, e1, wp)
674
+ c = mpi_add(e1, e2, prec)
675
+ s = mpi_sub(e1, e2, prec)
676
+ c = mpi_shift(c, -1)
677
+ s = mpi_shift(s, -1)
678
+ return c, s
679
+
680
+ def mpci_cos(x, prec):
681
+ a, b = x
682
+ wp = prec+10
683
+ c, s = mpi_cos_sin(a, wp)
684
+ ch, sh = mpi_cosh_sinh(b, wp)
685
+ re = mpi_mul(c, ch, prec)
686
+ im = mpi_mul(s, sh, prec)
687
+ return re, mpi_neg(im)
688
+
689
+ def mpci_sin(x, prec):
690
+ a, b = x
691
+ wp = prec+10
692
+ c, s = mpi_cos_sin(a, wp)
693
+ ch, sh = mpi_cosh_sinh(b, wp)
694
+ re = mpi_mul(s, ch, prec)
695
+ im = mpi_mul(c, sh, prec)
696
+ return re, im
697
+
698
+ def mpci_abs(x, prec):
699
+ a, b = x
700
+ if a == mpi_zero:
701
+ return mpi_abs(b)
702
+ if b == mpi_zero:
703
+ return mpi_abs(a)
704
+ # Important: nonnegative
705
+ a = mpi_square(a)
706
+ b = mpi_square(b)
707
+ t = mpi_add(a, b, prec+20)
708
+ return mpi_sqrt(t, prec)
709
+
710
+ def mpi_atan2(y, x, prec):
711
+ ya, yb = y
712
+ xa, xb = x
713
+ # Constrained to the real line
714
+ if ya == yb == fzero:
715
+ if mpf_ge(xa, fzero):
716
+ return mpi_zero
717
+ return mpi_pi(prec)
718
+ # Right half-plane
719
+ if mpf_ge(xa, fzero):
720
+ if mpf_ge(ya, fzero):
721
+ a = mpf_atan2(ya, xb, prec, round_floor)
722
+ else:
723
+ a = mpf_atan2(ya, xa, prec, round_floor)
724
+ if mpf_ge(yb, fzero):
725
+ b = mpf_atan2(yb, xa, prec, round_ceiling)
726
+ else:
727
+ b = mpf_atan2(yb, xb, prec, round_ceiling)
728
+ # Upper half-plane
729
+ elif mpf_ge(ya, fzero):
730
+ b = mpf_atan2(ya, xa, prec, round_ceiling)
731
+ if mpf_le(xb, fzero):
732
+ a = mpf_atan2(yb, xb, prec, round_floor)
733
+ else:
734
+ a = mpf_atan2(ya, xb, prec, round_floor)
735
+ # Lower half-plane
736
+ elif mpf_le(yb, fzero):
737
+ a = mpf_atan2(yb, xa, prec, round_floor)
738
+ if mpf_le(xb, fzero):
739
+ b = mpf_atan2(ya, xb, prec, round_ceiling)
740
+ else:
741
+ b = mpf_atan2(yb, xb, prec, round_ceiling)
742
+ # Covering the origin
743
+ else:
744
+ b = mpf_pi(prec, round_ceiling)
745
+ a = mpf_neg(b)
746
+ return a, b
747
+
748
+ def mpci_arg(z, prec):
749
+ x, y = z
750
+ return mpi_atan2(y, x, prec)
751
+
752
+ def mpci_log(z, prec):
753
+ x, y = z
754
+ re = mpi_log(mpci_abs(z, prec+20), prec)
755
+ im = mpci_arg(z, prec)
756
+ return re, im
757
+
758
+ def mpci_pow(x, y, prec):
759
+ # TODO: recognize/speed up real cases, integer y
760
+ yre, yim = y
761
+ if yim == mpi_zero:
762
+ ya, yb = yre
763
+ if ya == yb:
764
+ sign, man, exp, bc = yb
765
+ if man and exp >= 0:
766
+ return mpci_pow_int(x, (-1)**sign * int(man<<exp), prec)
767
+ # x^0
768
+ if yb == fzero:
769
+ return mpci_pow_int(x, 0, prec)
770
+ wp = prec+20
771
+ return mpci_exp(mpci_mul(y, mpci_log(x, wp), wp), prec)
772
+
773
+ def mpci_square(x, prec):
774
+ a, b = x
775
+ # (a+bi)^2 = (a^2-b^2) + 2abi
776
+ re = mpi_sub(mpi_square(a), mpi_square(b), prec)
777
+ im = mpi_mul(a, b, prec)
778
+ im = mpi_shift(im, 1)
779
+ return re, im
780
+
781
+ def mpci_pow_int(x, n, prec):
782
+ if n < 0:
783
+ return mpci_div((mpi_one,mpi_zero), mpci_pow_int(x, -n, prec+20), prec)
784
+ if n == 0:
785
+ return mpi_one, mpi_zero
786
+ if n == 1:
787
+ return mpci_pos(x, prec)
788
+ if n == 2:
789
+ return mpci_square(x, prec)
790
+ wp = prec + 20
791
+ result = (mpi_one, mpi_zero)
792
+ while n:
793
+ if n & 1:
794
+ result = mpci_mul(result, x, wp)
795
+ n -= 1
796
+ x = mpci_square(x, wp)
797
+ n >>= 1
798
+ return mpci_pos(result, prec)
799
+
800
+ gamma_min_a = from_float(1.46163214496)
801
+ gamma_min_b = from_float(1.46163214497)
802
+ gamma_min = (gamma_min_a, gamma_min_b)
803
+ gamma_mono_imag_a = from_float(-1.1)
804
+ gamma_mono_imag_b = from_float(1.1)
805
+
806
+ def mpi_overlap(x, y):
807
+ a, b = x
808
+ c, d = y
809
+ if mpf_lt(d, a): return False
810
+ if mpf_gt(c, b): return False
811
+ return True
812
+
813
+ # type = 0 -- gamma
814
+ # type = 1 -- factorial
815
+ # type = 2 -- 1/gamma
816
+ # type = 3 -- log-gamma
817
+
818
+ def mpi_gamma(z, prec, type=0):
819
+ a, b = z
820
+ wp = prec+20
821
+
822
+ if type == 1:
823
+ return mpi_gamma(mpi_add(z, mpi_one, wp), prec, 0)
824
+
825
+ # increasing
826
+ if mpf_gt(a, gamma_min_b):
827
+ if type == 0:
828
+ c = mpf_gamma(a, prec, round_floor)
829
+ d = mpf_gamma(b, prec, round_ceiling)
830
+ elif type == 2:
831
+ c = mpf_rgamma(b, prec, round_floor)
832
+ d = mpf_rgamma(a, prec, round_ceiling)
833
+ elif type == 3:
834
+ c = mpf_loggamma(a, prec, round_floor)
835
+ d = mpf_loggamma(b, prec, round_ceiling)
836
+ # decreasing
837
+ elif mpf_gt(a, fzero) and mpf_lt(b, gamma_min_a):
838
+ if type == 0:
839
+ c = mpf_gamma(b, prec, round_floor)
840
+ d = mpf_gamma(a, prec, round_ceiling)
841
+ elif type == 2:
842
+ c = mpf_rgamma(a, prec, round_floor)
843
+ d = mpf_rgamma(b, prec, round_ceiling)
844
+ elif type == 3:
845
+ c = mpf_loggamma(b, prec, round_floor)
846
+ d = mpf_loggamma(a, prec, round_ceiling)
847
+ else:
848
+ # TODO: reflection formula
849
+ znew = mpi_add(z, mpi_one, wp)
850
+ if type == 0: return mpi_div(mpi_gamma(znew, prec+2, 0), z, prec)
851
+ if type == 2: return mpi_mul(mpi_gamma(znew, prec+2, 2), z, prec)
852
+ if type == 3: return mpi_sub(mpi_gamma(znew, prec+2, 3), mpi_log(z, prec+2), prec)
853
+ return c, d
854
+
855
+ def mpci_gamma(z, prec, type=0):
856
+ (a1,a2), (b1,b2) = z
857
+
858
+ # Real case
859
+ if b1 == b2 == fzero and (type != 3 or mpf_gt(a1,fzero)):
860
+ return mpi_gamma(z, prec, type), mpi_zero
861
+
862
+ # Estimate precision
863
+ wp = prec+20
864
+ if type != 3:
865
+ amag = a2[2]+a2[3]
866
+ bmag = b2[2]+b2[3]
867
+ if a2 != fzero:
868
+ mag = max(amag, bmag)
869
+ else:
870
+ mag = bmag
871
+ an = abs(to_int(a2))
872
+ bn = abs(to_int(b2))
873
+ absn = max(an, bn)
874
+ gamma_size = max(0,absn*mag)
875
+ wp += bitcount(gamma_size)
876
+
877
+ # Assume type != 1
878
+ if type == 1:
879
+ (a1,a2) = mpi_add((a1,a2), mpi_one, wp); z = (a1,a2), (b1,b2)
880
+ type = 0
881
+
882
+ # Avoid non-monotonic region near the negative real axis
883
+ if mpf_lt(a1, gamma_min_b):
884
+ if mpi_overlap((b1,b2), (gamma_mono_imag_a, gamma_mono_imag_b)):
885
+ # TODO: reflection formula
886
+ #if mpf_lt(a2, mpf_shift(fone,-1)):
887
+ # znew = mpci_sub((mpi_one,mpi_zero),z,wp)
888
+ # ...
889
+ # Recurrence:
890
+ # gamma(z) = gamma(z+1)/z
891
+ znew = mpi_add((a1,a2), mpi_one, wp), (b1,b2)
892
+ if type == 0: return mpci_div(mpci_gamma(znew, prec+2, 0), z, prec)
893
+ if type == 2: return mpci_mul(mpci_gamma(znew, prec+2, 2), z, prec)
894
+ if type == 3: return mpci_sub(mpci_gamma(znew, prec+2, 3), mpci_log(z,prec+2), prec)
895
+
896
+ # Use monotonicity (except for a small region close to the
897
+ # origin and near poles)
898
+ # upper half-plane
899
+ if mpf_ge(b1, fzero):
900
+ minre = mpc_loggamma((a1,b2), wp, round_floor)
901
+ maxre = mpc_loggamma((a2,b1), wp, round_ceiling)
902
+ minim = mpc_loggamma((a1,b1), wp, round_floor)
903
+ maxim = mpc_loggamma((a2,b2), wp, round_ceiling)
904
+ # lower half-plane
905
+ elif mpf_le(b2, fzero):
906
+ minre = mpc_loggamma((a1,b1), wp, round_floor)
907
+ maxre = mpc_loggamma((a2,b2), wp, round_ceiling)
908
+ minim = mpc_loggamma((a2,b1), wp, round_floor)
909
+ maxim = mpc_loggamma((a1,b2), wp, round_ceiling)
910
+ # crosses real axis
911
+ else:
912
+ maxre = mpc_loggamma((a2,fzero), wp, round_ceiling)
913
+ # stretches more into the lower half-plane
914
+ if mpf_gt(mpf_neg(b1), b2):
915
+ minre = mpc_loggamma((a1,b1), wp, round_ceiling)
916
+ else:
917
+ minre = mpc_loggamma((a1,b2), wp, round_ceiling)
918
+ minim = mpc_loggamma((a2,b1), wp, round_floor)
919
+ maxim = mpc_loggamma((a2,b2), wp, round_floor)
920
+
921
+ w = (minre[0], maxre[0]), (minim[1], maxim[1])
922
+ if type == 3:
923
+ return mpi_pos(w[0], prec), mpi_pos(w[1], prec)
924
+ if type == 2:
925
+ w = mpci_neg(w)
926
+ return mpci_exp(w, prec)
927
+
928
+ def mpi_loggamma(z, prec): return mpi_gamma(z, prec, type=3)
929
+ def mpci_loggamma(z, prec): return mpci_gamma(z, prec, type=3)
930
+
931
+ def mpi_rgamma(z, prec): return mpi_gamma(z, prec, type=2)
932
+ def mpci_rgamma(z, prec): return mpci_gamma(z, prec, type=2)
933
+
934
+ def mpi_factorial(z, prec): return mpi_gamma(z, prec, type=1)
935
+ def mpci_factorial(z, prec): return mpci_gamma(z, prec, type=1)
lib/python3.11/site-packages/mpmath/math2.py ADDED
@@ -0,0 +1,672 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This module complements the math and cmath builtin modules by providing
3
+ fast machine precision versions of some additional functions (gamma, ...)
4
+ and wrapping math/cmath functions so that they can be called with either
5
+ real or complex arguments.
6
+ """
7
+
8
+ import operator
9
+ import math
10
+ import cmath
11
+
12
+ # Irrational (?) constants
13
+ pi = 3.1415926535897932385
14
+ e = 2.7182818284590452354
15
+ sqrt2 = 1.4142135623730950488
16
+ sqrt5 = 2.2360679774997896964
17
+ phi = 1.6180339887498948482
18
+ ln2 = 0.69314718055994530942
19
+ ln10 = 2.302585092994045684
20
+ euler = 0.57721566490153286061
21
+ catalan = 0.91596559417721901505
22
+ khinchin = 2.6854520010653064453
23
+ apery = 1.2020569031595942854
24
+
25
+ logpi = 1.1447298858494001741
26
+
27
+ def _mathfun_real(f_real, f_complex):
28
+ def f(x, **kwargs):
29
+ if type(x) is float:
30
+ return f_real(x)
31
+ if type(x) is complex:
32
+ return f_complex(x)
33
+ try:
34
+ x = float(x)
35
+ return f_real(x)
36
+ except (TypeError, ValueError):
37
+ x = complex(x)
38
+ return f_complex(x)
39
+ f.__name__ = f_real.__name__
40
+ return f
41
+
42
+ def _mathfun(f_real, f_complex):
43
+ def f(x, **kwargs):
44
+ if type(x) is complex:
45
+ return f_complex(x)
46
+ try:
47
+ return f_real(float(x))
48
+ except (TypeError, ValueError):
49
+ return f_complex(complex(x))
50
+ f.__name__ = f_real.__name__
51
+ return f
52
+
53
+ def _mathfun_n(f_real, f_complex):
54
+ def f(*args, **kwargs):
55
+ try:
56
+ return f_real(*(float(x) for x in args))
57
+ except (TypeError, ValueError):
58
+ return f_complex(*(complex(x) for x in args))
59
+ f.__name__ = f_real.__name__
60
+ return f
61
+
62
+ # Workaround for non-raising log and sqrt in Python 2.5 and 2.4
63
+ # on Unix system
64
+ try:
65
+ math.log(-2.0)
66
+ def math_log(x):
67
+ if x <= 0.0:
68
+ raise ValueError("math domain error")
69
+ return math.log(x)
70
+ def math_sqrt(x):
71
+ if x < 0.0:
72
+ raise ValueError("math domain error")
73
+ return math.sqrt(x)
74
+ except (ValueError, TypeError):
75
+ math_log = math.log
76
+ math_sqrt = math.sqrt
77
+
78
+ pow = _mathfun_n(operator.pow, lambda x, y: complex(x)**y)
79
+ log = _mathfun_n(math_log, cmath.log)
80
+ sqrt = _mathfun(math_sqrt, cmath.sqrt)
81
+ exp = _mathfun_real(math.exp, cmath.exp)
82
+
83
+ cos = _mathfun_real(math.cos, cmath.cos)
84
+ sin = _mathfun_real(math.sin, cmath.sin)
85
+ tan = _mathfun_real(math.tan, cmath.tan)
86
+
87
+ acos = _mathfun(math.acos, cmath.acos)
88
+ asin = _mathfun(math.asin, cmath.asin)
89
+ atan = _mathfun_real(math.atan, cmath.atan)
90
+
91
+ cosh = _mathfun_real(math.cosh, cmath.cosh)
92
+ sinh = _mathfun_real(math.sinh, cmath.sinh)
93
+ tanh = _mathfun_real(math.tanh, cmath.tanh)
94
+
95
+ floor = _mathfun_real(math.floor,
96
+ lambda z: complex(math.floor(z.real), math.floor(z.imag)))
97
+ ceil = _mathfun_real(math.ceil,
98
+ lambda z: complex(math.ceil(z.real), math.ceil(z.imag)))
99
+
100
+
101
+ cos_sin = _mathfun_real(lambda x: (math.cos(x), math.sin(x)),
102
+ lambda z: (cmath.cos(z), cmath.sin(z)))
103
+
104
+ cbrt = _mathfun(lambda x: x**(1./3), lambda z: z**(1./3))
105
+
106
+ def nthroot(x, n):
107
+ r = 1./n
108
+ try:
109
+ return float(x) ** r
110
+ except (ValueError, TypeError):
111
+ return complex(x) ** r
112
+
113
+ def _sinpi_real(x):
114
+ if x < 0:
115
+ return -_sinpi_real(-x)
116
+ n, r = divmod(x, 0.5)
117
+ r *= pi
118
+ n %= 4
119
+ if n == 0: return math.sin(r)
120
+ if n == 1: return math.cos(r)
121
+ if n == 2: return -math.sin(r)
122
+ if n == 3: return -math.cos(r)
123
+
124
+ def _cospi_real(x):
125
+ if x < 0:
126
+ x = -x
127
+ n, r = divmod(x, 0.5)
128
+ r *= pi
129
+ n %= 4
130
+ if n == 0: return math.cos(r)
131
+ if n == 1: return -math.sin(r)
132
+ if n == 2: return -math.cos(r)
133
+ if n == 3: return math.sin(r)
134
+
135
+ def _sinpi_complex(z):
136
+ if z.real < 0:
137
+ return -_sinpi_complex(-z)
138
+ n, r = divmod(z.real, 0.5)
139
+ z = pi*complex(r, z.imag)
140
+ n %= 4
141
+ if n == 0: return cmath.sin(z)
142
+ if n == 1: return cmath.cos(z)
143
+ if n == 2: return -cmath.sin(z)
144
+ if n == 3: return -cmath.cos(z)
145
+
146
+ def _cospi_complex(z):
147
+ if z.real < 0:
148
+ z = -z
149
+ n, r = divmod(z.real, 0.5)
150
+ z = pi*complex(r, z.imag)
151
+ n %= 4
152
+ if n == 0: return cmath.cos(z)
153
+ if n == 1: return -cmath.sin(z)
154
+ if n == 2: return -cmath.cos(z)
155
+ if n == 3: return cmath.sin(z)
156
+
157
+ cospi = _mathfun_real(_cospi_real, _cospi_complex)
158
+ sinpi = _mathfun_real(_sinpi_real, _sinpi_complex)
159
+
160
+ def tanpi(x):
161
+ try:
162
+ return sinpi(x) / cospi(x)
163
+ except OverflowError:
164
+ if complex(x).imag > 10:
165
+ return 1j
166
+ if complex(x).imag < 10:
167
+ return -1j
168
+ raise
169
+
170
+ def cotpi(x):
171
+ try:
172
+ return cospi(x) / sinpi(x)
173
+ except OverflowError:
174
+ if complex(x).imag > 10:
175
+ return -1j
176
+ if complex(x).imag < 10:
177
+ return 1j
178
+ raise
179
+
180
+ INF = 1e300*1e300
181
+ NINF = -INF
182
+ NAN = INF-INF
183
+ EPS = 2.2204460492503131e-16
184
+
185
+ _exact_gamma = (INF, 1.0, 1.0, 2.0, 6.0, 24.0, 120.0, 720.0, 5040.0, 40320.0,
186
+ 362880.0, 3628800.0, 39916800.0, 479001600.0, 6227020800.0, 87178291200.0,
187
+ 1307674368000.0, 20922789888000.0, 355687428096000.0, 6402373705728000.0,
188
+ 121645100408832000.0, 2432902008176640000.0)
189
+
190
+ _max_exact_gamma = len(_exact_gamma)-1
191
+
192
+ # Lanczos coefficients used by the GNU Scientific Library
193
+ _lanczos_g = 7
194
+ _lanczos_p = (0.99999999999980993, 676.5203681218851, -1259.1392167224028,
195
+ 771.32342877765313, -176.61502916214059, 12.507343278686905,
196
+ -0.13857109526572012, 9.9843695780195716e-6, 1.5056327351493116e-7)
197
+
198
+ def _gamma_real(x):
199
+ _intx = int(x)
200
+ if _intx == x:
201
+ if _intx <= 0:
202
+ #return (-1)**_intx * INF
203
+ raise ZeroDivisionError("gamma function pole")
204
+ if _intx <= _max_exact_gamma:
205
+ return _exact_gamma[_intx]
206
+ if x < 0.5:
207
+ # TODO: sinpi
208
+ return pi / (_sinpi_real(x)*_gamma_real(1-x))
209
+ else:
210
+ x -= 1.0
211
+ r = _lanczos_p[0]
212
+ for i in range(1, _lanczos_g+2):
213
+ r += _lanczos_p[i]/(x+i)
214
+ t = x + _lanczos_g + 0.5
215
+ return 2.506628274631000502417 * t**(x+0.5) * math.exp(-t) * r
216
+
217
+ def _gamma_complex(x):
218
+ if not x.imag:
219
+ return complex(_gamma_real(x.real))
220
+ if x.real < 0.5:
221
+ # TODO: sinpi
222
+ return pi / (_sinpi_complex(x)*_gamma_complex(1-x))
223
+ else:
224
+ x -= 1.0
225
+ r = _lanczos_p[0]
226
+ for i in range(1, _lanczos_g+2):
227
+ r += _lanczos_p[i]/(x+i)
228
+ t = x + _lanczos_g + 0.5
229
+ return 2.506628274631000502417 * t**(x+0.5) * cmath.exp(-t) * r
230
+
231
+ gamma = _mathfun_real(_gamma_real, _gamma_complex)
232
+
233
+ def rgamma(x):
234
+ try:
235
+ return 1./gamma(x)
236
+ except ZeroDivisionError:
237
+ return x*0.0
238
+
239
+ def factorial(x):
240
+ return gamma(x+1.0)
241
+
242
+ def arg(x):
243
+ if type(x) is float:
244
+ return math.atan2(0.0,x)
245
+ return math.atan2(x.imag,x.real)
246
+
247
+ # XXX: broken for negatives
248
+ def loggamma(x):
249
+ if type(x) not in (float, complex):
250
+ try:
251
+ x = float(x)
252
+ except (ValueError, TypeError):
253
+ x = complex(x)
254
+ try:
255
+ xreal = x.real
256
+ ximag = x.imag
257
+ except AttributeError: # py2.5
258
+ xreal = x
259
+ ximag = 0.0
260
+ # Reflection formula
261
+ # http://functions.wolfram.com/GammaBetaErf/LogGamma/16/01/01/0003/
262
+ if xreal < 0.0:
263
+ if abs(x) < 0.5:
264
+ v = log(gamma(x))
265
+ if ximag == 0:
266
+ v = v.conjugate()
267
+ return v
268
+ z = 1-x
269
+ try:
270
+ re = z.real
271
+ im = z.imag
272
+ except AttributeError: # py2.5
273
+ re = z
274
+ im = 0.0
275
+ refloor = floor(re)
276
+ if im == 0.0:
277
+ imsign = 0
278
+ elif im < 0.0:
279
+ imsign = -1
280
+ else:
281
+ imsign = 1
282
+ return (-pi*1j)*abs(refloor)*(1-abs(imsign)) + logpi - \
283
+ log(sinpi(z-refloor)) - loggamma(z) + 1j*pi*refloor*imsign
284
+ if x == 1.0 or x == 2.0:
285
+ return x*0
286
+ p = 0.
287
+ while abs(x) < 11:
288
+ p -= log(x)
289
+ x += 1.0
290
+ s = 0.918938533204672742 + (x-0.5)*log(x) - x
291
+ r = 1./x
292
+ r2 = r*r
293
+ s += 0.083333333333333333333*r; r *= r2
294
+ s += -0.0027777777777777777778*r; r *= r2
295
+ s += 0.00079365079365079365079*r; r *= r2
296
+ s += -0.0005952380952380952381*r; r *= r2
297
+ s += 0.00084175084175084175084*r; r *= r2
298
+ s += -0.0019175269175269175269*r; r *= r2
299
+ s += 0.0064102564102564102564*r; r *= r2
300
+ s += -0.02955065359477124183*r
301
+ return s + p
302
+
303
+ _psi_coeff = [
304
+ 0.083333333333333333333,
305
+ -0.0083333333333333333333,
306
+ 0.003968253968253968254,
307
+ -0.0041666666666666666667,
308
+ 0.0075757575757575757576,
309
+ -0.021092796092796092796,
310
+ 0.083333333333333333333,
311
+ -0.44325980392156862745,
312
+ 3.0539543302701197438,
313
+ -26.456212121212121212]
314
+
315
+ def _digamma_real(x):
316
+ _intx = int(x)
317
+ if _intx == x:
318
+ if _intx <= 0:
319
+ raise ZeroDivisionError("polygamma pole")
320
+ if x < 0.5:
321
+ x = 1.0-x
322
+ s = pi*cotpi(x)
323
+ else:
324
+ s = 0.0
325
+ while x < 10.0:
326
+ s -= 1.0/x
327
+ x += 1.0
328
+ x2 = x**-2
329
+ t = x2
330
+ for c in _psi_coeff:
331
+ s -= c*t
332
+ if t < 1e-20:
333
+ break
334
+ t *= x2
335
+ return s + math_log(x) - 0.5/x
336
+
337
+ def _digamma_complex(x):
338
+ if not x.imag:
339
+ return complex(_digamma_real(x.real))
340
+ if x.real < 0.5:
341
+ x = 1.0-x
342
+ s = pi*cotpi(x)
343
+ else:
344
+ s = 0.0
345
+ while abs(x) < 10.0:
346
+ s -= 1.0/x
347
+ x += 1.0
348
+ x2 = x**-2
349
+ t = x2
350
+ for c in _psi_coeff:
351
+ s -= c*t
352
+ if abs(t) < 1e-20:
353
+ break
354
+ t *= x2
355
+ return s + cmath.log(x) - 0.5/x
356
+
357
+ digamma = _mathfun_real(_digamma_real, _digamma_complex)
358
+
359
+ # TODO: could implement complex erf and erfc here. Need
360
+ # to find an accurate method (avoiding cancellation)
361
+ # for approx. 1 < abs(x) < 9.
362
+
363
+ _erfc_coeff_P = [
364
+ 1.0000000161203922312,
365
+ 2.1275306946297962644,
366
+ 2.2280433377390253297,
367
+ 1.4695509105618423961,
368
+ 0.66275911699770787537,
369
+ 0.20924776504163751585,
370
+ 0.045459713768411264339,
371
+ 0.0063065951710717791934,
372
+ 0.00044560259661560421715][::-1]
373
+
374
+ _erfc_coeff_Q = [
375
+ 1.0000000000000000000,
376
+ 3.2559100272784894318,
377
+ 4.9019435608903239131,
378
+ 4.4971472894498014205,
379
+ 2.7845640601891186528,
380
+ 1.2146026030046904138,
381
+ 0.37647108453729465912,
382
+ 0.080970149639040548613,
383
+ 0.011178148899483545902,
384
+ 0.00078981003831980423513][::-1]
385
+
386
+ def _polyval(coeffs, x):
387
+ p = coeffs[0]
388
+ for c in coeffs[1:]:
389
+ p = c + x*p
390
+ return p
391
+
392
+ def _erf_taylor(x):
393
+ # Taylor series assuming 0 <= x <= 1
394
+ x2 = x*x
395
+ s = t = x
396
+ n = 1
397
+ while abs(t) > 1e-17:
398
+ t *= x2/n
399
+ s -= t/(n+n+1)
400
+ n += 1
401
+ t *= x2/n
402
+ s += t/(n+n+1)
403
+ n += 1
404
+ return 1.1283791670955125739*s
405
+
406
+ def _erfc_mid(x):
407
+ # Rational approximation assuming 0 <= x <= 9
408
+ return exp(-x*x)*_polyval(_erfc_coeff_P,x)/_polyval(_erfc_coeff_Q,x)
409
+
410
+ def _erfc_asymp(x):
411
+ # Asymptotic expansion assuming x >= 9
412
+ x2 = x*x
413
+ v = exp(-x2)/x*0.56418958354775628695
414
+ r = t = 0.5 / x2
415
+ s = 1.0
416
+ for n in range(1,22,4):
417
+ s -= t
418
+ t *= r * (n+2)
419
+ s += t
420
+ t *= r * (n+4)
421
+ if abs(t) < 1e-17:
422
+ break
423
+ return s * v
424
+
425
+ def erf(x):
426
+ """
427
+ erf of a real number.
428
+ """
429
+ x = float(x)
430
+ if x != x:
431
+ return x
432
+ if x < 0.0:
433
+ return -erf(-x)
434
+ if x >= 1.0:
435
+ if x >= 6.0:
436
+ return 1.0
437
+ return 1.0 - _erfc_mid(x)
438
+ return _erf_taylor(x)
439
+
440
+ def erfc(x):
441
+ """
442
+ erfc of a real number.
443
+ """
444
+ x = float(x)
445
+ if x != x:
446
+ return x
447
+ if x < 0.0:
448
+ if x < -6.0:
449
+ return 2.0
450
+ return 2.0-erfc(-x)
451
+ if x > 9.0:
452
+ return _erfc_asymp(x)
453
+ if x >= 1.0:
454
+ return _erfc_mid(x)
455
+ return 1.0 - _erf_taylor(x)
456
+
457
+ gauss42 = [\
458
+ (0.99839961899006235, 0.0041059986046490839),
459
+ (-0.99839961899006235, 0.0041059986046490839),
460
+ (0.9915772883408609, 0.009536220301748501),
461
+ (-0.9915772883408609,0.009536220301748501),
462
+ (0.97934250806374812, 0.014922443697357493),
463
+ (-0.97934250806374812, 0.014922443697357493),
464
+ (0.96175936533820439,0.020227869569052644),
465
+ (-0.96175936533820439, 0.020227869569052644),
466
+ (0.93892355735498811, 0.025422959526113047),
467
+ (-0.93892355735498811,0.025422959526113047),
468
+ (0.91095972490412735, 0.030479240699603467),
469
+ (-0.91095972490412735, 0.030479240699603467),
470
+ (0.87802056981217269,0.03536907109759211),
471
+ (-0.87802056981217269, 0.03536907109759211),
472
+ (0.8402859832618168, 0.040065735180692258),
473
+ (-0.8402859832618168,0.040065735180692258),
474
+ (0.7979620532554873, 0.044543577771965874),
475
+ (-0.7979620532554873, 0.044543577771965874),
476
+ (0.75127993568948048,0.048778140792803244),
477
+ (-0.75127993568948048, 0.048778140792803244),
478
+ (0.70049459055617114, 0.052746295699174064),
479
+ (-0.70049459055617114,0.052746295699174064),
480
+ (0.64588338886924779, 0.056426369358018376),
481
+ (-0.64588338886924779, 0.056426369358018376),
482
+ (0.58774459748510932, 0.059798262227586649),
483
+ (-0.58774459748510932, 0.059798262227586649),
484
+ (0.5263957499311922, 0.062843558045002565),
485
+ (-0.5263957499311922, 0.062843558045002565),
486
+ (0.46217191207042191, 0.065545624364908975),
487
+ (-0.46217191207042191, 0.065545624364908975),
488
+ (0.39542385204297503, 0.067889703376521934),
489
+ (-0.39542385204297503, 0.067889703376521934),
490
+ (0.32651612446541151, 0.069862992492594159),
491
+ (-0.32651612446541151, 0.069862992492594159),
492
+ (0.25582507934287907, 0.071454714265170971),
493
+ (-0.25582507934287907, 0.071454714265170971),
494
+ (0.18373680656485453, 0.072656175243804091),
495
+ (-0.18373680656485453, 0.072656175243804091),
496
+ (0.11064502720851986, 0.073460813453467527),
497
+ (-0.11064502720851986, 0.073460813453467527),
498
+ (0.036948943165351772, 0.073864234232172879),
499
+ (-0.036948943165351772, 0.073864234232172879)]
500
+
501
+ EI_ASYMP_CONVERGENCE_RADIUS = 40.0
502
+
503
+ def ei_asymp(z, _e1=False):
504
+ r = 1./z
505
+ s = t = 1.0
506
+ k = 1
507
+ while 1:
508
+ t *= k*r
509
+ s += t
510
+ if abs(t) < 1e-16:
511
+ break
512
+ k += 1
513
+ v = s*exp(z)/z
514
+ if _e1:
515
+ if type(z) is complex:
516
+ zreal = z.real
517
+ zimag = z.imag
518
+ else:
519
+ zreal = z
520
+ zimag = 0.0
521
+ if zimag == 0.0 and zreal > 0.0:
522
+ v += pi*1j
523
+ else:
524
+ if type(z) is complex:
525
+ if z.imag > 0:
526
+ v += pi*1j
527
+ if z.imag < 0:
528
+ v -= pi*1j
529
+ return v
530
+
531
+ def ei_taylor(z, _e1=False):
532
+ s = t = z
533
+ k = 2
534
+ while 1:
535
+ t = t*z/k
536
+ term = t/k
537
+ if abs(term) < 1e-17:
538
+ break
539
+ s += term
540
+ k += 1
541
+ s += euler
542
+ if _e1:
543
+ s += log(-z)
544
+ else:
545
+ if type(z) is float or z.imag == 0.0:
546
+ s += math_log(abs(z))
547
+ else:
548
+ s += cmath.log(z)
549
+ return s
550
+
551
+ def ei(z, _e1=False):
552
+ typez = type(z)
553
+ if typez not in (float, complex):
554
+ try:
555
+ z = float(z)
556
+ typez = float
557
+ except (TypeError, ValueError):
558
+ z = complex(z)
559
+ typez = complex
560
+ if not z:
561
+ return -INF
562
+ absz = abs(z)
563
+ if absz > EI_ASYMP_CONVERGENCE_RADIUS:
564
+ return ei_asymp(z, _e1)
565
+ elif absz <= 2.0 or (typez is float and z > 0.0):
566
+ return ei_taylor(z, _e1)
567
+ # Integrate, starting from whichever is smaller of a Taylor
568
+ # series value or an asymptotic series value
569
+ if typez is complex and z.real > 0.0:
570
+ zref = z / absz
571
+ ref = ei_taylor(zref, _e1)
572
+ else:
573
+ zref = EI_ASYMP_CONVERGENCE_RADIUS * z / absz
574
+ ref = ei_asymp(zref, _e1)
575
+ C = (zref-z)*0.5
576
+ D = (zref+z)*0.5
577
+ s = 0.0
578
+ if type(z) is complex:
579
+ _exp = cmath.exp
580
+ else:
581
+ _exp = math.exp
582
+ for x,w in gauss42:
583
+ t = C*x+D
584
+ s += w*_exp(t)/t
585
+ ref -= C*s
586
+ return ref
587
+
588
+ def e1(z):
589
+ # hack to get consistent signs if the imaginary part if 0
590
+ # and signed
591
+ typez = type(z)
592
+ if type(z) not in (float, complex):
593
+ try:
594
+ z = float(z)
595
+ typez = float
596
+ except (TypeError, ValueError):
597
+ z = complex(z)
598
+ typez = complex
599
+ if typez is complex and not z.imag:
600
+ z = complex(z.real, 0.0)
601
+ # end hack
602
+ return -ei(-z, _e1=True)
603
+
604
+ _zeta_int = [\
605
+ -0.5,
606
+ 0.0,
607
+ 1.6449340668482264365,1.2020569031595942854,1.0823232337111381915,
608
+ 1.0369277551433699263,1.0173430619844491397,1.0083492773819228268,
609
+ 1.0040773561979443394,1.0020083928260822144,1.0009945751278180853,
610
+ 1.0004941886041194646,1.0002460865533080483,1.0001227133475784891,
611
+ 1.0000612481350587048,1.0000305882363070205,1.0000152822594086519,
612
+ 1.0000076371976378998,1.0000038172932649998,1.0000019082127165539,
613
+ 1.0000009539620338728,1.0000004769329867878,1.0000002384505027277,
614
+ 1.0000001192199259653,1.0000000596081890513,1.0000000298035035147,
615
+ 1.0000000149015548284]
616
+
617
+ _zeta_P = [-3.50000000087575873, -0.701274355654678147,
618
+ -0.0672313458590012612, -0.00398731457954257841,
619
+ -0.000160948723019303141, -4.67633010038383371e-6,
620
+ -1.02078104417700585e-7, -1.68030037095896287e-9,
621
+ -1.85231868742346722e-11][::-1]
622
+
623
+ _zeta_Q = [1.00000000000000000, -0.936552848762465319,
624
+ -0.0588835413263763741, -0.00441498861482948666,
625
+ -0.000143416758067432622, -5.10691659585090782e-6,
626
+ -9.58813053268913799e-8, -1.72963791443181972e-9,
627
+ -1.83527919681474132e-11][::-1]
628
+
629
+ _zeta_1 = [3.03768838606128127e-10, -1.21924525236601262e-8,
630
+ 2.01201845887608893e-7, -1.53917240683468381e-6,
631
+ -5.09890411005967954e-7, 0.000122464707271619326,
632
+ -0.000905721539353130232, -0.00239315326074843037,
633
+ 0.084239750013159168, 0.418938517907442414, 0.500000001921884009]
634
+
635
+ _zeta_0 = [-3.46092485016748794e-10, -6.42610089468292485e-9,
636
+ 1.76409071536679773e-7, -1.47141263991560698e-6, -6.38880222546167613e-7,
637
+ 0.000122641099800668209, -0.000905894913516772796, -0.00239303348507992713,
638
+ 0.0842396947501199816, 0.418938533204660256, 0.500000000000000052]
639
+
640
+ def zeta(s):
641
+ """
642
+ Riemann zeta function, real argument
643
+ """
644
+ if not isinstance(s, (float, int)):
645
+ try:
646
+ s = float(s)
647
+ except (ValueError, TypeError):
648
+ try:
649
+ s = complex(s)
650
+ if not s.imag:
651
+ return complex(zeta(s.real))
652
+ except (ValueError, TypeError):
653
+ pass
654
+ raise NotImplementedError
655
+ if s == 1:
656
+ raise ValueError("zeta(1) pole")
657
+ if s >= 27:
658
+ return 1.0 + 2.0**(-s) + 3.0**(-s)
659
+ n = int(s)
660
+ if n == s:
661
+ if n >= 0:
662
+ return _zeta_int[n]
663
+ if not (n % 2):
664
+ return 0.0
665
+ if s <= 0.0:
666
+ return 2.**s*pi**(s-1)*_sinpi_real(0.5*s)*_gamma_real(1-s)*zeta(1-s)
667
+ if s <= 2.0:
668
+ if s <= 1.0:
669
+ return _polyval(_zeta_0,s)/(s-1)
670
+ return _polyval(_zeta_1,s)/(s-1)
671
+ z = _polyval(_zeta_P,s) / _polyval(_zeta_Q,s)
672
+ return 1.0 + 2.0**(-s) + 3.0**(-s) + 4.0**(-s)*z
lib/python3.11/site-packages/mpmath/matrices/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from . import eigen # to set methods
2
+ from . import eigen_symmetric # to set methods
lib/python3.11/site-packages/mpmath/matrices/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (328 Bytes). View file
 
lib/python3.11/site-packages/mpmath/matrices/__pycache__/calculus.cpython-311.pyc ADDED
Binary file (22.9 kB). View file
 
lib/python3.11/site-packages/mpmath/matrices/__pycache__/eigen.cpython-311.pyc ADDED
Binary file (31 kB). View file
 
lib/python3.11/site-packages/mpmath/matrices/__pycache__/eigen_symmetric.cpython-311.pyc ADDED
Binary file (70 kB). View file
 
lib/python3.11/site-packages/mpmath/matrices/__pycache__/linalg.cpython-311.pyc ADDED
Binary file (40.1 kB). View file
 
lib/python3.11/site-packages/mpmath/matrices/__pycache__/matrices.cpython-311.pyc ADDED
Binary file (44.8 kB). View file
 
lib/python3.11/site-packages/mpmath/matrices/calculus.py ADDED
@@ -0,0 +1,531 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ..libmp.backend import xrange
2
+
3
+ # TODO: should use diagonalization-based algorithms
4
+
5
+ class MatrixCalculusMethods(object):
6
+
7
+ def _exp_pade(ctx, a):
8
+ """
9
+ Exponential of a matrix using Pade approximants.
10
+
11
+ See G. H. Golub, C. F. van Loan 'Matrix Computations',
12
+ third Ed., page 572
13
+
14
+ TODO:
15
+ - find a good estimate for q
16
+ - reduce the number of matrix multiplications to improve
17
+ performance
18
+ """
19
+ def eps_pade(p):
20
+ return ctx.mpf(2)**(3-2*p) * \
21
+ ctx.factorial(p)**2/(ctx.factorial(2*p)**2 * (2*p + 1))
22
+ q = 4
23
+ extraq = 8
24
+ while 1:
25
+ if eps_pade(q) < ctx.eps:
26
+ break
27
+ q += 1
28
+ q += extraq
29
+ j = int(max(1, ctx.mag(ctx.mnorm(a,'inf'))))
30
+ extra = q
31
+ prec = ctx.prec
32
+ ctx.dps += extra + 3
33
+ try:
34
+ a = a/2**j
35
+ na = a.rows
36
+ den = ctx.eye(na)
37
+ num = ctx.eye(na)
38
+ x = ctx.eye(na)
39
+ c = ctx.mpf(1)
40
+ for k in range(1, q+1):
41
+ c *= ctx.mpf(q - k + 1)/((2*q - k + 1) * k)
42
+ x = a*x
43
+ cx = c*x
44
+ num += cx
45
+ den += (-1)**k * cx
46
+ f = ctx.lu_solve_mat(den, num)
47
+ for k in range(j):
48
+ f = f*f
49
+ finally:
50
+ ctx.prec = prec
51
+ return f*1
52
+
53
+ def expm(ctx, A, method='taylor'):
54
+ r"""
55
+ Computes the matrix exponential of a square matrix `A`, which is defined
56
+ by the power series
57
+
58
+ .. math ::
59
+
60
+ \exp(A) = I + A + \frac{A^2}{2!} + \frac{A^3}{3!} + \ldots
61
+
62
+ With method='taylor', the matrix exponential is computed
63
+ using the Taylor series. With method='pade', Pade approximants
64
+ are used instead.
65
+
66
+ **Examples**
67
+
68
+ Basic examples::
69
+
70
+ >>> from mpmath import *
71
+ >>> mp.dps = 15; mp.pretty = True
72
+ >>> expm(zeros(3))
73
+ [1.0 0.0 0.0]
74
+ [0.0 1.0 0.0]
75
+ [0.0 0.0 1.0]
76
+ >>> expm(eye(3))
77
+ [2.71828182845905 0.0 0.0]
78
+ [ 0.0 2.71828182845905 0.0]
79
+ [ 0.0 0.0 2.71828182845905]
80
+ >>> expm([[1,1,0],[1,0,1],[0,1,0]])
81
+ [ 3.86814500615414 2.26812870852145 0.841130841230196]
82
+ [ 2.26812870852145 2.44114713886289 1.42699786729125]
83
+ [0.841130841230196 1.42699786729125 1.6000162976327]
84
+ >>> expm([[1,1,0],[1,0,1],[0,1,0]], method='pade')
85
+ [ 3.86814500615414 2.26812870852145 0.841130841230196]
86
+ [ 2.26812870852145 2.44114713886289 1.42699786729125]
87
+ [0.841130841230196 1.42699786729125 1.6000162976327]
88
+ >>> expm([[1+j, 0], [1+j,1]])
89
+ [(1.46869393991589 + 2.28735528717884j) 0.0]
90
+ [ (1.03776739863568 + 3.536943175722j) (2.71828182845905 + 0.0j)]
91
+
92
+ Matrices with large entries are allowed::
93
+
94
+ >>> expm(matrix([[1,2],[2,3]])**25)
95
+ [5.65024064048415e+2050488462815550 9.14228140091932e+2050488462815550]
96
+ [9.14228140091932e+2050488462815550 1.47925220414035e+2050488462815551]
97
+
98
+ The identity `\exp(A+B) = \exp(A) \exp(B)` does not hold for
99
+ noncommuting matrices::
100
+
101
+ >>> A = hilbert(3)
102
+ >>> B = A + eye(3)
103
+ >>> chop(mnorm(A*B - B*A))
104
+ 0.0
105
+ >>> chop(mnorm(expm(A+B) - expm(A)*expm(B)))
106
+ 0.0
107
+ >>> B = A + ones(3)
108
+ >>> mnorm(A*B - B*A)
109
+ 1.8
110
+ >>> mnorm(expm(A+B) - expm(A)*expm(B))
111
+ 42.0927851137247
112
+
113
+ """
114
+ if method == 'pade':
115
+ prec = ctx.prec
116
+ try:
117
+ A = ctx.matrix(A)
118
+ ctx.prec += 2*A.rows
119
+ res = ctx._exp_pade(A)
120
+ finally:
121
+ ctx.prec = prec
122
+ return res
123
+ A = ctx.matrix(A)
124
+ prec = ctx.prec
125
+ j = int(max(1, ctx.mag(ctx.mnorm(A,'inf'))))
126
+ j += int(0.5*prec**0.5)
127
+ try:
128
+ ctx.prec += 10 + 2*j
129
+ tol = +ctx.eps
130
+ A = A/2**j
131
+ T = A
132
+ Y = A**0 + A
133
+ k = 2
134
+ while 1:
135
+ T *= A * (1/ctx.mpf(k))
136
+ if ctx.mnorm(T, 'inf') < tol:
137
+ break
138
+ Y += T
139
+ k += 1
140
+ for k in xrange(j):
141
+ Y = Y*Y
142
+ finally:
143
+ ctx.prec = prec
144
+ Y *= 1
145
+ return Y
146
+
147
+ def cosm(ctx, A):
148
+ r"""
149
+ Gives the cosine of a square matrix `A`, defined in analogy
150
+ with the matrix exponential.
151
+
152
+ Examples::
153
+
154
+ >>> from mpmath import *
155
+ >>> mp.dps = 15; mp.pretty = True
156
+ >>> X = eye(3)
157
+ >>> cosm(X)
158
+ [0.54030230586814 0.0 0.0]
159
+ [ 0.0 0.54030230586814 0.0]
160
+ [ 0.0 0.0 0.54030230586814]
161
+ >>> X = hilbert(3)
162
+ >>> cosm(X)
163
+ [ 0.424403834569555 -0.316643413047167 -0.221474945949293]
164
+ [-0.316643413047167 0.820646708837824 -0.127183694770039]
165
+ [-0.221474945949293 -0.127183694770039 0.909236687217541]
166
+ >>> X = matrix([[1+j,-2],[0,-j]])
167
+ >>> cosm(X)
168
+ [(0.833730025131149 - 0.988897705762865j) (1.07485840848393 - 0.17192140544213j)]
169
+ [ 0.0 (1.54308063481524 + 0.0j)]
170
+ """
171
+ B = 0.5 * (ctx.expm(A*ctx.j) + ctx.expm(A*(-ctx.j)))
172
+ if not sum(A.apply(ctx.im).apply(abs)):
173
+ B = B.apply(ctx.re)
174
+ return B
175
+
176
+ def sinm(ctx, A):
177
+ r"""
178
+ Gives the sine of a square matrix `A`, defined in analogy
179
+ with the matrix exponential.
180
+
181
+ Examples::
182
+
183
+ >>> from mpmath import *
184
+ >>> mp.dps = 15; mp.pretty = True
185
+ >>> X = eye(3)
186
+ >>> sinm(X)
187
+ [0.841470984807897 0.0 0.0]
188
+ [ 0.0 0.841470984807897 0.0]
189
+ [ 0.0 0.0 0.841470984807897]
190
+ >>> X = hilbert(3)
191
+ >>> sinm(X)
192
+ [0.711608512150994 0.339783913247439 0.220742837314741]
193
+ [0.339783913247439 0.244113865695532 0.187231271174372]
194
+ [0.220742837314741 0.187231271174372 0.155816730769635]
195
+ >>> X = matrix([[1+j,-2],[0,-j]])
196
+ >>> sinm(X)
197
+ [(1.29845758141598 + 0.634963914784736j) (-1.96751511930922 + 0.314700021761367j)]
198
+ [ 0.0 (0.0 - 1.1752011936438j)]
199
+ """
200
+ B = (-0.5j) * (ctx.expm(A*ctx.j) - ctx.expm(A*(-ctx.j)))
201
+ if not sum(A.apply(ctx.im).apply(abs)):
202
+ B = B.apply(ctx.re)
203
+ return B
204
+
205
+ def _sqrtm_rot(ctx, A, _may_rotate):
206
+ # If the iteration fails to converge, cheat by performing
207
+ # a rotation by a complex number
208
+ u = ctx.j**0.3
209
+ return ctx.sqrtm(u*A, _may_rotate) / ctx.sqrt(u)
210
+
211
+ def sqrtm(ctx, A, _may_rotate=2):
212
+ r"""
213
+ Computes a square root of the square matrix `A`, i.e. returns
214
+ a matrix `B = A^{1/2}` such that `B^2 = A`. The square root
215
+ of a matrix, if it exists, is not unique.
216
+
217
+ **Examples**
218
+
219
+ Square roots of some simple matrices::
220
+
221
+ >>> from mpmath import *
222
+ >>> mp.dps = 15; mp.pretty = True
223
+ >>> sqrtm([[1,0], [0,1]])
224
+ [1.0 0.0]
225
+ [0.0 1.0]
226
+ >>> sqrtm([[0,0], [0,0]])
227
+ [0.0 0.0]
228
+ [0.0 0.0]
229
+ >>> sqrtm([[2,0],[0,1]])
230
+ [1.4142135623731 0.0]
231
+ [ 0.0 1.0]
232
+ >>> sqrtm([[1,1],[1,0]])
233
+ [ (0.920442065259926 - 0.21728689675164j) (0.568864481005783 + 0.351577584254143j)]
234
+ [(0.568864481005783 + 0.351577584254143j) (0.351577584254143 - 0.568864481005783j)]
235
+ >>> sqrtm([[1,0],[0,1]])
236
+ [1.0 0.0]
237
+ [0.0 1.0]
238
+ >>> sqrtm([[-1,0],[0,1]])
239
+ [(0.0 - 1.0j) 0.0]
240
+ [ 0.0 (1.0 + 0.0j)]
241
+ >>> sqrtm([[j,0],[0,j]])
242
+ [(0.707106781186547 + 0.707106781186547j) 0.0]
243
+ [ 0.0 (0.707106781186547 + 0.707106781186547j)]
244
+
245
+ A square root of a rotation matrix, giving the corresponding
246
+ half-angle rotation matrix::
247
+
248
+ >>> t1 = 0.75
249
+ >>> t2 = t1 * 0.5
250
+ >>> A1 = matrix([[cos(t1), -sin(t1)], [sin(t1), cos(t1)]])
251
+ >>> A2 = matrix([[cos(t2), -sin(t2)], [sin(t2), cos(t2)]])
252
+ >>> sqrtm(A1)
253
+ [0.930507621912314 -0.366272529086048]
254
+ [0.366272529086048 0.930507621912314]
255
+ >>> A2
256
+ [0.930507621912314 -0.366272529086048]
257
+ [0.366272529086048 0.930507621912314]
258
+
259
+ The identity `(A^2)^{1/2} = A` does not necessarily hold::
260
+
261
+ >>> A = matrix([[4,1,4],[7,8,9],[10,2,11]])
262
+ >>> sqrtm(A**2)
263
+ [ 4.0 1.0 4.0]
264
+ [ 7.0 8.0 9.0]
265
+ [10.0 2.0 11.0]
266
+ >>> sqrtm(A)**2
267
+ [ 4.0 1.0 4.0]
268
+ [ 7.0 8.0 9.0]
269
+ [10.0 2.0 11.0]
270
+ >>> A = matrix([[-4,1,4],[7,-8,9],[10,2,11]])
271
+ >>> sqrtm(A**2)
272
+ [ 7.43715112194995 -0.324127569985474 1.8481718827526]
273
+ [-0.251549715716942 9.32699765900402 2.48221180985147]
274
+ [ 4.11609388833616 0.775751877098258 13.017955697342]
275
+ >>> chop(sqrtm(A)**2)
276
+ [-4.0 1.0 4.0]
277
+ [ 7.0 -8.0 9.0]
278
+ [10.0 2.0 11.0]
279
+
280
+ For some matrices, a square root does not exist::
281
+
282
+ >>> sqrtm([[0,1], [0,0]])
283
+ Traceback (most recent call last):
284
+ ...
285
+ ZeroDivisionError: matrix is numerically singular
286
+
287
+ Two examples from the documentation for Matlab's ``sqrtm``::
288
+
289
+ >>> mp.dps = 15; mp.pretty = True
290
+ >>> sqrtm([[7,10],[15,22]])
291
+ [1.56669890360128 1.74077655955698]
292
+ [2.61116483933547 4.17786374293675]
293
+ >>>
294
+ >>> X = matrix(\
295
+ ... [[5,-4,1,0,0],
296
+ ... [-4,6,-4,1,0],
297
+ ... [1,-4,6,-4,1],
298
+ ... [0,1,-4,6,-4],
299
+ ... [0,0,1,-4,5]])
300
+ >>> Y = matrix(\
301
+ ... [[2,-1,-0,-0,-0],
302
+ ... [-1,2,-1,0,-0],
303
+ ... [0,-1,2,-1,0],
304
+ ... [-0,0,-1,2,-1],
305
+ ... [-0,-0,-0,-1,2]])
306
+ >>> mnorm(sqrtm(X) - Y)
307
+ 4.53155328326114e-19
308
+
309
+ """
310
+ A = ctx.matrix(A)
311
+ # Trivial
312
+ if A*0 == A:
313
+ return A
314
+ prec = ctx.prec
315
+ if _may_rotate:
316
+ d = ctx.det(A)
317
+ if abs(ctx.im(d)) < 16*ctx.eps and ctx.re(d) < 0:
318
+ return ctx._sqrtm_rot(A, _may_rotate-1)
319
+ try:
320
+ ctx.prec += 10
321
+ tol = ctx.eps * 128
322
+ Y = A
323
+ Z = I = A**0
324
+ k = 0
325
+ # Denman-Beavers iteration
326
+ while 1:
327
+ Yprev = Y
328
+ try:
329
+ Y, Z = 0.5*(Y+ctx.inverse(Z)), 0.5*(Z+ctx.inverse(Y))
330
+ except ZeroDivisionError:
331
+ if _may_rotate:
332
+ Y = ctx._sqrtm_rot(A, _may_rotate-1)
333
+ break
334
+ else:
335
+ raise
336
+ mag1 = ctx.mnorm(Y-Yprev, 'inf')
337
+ mag2 = ctx.mnorm(Y, 'inf')
338
+ if mag1 <= mag2*tol:
339
+ break
340
+ if _may_rotate and k > 6 and not mag1 < mag2 * 0.001:
341
+ return ctx._sqrtm_rot(A, _may_rotate-1)
342
+ k += 1
343
+ if k > ctx.prec:
344
+ raise ctx.NoConvergence
345
+ finally:
346
+ ctx.prec = prec
347
+ Y *= 1
348
+ return Y
349
+
350
+ def logm(ctx, A):
351
+ r"""
352
+ Computes a logarithm of the square matrix `A`, i.e. returns
353
+ a matrix `B = \log(A)` such that `\exp(B) = A`. The logarithm
354
+ of a matrix, if it exists, is not unique.
355
+
356
+ **Examples**
357
+
358
+ Logarithms of some simple matrices::
359
+
360
+ >>> from mpmath import *
361
+ >>> mp.dps = 15; mp.pretty = True
362
+ >>> X = eye(3)
363
+ >>> logm(X)
364
+ [0.0 0.0 0.0]
365
+ [0.0 0.0 0.0]
366
+ [0.0 0.0 0.0]
367
+ >>> logm(2*X)
368
+ [0.693147180559945 0.0 0.0]
369
+ [ 0.0 0.693147180559945 0.0]
370
+ [ 0.0 0.0 0.693147180559945]
371
+ >>> logm(expm(X))
372
+ [1.0 0.0 0.0]
373
+ [0.0 1.0 0.0]
374
+ [0.0 0.0 1.0]
375
+
376
+ A logarithm of a complex matrix::
377
+
378
+ >>> X = matrix([[2+j, 1, 3], [1-j, 1-2*j, 1], [-4, -5, j]])
379
+ >>> B = logm(X)
380
+ >>> nprint(B)
381
+ [ (0.808757 + 0.107759j) (2.20752 + 0.202762j) (1.07376 - 0.773874j)]
382
+ [ (0.905709 - 0.107795j) (0.0287395 - 0.824993j) (0.111619 + 0.514272j)]
383
+ [(-0.930151 + 0.399512j) (-2.06266 - 0.674397j) (0.791552 + 0.519839j)]
384
+ >>> chop(expm(B))
385
+ [(2.0 + 1.0j) 1.0 3.0]
386
+ [(1.0 - 1.0j) (1.0 - 2.0j) 1.0]
387
+ [ -4.0 -5.0 (0.0 + 1.0j)]
388
+
389
+ A matrix `X` close to the identity matrix, for which
390
+ `\log(\exp(X)) = \exp(\log(X)) = X` holds::
391
+
392
+ >>> X = eye(3) + hilbert(3)/4
393
+ >>> X
394
+ [ 1.25 0.125 0.0833333333333333]
395
+ [ 0.125 1.08333333333333 0.0625]
396
+ [0.0833333333333333 0.0625 1.05]
397
+ >>> logm(expm(X))
398
+ [ 1.25 0.125 0.0833333333333333]
399
+ [ 0.125 1.08333333333333 0.0625]
400
+ [0.0833333333333333 0.0625 1.05]
401
+ >>> expm(logm(X))
402
+ [ 1.25 0.125 0.0833333333333333]
403
+ [ 0.125 1.08333333333333 0.0625]
404
+ [0.0833333333333333 0.0625 1.05]
405
+
406
+ A logarithm of a rotation matrix, giving back the angle of
407
+ the rotation::
408
+
409
+ >>> t = 3.7
410
+ >>> A = matrix([[cos(t),sin(t)],[-sin(t),cos(t)]])
411
+ >>> chop(logm(A))
412
+ [ 0.0 -2.58318530717959]
413
+ [2.58318530717959 0.0]
414
+ >>> (2*pi-t)
415
+ 2.58318530717959
416
+
417
+ For some matrices, a logarithm does not exist::
418
+
419
+ >>> logm([[1,0], [0,0]])
420
+ Traceback (most recent call last):
421
+ ...
422
+ ZeroDivisionError: matrix is numerically singular
423
+
424
+ Logarithm of a matrix with large entries::
425
+
426
+ >>> logm(hilbert(3) * 10**20).apply(re)
427
+ [ 45.5597513593433 1.27721006042799 0.317662687717978]
428
+ [ 1.27721006042799 42.5222778973542 2.24003708791604]
429
+ [0.317662687717978 2.24003708791604 42.395212822267]
430
+
431
+ """
432
+ A = ctx.matrix(A)
433
+ prec = ctx.prec
434
+ try:
435
+ ctx.prec += 10
436
+ tol = ctx.eps * 128
437
+ I = A**0
438
+ B = A
439
+ n = 0
440
+ while 1:
441
+ B = ctx.sqrtm(B)
442
+ n += 1
443
+ if ctx.mnorm(B-I, 'inf') < 0.125:
444
+ break
445
+ T = X = B-I
446
+ L = X*0
447
+ k = 1
448
+ while 1:
449
+ if k & 1:
450
+ L += T / k
451
+ else:
452
+ L -= T / k
453
+ T *= X
454
+ if ctx.mnorm(T, 'inf') < tol:
455
+ break
456
+ k += 1
457
+ if k > ctx.prec:
458
+ raise ctx.NoConvergence
459
+ finally:
460
+ ctx.prec = prec
461
+ L *= 2**n
462
+ return L
463
+
464
+ def powm(ctx, A, r):
465
+ r"""
466
+ Computes `A^r = \exp(A \log r)` for a matrix `A` and complex
467
+ number `r`.
468
+
469
+ **Examples**
470
+
471
+ Powers and inverse powers of a matrix::
472
+
473
+ >>> from mpmath import *
474
+ >>> mp.dps = 15; mp.pretty = True
475
+ >>> A = matrix([[4,1,4],[7,8,9],[10,2,11]])
476
+ >>> powm(A, 2)
477
+ [ 63.0 20.0 69.0]
478
+ [174.0 89.0 199.0]
479
+ [164.0 48.0 179.0]
480
+ >>> chop(powm(powm(A, 4), 1/4.))
481
+ [ 4.0 1.0 4.0]
482
+ [ 7.0 8.0 9.0]
483
+ [10.0 2.0 11.0]
484
+ >>> powm(extraprec(20)(powm)(A, -4), -1/4.)
485
+ [ 4.0 1.0 4.0]
486
+ [ 7.0 8.0 9.0]
487
+ [10.0 2.0 11.0]
488
+ >>> chop(powm(powm(A, 1+0.5j), 1/(1+0.5j)))
489
+ [ 4.0 1.0 4.0]
490
+ [ 7.0 8.0 9.0]
491
+ [10.0 2.0 11.0]
492
+ >>> powm(extraprec(5)(powm)(A, -1.5), -1/(1.5))
493
+ [ 4.0 1.0 4.0]
494
+ [ 7.0 8.0 9.0]
495
+ [10.0 2.0 11.0]
496
+
497
+ A Fibonacci-generating matrix::
498
+
499
+ >>> powm([[1,1],[1,0]], 10)
500
+ [89.0 55.0]
501
+ [55.0 34.0]
502
+ >>> fib(10)
503
+ 55.0
504
+ >>> powm([[1,1],[1,0]], 6.5)
505
+ [(16.5166626964253 - 0.0121089837381789j) (10.2078589271083 + 0.0195927472575932j)]
506
+ [(10.2078589271083 + 0.0195927472575932j) (6.30880376931698 - 0.0317017309957721j)]
507
+ >>> (phi**6.5 - (1-phi)**6.5)/sqrt(5)
508
+ (10.2078589271083 - 0.0195927472575932j)
509
+ >>> powm([[1,1],[1,0]], 6.2)
510
+ [ (14.3076953002666 - 0.008222855781077j) (8.81733464837593 + 0.0133048601383712j)]
511
+ [(8.81733464837593 + 0.0133048601383712j) (5.49036065189071 - 0.0215277159194482j)]
512
+ >>> (phi**6.2 - (1-phi)**6.2)/sqrt(5)
513
+ (8.81733464837593 - 0.0133048601383712j)
514
+
515
+ """
516
+ A = ctx.matrix(A)
517
+ r = ctx.convert(r)
518
+ prec = ctx.prec
519
+ try:
520
+ ctx.prec += 10
521
+ if ctx.isint(r):
522
+ v = A ** int(r)
523
+ elif ctx.isint(r*2):
524
+ y = int(r*2)
525
+ v = ctx.sqrtm(A) ** y
526
+ else:
527
+ v = ctx.expm(r*ctx.logm(A))
528
+ finally:
529
+ ctx.prec = prec
530
+ v *= 1
531
+ return v
lib/python3.11/site-packages/mpmath/matrices/eigen.py ADDED
@@ -0,0 +1,877 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ ##################################################################################################
5
+ # module for the eigenvalue problem
6
+ # Copyright 2013 Timo Hartmann (thartmann15 at gmail.com)
7
+ #
8
+ # todo:
9
+ # - implement balancing
10
+ # - agressive early deflation
11
+ #
12
+ ##################################################################################################
13
+
14
+ """
15
+ The eigenvalue problem
16
+ ----------------------
17
+
18
+ This file contains routines for the eigenvalue problem.
19
+
20
+ high level routines:
21
+
22
+ hessenberg : reduction of a real or complex square matrix to upper Hessenberg form
23
+ schur : reduction of a real or complex square matrix to upper Schur form
24
+ eig : eigenvalues and eigenvectors of a real or complex square matrix
25
+
26
+ low level routines:
27
+
28
+ hessenberg_reduce_0 : reduction of a real or complex square matrix to upper Hessenberg form
29
+ hessenberg_reduce_1 : auxiliary routine to hessenberg_reduce_0
30
+ qr_step : a single implicitly shifted QR step for an upper Hessenberg matrix
31
+ hessenberg_qr : Schur decomposition of an upper Hessenberg matrix
32
+ eig_tr_r : right eigenvectors of an upper triangular matrix
33
+ eig_tr_l : left eigenvectors of an upper triangular matrix
34
+ """
35
+
36
+ from ..libmp.backend import xrange
37
+
38
+ class Eigen(object):
39
+ pass
40
+
41
+ def defun(f):
42
+ setattr(Eigen, f.__name__, f)
43
+ return f
44
+
45
+ def hessenberg_reduce_0(ctx, A, T):
46
+ """
47
+ This routine computes the (upper) Hessenberg decomposition of a square matrix A.
48
+ Given A, an unitary matrix Q is calculated such that
49
+
50
+ Q' A Q = H and Q' Q = Q Q' = 1
51
+
52
+ where H is an upper Hessenberg matrix, meaning that it only contains zeros
53
+ below the first subdiagonal. Here ' denotes the hermitian transpose (i.e.
54
+ transposition and conjugation).
55
+
56
+ parameters:
57
+ A (input/output) On input, A contains the square matrix A of
58
+ dimension (n,n). On output, A contains a compressed representation
59
+ of Q and H.
60
+ T (output) An array of length n containing the first elements of
61
+ the Householder reflectors.
62
+ """
63
+
64
+ # internally we work with householder reflections from the right.
65
+ # let u be a row vector (i.e. u[i]=A[i,:i]). then
66
+ # Q is build up by reflectors of the type (1-v'v) where v is a suitable
67
+ # modification of u. these reflectors are applyed to A from the right.
68
+ # because we work with reflectors from the right we have to start with
69
+ # the bottom row of A and work then upwards (this corresponds to
70
+ # some kind of RQ decomposition).
71
+ # the first part of the vectors v (i.e. A[i,:(i-1)]) are stored as row vectors
72
+ # in the lower left part of A (excluding the diagonal and subdiagonal).
73
+ # the last entry of v is stored in T.
74
+ # the upper right part of A (including diagonal and subdiagonal) becomes H.
75
+
76
+
77
+ n = A.rows
78
+ if n <= 2: return
79
+
80
+ for i in xrange(n-1, 1, -1):
81
+
82
+ # scale the vector
83
+
84
+ scale = 0
85
+ for k in xrange(0, i):
86
+ scale += abs(ctx.re(A[i,k])) + abs(ctx.im(A[i,k]))
87
+
88
+ scale_inv = 0
89
+ if scale != 0:
90
+ scale_inv = 1 / scale
91
+
92
+ if scale == 0 or ctx.isinf(scale_inv):
93
+ # sadly there are floating point numbers not equal to zero whose reciprocal is infinity
94
+ T[i] = 0
95
+ A[i,i-1] = 0
96
+ continue
97
+
98
+ # calculate parameters for housholder transformation
99
+
100
+ H = 0
101
+ for k in xrange(0, i):
102
+ A[i,k] *= scale_inv
103
+ rr = ctx.re(A[i,k])
104
+ ii = ctx.im(A[i,k])
105
+ H += rr * rr + ii * ii
106
+
107
+ F = A[i,i-1]
108
+ f = abs(F)
109
+ G = ctx.sqrt(H)
110
+ A[i,i-1] = - G * scale
111
+
112
+ if f == 0:
113
+ T[i] = G
114
+ else:
115
+ ff = F / f
116
+ T[i] = F + G * ff
117
+ A[i,i-1] *= ff
118
+
119
+ H += G * f
120
+ H = 1 / ctx.sqrt(H)
121
+
122
+ T[i] *= H
123
+ for k in xrange(0, i - 1):
124
+ A[i,k] *= H
125
+
126
+ for j in xrange(0, i):
127
+ # apply housholder transformation (from right)
128
+
129
+ G = ctx.conj(T[i]) * A[j,i-1]
130
+ for k in xrange(0, i-1):
131
+ G += ctx.conj(A[i,k]) * A[j,k]
132
+
133
+ A[j,i-1] -= G * T[i]
134
+ for k in xrange(0, i-1):
135
+ A[j,k] -= G * A[i,k]
136
+
137
+ for j in xrange(0, n):
138
+ # apply housholder transformation (from left)
139
+
140
+ G = T[i] * A[i-1,j]
141
+ for k in xrange(0, i-1):
142
+ G += A[i,k] * A[k,j]
143
+
144
+ A[i-1,j] -= G * ctx.conj(T[i])
145
+ for k in xrange(0, i-1):
146
+ A[k,j] -= G * ctx.conj(A[i,k])
147
+
148
+
149
+
150
+ def hessenberg_reduce_1(ctx, A, T):
151
+ """
152
+ This routine forms the unitary matrix Q described in hessenberg_reduce_0.
153
+
154
+ parameters:
155
+ A (input/output) On input, A is the same matrix as delivered by
156
+ hessenberg_reduce_0. On output, A is set to Q.
157
+
158
+ T (input) On input, T is the same array as delivered by hessenberg_reduce_0.
159
+ """
160
+
161
+ n = A.rows
162
+
163
+ if n == 1:
164
+ A[0,0] = 1
165
+ return
166
+
167
+ A[0,0] = A[1,1] = 1
168
+ A[0,1] = A[1,0] = 0
169
+
170
+ for i in xrange(2, n):
171
+ if T[i] != 0:
172
+
173
+ for j in xrange(0, i):
174
+ G = T[i] * A[i-1,j]
175
+ for k in xrange(0, i-1):
176
+ G += A[i,k] * A[k,j]
177
+
178
+ A[i-1,j] -= G * ctx.conj(T[i])
179
+ for k in xrange(0, i-1):
180
+ A[k,j] -= G * ctx.conj(A[i,k])
181
+
182
+ A[i,i] = 1
183
+ for j in xrange(0, i):
184
+ A[j,i] = A[i,j] = 0
185
+
186
+
187
+
188
+ @defun
189
+ def hessenberg(ctx, A, overwrite_a = False):
190
+ """
191
+ This routine computes the Hessenberg decomposition of a square matrix A.
192
+ Given A, an unitary matrix Q is determined such that
193
+
194
+ Q' A Q = H and Q' Q = Q Q' = 1
195
+
196
+ where H is an upper right Hessenberg matrix. Here ' denotes the hermitian
197
+ transpose (i.e. transposition and conjugation).
198
+
199
+ input:
200
+ A : a real or complex square matrix
201
+ overwrite_a : if true, allows modification of A which may improve
202
+ performance. if false, A is not modified.
203
+
204
+ output:
205
+ Q : an unitary matrix
206
+ H : an upper right Hessenberg matrix
207
+
208
+ example:
209
+ >>> from mpmath import mp
210
+ >>> A = mp.matrix([[3, -1, 2], [2, 5, -5], [-2, -3, 7]])
211
+ >>> Q, H = mp.hessenberg(A)
212
+ >>> mp.nprint(H, 3) # doctest:+SKIP
213
+ [ 3.15 2.23 4.44]
214
+ [-0.769 4.85 3.05]
215
+ [ 0.0 3.61 7.0]
216
+ >>> print(mp.chop(A - Q * H * Q.transpose_conj()))
217
+ [0.0 0.0 0.0]
218
+ [0.0 0.0 0.0]
219
+ [0.0 0.0 0.0]
220
+
221
+ return value: (Q, H)
222
+ """
223
+
224
+ n = A.rows
225
+
226
+ if n == 1:
227
+ return (ctx.matrix([[1]]), A)
228
+
229
+ if not overwrite_a:
230
+ A = A.copy()
231
+
232
+ T = ctx.matrix(n, 1)
233
+
234
+ hessenberg_reduce_0(ctx, A, T)
235
+ Q = A.copy()
236
+ hessenberg_reduce_1(ctx, Q, T)
237
+
238
+ for x in xrange(n):
239
+ for y in xrange(x+2, n):
240
+ A[y,x] = 0
241
+
242
+ return Q, A
243
+
244
+
245
+ ###########################################################################
246
+
247
+
248
+ def qr_step(ctx, n0, n1, A, Q, shift):
249
+ """
250
+ This subroutine executes a single implicitly shifted QR step applied to an
251
+ upper Hessenberg matrix A. Given A and shift as input, first an QR
252
+ decomposition is calculated:
253
+
254
+ Q R = A - shift * 1 .
255
+
256
+ The output is then following matrix:
257
+
258
+ R Q + shift * 1
259
+
260
+ parameters:
261
+ n0, n1 (input) Two integers which specify the submatrix A[n0:n1,n0:n1]
262
+ on which this subroutine operators. The subdiagonal elements
263
+ to the left and below this submatrix must be deflated (i.e. zero).
264
+ following restriction is imposed: n1>=n0+2
265
+ A (input/output) On input, A is an upper Hessenberg matrix.
266
+ On output, A is replaced by "R Q + shift * 1"
267
+ Q (input/output) The parameter Q is multiplied by the unitary matrix
268
+ Q arising from the QR decomposition. Q can also be false, in which
269
+ case the unitary matrix Q is not computated.
270
+ shift (input) a complex number specifying the shift. idealy close to an
271
+ eigenvalue of the bottemmost part of the submatrix A[n0:n1,n0:n1].
272
+
273
+ references:
274
+ Stoer, Bulirsch - Introduction to Numerical Analysis.
275
+ Kresser : Numerical Methods for General and Structured Eigenvalue Problems
276
+ """
277
+
278
+ # implicitly shifted and bulge chasing is explained at p.398/399 in "Stoer, Bulirsch - Introduction to Numerical Analysis"
279
+ # for bulge chasing see also "Watkins - The Matrix Eigenvalue Problem" sec.4.5,p.173
280
+
281
+ # the Givens rotation we used is determined as follows: let c,s be two complex
282
+ # numbers. then we have following relation:
283
+ #
284
+ # v = sqrt(|c|^2 + |s|^2)
285
+ #
286
+ # 1/v [ c~ s~] [c] = [v]
287
+ # [-s c ] [s] [0]
288
+ #
289
+ # the matrix on the left is our Givens rotation.
290
+
291
+ n = A.rows
292
+
293
+ # first step
294
+
295
+ # calculate givens rotation
296
+ c = A[n0 ,n0] - shift
297
+ s = A[n0+1,n0]
298
+
299
+ v = ctx.hypot(ctx.hypot(ctx.re(c), ctx.im(c)), ctx.hypot(ctx.re(s), ctx.im(s)))
300
+
301
+ if v == 0:
302
+ v = 1
303
+ c = 1
304
+ s = 0
305
+ else:
306
+ c /= v
307
+ s /= v
308
+
309
+ cc = ctx.conj(c)
310
+ cs = ctx.conj(s)
311
+
312
+ for k in xrange(n0, n):
313
+ # apply givens rotation from the left
314
+ x = A[n0 ,k]
315
+ y = A[n0+1,k]
316
+ A[n0 ,k] = cc * x + cs * y
317
+ A[n0+1,k] = c * y - s * x
318
+
319
+ for k in xrange(min(n1, n0+3)):
320
+ # apply givens rotation from the right
321
+ x = A[k,n0 ]
322
+ y = A[k,n0+1]
323
+ A[k,n0 ] = c * x + s * y
324
+ A[k,n0+1] = cc * y - cs * x
325
+
326
+ if not isinstance(Q, bool):
327
+ for k in xrange(n):
328
+ # eigenvectors
329
+ x = Q[k,n0 ]
330
+ y = Q[k,n0+1]
331
+ Q[k,n0 ] = c * x + s * y
332
+ Q[k,n0+1] = cc * y - cs * x
333
+
334
+ # chase the bulge
335
+
336
+ for j in xrange(n0, n1 - 2):
337
+ # calculate givens rotation
338
+
339
+ c = A[j+1,j]
340
+ s = A[j+2,j]
341
+
342
+ v = ctx.hypot(ctx.hypot(ctx.re(c), ctx.im(c)), ctx.hypot(ctx.re(s), ctx.im(s)))
343
+
344
+ if v == 0:
345
+ A[j+1,j] = 0
346
+ v = 1
347
+ c = 1
348
+ s = 0
349
+ else:
350
+ A[j+1,j] = v
351
+ c /= v
352
+ s /= v
353
+
354
+ A[j+2,j] = 0
355
+
356
+ cc = ctx.conj(c)
357
+ cs = ctx.conj(s)
358
+
359
+ for k in xrange(j+1, n):
360
+ # apply givens rotation from the left
361
+ x = A[j+1,k]
362
+ y = A[j+2,k]
363
+ A[j+1,k] = cc * x + cs * y
364
+ A[j+2,k] = c * y - s * x
365
+
366
+ for k in xrange(0, min(n1, j+4)):
367
+ # apply givens rotation from the right
368
+ x = A[k,j+1]
369
+ y = A[k,j+2]
370
+ A[k,j+1] = c * x + s * y
371
+ A[k,j+2] = cc * y - cs * x
372
+
373
+ if not isinstance(Q, bool):
374
+ for k in xrange(0, n):
375
+ # eigenvectors
376
+ x = Q[k,j+1]
377
+ y = Q[k,j+2]
378
+ Q[k,j+1] = c * x + s * y
379
+ Q[k,j+2] = cc * y - cs * x
380
+
381
+
382
+
383
+ def hessenberg_qr(ctx, A, Q):
384
+ """
385
+ This routine computes the Schur decomposition of an upper Hessenberg matrix A.
386
+ Given A, an unitary matrix Q is determined such that
387
+
388
+ Q' A Q = R and Q' Q = Q Q' = 1
389
+
390
+ where R is an upper right triangular matrix. Here ' denotes the hermitian
391
+ transpose (i.e. transposition and conjugation).
392
+
393
+ parameters:
394
+ A (input/output) On input, A contains an upper Hessenberg matrix.
395
+ On output, A is replace by the upper right triangluar matrix R.
396
+
397
+ Q (input/output) The parameter Q is multiplied by the unitary
398
+ matrix Q arising from the Schur decomposition. Q can also be
399
+ false, in which case the unitary matrix Q is not computated.
400
+ """
401
+
402
+ n = A.rows
403
+
404
+ norm = 0
405
+ for x in xrange(n):
406
+ for y in xrange(min(x+2, n)):
407
+ norm += ctx.re(A[y,x]) ** 2 + ctx.im(A[y,x]) ** 2
408
+ norm = ctx.sqrt(norm) / n
409
+
410
+ if norm == 0:
411
+ return
412
+
413
+ n0 = 0
414
+ n1 = n
415
+
416
+ eps = ctx.eps / (100 * n)
417
+ maxits = ctx.dps * 4
418
+
419
+ its = totalits = 0
420
+
421
+ while 1:
422
+ # kressner p.32 algo 3
423
+ # the active submatrix is A[n0:n1,n0:n1]
424
+
425
+ k = n0
426
+
427
+ while k + 1 < n1:
428
+ s = abs(ctx.re(A[k,k])) + abs(ctx.im(A[k,k])) + abs(ctx.re(A[k+1,k+1])) + abs(ctx.im(A[k+1,k+1]))
429
+ if s < eps * norm:
430
+ s = norm
431
+ if abs(A[k+1,k]) < eps * s:
432
+ break
433
+ k += 1
434
+
435
+ if k + 1 < n1:
436
+ # deflation found at position (k+1, k)
437
+
438
+ A[k+1,k] = 0
439
+ n0 = k + 1
440
+
441
+ its = 0
442
+
443
+ if n0 + 1 >= n1:
444
+ # block of size at most two has converged
445
+ n0 = 0
446
+ n1 = k + 1
447
+ if n1 < 2:
448
+ # QR algorithm has converged
449
+ return
450
+ else:
451
+ if (its % 30) == 10:
452
+ # exceptional shift
453
+ shift = A[n1-1,n1-2]
454
+ elif (its % 30) == 20:
455
+ # exceptional shift
456
+ shift = abs(A[n1-1,n1-2])
457
+ elif (its % 30) == 29:
458
+ # exceptional shift
459
+ shift = norm
460
+ else:
461
+ # A = [ a b ] det(x-A)=x*x-x*tr(A)+det(A)
462
+ # [ c d ]
463
+ #
464
+ # eigenvalues bad: (tr(A)+sqrt((tr(A))**2-4*det(A)))/2
465
+ # bad because of cancellation if |c| is small and |a-d| is small, too.
466
+ #
467
+ # eigenvalues good: (a+d+sqrt((a-d)**2+4*b*c))/2
468
+
469
+ t = A[n1-2,n1-2] + A[n1-1,n1-1]
470
+ s = (A[n1-1,n1-1] - A[n1-2,n1-2]) ** 2 + 4 * A[n1-1,n1-2] * A[n1-2,n1-1]
471
+ if ctx.re(s) > 0:
472
+ s = ctx.sqrt(s)
473
+ else:
474
+ s = ctx.sqrt(-s) * 1j
475
+ a = (t + s) / 2
476
+ b = (t - s) / 2
477
+ if abs(A[n1-1,n1-1] - a) > abs(A[n1-1,n1-1] - b):
478
+ shift = b
479
+ else:
480
+ shift = a
481
+
482
+ its += 1
483
+ totalits += 1
484
+
485
+ qr_step(ctx, n0, n1, A, Q, shift)
486
+
487
+ if its > maxits:
488
+ raise RuntimeError("qr: failed to converge after %d steps" % its)
489
+
490
+
491
+ @defun
492
+ def schur(ctx, A, overwrite_a = False):
493
+ """
494
+ This routine computes the Schur decomposition of a square matrix A.
495
+ Given A, an unitary matrix Q is determined such that
496
+
497
+ Q' A Q = R and Q' Q = Q Q' = 1
498
+
499
+ where R is an upper right triangular matrix. Here ' denotes the
500
+ hermitian transpose (i.e. transposition and conjugation).
501
+
502
+ input:
503
+ A : a real or complex square matrix
504
+ overwrite_a : if true, allows modification of A which may improve
505
+ performance. if false, A is not modified.
506
+
507
+ output:
508
+ Q : an unitary matrix
509
+ R : an upper right triangular matrix
510
+
511
+ return value: (Q, R)
512
+
513
+ example:
514
+ >>> from mpmath import mp
515
+ >>> A = mp.matrix([[3, -1, 2], [2, 5, -5], [-2, -3, 7]])
516
+ >>> Q, R = mp.schur(A)
517
+ >>> mp.nprint(R, 3) # doctest:+SKIP
518
+ [2.0 0.417 -2.53]
519
+ [0.0 4.0 -4.74]
520
+ [0.0 0.0 9.0]
521
+ >>> print(mp.chop(A - Q * R * Q.transpose_conj()))
522
+ [0.0 0.0 0.0]
523
+ [0.0 0.0 0.0]
524
+ [0.0 0.0 0.0]
525
+
526
+ warning: The Schur decomposition is not unique.
527
+ """
528
+
529
+ n = A.rows
530
+
531
+ if n == 1:
532
+ return (ctx.matrix([[1]]), A)
533
+
534
+ if not overwrite_a:
535
+ A = A.copy()
536
+
537
+ T = ctx.matrix(n, 1)
538
+
539
+ hessenberg_reduce_0(ctx, A, T)
540
+ Q = A.copy()
541
+ hessenberg_reduce_1(ctx, Q, T)
542
+
543
+ for x in xrange(n):
544
+ for y in xrange(x + 2, n):
545
+ A[y,x] = 0
546
+
547
+ hessenberg_qr(ctx, A, Q)
548
+
549
+ return Q, A
550
+
551
+
552
+ def eig_tr_r(ctx, A):
553
+ """
554
+ This routine calculates the right eigenvectors of an upper right triangular matrix.
555
+
556
+ input:
557
+ A an upper right triangular matrix
558
+
559
+ output:
560
+ ER a matrix whose columns form the right eigenvectors of A
561
+
562
+ return value: ER
563
+ """
564
+
565
+ # this subroutine is inspired by the lapack routines ctrevc.f,clatrs.f
566
+
567
+ n = A.rows
568
+
569
+ ER = ctx.eye(n)
570
+
571
+ eps = ctx.eps
572
+
573
+ unfl = ctx.ldexp(ctx.one, -ctx.prec * 30)
574
+ # since mpmath effectively has no limits on the exponent, we simply scale doubles up
575
+ # original double has prec*20
576
+
577
+ smlnum = unfl * (n / eps)
578
+ simin = 1 / ctx.sqrt(eps)
579
+
580
+ rmax = 1
581
+
582
+ for i in xrange(1, n):
583
+ s = A[i,i]
584
+
585
+ smin = max(eps * abs(s), smlnum)
586
+
587
+ for j in xrange(i - 1, -1, -1):
588
+
589
+ r = 0
590
+ for k in xrange(j + 1, i + 1):
591
+ r += A[j,k] * ER[k,i]
592
+
593
+ t = A[j,j] - s
594
+ if abs(t) < smin:
595
+ t = smin
596
+
597
+ r = -r / t
598
+ ER[j,i] = r
599
+
600
+ rmax = max(rmax, abs(r))
601
+ if rmax > simin:
602
+ for k in xrange(j, i+1):
603
+ ER[k,i] /= rmax
604
+ rmax = 1
605
+
606
+ if rmax != 1:
607
+ for k in xrange(0, i + 1):
608
+ ER[k,i] /= rmax
609
+
610
+ return ER
611
+
612
+ def eig_tr_l(ctx, A):
613
+ """
614
+ This routine calculates the left eigenvectors of an upper right triangular matrix.
615
+
616
+ input:
617
+ A an upper right triangular matrix
618
+
619
+ output:
620
+ EL a matrix whose rows form the left eigenvectors of A
621
+
622
+ return value: EL
623
+ """
624
+
625
+ n = A.rows
626
+
627
+ EL = ctx.eye(n)
628
+
629
+ eps = ctx.eps
630
+
631
+ unfl = ctx.ldexp(ctx.one, -ctx.prec * 30)
632
+ # since mpmath effectively has no limits on the exponent, we simply scale doubles up
633
+ # original double has prec*20
634
+
635
+ smlnum = unfl * (n / eps)
636
+ simin = 1 / ctx.sqrt(eps)
637
+
638
+ rmax = 1
639
+
640
+ for i in xrange(0, n - 1):
641
+ s = A[i,i]
642
+
643
+ smin = max(eps * abs(s), smlnum)
644
+
645
+ for j in xrange(i + 1, n):
646
+
647
+ r = 0
648
+ for k in xrange(i, j):
649
+ r += EL[i,k] * A[k,j]
650
+
651
+ t = A[j,j] - s
652
+ if abs(t) < smin:
653
+ t = smin
654
+
655
+ r = -r / t
656
+ EL[i,j] = r
657
+
658
+ rmax = max(rmax, abs(r))
659
+ if rmax > simin:
660
+ for k in xrange(i, j + 1):
661
+ EL[i,k] /= rmax
662
+ rmax = 1
663
+
664
+ if rmax != 1:
665
+ for k in xrange(i, n):
666
+ EL[i,k] /= rmax
667
+
668
+ return EL
669
+
670
+ @defun
671
+ def eig(ctx, A, left = False, right = True, overwrite_a = False):
672
+ """
673
+ This routine computes the eigenvalues and optionally the left and right
674
+ eigenvectors of a square matrix A. Given A, a vector E and matrices ER
675
+ and EL are calculated such that
676
+
677
+ A ER[:,i] = E[i] ER[:,i]
678
+ EL[i,:] A = EL[i,:] E[i]
679
+
680
+ E contains the eigenvalues of A. The columns of ER contain the right eigenvectors
681
+ of A whereas the rows of EL contain the left eigenvectors.
682
+
683
+
684
+ input:
685
+ A : a real or complex square matrix of shape (n, n)
686
+ left : if true, the left eigenvectors are calculated.
687
+ right : if true, the right eigenvectors are calculated.
688
+ overwrite_a : if true, allows modification of A which may improve
689
+ performance. if false, A is not modified.
690
+
691
+ output:
692
+ E : a list of length n containing the eigenvalues of A.
693
+ ER : a matrix whose columns contain the right eigenvectors of A.
694
+ EL : a matrix whose rows contain the left eigenvectors of A.
695
+
696
+ return values:
697
+ E if left and right are both false.
698
+ (E, ER) if right is true and left is false.
699
+ (E, EL) if left is true and right is false.
700
+ (E, EL, ER) if left and right are true.
701
+
702
+
703
+ examples:
704
+ >>> from mpmath import mp
705
+ >>> A = mp.matrix([[3, -1, 2], [2, 5, -5], [-2, -3, 7]])
706
+ >>> E, ER = mp.eig(A)
707
+ >>> print(mp.chop(A * ER[:,0] - E[0] * ER[:,0]))
708
+ [0.0]
709
+ [0.0]
710
+ [0.0]
711
+
712
+ >>> E, EL, ER = mp.eig(A,left = True, right = True)
713
+ >>> E, EL, ER = mp.eig_sort(E, EL, ER)
714
+ >>> mp.nprint(E)
715
+ [2.0, 4.0, 9.0]
716
+ >>> print(mp.chop(A * ER[:,0] - E[0] * ER[:,0]))
717
+ [0.0]
718
+ [0.0]
719
+ [0.0]
720
+ >>> print(mp.chop( EL[0,:] * A - EL[0,:] * E[0]))
721
+ [0.0 0.0 0.0]
722
+
723
+ warning:
724
+ - If there are multiple eigenvalues, the eigenvectors do not necessarily
725
+ span the whole vectorspace, i.e. ER and EL may have not full rank.
726
+ Furthermore in that case the eigenvectors are numerical ill-conditioned.
727
+ - In the general case the eigenvalues have no natural order.
728
+
729
+ see also:
730
+ - eigh (or eigsy, eighe) for the symmetric eigenvalue problem.
731
+ - eig_sort for sorting of eigenvalues and eigenvectors
732
+ """
733
+
734
+ n = A.rows
735
+
736
+ if n == 1:
737
+ if left and (not right):
738
+ return ([A[0]], ctx.matrix([[1]]))
739
+
740
+ if right and (not left):
741
+ return ([A[0]], ctx.matrix([[1]]))
742
+
743
+ return ([A[0]], ctx.matrix([[1]]), ctx.matrix([[1]]))
744
+
745
+ if not overwrite_a:
746
+ A = A.copy()
747
+
748
+ T = ctx.zeros(n, 1)
749
+
750
+ hessenberg_reduce_0(ctx, A, T)
751
+
752
+ if left or right:
753
+ Q = A.copy()
754
+ hessenberg_reduce_1(ctx, Q, T)
755
+ else:
756
+ Q = False
757
+
758
+ for x in xrange(n):
759
+ for y in xrange(x + 2, n):
760
+ A[y,x] = 0
761
+
762
+ hessenberg_qr(ctx, A, Q)
763
+
764
+ E = [0 for i in xrange(n)]
765
+ for i in xrange(n):
766
+ E[i] = A[i,i]
767
+
768
+ if not (left or right):
769
+ return E
770
+
771
+ if left:
772
+ EL = eig_tr_l(ctx, A)
773
+ EL = EL * Q.transpose_conj()
774
+
775
+ if right:
776
+ ER = eig_tr_r(ctx, A)
777
+ ER = Q * ER
778
+
779
+ if left and (not right):
780
+ return (E, EL)
781
+
782
+ if right and (not left):
783
+ return (E, ER)
784
+
785
+ return (E, EL, ER)
786
+
787
+ @defun
788
+ def eig_sort(ctx, E, EL = False, ER = False, f = "real"):
789
+ """
790
+ This routine sorts the eigenvalues and eigenvectors delivered by ``eig``.
791
+
792
+ parameters:
793
+ E : the eigenvalues as delivered by eig
794
+ EL : the left eigenvectors as delivered by eig, or false
795
+ ER : the right eigenvectors as delivered by eig, or false
796
+ f : either a string ("real" sort by increasing real part, "imag" sort by
797
+ increasing imag part, "abs" sort by absolute value) or a function
798
+ mapping complexs to the reals, i.e. ``f = lambda x: -mp.re(x) ``
799
+ would sort the eigenvalues by decreasing real part.
800
+
801
+ return values:
802
+ E if EL and ER are both false.
803
+ (E, ER) if ER is not false and left is false.
804
+ (E, EL) if EL is not false and right is false.
805
+ (E, EL, ER) if EL and ER are not false.
806
+
807
+ example:
808
+ >>> from mpmath import mp
809
+ >>> A = mp.matrix([[3, -1, 2], [2, 5, -5], [-2, -3, 7]])
810
+ >>> E, EL, ER = mp.eig(A,left = True, right = True)
811
+ >>> E, EL, ER = mp.eig_sort(E, EL, ER)
812
+ >>> mp.nprint(E)
813
+ [2.0, 4.0, 9.0]
814
+ >>> E, EL, ER = mp.eig_sort(E, EL, ER,f = lambda x: -mp.re(x))
815
+ >>> mp.nprint(E)
816
+ [9.0, 4.0, 2.0]
817
+ >>> print(mp.chop(A * ER[:,0] - E[0] * ER[:,0]))
818
+ [0.0]
819
+ [0.0]
820
+ [0.0]
821
+ >>> print(mp.chop( EL[0,:] * A - EL[0,:] * E[0]))
822
+ [0.0 0.0 0.0]
823
+ """
824
+
825
+ if isinstance(f, str):
826
+ if f == "real":
827
+ f = ctx.re
828
+ elif f == "imag":
829
+ f = ctx.im
830
+ elif f == "abs":
831
+ f = abs
832
+ else:
833
+ raise RuntimeError("unknown function %s" % f)
834
+
835
+ n = len(E)
836
+
837
+ # Sort eigenvalues (bubble-sort)
838
+
839
+ for i in xrange(n):
840
+ imax = i
841
+ s = f(E[i]) # s is the current maximal element
842
+
843
+ for j in xrange(i + 1, n):
844
+ c = f(E[j])
845
+ if c < s:
846
+ s = c
847
+ imax = j
848
+
849
+ if imax != i:
850
+ # swap eigenvalues
851
+
852
+ z = E[i]
853
+ E[i] = E[imax]
854
+ E[imax] = z
855
+
856
+ if not isinstance(EL, bool):
857
+ for j in xrange(n):
858
+ z = EL[i,j]
859
+ EL[i,j] = EL[imax,j]
860
+ EL[imax,j] = z
861
+
862
+ if not isinstance(ER, bool):
863
+ for j in xrange(n):
864
+ z = ER[j,i]
865
+ ER[j,i] = ER[j,imax]
866
+ ER[j,imax] = z
867
+
868
+ if isinstance(EL, bool) and isinstance(ER, bool):
869
+ return E
870
+
871
+ if isinstance(EL, bool) and not(isinstance(ER, bool)):
872
+ return (E, ER)
873
+
874
+ if isinstance(ER, bool) and not(isinstance(EL, bool)):
875
+ return (E, EL)
876
+
877
+ return (E, EL, ER)
lib/python3.11/site-packages/mpmath/matrices/eigen_symmetric.py ADDED
@@ -0,0 +1,1807 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ ##################################################################################################
5
+ # module for the symmetric eigenvalue problem
6
+ # Copyright 2013 Timo Hartmann (thartmann15 at gmail.com)
7
+ #
8
+ # todo:
9
+ # - implement balancing
10
+ #
11
+ ##################################################################################################
12
+
13
+ """
14
+ The symmetric eigenvalue problem.
15
+ ---------------------------------
16
+
17
+ This file contains routines for the symmetric eigenvalue problem.
18
+
19
+ high level routines:
20
+
21
+ eigsy : real symmetric (ordinary) eigenvalue problem
22
+ eighe : complex hermitian (ordinary) eigenvalue problem
23
+ eigh : unified interface for eigsy and eighe
24
+ svd_r : singular value decomposition for real matrices
25
+ svd_c : singular value decomposition for complex matrices
26
+ svd : unified interface for svd_r and svd_c
27
+
28
+
29
+ low level routines:
30
+
31
+ r_sy_tridiag : reduction of real symmetric matrix to real symmetric tridiagonal matrix
32
+ c_he_tridiag_0 : reduction of complex hermitian matrix to real symmetric tridiagonal matrix
33
+ c_he_tridiag_1 : auxiliary routine to c_he_tridiag_0
34
+ c_he_tridiag_2 : auxiliary routine to c_he_tridiag_0
35
+ tridiag_eigen : solves the real symmetric tridiagonal matrix eigenvalue problem
36
+ svd_r_raw : raw singular value decomposition for real matrices
37
+ svd_c_raw : raw singular value decomposition for complex matrices
38
+ """
39
+
40
+ from ..libmp.backend import xrange
41
+ from .eigen import defun
42
+
43
+
44
+ def r_sy_tridiag(ctx, A, D, E, calc_ev = True):
45
+ """
46
+ This routine transforms a real symmetric matrix A to a real symmetric
47
+ tridiagonal matrix T using an orthogonal similarity transformation:
48
+ Q' * A * Q = T (here ' denotes the matrix transpose).
49
+ The orthogonal matrix Q is build up from Householder reflectors.
50
+
51
+ parameters:
52
+ A (input/output) On input, A contains the real symmetric matrix of
53
+ dimension (n,n). On output, if calc_ev is true, A contains the
54
+ orthogonal matrix Q, otherwise A is destroyed.
55
+
56
+ D (output) real array of length n, contains the diagonal elements
57
+ of the tridiagonal matrix
58
+
59
+ E (output) real array of length n, contains the offdiagonal elements
60
+ of the tridiagonal matrix in E[0:(n-1)] where is the dimension of
61
+ the matrix A. E[n-1] is undefined.
62
+
63
+ calc_ev (input) If calc_ev is true, this routine explicitly calculates the
64
+ orthogonal matrix Q which is then returned in A. If calc_ev is
65
+ false, Q is not explicitly calculated resulting in a shorter run time.
66
+
67
+ This routine is a python translation of the fortran routine tred2.f in the
68
+ software library EISPACK (see netlib.org) which itself is based on the algol
69
+ procedure tred2 described in:
70
+ - Num. Math. 11, p.181-195 (1968) by Martin, Reinsch and Wilkonson
71
+ - Handbook for auto. comp., Vol II, Linear Algebra, p.212-226 (1971)
72
+
73
+ For a good introduction to Householder reflections, see also
74
+ Stoer, Bulirsch - Introduction to Numerical Analysis.
75
+ """
76
+
77
+ # note : the vector v of the i-th houshoulder reflector is stored in a[(i+1):,i]
78
+ # whereas v/<v,v> is stored in a[i,(i+1):]
79
+
80
+ n = A.rows
81
+ for i in xrange(n - 1, 0, -1):
82
+ # scale the vector
83
+
84
+ scale = 0
85
+ for k in xrange(0, i):
86
+ scale += abs(A[k,i])
87
+
88
+ scale_inv = 0
89
+ if scale != 0:
90
+ scale_inv = 1/scale
91
+
92
+ # sadly there are floating point numbers not equal to zero whose reciprocal is infinity
93
+
94
+ if i == 1 or scale == 0 or ctx.isinf(scale_inv):
95
+ E[i] = A[i-1,i] # nothing to do
96
+ D[i] = 0
97
+ continue
98
+
99
+ # calculate parameters for housholder transformation
100
+
101
+ H = 0
102
+ for k in xrange(0, i):
103
+ A[k,i] *= scale_inv
104
+ H += A[k,i] * A[k,i]
105
+
106
+ F = A[i-1,i]
107
+ G = ctx.sqrt(H)
108
+ if F > 0:
109
+ G = -G
110
+ E[i] = scale * G
111
+ H -= F * G
112
+ A[i-1,i] = F - G
113
+ F = 0
114
+
115
+ # apply housholder transformation
116
+
117
+ for j in xrange(0, i):
118
+ if calc_ev:
119
+ A[i,j] = A[j,i] / H
120
+
121
+ G = 0 # calculate A*U
122
+ for k in xrange(0, j + 1):
123
+ G += A[k,j] * A[k,i]
124
+ for k in xrange(j + 1, i):
125
+ G += A[j,k] * A[k,i]
126
+
127
+ E[j] = G / H # calculate P
128
+ F += E[j] * A[j,i]
129
+
130
+ HH = F / (2 * H)
131
+
132
+ for j in xrange(0, i): # calculate reduced A
133
+ F = A[j,i]
134
+ G = E[j] - HH * F # calculate Q
135
+ E[j] = G
136
+
137
+ for k in xrange(0, j + 1):
138
+ A[k,j] -= F * E[k] + G * A[k,i]
139
+
140
+ D[i] = H
141
+
142
+ for i in xrange(1, n): # better for compatibility
143
+ E[i-1] = E[i]
144
+ E[n-1] = 0
145
+
146
+ if calc_ev:
147
+ D[0] = 0
148
+ for i in xrange(0, n):
149
+ if D[i] != 0:
150
+ for j in xrange(0, i): # accumulate transformation matrices
151
+ G = 0
152
+ for k in xrange(0, i):
153
+ G += A[i,k] * A[k,j]
154
+ for k in xrange(0, i):
155
+ A[k,j] -= G * A[k,i]
156
+
157
+ D[i] = A[i,i]
158
+ A[i,i] = 1
159
+
160
+ for j in xrange(0, i):
161
+ A[j,i] = A[i,j] = 0
162
+ else:
163
+ for i in xrange(0, n):
164
+ D[i] = A[i,i]
165
+
166
+
167
+
168
+
169
+
170
+ def c_he_tridiag_0(ctx, A, D, E, T):
171
+ """
172
+ This routine transforms a complex hermitian matrix A to a real symmetric
173
+ tridiagonal matrix T using an unitary similarity transformation:
174
+ Q' * A * Q = T (here ' denotes the hermitian matrix transpose,
175
+ i.e. transposition und conjugation).
176
+ The unitary matrix Q is build up from Householder reflectors and
177
+ an unitary diagonal matrix.
178
+
179
+ parameters:
180
+ A (input/output) On input, A contains the complex hermitian matrix
181
+ of dimension (n,n). On output, A contains the unitary matrix Q
182
+ in compressed form.
183
+
184
+ D (output) real array of length n, contains the diagonal elements
185
+ of the tridiagonal matrix.
186
+
187
+ E (output) real array of length n, contains the offdiagonal elements
188
+ of the tridiagonal matrix in E[0:(n-1)] where is the dimension of
189
+ the matrix A. E[n-1] is undefined.
190
+
191
+ T (output) complex array of length n, contains a unitary diagonal
192
+ matrix.
193
+
194
+ This routine is a python translation (in slightly modified form) of the fortran
195
+ routine htridi.f in the software library EISPACK (see netlib.org) which itself
196
+ is a complex version of the algol procedure tred1 described in:
197
+ - Num. Math. 11, p.181-195 (1968) by Martin, Reinsch and Wilkonson
198
+ - Handbook for auto. comp., Vol II, Linear Algebra, p.212-226 (1971)
199
+
200
+ For a good introduction to Householder reflections, see also
201
+ Stoer, Bulirsch - Introduction to Numerical Analysis.
202
+ """
203
+
204
+ n = A.rows
205
+ T[n-1] = 1
206
+ for i in xrange(n - 1, 0, -1):
207
+
208
+ # scale the vector
209
+
210
+ scale = 0
211
+ for k in xrange(0, i):
212
+ scale += abs(ctx.re(A[k,i])) + abs(ctx.im(A[k,i]))
213
+
214
+ scale_inv = 0
215
+ if scale != 0:
216
+ scale_inv = 1 / scale
217
+
218
+ # sadly there are floating point numbers not equal to zero whose reciprocal is infinity
219
+
220
+ if scale == 0 or ctx.isinf(scale_inv):
221
+ E[i] = 0
222
+ D[i] = 0
223
+ T[i-1] = 1
224
+ continue
225
+
226
+ if i == 1:
227
+ F = A[i-1,i]
228
+ f = abs(F)
229
+ E[i] = f
230
+ D[i] = 0
231
+ if f != 0:
232
+ T[i-1] = T[i] * F / f
233
+ else:
234
+ T[i-1] = T[i]
235
+ continue
236
+
237
+ # calculate parameters for housholder transformation
238
+
239
+ H = 0
240
+ for k in xrange(0, i):
241
+ A[k,i] *= scale_inv
242
+ rr = ctx.re(A[k,i])
243
+ ii = ctx.im(A[k,i])
244
+ H += rr * rr + ii * ii
245
+
246
+ F = A[i-1,i]
247
+ f = abs(F)
248
+ G = ctx.sqrt(H)
249
+ H += G * f
250
+ E[i] = scale * G
251
+ if f != 0:
252
+ F = F / f
253
+ TZ = - T[i] * F # T[i-1]=-T[i]*F, but we need T[i-1] as temporary storage
254
+ G *= F
255
+ else:
256
+ TZ = -T[i] # T[i-1]=-T[i]
257
+ A[i-1,i] += G
258
+ F = 0
259
+
260
+ # apply housholder transformation
261
+
262
+ for j in xrange(0, i):
263
+ A[i,j] = A[j,i] / H
264
+
265
+ G = 0 # calculate A*U
266
+ for k in xrange(0, j + 1):
267
+ G += ctx.conj(A[k,j]) * A[k,i]
268
+ for k in xrange(j + 1, i):
269
+ G += A[j,k] * A[k,i]
270
+
271
+ T[j] = G / H # calculate P
272
+ F += ctx.conj(T[j]) * A[j,i]
273
+
274
+ HH = F / (2 * H)
275
+
276
+ for j in xrange(0, i): # calculate reduced A
277
+ F = A[j,i]
278
+ G = T[j] - HH * F # calculate Q
279
+ T[j] = G
280
+
281
+ for k in xrange(0, j + 1):
282
+ A[k,j] -= ctx.conj(F) * T[k] + ctx.conj(G) * A[k,i]
283
+ # as we use the lower left part for storage
284
+ # we have to use the transpose of the normal formula
285
+
286
+ T[i-1] = TZ
287
+ D[i] = H
288
+
289
+ for i in xrange(1, n): # better for compatibility
290
+ E[i-1] = E[i]
291
+ E[n-1] = 0
292
+
293
+ D[0] = 0
294
+ for i in xrange(0, n):
295
+ zw = D[i]
296
+ D[i] = ctx.re(A[i,i])
297
+ A[i,i] = zw
298
+
299
+
300
+
301
+
302
+
303
+
304
+
305
+ def c_he_tridiag_1(ctx, A, T):
306
+ """
307
+ This routine forms the unitary matrix Q described in c_he_tridiag_0.
308
+
309
+ parameters:
310
+ A (input/output) On input, A is the same matrix as delivered by
311
+ c_he_tridiag_0. On output, A is set to Q.
312
+
313
+ T (input) On input, T is the same array as delivered by c_he_tridiag_0.
314
+
315
+ """
316
+
317
+ n = A.rows
318
+
319
+ for i in xrange(0, n):
320
+ if A[i,i] != 0:
321
+ for j in xrange(0, i):
322
+ G = 0
323
+ for k in xrange(0, i):
324
+ G += ctx.conj(A[i,k]) * A[k,j]
325
+ for k in xrange(0, i):
326
+ A[k,j] -= G * A[k,i]
327
+
328
+ A[i,i] = 1
329
+
330
+ for j in xrange(0, i):
331
+ A[j,i] = A[i,j] = 0
332
+
333
+ for i in xrange(0, n):
334
+ for k in xrange(0, n):
335
+ A[i,k] *= T[k]
336
+
337
+
338
+
339
+
340
+ def c_he_tridiag_2(ctx, A, T, B):
341
+ """
342
+ This routine applied the unitary matrix Q described in c_he_tridiag_0
343
+ onto the the matrix B, i.e. it forms Q*B.
344
+
345
+ parameters:
346
+ A (input) On input, A is the same matrix as delivered by c_he_tridiag_0.
347
+
348
+ T (input) On input, T is the same array as delivered by c_he_tridiag_0.
349
+
350
+ B (input/output) On input, B is a complex matrix. On output B is replaced
351
+ by Q*B.
352
+
353
+ This routine is a python translation of the fortran routine htribk.f in the
354
+ software library EISPACK (see netlib.org). See c_he_tridiag_0 for more
355
+ references.
356
+ """
357
+
358
+ n = A.rows
359
+
360
+ for i in xrange(0, n):
361
+ for k in xrange(0, n):
362
+ B[k,i] *= T[k]
363
+
364
+ for i in xrange(0, n):
365
+ if A[i,i] != 0:
366
+ for j in xrange(0, n):
367
+ G = 0
368
+ for k in xrange(0, i):
369
+ G += ctx.conj(A[i,k]) * B[k,j]
370
+ for k in xrange(0, i):
371
+ B[k,j] -= G * A[k,i]
372
+
373
+
374
+
375
+
376
+
377
+ def tridiag_eigen(ctx, d, e, z = False):
378
+ """
379
+ This subroutine find the eigenvalues and the first components of the
380
+ eigenvectors of a real symmetric tridiagonal matrix using the implicit
381
+ QL method.
382
+
383
+ parameters:
384
+
385
+ d (input/output) real array of length n. on input, d contains the diagonal
386
+ elements of the input matrix. on output, d contains the eigenvalues in
387
+ ascending order.
388
+
389
+ e (input) real array of length n. on input, e contains the offdiagonal
390
+ elements of the input matrix in e[0:(n-1)]. On output, e has been
391
+ destroyed.
392
+
393
+ z (input/output) If z is equal to False, no eigenvectors will be computed.
394
+ Otherwise on input z should have the format z[0:m,0:n] (i.e. a real or
395
+ complex matrix of dimension (m,n) ). On output this matrix will be
396
+ multiplied by the matrix of the eigenvectors (i.e. the columns of this
397
+ matrix are the eigenvectors): z --> z*EV
398
+ That means if z[i,j]={1 if j==j; 0 otherwise} on input, then on output
399
+ z will contain the first m components of the eigenvectors. That means
400
+ if m is equal to n, the i-th eigenvector will be z[:,i].
401
+
402
+ This routine is a python translation (in slightly modified form) of the
403
+ fortran routine imtql2.f in the software library EISPACK (see netlib.org)
404
+ which itself is based on the algol procudure imtql2 desribed in:
405
+ - num. math. 12, p. 377-383(1968) by matrin and wilkinson
406
+ - modified in num. math. 15, p. 450(1970) by dubrulle
407
+ - handbook for auto. comp., vol. II-linear algebra, p. 241-248 (1971)
408
+ See also the routine gaussq.f in netlog.org or acm algorithm 726.
409
+ """
410
+
411
+ n = len(d)
412
+ e[n-1] = 0
413
+ iterlim = 2 * ctx.dps
414
+
415
+ for l in xrange(n):
416
+ j = 0
417
+ while 1:
418
+ m = l
419
+ while 1:
420
+ # look for a small subdiagonal element
421
+ if m + 1 == n:
422
+ break
423
+ if abs(e[m]) <= ctx.eps * (abs(d[m]) + abs(d[m + 1])):
424
+ break
425
+ m = m + 1
426
+ if m == l:
427
+ break
428
+
429
+ if j >= iterlim:
430
+ raise RuntimeError("tridiag_eigen: no convergence to an eigenvalue after %d iterations" % iterlim)
431
+
432
+ j += 1
433
+
434
+ # form shift
435
+
436
+ p = d[l]
437
+ g = (d[l + 1] - p) / (2 * e[l])
438
+ r = ctx.hypot(g, 1)
439
+
440
+ if g < 0:
441
+ s = g - r
442
+ else:
443
+ s = g + r
444
+
445
+ g = d[m] - p + e[l] / s
446
+
447
+ s, c, p = 1, 1, 0
448
+
449
+ for i in xrange(m - 1, l - 1, -1):
450
+ f = s * e[i]
451
+ b = c * e[i]
452
+ if abs(f) > abs(g): # this here is a slight improvement also used in gaussq.f or acm algorithm 726.
453
+ c = g / f
454
+ r = ctx.hypot(c, 1)
455
+ e[i + 1] = f * r
456
+ s = 1 / r
457
+ c = c * s
458
+ else:
459
+ s = f / g
460
+ r = ctx.hypot(s, 1)
461
+ e[i + 1] = g * r
462
+ c = 1 / r
463
+ s = s * c
464
+ g = d[i + 1] - p
465
+ r = (d[i] - g) * s + 2 * c * b
466
+ p = s * r
467
+ d[i + 1] = g + p
468
+ g = c * r - b
469
+
470
+ if not isinstance(z, bool):
471
+ # calculate eigenvectors
472
+ for w in xrange(z.rows):
473
+ f = z[w,i+1]
474
+ z[w,i+1] = s * z[w,i] + c * f
475
+ z[w,i ] = c * z[w,i] - s * f
476
+
477
+ d[l] = d[l] - p
478
+ e[l] = g
479
+ e[m] = 0
480
+
481
+ for ii in xrange(1, n):
482
+ # sort eigenvalues and eigenvectors (bubble-sort)
483
+ i = ii - 1
484
+ k = i
485
+ p = d[i]
486
+ for j in xrange(ii, n):
487
+ if d[j] >= p:
488
+ continue
489
+ k = j
490
+ p = d[k]
491
+ if k == i:
492
+ continue
493
+ d[k] = d[i]
494
+ d[i] = p
495
+
496
+ if not isinstance(z, bool):
497
+ for w in xrange(z.rows):
498
+ p = z[w,i]
499
+ z[w,i] = z[w,k]
500
+ z[w,k] = p
501
+
502
+ ########################################################################################
503
+
504
+ @defun
505
+ def eigsy(ctx, A, eigvals_only = False, overwrite_a = False):
506
+ """
507
+ This routine solves the (ordinary) eigenvalue problem for a real symmetric
508
+ square matrix A. Given A, an orthogonal matrix Q is calculated which
509
+ diagonalizes A:
510
+
511
+ Q' A Q = diag(E) and Q Q' = Q' Q = 1
512
+
513
+ Here diag(E) is a diagonal matrix whose diagonal is E.
514
+ ' denotes the transpose.
515
+
516
+ The columns of Q are the eigenvectors of A and E contains the eigenvalues:
517
+
518
+ A Q[:,i] = E[i] Q[:,i]
519
+
520
+
521
+ input:
522
+
523
+ A: real matrix of format (n,n) which is symmetric
524
+ (i.e. A=A' or A[i,j]=A[j,i])
525
+
526
+ eigvals_only: if true, calculates only the eigenvalues E.
527
+ if false, calculates both eigenvectors and eigenvalues.
528
+
529
+ overwrite_a: if true, allows modification of A which may improve
530
+ performance. if false, A is not modified.
531
+
532
+ output:
533
+
534
+ E: vector of format (n). contains the eigenvalues of A in ascending order.
535
+
536
+ Q: orthogonal matrix of format (n,n). contains the eigenvectors
537
+ of A as columns.
538
+
539
+ return value:
540
+
541
+ E if eigvals_only is true
542
+ (E, Q) if eigvals_only is false
543
+
544
+ example:
545
+ >>> from mpmath import mp
546
+ >>> A = mp.matrix([[3, 2], [2, 0]])
547
+ >>> E = mp.eigsy(A, eigvals_only = True)
548
+ >>> print(E)
549
+ [-1.0]
550
+ [ 4.0]
551
+
552
+ >>> A = mp.matrix([[1, 2], [2, 3]])
553
+ >>> E, Q = mp.eigsy(A)
554
+ >>> print(mp.chop(A * Q[:,0] - E[0] * Q[:,0]))
555
+ [0.0]
556
+ [0.0]
557
+
558
+ see also: eighe, eigh, eig
559
+ """
560
+
561
+ if not overwrite_a:
562
+ A = A.copy()
563
+
564
+ d = ctx.zeros(A.rows, 1)
565
+ e = ctx.zeros(A.rows, 1)
566
+
567
+ if eigvals_only:
568
+ r_sy_tridiag(ctx, A, d, e, calc_ev = False)
569
+ tridiag_eigen(ctx, d, e, False)
570
+ return d
571
+ else:
572
+ r_sy_tridiag(ctx, A, d, e, calc_ev = True)
573
+ tridiag_eigen(ctx, d, e, A)
574
+ return (d, A)
575
+
576
+
577
+ @defun
578
+ def eighe(ctx, A, eigvals_only = False, overwrite_a = False):
579
+ """
580
+ This routine solves the (ordinary) eigenvalue problem for a complex
581
+ hermitian square matrix A. Given A, an unitary matrix Q is calculated which
582
+ diagonalizes A:
583
+
584
+ Q' A Q = diag(E) and Q Q' = Q' Q = 1
585
+
586
+ Here diag(E) a is diagonal matrix whose diagonal is E.
587
+ ' denotes the hermitian transpose (i.e. ordinary transposition and
588
+ complex conjugation).
589
+
590
+ The columns of Q are the eigenvectors of A and E contains the eigenvalues:
591
+
592
+ A Q[:,i] = E[i] Q[:,i]
593
+
594
+
595
+ input:
596
+
597
+ A: complex matrix of format (n,n) which is hermitian
598
+ (i.e. A=A' or A[i,j]=conj(A[j,i]))
599
+
600
+ eigvals_only: if true, calculates only the eigenvalues E.
601
+ if false, calculates both eigenvectors and eigenvalues.
602
+
603
+ overwrite_a: if true, allows modification of A which may improve
604
+ performance. if false, A is not modified.
605
+
606
+ output:
607
+
608
+ E: vector of format (n). contains the eigenvalues of A in ascending order.
609
+
610
+ Q: unitary matrix of format (n,n). contains the eigenvectors
611
+ of A as columns.
612
+
613
+ return value:
614
+
615
+ E if eigvals_only is true
616
+ (E, Q) if eigvals_only is false
617
+
618
+ example:
619
+ >>> from mpmath import mp
620
+ >>> A = mp.matrix([[1, -3 - 1j], [-3 + 1j, -2]])
621
+ >>> E = mp.eighe(A, eigvals_only = True)
622
+ >>> print(E)
623
+ [-4.0]
624
+ [ 3.0]
625
+
626
+ >>> A = mp.matrix([[1, 2 + 5j], [2 - 5j, 3]])
627
+ >>> E, Q = mp.eighe(A)
628
+ >>> print(mp.chop(A * Q[:,0] - E[0] * Q[:,0]))
629
+ [0.0]
630
+ [0.0]
631
+
632
+ see also: eigsy, eigh, eig
633
+ """
634
+
635
+ if not overwrite_a:
636
+ A = A.copy()
637
+
638
+ d = ctx.zeros(A.rows, 1)
639
+ e = ctx.zeros(A.rows, 1)
640
+ t = ctx.zeros(A.rows, 1)
641
+
642
+ if eigvals_only:
643
+ c_he_tridiag_0(ctx, A, d, e, t)
644
+ tridiag_eigen(ctx, d, e, False)
645
+ return d
646
+ else:
647
+ c_he_tridiag_0(ctx, A, d, e, t)
648
+ B = ctx.eye(A.rows)
649
+ tridiag_eigen(ctx, d, e, B)
650
+ c_he_tridiag_2(ctx, A, t, B)
651
+ return (d, B)
652
+
653
+ @defun
654
+ def eigh(ctx, A, eigvals_only = False, overwrite_a = False):
655
+ """
656
+ "eigh" is a unified interface for "eigsy" and "eighe". Depending on
657
+ whether A is real or complex the appropriate function is called.
658
+
659
+ This routine solves the (ordinary) eigenvalue problem for a real symmetric
660
+ or complex hermitian square matrix A. Given A, an orthogonal (A real) or
661
+ unitary (A complex) matrix Q is calculated which diagonalizes A:
662
+
663
+ Q' A Q = diag(E) and Q Q' = Q' Q = 1
664
+
665
+ Here diag(E) a is diagonal matrix whose diagonal is E.
666
+ ' denotes the hermitian transpose (i.e. ordinary transposition and
667
+ complex conjugation).
668
+
669
+ The columns of Q are the eigenvectors of A and E contains the eigenvalues:
670
+
671
+ A Q[:,i] = E[i] Q[:,i]
672
+
673
+ input:
674
+
675
+ A: a real or complex square matrix of format (n,n) which is symmetric
676
+ (i.e. A[i,j]=A[j,i]) or hermitian (i.e. A[i,j]=conj(A[j,i])).
677
+
678
+ eigvals_only: if true, calculates only the eigenvalues E.
679
+ if false, calculates both eigenvectors and eigenvalues.
680
+
681
+ overwrite_a: if true, allows modification of A which may improve
682
+ performance. if false, A is not modified.
683
+
684
+ output:
685
+
686
+ E: vector of format (n). contains the eigenvalues of A in ascending order.
687
+
688
+ Q: an orthogonal or unitary matrix of format (n,n). contains the
689
+ eigenvectors of A as columns.
690
+
691
+ return value:
692
+
693
+ E if eigvals_only is true
694
+ (E, Q) if eigvals_only is false
695
+
696
+ example:
697
+ >>> from mpmath import mp
698
+ >>> A = mp.matrix([[3, 2], [2, 0]])
699
+ >>> E = mp.eigh(A, eigvals_only = True)
700
+ >>> print(E)
701
+ [-1.0]
702
+ [ 4.0]
703
+
704
+ >>> A = mp.matrix([[1, 2], [2, 3]])
705
+ >>> E, Q = mp.eigh(A)
706
+ >>> print(mp.chop(A * Q[:,0] - E[0] * Q[:,0]))
707
+ [0.0]
708
+ [0.0]
709
+
710
+ >>> A = mp.matrix([[1, 2 + 5j], [2 - 5j, 3]])
711
+ >>> E, Q = mp.eigh(A)
712
+ >>> print(mp.chop(A * Q[:,0] - E[0] * Q[:,0]))
713
+ [0.0]
714
+ [0.0]
715
+
716
+ see also: eigsy, eighe, eig
717
+ """
718
+
719
+ iscomplex = any(type(x) is ctx.mpc for x in A)
720
+
721
+ if iscomplex:
722
+ return ctx.eighe(A, eigvals_only = eigvals_only, overwrite_a = overwrite_a)
723
+ else:
724
+ return ctx.eigsy(A, eigvals_only = eigvals_only, overwrite_a = overwrite_a)
725
+
726
+
727
+ @defun
728
+ def gauss_quadrature(ctx, n, qtype = "legendre", alpha = 0, beta = 0):
729
+ """
730
+ This routine calulates gaussian quadrature rules for different
731
+ families of orthogonal polynomials. Let (a, b) be an interval,
732
+ W(x) a positive weight function and n a positive integer.
733
+ Then the purpose of this routine is to calculate pairs (x_k, w_k)
734
+ for k=0, 1, 2, ... (n-1) which give
735
+
736
+ int(W(x) * F(x), x = a..b) = sum(w_k * F(x_k),k = 0..(n-1))
737
+
738
+ exact for all polynomials F(x) of degree (strictly) less than 2*n. For all
739
+ integrable functions F(x) the sum is a (more or less) good approximation to
740
+ the integral. The x_k are called nodes (which are the zeros of the
741
+ related orthogonal polynomials) and the w_k are called the weights.
742
+
743
+ parameters
744
+ n (input) The degree of the quadrature rule, i.e. its number of
745
+ nodes.
746
+
747
+ qtype (input) The family of orthogonal polynmomials for which to
748
+ compute the quadrature rule. See the list below.
749
+
750
+ alpha (input) real number, used as parameter for some orthogonal
751
+ polynomials
752
+
753
+ beta (input) real number, used as parameter for some orthogonal
754
+ polynomials.
755
+
756
+ return value
757
+
758
+ (X, W) a pair of two real arrays where x_k = X[k] and w_k = W[k].
759
+
760
+
761
+ orthogonal polynomials:
762
+
763
+ qtype polynomial
764
+ ----- ----------
765
+
766
+ "legendre" Legendre polynomials, W(x)=1 on the interval (-1, +1)
767
+ "legendre01" shifted Legendre polynomials, W(x)=1 on the interval (0, +1)
768
+ "hermite" Hermite polynomials, W(x)=exp(-x*x) on (-infinity,+infinity)
769
+ "laguerre" Laguerre polynomials, W(x)=exp(-x) on (0,+infinity)
770
+ "glaguerre" generalized Laguerre polynomials, W(x)=exp(-x)*x**alpha
771
+ on (0, +infinity)
772
+ "chebyshev1" Chebyshev polynomials of the first kind, W(x)=1/sqrt(1-x*x)
773
+ on (-1, +1)
774
+ "chebyshev2" Chebyshev polynomials of the second kind, W(x)=sqrt(1-x*x)
775
+ on (-1, +1)
776
+ "jacobi" Jacobi polynomials, W(x)=(1-x)**alpha * (1+x)**beta on (-1, +1)
777
+ with alpha>-1 and beta>-1
778
+
779
+ examples:
780
+ >>> from mpmath import mp
781
+ >>> f = lambda x: x**8 + 2 * x**6 - 3 * x**4 + 5 * x**2 - 7
782
+ >>> X, W = mp.gauss_quadrature(5, "hermite")
783
+ >>> A = mp.fdot([(f(x), w) for x, w in zip(X, W)])
784
+ >>> B = mp.sqrt(mp.pi) * 57 / 16
785
+ >>> C = mp.quad(lambda x: mp.exp(- x * x) * f(x), [-mp.inf, +mp.inf])
786
+ >>> mp.nprint((mp.chop(A-B, tol = 1e-10), mp.chop(A-C, tol = 1e-10)))
787
+ (0.0, 0.0)
788
+
789
+ >>> f = lambda x: x**5 - 2 * x**4 + 3 * x**3 - 5 * x**2 + 7 * x - 11
790
+ >>> X, W = mp.gauss_quadrature(3, "laguerre")
791
+ >>> A = mp.fdot([(f(x), w) for x, w in zip(X, W)])
792
+ >>> B = 76
793
+ >>> C = mp.quad(lambda x: mp.exp(-x) * f(x), [0, +mp.inf])
794
+ >>> mp.nprint(mp.chop(A-B, tol = 1e-10), mp.chop(A-C, tol = 1e-10))
795
+ .0
796
+
797
+ # orthogonality of the chebyshev polynomials:
798
+ >>> f = lambda x: mp.chebyt(3, x) * mp.chebyt(2, x)
799
+ >>> X, W = mp.gauss_quadrature(3, "chebyshev1")
800
+ >>> A = mp.fdot([(f(x), w) for x, w in zip(X, W)])
801
+ >>> print(mp.chop(A, tol = 1e-10))
802
+ 0.0
803
+
804
+ references:
805
+ - golub and welsch, "calculations of gaussian quadrature rules", mathematics of
806
+ computation 23, p. 221-230 (1969)
807
+ - golub, "some modified matrix eigenvalue problems", siam review 15, p. 318-334 (1973)
808
+ - stroud and secrest, "gaussian quadrature formulas", prentice-hall (1966)
809
+
810
+ See also the routine gaussq.f in netlog.org or ACM Transactions on
811
+ Mathematical Software algorithm 726.
812
+ """
813
+
814
+ d = ctx.zeros(n, 1)
815
+ e = ctx.zeros(n, 1)
816
+ z = ctx.zeros(1, n)
817
+
818
+ z[0,0] = 1
819
+
820
+ if qtype == "legendre":
821
+ # legendre on the range -1 +1 , abramowitz, table 25.4, p.916
822
+ w = 2
823
+ for i in xrange(n):
824
+ j = i + 1
825
+ e[i] = ctx.sqrt(j * j / (4 * j * j - ctx.mpf(1)))
826
+ elif qtype == "legendre01":
827
+ # legendre shifted to 0 1 , abramowitz, table 25.8, p.921
828
+ w = 1
829
+ for i in xrange(n):
830
+ d[i] = 1 / ctx.mpf(2)
831
+ j = i + 1
832
+ e[i] = ctx.sqrt(j * j / (16 * j * j - ctx.mpf(4)))
833
+ elif qtype == "hermite":
834
+ # hermite on the range -inf +inf , abramowitz, table 25.10,p.924
835
+ w = ctx.sqrt(ctx.pi)
836
+ for i in xrange(n):
837
+ j = i + 1
838
+ e[i] = ctx.sqrt(j / ctx.mpf(2))
839
+ elif qtype == "laguerre":
840
+ # laguerre on the range 0 +inf , abramowitz, table 25.9, p. 923
841
+ w = 1
842
+ for i in xrange(n):
843
+ j = i + 1
844
+ d[i] = 2 * j - 1
845
+ e[i] = j
846
+ elif qtype=="chebyshev1":
847
+ # chebyshev polynimials of the first kind
848
+ w = ctx.pi
849
+ for i in xrange(n):
850
+ e[i] = 1 / ctx.mpf(2)
851
+ e[0] = ctx.sqrt(1 / ctx.mpf(2))
852
+ elif qtype == "chebyshev2":
853
+ # chebyshev polynimials of the second kind
854
+ w = ctx.pi / 2
855
+ for i in xrange(n):
856
+ e[i] = 1 / ctx.mpf(2)
857
+ elif qtype == "glaguerre":
858
+ # generalized laguerre on the range 0 +inf
859
+ w = ctx.gamma(1 + alpha)
860
+ for i in xrange(n):
861
+ j = i + 1
862
+ d[i] = 2 * j - 1 + alpha
863
+ e[i] = ctx.sqrt(j * (j + alpha))
864
+ elif qtype == "jacobi":
865
+ # jacobi polynomials
866
+ alpha = ctx.mpf(alpha)
867
+ beta = ctx.mpf(beta)
868
+ ab = alpha + beta
869
+ abi = ab + 2
870
+ w = (2**(ab+1)) * ctx.gamma(alpha + 1) * ctx.gamma(beta + 1) / ctx.gamma(abi)
871
+ d[0] = (beta - alpha) / abi
872
+ e[0] = ctx.sqrt(4 * (1 + alpha) * (1 + beta) / ((abi + 1) * (abi * abi)))
873
+ a2b2 = beta * beta - alpha * alpha
874
+ for i in xrange(1, n):
875
+ j = i + 1
876
+ abi = 2 * j + ab
877
+ d[i] = a2b2 / ((abi - 2) * abi)
878
+ e[i] = ctx.sqrt(4 * j * (j + alpha) * (j + beta) * (j + ab) / ((abi * abi - 1) * abi * abi))
879
+ elif isinstance(qtype, str):
880
+ raise ValueError("unknown quadrature rule \"%s\"" % qtype)
881
+ elif not isinstance(qtype, str):
882
+ w = qtype(d, e)
883
+ else:
884
+ assert 0
885
+
886
+ tridiag_eigen(ctx, d, e, z)
887
+
888
+ for i in xrange(len(z)):
889
+ z[i] *= z[i]
890
+
891
+ z = z.transpose()
892
+ return (d, w * z)
893
+
894
+ ##################################################################################################
895
+ ##################################################################################################
896
+ ##################################################################################################
897
+
898
+ def svd_r_raw(ctx, A, V = False, calc_u = False):
899
+ """
900
+ This routine computes the singular value decomposition of a matrix A.
901
+ Given A, two orthogonal matrices U and V are calculated such that
902
+
903
+ A = U S V
904
+
905
+ where S is a suitable shaped matrix whose off-diagonal elements are zero.
906
+ The diagonal elements of S are the singular values of A, i.e. the
907
+ squareroots of the eigenvalues of A' A or A A'. Here ' denotes the transpose.
908
+ Householder bidiagonalization and a variant of the QR algorithm is used.
909
+
910
+ overview of the matrices :
911
+
912
+ A : m*n A gets replaced by U
913
+ U : m*n U replaces A. If n>m then only the first m*m block of U is
914
+ non-zero. column-orthogonal: U' U = B
915
+ here B is a n*n matrix whose first min(m,n) diagonal
916
+ elements are 1 and all other elements are zero.
917
+ S : n*n diagonal matrix, only the diagonal elements are stored in
918
+ the array S. only the first min(m,n) diagonal elements are non-zero.
919
+ V : n*n orthogonal: V V' = V' V = 1
920
+
921
+ parameters:
922
+ A (input/output) On input, A contains a real matrix of shape m*n.
923
+ On output, if calc_u is true A contains the column-orthogonal
924
+ matrix U; otherwise A is simply used as workspace and thus destroyed.
925
+
926
+ V (input/output) if false, the matrix V is not calculated. otherwise
927
+ V must be a matrix of shape n*n.
928
+
929
+ calc_u (input) If true, the matrix U is calculated and replaces A.
930
+ if false, U is not calculated and A is simply destroyed
931
+
932
+ return value:
933
+ S an array of length n containing the singular values of A sorted by
934
+ decreasing magnitude. only the first min(m,n) elements are non-zero.
935
+
936
+ This routine is a python translation of the fortran routine svd.f in the
937
+ software library EISPACK (see netlib.org) which itself is based on the
938
+ algol procedure svd described in:
939
+ - num. math. 14, 403-420(1970) by golub and reinsch.
940
+ - wilkinson/reinsch: handbook for auto. comp., vol ii-linear algebra, 134-151(1971).
941
+
942
+ """
943
+
944
+ m, n = A.rows, A.cols
945
+
946
+ S = ctx.zeros(n, 1)
947
+
948
+ # work is a temporary array of size n
949
+ work = ctx.zeros(n, 1)
950
+
951
+ g = scale = anorm = 0
952
+ maxits = 3 * ctx.dps
953
+
954
+ for i in xrange(n): # householder reduction to bidiagonal form
955
+ work[i] = scale*g
956
+ g = s = scale = 0
957
+ if i < m:
958
+ for k in xrange(i, m):
959
+ scale += ctx.fabs(A[k,i])
960
+ if scale != 0:
961
+ for k in xrange(i, m):
962
+ A[k,i] /= scale
963
+ s += A[k,i] * A[k,i]
964
+ f = A[i,i]
965
+ g = -ctx.sqrt(s)
966
+ if f < 0:
967
+ g = -g
968
+ h = f * g - s
969
+ A[i,i] = f - g
970
+ for j in xrange(i+1, n):
971
+ s = 0
972
+ for k in xrange(i, m):
973
+ s += A[k,i] * A[k,j]
974
+ f = s / h
975
+ for k in xrange(i, m):
976
+ A[k,j] += f * A[k,i]
977
+ for k in xrange(i,m):
978
+ A[k,i] *= scale
979
+
980
+ S[i] = scale * g
981
+ g = s = scale = 0
982
+
983
+ if i < m and i != n - 1:
984
+ for k in xrange(i+1, n):
985
+ scale += ctx.fabs(A[i,k])
986
+ if scale:
987
+ for k in xrange(i+1, n):
988
+ A[i,k] /= scale
989
+ s += A[i,k] * A[i,k]
990
+ f = A[i,i+1]
991
+ g = -ctx.sqrt(s)
992
+ if f < 0:
993
+ g = -g
994
+ h = f * g - s
995
+ A[i,i+1] = f - g
996
+
997
+ for k in xrange(i+1, n):
998
+ work[k] = A[i,k] / h
999
+
1000
+ for j in xrange(i+1, m):
1001
+ s = 0
1002
+ for k in xrange(i+1, n):
1003
+ s += A[j,k] * A[i,k]
1004
+ for k in xrange(i+1, n):
1005
+ A[j,k] += s * work[k]
1006
+
1007
+ for k in xrange(i+1, n):
1008
+ A[i,k] *= scale
1009
+
1010
+ anorm = max(anorm, ctx.fabs(S[i]) + ctx.fabs(work[i]))
1011
+
1012
+ if not isinstance(V, bool):
1013
+ for i in xrange(n-2, -1, -1): # accumulation of right hand transformations
1014
+ V[i+1,i+1] = 1
1015
+
1016
+ if work[i+1] != 0:
1017
+ for j in xrange(i+1, n):
1018
+ V[i,j] = (A[i,j] / A[i,i+1]) / work[i+1]
1019
+ for j in xrange(i+1, n):
1020
+ s = 0
1021
+ for k in xrange(i+1, n):
1022
+ s += A[i,k] * V[j,k]
1023
+ for k in xrange(i+1, n):
1024
+ V[j,k] += s * V[i,k]
1025
+
1026
+ for j in xrange(i+1, n):
1027
+ V[j,i] = V[i,j] = 0
1028
+
1029
+ V[0,0] = 1
1030
+
1031
+ if m<n : minnm = m
1032
+ else : minnm = n
1033
+
1034
+ if calc_u:
1035
+ for i in xrange(minnm-1, -1, -1): # accumulation of left hand transformations
1036
+ g = S[i]
1037
+ for j in xrange(i+1, n):
1038
+ A[i,j] = 0
1039
+ if g != 0:
1040
+ g = 1 / g
1041
+ for j in xrange(i+1, n):
1042
+ s = 0
1043
+ for k in xrange(i+1, m):
1044
+ s += A[k,i] * A[k,j]
1045
+ f = (s / A[i,i]) * g
1046
+ for k in xrange(i, m):
1047
+ A[k,j] += f * A[k,i]
1048
+ for j in xrange(i, m):
1049
+ A[j,i] *= g
1050
+ else:
1051
+ for j in xrange(i, m):
1052
+ A[j,i] = 0
1053
+ A[i,i] += 1
1054
+
1055
+ for k in xrange(n - 1, -1, -1):
1056
+ # diagonalization of the bidiagonal form:
1057
+ # loop over singular values, and over allowed itations
1058
+
1059
+ its = 0
1060
+ while 1:
1061
+ its += 1
1062
+ flag = True
1063
+
1064
+ for l in xrange(k, -1, -1):
1065
+ nm = l-1
1066
+
1067
+ if ctx.fabs(work[l]) + anorm == anorm:
1068
+ flag = False
1069
+ break
1070
+
1071
+ if ctx.fabs(S[nm]) + anorm == anorm:
1072
+ break
1073
+
1074
+ if flag:
1075
+ c = 0
1076
+ s = 1
1077
+ for i in xrange(l, k + 1):
1078
+ f = s * work[i]
1079
+ work[i] *= c
1080
+ if ctx.fabs(f) + anorm == anorm:
1081
+ break
1082
+ g = S[i]
1083
+ h = ctx.hypot(f, g)
1084
+ S[i] = h
1085
+ h = 1 / h
1086
+ c = g * h
1087
+ s = - f * h
1088
+
1089
+ if calc_u:
1090
+ for j in xrange(m):
1091
+ y = A[j,nm]
1092
+ z = A[j,i]
1093
+ A[j,nm] = y * c + z * s
1094
+ A[j,i] = z * c - y * s
1095
+
1096
+ z = S[k]
1097
+
1098
+ if l == k: # convergence
1099
+ if z < 0: # singular value is made nonnegative
1100
+ S[k] = -z
1101
+ if not isinstance(V, bool):
1102
+ for j in xrange(n):
1103
+ V[k,j] = -V[k,j]
1104
+ break
1105
+
1106
+ if its >= maxits:
1107
+ raise RuntimeError("svd: no convergence to an eigenvalue after %d iterations" % its)
1108
+
1109
+ x = S[l] # shift from bottom 2 by 2 minor
1110
+ nm = k-1
1111
+ y = S[nm]
1112
+ g = work[nm]
1113
+ h = work[k]
1114
+ f = ((y - z) * (y + z) + (g - h) * (g + h))/(2 * h * y)
1115
+ g = ctx.hypot(f, 1)
1116
+ if f >= 0: f = ((x - z) * (x + z) + h * ((y / (f + g)) - h)) / x
1117
+ else: f = ((x - z) * (x + z) + h * ((y / (f - g)) - h)) / x
1118
+
1119
+ c = s = 1 # next qt transformation
1120
+
1121
+ for j in xrange(l, nm + 1):
1122
+ g = work[j+1]
1123
+ y = S[j+1]
1124
+ h = s * g
1125
+ g = c * g
1126
+ z = ctx.hypot(f, h)
1127
+ work[j] = z
1128
+ c = f / z
1129
+ s = h / z
1130
+ f = x * c + g * s
1131
+ g = g * c - x * s
1132
+ h = y * s
1133
+ y *= c
1134
+ if not isinstance(V, bool):
1135
+ for jj in xrange(n):
1136
+ x = V[j ,jj]
1137
+ z = V[j+1,jj]
1138
+ V[j ,jj]= x * c + z * s
1139
+ V[j+1 ,jj]= z * c - x * s
1140
+ z = ctx.hypot(f, h)
1141
+ S[j] = z
1142
+ if z != 0: # rotation can be arbitray if z=0
1143
+ z = 1 / z
1144
+ c = f * z
1145
+ s = h * z
1146
+ f = c * g + s * y
1147
+ x = c * y - s * g
1148
+
1149
+ if calc_u:
1150
+ for jj in xrange(m):
1151
+ y = A[jj,j ]
1152
+ z = A[jj,j+1]
1153
+ A[jj,j ] = y * c + z * s
1154
+ A[jj,j+1 ] = z * c - y * s
1155
+
1156
+ work[l] = 0
1157
+ work[k] = f
1158
+ S[k] = x
1159
+
1160
+ ##########################
1161
+
1162
+ # Sort singular values into decreasing order (bubble-sort)
1163
+
1164
+ for i in xrange(n):
1165
+ imax = i
1166
+ s = ctx.fabs(S[i]) # s is the current maximal element
1167
+
1168
+ for j in xrange(i + 1, n):
1169
+ c = ctx.fabs(S[j])
1170
+ if c > s:
1171
+ s = c
1172
+ imax = j
1173
+
1174
+ if imax != i:
1175
+ # swap singular values
1176
+
1177
+ z = S[i]
1178
+ S[i] = S[imax]
1179
+ S[imax] = z
1180
+
1181
+ if calc_u:
1182
+ for j in xrange(m):
1183
+ z = A[j,i]
1184
+ A[j,i] = A[j,imax]
1185
+ A[j,imax] = z
1186
+
1187
+ if not isinstance(V, bool):
1188
+ for j in xrange(n):
1189
+ z = V[i,j]
1190
+ V[i,j] = V[imax,j]
1191
+ V[imax,j] = z
1192
+
1193
+ return S
1194
+
1195
+ #######################
1196
+
1197
+ def svd_c_raw(ctx, A, V = False, calc_u = False):
1198
+ """
1199
+ This routine computes the singular value decomposition of a matrix A.
1200
+ Given A, two unitary matrices U and V are calculated such that
1201
+
1202
+ A = U S V
1203
+
1204
+ where S is a suitable shaped matrix whose off-diagonal elements are zero.
1205
+ The diagonal elements of S are the singular values of A, i.e. the
1206
+ squareroots of the eigenvalues of A' A or A A'. Here ' denotes the hermitian
1207
+ transpose (i.e. transposition and conjugation). Householder bidiagonalization
1208
+ and a variant of the QR algorithm is used.
1209
+
1210
+ overview of the matrices :
1211
+
1212
+ A : m*n A gets replaced by U
1213
+ U : m*n U replaces A. If n>m then only the first m*m block of U is
1214
+ non-zero. column-unitary: U' U = B
1215
+ here B is a n*n matrix whose first min(m,n) diagonal
1216
+ elements are 1 and all other elements are zero.
1217
+ S : n*n diagonal matrix, only the diagonal elements are stored in
1218
+ the array S. only the first min(m,n) diagonal elements are non-zero.
1219
+ V : n*n unitary: V V' = V' V = 1
1220
+
1221
+ parameters:
1222
+ A (input/output) On input, A contains a complex matrix of shape m*n.
1223
+ On output, if calc_u is true A contains the column-unitary
1224
+ matrix U; otherwise A is simply used as workspace and thus destroyed.
1225
+
1226
+ V (input/output) if false, the matrix V is not calculated. otherwise
1227
+ V must be a matrix of shape n*n.
1228
+
1229
+ calc_u (input) If true, the matrix U is calculated and replaces A.
1230
+ if false, U is not calculated and A is simply destroyed
1231
+
1232
+ return value:
1233
+ S an array of length n containing the singular values of A sorted by
1234
+ decreasing magnitude. only the first min(m,n) elements are non-zero.
1235
+
1236
+ This routine is a python translation of the fortran routine svd.f in the
1237
+ software library EISPACK (see netlib.org) which itself is based on the
1238
+ algol procedure svd described in:
1239
+ - num. math. 14, 403-420(1970) by golub and reinsch.
1240
+ - wilkinson/reinsch: handbook for auto. comp., vol ii-linear algebra, 134-151(1971).
1241
+
1242
+ """
1243
+
1244
+ m, n = A.rows, A.cols
1245
+
1246
+ S = ctx.zeros(n, 1)
1247
+
1248
+ # work is a temporary array of size n
1249
+ work = ctx.zeros(n, 1)
1250
+ lbeta = ctx.zeros(n, 1)
1251
+ rbeta = ctx.zeros(n, 1)
1252
+ dwork = ctx.zeros(n, 1)
1253
+
1254
+ g = scale = anorm = 0
1255
+ maxits = 3 * ctx.dps
1256
+
1257
+ for i in xrange(n): # householder reduction to bidiagonal form
1258
+ dwork[i] = scale * g # dwork are the side-diagonal elements
1259
+ g = s = scale = 0
1260
+ if i < m:
1261
+ for k in xrange(i, m):
1262
+ scale += ctx.fabs(ctx.re(A[k,i])) + ctx.fabs(ctx.im(A[k,i]))
1263
+ if scale != 0:
1264
+ for k in xrange(i, m):
1265
+ A[k,i] /= scale
1266
+ ar = ctx.re(A[k,i])
1267
+ ai = ctx.im(A[k,i])
1268
+ s += ar * ar + ai * ai
1269
+ f = A[i,i]
1270
+ g = -ctx.sqrt(s)
1271
+ if ctx.re(f) < 0:
1272
+ beta = -g - ctx.conj(f)
1273
+ g = -g
1274
+ else:
1275
+ beta = -g + ctx.conj(f)
1276
+ beta /= ctx.conj(beta)
1277
+ beta += 1
1278
+ h = 2 * (ctx.re(f) * g - s)
1279
+ A[i,i] = f - g
1280
+ beta /= h
1281
+ lbeta[i] = (beta / scale) / scale
1282
+ for j in xrange(i+1, n):
1283
+ s = 0
1284
+ for k in xrange(i, m):
1285
+ s += ctx.conj(A[k,i]) * A[k,j]
1286
+ f = beta * s
1287
+ for k in xrange(i, m):
1288
+ A[k,j] += f * A[k,i]
1289
+ for k in xrange(i, m):
1290
+ A[k,i] *= scale
1291
+
1292
+ S[i] = scale * g # S are the diagonal elements
1293
+ g = s = scale = 0
1294
+
1295
+ if i < m and i != n - 1:
1296
+ for k in xrange(i+1, n):
1297
+ scale += ctx.fabs(ctx.re(A[i,k])) + ctx.fabs(ctx.im(A[i,k]))
1298
+ if scale:
1299
+ for k in xrange(i+1, n):
1300
+ A[i,k] /= scale
1301
+ ar = ctx.re(A[i,k])
1302
+ ai = ctx.im(A[i,k])
1303
+ s += ar * ar + ai * ai
1304
+ f = A[i,i+1]
1305
+ g = -ctx.sqrt(s)
1306
+ if ctx.re(f) < 0:
1307
+ beta = -g - ctx.conj(f)
1308
+ g = -g
1309
+ else:
1310
+ beta = -g + ctx.conj(f)
1311
+
1312
+ beta /= ctx.conj(beta)
1313
+ beta += 1
1314
+
1315
+ h = 2 * (ctx.re(f) * g - s)
1316
+ A[i,i+1] = f - g
1317
+
1318
+ beta /= h
1319
+ rbeta[i] = (beta / scale) / scale
1320
+
1321
+ for k in xrange(i+1, n):
1322
+ work[k] = A[i, k]
1323
+
1324
+ for j in xrange(i+1, m):
1325
+ s = 0
1326
+ for k in xrange(i+1, n):
1327
+ s += ctx.conj(A[i,k]) * A[j,k]
1328
+ f = s * beta
1329
+ for k in xrange(i+1,n):
1330
+ A[j,k] += f * work[k]
1331
+
1332
+ for k in xrange(i+1, n):
1333
+ A[i,k] *= scale
1334
+
1335
+ anorm = max(anorm,ctx.fabs(S[i]) + ctx.fabs(dwork[i]))
1336
+
1337
+ if not isinstance(V, bool):
1338
+ for i in xrange(n-2, -1, -1): # accumulation of right hand transformations
1339
+ V[i+1,i+1] = 1
1340
+
1341
+ if dwork[i+1] != 0:
1342
+ f = ctx.conj(rbeta[i])
1343
+ for j in xrange(i+1, n):
1344
+ V[i,j] = A[i,j] * f
1345
+ for j in xrange(i+1, n):
1346
+ s = 0
1347
+ for k in xrange(i+1, n):
1348
+ s += ctx.conj(A[i,k]) * V[j,k]
1349
+ for k in xrange(i+1, n):
1350
+ V[j,k] += s * V[i,k]
1351
+
1352
+ for j in xrange(i+1,n):
1353
+ V[j,i] = V[i,j] = 0
1354
+
1355
+ V[0,0] = 1
1356
+
1357
+ if m < n : minnm = m
1358
+ else : minnm = n
1359
+
1360
+ if calc_u:
1361
+ for i in xrange(minnm-1, -1, -1): # accumulation of left hand transformations
1362
+ g = S[i]
1363
+ for j in xrange(i+1, n):
1364
+ A[i,j] = 0
1365
+ if g != 0:
1366
+ g = 1 / g
1367
+ for j in xrange(i+1, n):
1368
+ s = 0
1369
+ for k in xrange(i+1, m):
1370
+ s += ctx.conj(A[k,i]) * A[k,j]
1371
+ f = s * ctx.conj(lbeta[i])
1372
+ for k in xrange(i, m):
1373
+ A[k,j] += f * A[k,i]
1374
+ for j in xrange(i, m):
1375
+ A[j,i] *= g
1376
+ else:
1377
+ for j in xrange(i, m):
1378
+ A[j,i] = 0
1379
+ A[i,i] += 1
1380
+
1381
+ for k in xrange(n-1, -1, -1):
1382
+ # diagonalization of the bidiagonal form:
1383
+ # loop over singular values, and over allowed itations
1384
+
1385
+ its = 0
1386
+ while 1:
1387
+ its += 1
1388
+ flag = True
1389
+
1390
+ for l in xrange(k, -1, -1):
1391
+ nm = l - 1
1392
+
1393
+ if ctx.fabs(dwork[l]) + anorm == anorm:
1394
+ flag = False
1395
+ break
1396
+
1397
+ if ctx.fabs(S[nm]) + anorm == anorm:
1398
+ break
1399
+
1400
+ if flag:
1401
+ c = 0
1402
+ s = 1
1403
+ for i in xrange(l, k+1):
1404
+ f = s * dwork[i]
1405
+ dwork[i] *= c
1406
+ if ctx.fabs(f) + anorm == anorm:
1407
+ break
1408
+ g = S[i]
1409
+ h = ctx.hypot(f, g)
1410
+ S[i] = h
1411
+ h = 1 / h
1412
+ c = g * h
1413
+ s = -f * h
1414
+
1415
+ if calc_u:
1416
+ for j in xrange(m):
1417
+ y = A[j,nm]
1418
+ z = A[j,i]
1419
+ A[j,nm]= y * c + z * s
1420
+ A[j,i] = z * c - y * s
1421
+
1422
+ z = S[k]
1423
+
1424
+ if l == k: # convergence
1425
+ if z < 0: # singular value is made nonnegative
1426
+ S[k] = -z
1427
+ if not isinstance(V, bool):
1428
+ for j in xrange(n):
1429
+ V[k,j] = -V[k,j]
1430
+ break
1431
+
1432
+ if its >= maxits:
1433
+ raise RuntimeError("svd: no convergence to an eigenvalue after %d iterations" % its)
1434
+
1435
+ x = S[l] # shift from bottom 2 by 2 minor
1436
+ nm = k-1
1437
+ y = S[nm]
1438
+ g = dwork[nm]
1439
+ h = dwork[k]
1440
+ f = ((y - z) * (y + z) + (g - h) * (g + h)) / (2 * h * y)
1441
+ g = ctx.hypot(f, 1)
1442
+ if f >=0: f = (( x - z) *( x + z) + h *((y / (f + g)) - h)) / x
1443
+ else: f = (( x - z) *( x + z) + h *((y / (f - g)) - h)) / x
1444
+
1445
+ c = s = 1 # next qt transformation
1446
+
1447
+ for j in xrange(l, nm + 1):
1448
+ g = dwork[j+1]
1449
+ y = S[j+1]
1450
+ h = s * g
1451
+ g = c * g
1452
+ z = ctx.hypot(f, h)
1453
+ dwork[j] = z
1454
+ c = f / z
1455
+ s = h / z
1456
+ f = x * c + g * s
1457
+ g = g * c - x * s
1458
+ h = y * s
1459
+ y *= c
1460
+ if not isinstance(V, bool):
1461
+ for jj in xrange(n):
1462
+ x = V[j ,jj]
1463
+ z = V[j+1,jj]
1464
+ V[j ,jj]= x * c + z * s
1465
+ V[j+1,jj ]= z * c - x * s
1466
+ z = ctx.hypot(f, h)
1467
+ S[j] = z
1468
+ if z != 0: # rotation can be arbitray if z=0
1469
+ z = 1 / z
1470
+ c = f * z
1471
+ s = h * z
1472
+ f = c * g + s * y
1473
+ x = c * y - s * g
1474
+ if calc_u:
1475
+ for jj in xrange(m):
1476
+ y = A[jj,j ]
1477
+ z = A[jj,j+1]
1478
+ A[jj,j ]= y * c + z * s
1479
+ A[jj,j+1 ]= z * c - y * s
1480
+
1481
+ dwork[l] = 0
1482
+ dwork[k] = f
1483
+ S[k] = x
1484
+
1485
+ ##########################
1486
+
1487
+ # Sort singular values into decreasing order (bubble-sort)
1488
+
1489
+ for i in xrange(n):
1490
+ imax = i
1491
+ s = ctx.fabs(S[i]) # s is the current maximal element
1492
+
1493
+ for j in xrange(i + 1, n):
1494
+ c = ctx.fabs(S[j])
1495
+ if c > s:
1496
+ s = c
1497
+ imax = j
1498
+
1499
+ if imax != i:
1500
+ # swap singular values
1501
+
1502
+ z = S[i]
1503
+ S[i] = S[imax]
1504
+ S[imax] = z
1505
+
1506
+ if calc_u:
1507
+ for j in xrange(m):
1508
+ z = A[j,i]
1509
+ A[j,i] = A[j,imax]
1510
+ A[j,imax] = z
1511
+
1512
+ if not isinstance(V, bool):
1513
+ for j in xrange(n):
1514
+ z = V[i,j]
1515
+ V[i,j] = V[imax,j]
1516
+ V[imax,j] = z
1517
+
1518
+ return S
1519
+
1520
+ ##################################################################################################
1521
+
1522
+ @defun
1523
+ def svd_r(ctx, A, full_matrices = False, compute_uv = True, overwrite_a = False):
1524
+ """
1525
+ This routine computes the singular value decomposition of a matrix A.
1526
+ Given A, two orthogonal matrices U and V are calculated such that
1527
+
1528
+ A = U S V and U' U = 1 and V V' = 1
1529
+
1530
+ where S is a suitable shaped matrix whose off-diagonal elements are zero.
1531
+ Here ' denotes the transpose. The diagonal elements of S are the singular
1532
+ values of A, i.e. the squareroots of the eigenvalues of A' A or A A'.
1533
+
1534
+ input:
1535
+ A : a real matrix of shape (m, n)
1536
+ full_matrices : if true, U and V are of shape (m, m) and (n, n).
1537
+ if false, U and V are of shape (m, min(m, n)) and (min(m, n), n).
1538
+ compute_uv : if true, U and V are calculated. if false, only S is calculated.
1539
+ overwrite_a : if true, allows modification of A which may improve
1540
+ performance. if false, A is not modified.
1541
+
1542
+ output:
1543
+ U : an orthogonal matrix: U' U = 1. if full_matrices is true, U is of
1544
+ shape (m, m). ortherwise it is of shape (m, min(m, n)).
1545
+
1546
+ S : an array of length min(m, n) containing the singular values of A sorted by
1547
+ decreasing magnitude.
1548
+
1549
+ V : an orthogonal matrix: V V' = 1. if full_matrices is true, V is of
1550
+ shape (n, n). ortherwise it is of shape (min(m, n), n).
1551
+
1552
+ return value:
1553
+
1554
+ S if compute_uv is false
1555
+ (U, S, V) if compute_uv is true
1556
+
1557
+ overview of the matrices:
1558
+
1559
+ full_matrices true:
1560
+ A : m*n
1561
+ U : m*m U' U = 1
1562
+ S as matrix : m*n
1563
+ V : n*n V V' = 1
1564
+
1565
+ full_matrices false:
1566
+ A : m*n
1567
+ U : m*min(n,m) U' U = 1
1568
+ S as matrix : min(m,n)*min(m,n)
1569
+ V : min(m,n)*n V V' = 1
1570
+
1571
+ examples:
1572
+
1573
+ >>> from mpmath import mp
1574
+ >>> A = mp.matrix([[2, -2, -1], [3, 4, -2], [-2, -2, 0]])
1575
+ >>> S = mp.svd_r(A, compute_uv = False)
1576
+ >>> print(S)
1577
+ [6.0]
1578
+ [3.0]
1579
+ [1.0]
1580
+
1581
+ >>> U, S, V = mp.svd_r(A)
1582
+ >>> print(mp.chop(A - U * mp.diag(S) * V))
1583
+ [0.0 0.0 0.0]
1584
+ [0.0 0.0 0.0]
1585
+ [0.0 0.0 0.0]
1586
+
1587
+
1588
+ see also: svd, svd_c
1589
+ """
1590
+
1591
+ m, n = A.rows, A.cols
1592
+
1593
+ if not compute_uv:
1594
+ if not overwrite_a:
1595
+ A = A.copy()
1596
+ S = svd_r_raw(ctx, A, V = False, calc_u = False)
1597
+ S = S[:min(m,n)]
1598
+ return S
1599
+
1600
+ if full_matrices and n < m:
1601
+ V = ctx.zeros(m, m)
1602
+ A0 = ctx.zeros(m, m)
1603
+ A0[:,:n] = A
1604
+ S = svd_r_raw(ctx, A0, V, calc_u = True)
1605
+
1606
+ S = S[:n]
1607
+ V = V[:n,:n]
1608
+
1609
+ return (A0, S, V)
1610
+ else:
1611
+ if not overwrite_a:
1612
+ A = A.copy()
1613
+ V = ctx.zeros(n, n)
1614
+ S = svd_r_raw(ctx, A, V, calc_u = True)
1615
+
1616
+ if n > m:
1617
+ if full_matrices == False:
1618
+ V = V[:m,:]
1619
+
1620
+ S = S[:m]
1621
+ A = A[:,:m]
1622
+
1623
+ return (A, S, V)
1624
+
1625
+ ##############################
1626
+
1627
+ @defun
1628
+ def svd_c(ctx, A, full_matrices = False, compute_uv = True, overwrite_a = False):
1629
+ """
1630
+ This routine computes the singular value decomposition of a matrix A.
1631
+ Given A, two unitary matrices U and V are calculated such that
1632
+
1633
+ A = U S V and U' U = 1 and V V' = 1
1634
+
1635
+ where S is a suitable shaped matrix whose off-diagonal elements are zero.
1636
+ Here ' denotes the hermitian transpose (i.e. transposition and complex
1637
+ conjugation). The diagonal elements of S are the singular values of A,
1638
+ i.e. the squareroots of the eigenvalues of A' A or A A'.
1639
+
1640
+ input:
1641
+ A : a complex matrix of shape (m, n)
1642
+ full_matrices : if true, U and V are of shape (m, m) and (n, n).
1643
+ if false, U and V are of shape (m, min(m, n)) and (min(m, n), n).
1644
+ compute_uv : if true, U and V are calculated. if false, only S is calculated.
1645
+ overwrite_a : if true, allows modification of A which may improve
1646
+ performance. if false, A is not modified.
1647
+
1648
+ output:
1649
+ U : an unitary matrix: U' U = 1. if full_matrices is true, U is of
1650
+ shape (m, m). ortherwise it is of shape (m, min(m, n)).
1651
+
1652
+ S : an array of length min(m, n) containing the singular values of A sorted by
1653
+ decreasing magnitude.
1654
+
1655
+ V : an unitary matrix: V V' = 1. if full_matrices is true, V is of
1656
+ shape (n, n). ortherwise it is of shape (min(m, n), n).
1657
+
1658
+ return value:
1659
+
1660
+ S if compute_uv is false
1661
+ (U, S, V) if compute_uv is true
1662
+
1663
+ overview of the matrices:
1664
+
1665
+ full_matrices true:
1666
+ A : m*n
1667
+ U : m*m U' U = 1
1668
+ S as matrix : m*n
1669
+ V : n*n V V' = 1
1670
+
1671
+ full_matrices false:
1672
+ A : m*n
1673
+ U : m*min(n,m) U' U = 1
1674
+ S as matrix : min(m,n)*min(m,n)
1675
+ V : min(m,n)*n V V' = 1
1676
+
1677
+ example:
1678
+ >>> from mpmath import mp
1679
+ >>> A = mp.matrix([[-2j, -1-3j, -2+2j], [2-2j, -1-3j, 1], [-3+1j,-2j,0]])
1680
+ >>> S = mp.svd_c(A, compute_uv = False)
1681
+ >>> print(mp.chop(S - mp.matrix([mp.sqrt(34), mp.sqrt(15), mp.sqrt(6)])))
1682
+ [0.0]
1683
+ [0.0]
1684
+ [0.0]
1685
+
1686
+ >>> U, S, V = mp.svd_c(A)
1687
+ >>> print(mp.chop(A - U * mp.diag(S) * V))
1688
+ [0.0 0.0 0.0]
1689
+ [0.0 0.0 0.0]
1690
+ [0.0 0.0 0.0]
1691
+
1692
+ see also: svd, svd_r
1693
+ """
1694
+
1695
+ m, n = A.rows, A.cols
1696
+
1697
+ if not compute_uv:
1698
+ if not overwrite_a:
1699
+ A = A.copy()
1700
+ S = svd_c_raw(ctx, A, V = False, calc_u = False)
1701
+ S = S[:min(m,n)]
1702
+ return S
1703
+
1704
+ if full_matrices and n < m:
1705
+ V = ctx.zeros(m, m)
1706
+ A0 = ctx.zeros(m, m)
1707
+ A0[:,:n] = A
1708
+ S = svd_c_raw(ctx, A0, V, calc_u = True)
1709
+
1710
+ S = S[:n]
1711
+ V = V[:n,:n]
1712
+
1713
+ return (A0, S, V)
1714
+ else:
1715
+ if not overwrite_a:
1716
+ A = A.copy()
1717
+ V = ctx.zeros(n, n)
1718
+ S = svd_c_raw(ctx, A, V, calc_u = True)
1719
+
1720
+ if n > m:
1721
+ if full_matrices == False:
1722
+ V = V[:m,:]
1723
+
1724
+ S = S[:m]
1725
+ A = A[:,:m]
1726
+
1727
+ return (A, S, V)
1728
+
1729
+ @defun
1730
+ def svd(ctx, A, full_matrices = False, compute_uv = True, overwrite_a = False):
1731
+ """
1732
+ "svd" is a unified interface for "svd_r" and "svd_c". Depending on
1733
+ whether A is real or complex the appropriate function is called.
1734
+
1735
+ This routine computes the singular value decomposition of a matrix A.
1736
+ Given A, two orthogonal (A real) or unitary (A complex) matrices U and V
1737
+ are calculated such that
1738
+
1739
+ A = U S V and U' U = 1 and V V' = 1
1740
+
1741
+ where S is a suitable shaped matrix whose off-diagonal elements are zero.
1742
+ Here ' denotes the hermitian transpose (i.e. transposition and complex
1743
+ conjugation). The diagonal elements of S are the singular values of A,
1744
+ i.e. the squareroots of the eigenvalues of A' A or A A'.
1745
+
1746
+ input:
1747
+ A : a real or complex matrix of shape (m, n)
1748
+ full_matrices : if true, U and V are of shape (m, m) and (n, n).
1749
+ if false, U and V are of shape (m, min(m, n)) and (min(m, n), n).
1750
+ compute_uv : if true, U and V are calculated. if false, only S is calculated.
1751
+ overwrite_a : if true, allows modification of A which may improve
1752
+ performance. if false, A is not modified.
1753
+
1754
+ output:
1755
+ U : an orthogonal or unitary matrix: U' U = 1. if full_matrices is true, U is of
1756
+ shape (m, m). ortherwise it is of shape (m, min(m, n)).
1757
+
1758
+ S : an array of length min(m, n) containing the singular values of A sorted by
1759
+ decreasing magnitude.
1760
+
1761
+ V : an orthogonal or unitary matrix: V V' = 1. if full_matrices is true, V is of
1762
+ shape (n, n). ortherwise it is of shape (min(m, n), n).
1763
+
1764
+ return value:
1765
+
1766
+ S if compute_uv is false
1767
+ (U, S, V) if compute_uv is true
1768
+
1769
+ overview of the matrices:
1770
+
1771
+ full_matrices true:
1772
+ A : m*n
1773
+ U : m*m U' U = 1
1774
+ S as matrix : m*n
1775
+ V : n*n V V' = 1
1776
+
1777
+ full_matrices false:
1778
+ A : m*n
1779
+ U : m*min(n,m) U' U = 1
1780
+ S as matrix : min(m,n)*min(m,n)
1781
+ V : min(m,n)*n V V' = 1
1782
+
1783
+ examples:
1784
+
1785
+ >>> from mpmath import mp
1786
+ >>> A = mp.matrix([[2, -2, -1], [3, 4, -2], [-2, -2, 0]])
1787
+ >>> S = mp.svd(A, compute_uv = False)
1788
+ >>> print(S)
1789
+ [6.0]
1790
+ [3.0]
1791
+ [1.0]
1792
+
1793
+ >>> U, S, V = mp.svd(A)
1794
+ >>> print(mp.chop(A - U * mp.diag(S) * V))
1795
+ [0.0 0.0 0.0]
1796
+ [0.0 0.0 0.0]
1797
+ [0.0 0.0 0.0]
1798
+
1799
+ see also: svd_r, svd_c
1800
+ """
1801
+
1802
+ iscomplex = any(type(x) is ctx.mpc for x in A)
1803
+
1804
+ if iscomplex:
1805
+ return ctx.svd_c(A, full_matrices = full_matrices, compute_uv = compute_uv, overwrite_a = overwrite_a)
1806
+ else:
1807
+ return ctx.svd_r(A, full_matrices = full_matrices, compute_uv = compute_uv, overwrite_a = overwrite_a)
lib/python3.11/site-packages/mpmath/matrices/linalg.py ADDED
@@ -0,0 +1,790 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Linear algebra
3
+ --------------
4
+
5
+ Linear equations
6
+ ................
7
+
8
+ Basic linear algebra is implemented; you can for example solve the linear
9
+ equation system::
10
+
11
+ x + 2*y = -10
12
+ 3*x + 4*y = 10
13
+
14
+ using ``lu_solve``::
15
+
16
+ >>> from mpmath import *
17
+ >>> mp.pretty = False
18
+ >>> A = matrix([[1, 2], [3, 4]])
19
+ >>> b = matrix([-10, 10])
20
+ >>> x = lu_solve(A, b)
21
+ >>> x
22
+ matrix(
23
+ [['30.0'],
24
+ ['-20.0']])
25
+
26
+ If you don't trust the result, use ``residual`` to calculate the residual ||A*x-b||::
27
+
28
+ >>> residual(A, x, b)
29
+ matrix(
30
+ [['3.46944695195361e-18'],
31
+ ['3.46944695195361e-18']])
32
+ >>> str(eps)
33
+ '2.22044604925031e-16'
34
+
35
+ As you can see, the solution is quite accurate. The error is caused by the
36
+ inaccuracy of the internal floating point arithmetic. Though, it's even smaller
37
+ than the current machine epsilon, which basically means you can trust the
38
+ result.
39
+
40
+ If you need more speed, use NumPy, or ``fp.lu_solve`` for a floating-point computation.
41
+
42
+ >>> fp.lu_solve(A, b) # doctest: +ELLIPSIS
43
+ matrix(...)
44
+
45
+ ``lu_solve`` accepts overdetermined systems. It is usually not possible to solve
46
+ such systems, so the residual is minimized instead. Internally this is done
47
+ using Cholesky decomposition to compute a least squares approximation. This means
48
+ that that ``lu_solve`` will square the errors. If you can't afford this, use
49
+ ``qr_solve`` instead. It is twice as slow but more accurate, and it calculates
50
+ the residual automatically.
51
+
52
+
53
+ Matrix factorization
54
+ ....................
55
+
56
+ The function ``lu`` computes an explicit LU factorization of a matrix::
57
+
58
+ >>> P, L, U = lu(matrix([[0,2,3],[4,5,6],[7,8,9]]))
59
+ >>> print(P)
60
+ [0.0 0.0 1.0]
61
+ [1.0 0.0 0.0]
62
+ [0.0 1.0 0.0]
63
+ >>> print(L)
64
+ [ 1.0 0.0 0.0]
65
+ [ 0.0 1.0 0.0]
66
+ [0.571428571428571 0.214285714285714 1.0]
67
+ >>> print(U)
68
+ [7.0 8.0 9.0]
69
+ [0.0 2.0 3.0]
70
+ [0.0 0.0 0.214285714285714]
71
+ >>> print(P.T*L*U)
72
+ [0.0 2.0 3.0]
73
+ [4.0 5.0 6.0]
74
+ [7.0 8.0 9.0]
75
+
76
+ Interval matrices
77
+ -----------------
78
+
79
+ Matrices may contain interval elements. This allows one to perform
80
+ basic linear algebra operations such as matrix multiplication
81
+ and equation solving with rigorous error bounds::
82
+
83
+ >>> a = iv.matrix([['0.1','0.3','1.0'],
84
+ ... ['7.1','5.5','4.8'],
85
+ ... ['3.2','4.4','5.6']])
86
+ >>>
87
+ >>> b = iv.matrix(['4','0.6','0.5'])
88
+ >>> c = iv.lu_solve(a, b)
89
+ >>> print(c)
90
+ [ [5.2582327113062568605927528666, 5.25823271130625686059275702219]]
91
+ [[-13.1550493962678375411635581388, -13.1550493962678375411635540152]]
92
+ [ [7.42069154774972557628979076189, 7.42069154774972557628979190734]]
93
+ >>> print(a*c)
94
+ [ [3.99999999999999999999999844904, 4.00000000000000000000000155096]]
95
+ [[0.599999999999999999999968898009, 0.600000000000000000000031763736]]
96
+ [[0.499999999999999999999979320485, 0.500000000000000000000020679515]]
97
+ """
98
+
99
+ # TODO:
100
+ # *implement high-level qr()
101
+ # *test unitvector
102
+ # *iterative solving
103
+
104
+ from copy import copy
105
+
106
+ from ..libmp.backend import xrange
107
+
108
+ class LinearAlgebraMethods(object):
109
+
110
+ def LU_decomp(ctx, A, overwrite=False, use_cache=True):
111
+ """
112
+ LU-factorization of a n*n matrix using the Gauss algorithm.
113
+ Returns L and U in one matrix and the pivot indices.
114
+
115
+ Use overwrite to specify whether A will be overwritten with L and U.
116
+ """
117
+ if not A.rows == A.cols:
118
+ raise ValueError('need n*n matrix')
119
+ # get from cache if possible
120
+ if use_cache and isinstance(A, ctx.matrix) and A._LU:
121
+ return A._LU
122
+ if not overwrite:
123
+ orig = A
124
+ A = A.copy()
125
+ tol = ctx.absmin(ctx.mnorm(A,1) * ctx.eps) # each pivot element has to be bigger
126
+ n = A.rows
127
+ p = [None]*(n - 1)
128
+ for j in xrange(n - 1):
129
+ # pivoting, choose max(abs(reciprocal row sum)*abs(pivot element))
130
+ biggest = 0
131
+ for k in xrange(j, n):
132
+ s = ctx.fsum([ctx.absmin(A[k,l]) for l in xrange(j, n)])
133
+ if ctx.absmin(s) <= tol:
134
+ raise ZeroDivisionError('matrix is numerically singular')
135
+ current = 1/s * ctx.absmin(A[k,j])
136
+ if current > biggest: # TODO: what if equal?
137
+ biggest = current
138
+ p[j] = k
139
+ # swap rows according to p
140
+ ctx.swap_row(A, j, p[j])
141
+ if ctx.absmin(A[j,j]) <= tol:
142
+ raise ZeroDivisionError('matrix is numerically singular')
143
+ # calculate elimination factors and add rows
144
+ for i in xrange(j + 1, n):
145
+ A[i,j] /= A[j,j]
146
+ for k in xrange(j + 1, n):
147
+ A[i,k] -= A[i,j]*A[j,k]
148
+ if ctx.absmin(A[n - 1,n - 1]) <= tol:
149
+ raise ZeroDivisionError('matrix is numerically singular')
150
+ # cache decomposition
151
+ if not overwrite and isinstance(orig, ctx.matrix):
152
+ orig._LU = (A, p)
153
+ return A, p
154
+
155
+ def L_solve(ctx, L, b, p=None):
156
+ """
157
+ Solve the lower part of a LU factorized matrix for y.
158
+ """
159
+ if L.rows != L.cols:
160
+ raise RuntimeError("need n*n matrix")
161
+ n = L.rows
162
+ if len(b) != n:
163
+ raise ValueError("Value should be equal to n")
164
+ b = copy(b)
165
+ if p: # swap b according to p
166
+ for k in xrange(0, len(p)):
167
+ ctx.swap_row(b, k, p[k])
168
+ # solve
169
+ for i in xrange(1, n):
170
+ for j in xrange(i):
171
+ b[i] -= L[i,j] * b[j]
172
+ return b
173
+
174
+ def U_solve(ctx, U, y):
175
+ """
176
+ Solve the upper part of a LU factorized matrix for x.
177
+ """
178
+ if U.rows != U.cols:
179
+ raise RuntimeError("need n*n matrix")
180
+ n = U.rows
181
+ if len(y) != n:
182
+ raise ValueError("Value should be equal to n")
183
+ x = copy(y)
184
+ for i in xrange(n - 1, -1, -1):
185
+ for j in xrange(i + 1, n):
186
+ x[i] -= U[i,j] * x[j]
187
+ x[i] /= U[i,i]
188
+ return x
189
+
190
+ def lu_solve(ctx, A, b, **kwargs):
191
+ """
192
+ Ax = b => x
193
+
194
+ Solve a determined or overdetermined linear equations system.
195
+ Fast LU decomposition is used, which is less accurate than QR decomposition
196
+ (especially for overdetermined systems), but it's twice as efficient.
197
+ Use qr_solve if you want more precision or have to solve a very ill-
198
+ conditioned system.
199
+
200
+ If you specify real=True, it does not check for overdeterminded complex
201
+ systems.
202
+ """
203
+ prec = ctx.prec
204
+ try:
205
+ ctx.prec += 10
206
+ # do not overwrite A nor b
207
+ A, b = ctx.matrix(A, **kwargs).copy(), ctx.matrix(b, **kwargs).copy()
208
+ if A.rows < A.cols:
209
+ raise ValueError('cannot solve underdetermined system')
210
+ if A.rows > A.cols:
211
+ # use least-squares method if overdetermined
212
+ # (this increases errors)
213
+ AH = A.H
214
+ A = AH * A
215
+ b = AH * b
216
+ if (kwargs.get('real', False) or
217
+ not sum(type(i) is ctx.mpc for i in A)):
218
+ # TODO: necessary to check also b?
219
+ x = ctx.cholesky_solve(A, b)
220
+ else:
221
+ x = ctx.lu_solve(A, b)
222
+ else:
223
+ # LU factorization
224
+ A, p = ctx.LU_decomp(A)
225
+ b = ctx.L_solve(A, b, p)
226
+ x = ctx.U_solve(A, b)
227
+ finally:
228
+ ctx.prec = prec
229
+ return x
230
+
231
+ def improve_solution(ctx, A, x, b, maxsteps=1):
232
+ """
233
+ Improve a solution to a linear equation system iteratively.
234
+
235
+ This re-uses the LU decomposition and is thus cheap.
236
+ Usually 3 up to 4 iterations are giving the maximal improvement.
237
+ """
238
+ if A.rows != A.cols:
239
+ raise RuntimeError("need n*n matrix") # TODO: really?
240
+ for _ in xrange(maxsteps):
241
+ r = ctx.residual(A, x, b)
242
+ if ctx.norm(r, 2) < 10*ctx.eps:
243
+ break
244
+ # this uses cached LU decomposition and is thus cheap
245
+ dx = ctx.lu_solve(A, -r)
246
+ x += dx
247
+ return x
248
+
249
+ def lu(ctx, A):
250
+ """
251
+ A -> P, L, U
252
+
253
+ LU factorisation of a square matrix A. L is the lower, U the upper part.
254
+ P is the permutation matrix indicating the row swaps.
255
+
256
+ P*A = L*U
257
+
258
+ If you need efficiency, use the low-level method LU_decomp instead, it's
259
+ much more memory efficient.
260
+ """
261
+ # get factorization
262
+ A, p = ctx.LU_decomp(A)
263
+ n = A.rows
264
+ L = ctx.matrix(n)
265
+ U = ctx.matrix(n)
266
+ for i in xrange(n):
267
+ for j in xrange(n):
268
+ if i > j:
269
+ L[i,j] = A[i,j]
270
+ elif i == j:
271
+ L[i,j] = 1
272
+ U[i,j] = A[i,j]
273
+ else:
274
+ U[i,j] = A[i,j]
275
+ # calculate permutation matrix
276
+ P = ctx.eye(n)
277
+ for k in xrange(len(p)):
278
+ ctx.swap_row(P, k, p[k])
279
+ return P, L, U
280
+
281
+ def unitvector(ctx, n, i):
282
+ """
283
+ Return the i-th n-dimensional unit vector.
284
+ """
285
+ assert 0 < i <= n, 'this unit vector does not exist'
286
+ return [ctx.zero]*(i-1) + [ctx.one] + [ctx.zero]*(n-i)
287
+
288
+ def inverse(ctx, A, **kwargs):
289
+ """
290
+ Calculate the inverse of a matrix.
291
+
292
+ If you want to solve an equation system Ax = b, it's recommended to use
293
+ solve(A, b) instead, it's about 3 times more efficient.
294
+ """
295
+ prec = ctx.prec
296
+ try:
297
+ ctx.prec += 10
298
+ # do not overwrite A
299
+ A = ctx.matrix(A, **kwargs).copy()
300
+ n = A.rows
301
+ # get LU factorisation
302
+ A, p = ctx.LU_decomp(A)
303
+ cols = []
304
+ # calculate unit vectors and solve corresponding system to get columns
305
+ for i in xrange(1, n + 1):
306
+ e = ctx.unitvector(n, i)
307
+ y = ctx.L_solve(A, e, p)
308
+ cols.append(ctx.U_solve(A, y))
309
+ # convert columns to matrix
310
+ inv = []
311
+ for i in xrange(n):
312
+ row = []
313
+ for j in xrange(n):
314
+ row.append(cols[j][i])
315
+ inv.append(row)
316
+ result = ctx.matrix(inv, **kwargs)
317
+ finally:
318
+ ctx.prec = prec
319
+ return result
320
+
321
+ def householder(ctx, A):
322
+ """
323
+ (A|b) -> H, p, x, res
324
+
325
+ (A|b) is the coefficient matrix with left hand side of an optionally
326
+ overdetermined linear equation system.
327
+ H and p contain all information about the transformation matrices.
328
+ x is the solution, res the residual.
329
+ """
330
+ if not isinstance(A, ctx.matrix):
331
+ raise TypeError("A should be a type of ctx.matrix")
332
+ m = A.rows
333
+ n = A.cols
334
+ if m < n - 1:
335
+ raise RuntimeError("Columns should not be less than rows")
336
+ # calculate Householder matrix
337
+ p = []
338
+ for j in xrange(0, n - 1):
339
+ s = ctx.fsum(abs(A[i,j])**2 for i in xrange(j, m))
340
+ if not abs(s) > ctx.eps:
341
+ raise ValueError('matrix is numerically singular')
342
+ p.append(-ctx.sign(ctx.re(A[j,j])) * ctx.sqrt(s))
343
+ kappa = ctx.one / (s - p[j] * A[j,j])
344
+ A[j,j] -= p[j]
345
+ for k in xrange(j+1, n):
346
+ y = ctx.fsum(ctx.conj(A[i,j]) * A[i,k] for i in xrange(j, m)) * kappa
347
+ for i in xrange(j, m):
348
+ A[i,k] -= A[i,j] * y
349
+ # solve Rx = c1
350
+ x = [A[i,n - 1] for i in xrange(n - 1)]
351
+ for i in xrange(n - 2, -1, -1):
352
+ x[i] -= ctx.fsum(A[i,j] * x[j] for j in xrange(i + 1, n - 1))
353
+ x[i] /= p[i]
354
+ # calculate residual
355
+ if not m == n - 1:
356
+ r = [A[m-1-i, n-1] for i in xrange(m - n + 1)]
357
+ else:
358
+ # determined system, residual should be 0
359
+ r = [0]*m # maybe a bad idea, changing r[i] will change all elements
360
+ return A, p, x, r
361
+
362
+ #def qr(ctx, A):
363
+ # """
364
+ # A -> Q, R
365
+ #
366
+ # QR factorisation of a square matrix A using Householder decomposition.
367
+ # Q is orthogonal, this leads to very few numerical errors.
368
+ #
369
+ # A = Q*R
370
+ # """
371
+ # H, p, x, res = householder(A)
372
+ # TODO: implement this
373
+
374
+ def residual(ctx, A, x, b, **kwargs):
375
+ """
376
+ Calculate the residual of a solution to a linear equation system.
377
+
378
+ r = A*x - b for A*x = b
379
+ """
380
+ oldprec = ctx.prec
381
+ try:
382
+ ctx.prec *= 2
383
+ A, x, b = ctx.matrix(A, **kwargs), ctx.matrix(x, **kwargs), ctx.matrix(b, **kwargs)
384
+ return A*x - b
385
+ finally:
386
+ ctx.prec = oldprec
387
+
388
+ def qr_solve(ctx, A, b, norm=None, **kwargs):
389
+ """
390
+ Ax = b => x, ||Ax - b||
391
+
392
+ Solve a determined or overdetermined linear equations system and
393
+ calculate the norm of the residual (error).
394
+ QR decomposition using Householder factorization is applied, which gives very
395
+ accurate results even for ill-conditioned matrices. qr_solve is twice as
396
+ efficient.
397
+ """
398
+ if norm is None:
399
+ norm = ctx.norm
400
+ prec = ctx.prec
401
+ try:
402
+ ctx.prec += 10
403
+ # do not overwrite A nor b
404
+ A, b = ctx.matrix(A, **kwargs).copy(), ctx.matrix(b, **kwargs).copy()
405
+ if A.rows < A.cols:
406
+ raise ValueError('cannot solve underdetermined system')
407
+ H, p, x, r = ctx.householder(ctx.extend(A, b))
408
+ res = ctx.norm(r)
409
+ # calculate residual "manually" for determined systems
410
+ if res == 0:
411
+ res = ctx.norm(ctx.residual(A, x, b))
412
+ return ctx.matrix(x, **kwargs), res
413
+ finally:
414
+ ctx.prec = prec
415
+
416
+ def cholesky(ctx, A, tol=None):
417
+ r"""
418
+ Cholesky decomposition of a symmetric positive-definite matrix `A`.
419
+ Returns a lower triangular matrix `L` such that `A = L \times L^T`.
420
+ More generally, for a complex Hermitian positive-definite matrix,
421
+ a Cholesky decomposition satisfying `A = L \times L^H` is returned.
422
+
423
+ The Cholesky decomposition can be used to solve linear equation
424
+ systems twice as efficiently as LU decomposition, or to
425
+ test whether `A` is positive-definite.
426
+
427
+ The optional parameter ``tol`` determines the tolerance for
428
+ verifying positive-definiteness.
429
+
430
+ **Examples**
431
+
432
+ Cholesky decomposition of a positive-definite symmetric matrix::
433
+
434
+ >>> from mpmath import *
435
+ >>> mp.dps = 25; mp.pretty = True
436
+ >>> A = eye(3) + hilbert(3)
437
+ >>> nprint(A)
438
+ [ 2.0 0.5 0.333333]
439
+ [ 0.5 1.33333 0.25]
440
+ [0.333333 0.25 1.2]
441
+ >>> L = cholesky(A)
442
+ >>> nprint(L)
443
+ [ 1.41421 0.0 0.0]
444
+ [0.353553 1.09924 0.0]
445
+ [0.235702 0.15162 1.05899]
446
+ >>> chop(A - L*L.T)
447
+ [0.0 0.0 0.0]
448
+ [0.0 0.0 0.0]
449
+ [0.0 0.0 0.0]
450
+
451
+ Cholesky decomposition of a Hermitian matrix::
452
+
453
+ >>> A = eye(3) + matrix([[0,0.25j,-0.5j],[-0.25j,0,0],[0.5j,0,0]])
454
+ >>> L = cholesky(A)
455
+ >>> nprint(L)
456
+ [ 1.0 0.0 0.0]
457
+ [(0.0 - 0.25j) (0.968246 + 0.0j) 0.0]
458
+ [ (0.0 + 0.5j) (0.129099 + 0.0j) (0.856349 + 0.0j)]
459
+ >>> chop(A - L*L.H)
460
+ [0.0 0.0 0.0]
461
+ [0.0 0.0 0.0]
462
+ [0.0 0.0 0.0]
463
+
464
+ Attempted Cholesky decomposition of a matrix that is not positive
465
+ definite::
466
+
467
+ >>> A = -eye(3) + hilbert(3)
468
+ >>> L = cholesky(A)
469
+ Traceback (most recent call last):
470
+ ...
471
+ ValueError: matrix is not positive-definite
472
+
473
+ **References**
474
+
475
+ 1. [Wikipedia]_ http://en.wikipedia.org/wiki/Cholesky_decomposition
476
+
477
+ """
478
+ if not isinstance(A, ctx.matrix):
479
+ raise RuntimeError("A should be a type of ctx.matrix")
480
+ if not A.rows == A.cols:
481
+ raise ValueError('need n*n matrix')
482
+ if tol is None:
483
+ tol = +ctx.eps
484
+ n = A.rows
485
+ L = ctx.matrix(n)
486
+ for j in xrange(n):
487
+ c = ctx.re(A[j,j])
488
+ if abs(c-A[j,j]) > tol:
489
+ raise ValueError('matrix is not Hermitian')
490
+ s = c - ctx.fsum((L[j,k] for k in xrange(j)),
491
+ absolute=True, squared=True)
492
+ if s < tol:
493
+ raise ValueError('matrix is not positive-definite')
494
+ L[j,j] = ctx.sqrt(s)
495
+ for i in xrange(j, n):
496
+ it1 = (L[i,k] for k in xrange(j))
497
+ it2 = (L[j,k] for k in xrange(j))
498
+ t = ctx.fdot(it1, it2, conjugate=True)
499
+ L[i,j] = (A[i,j] - t) / L[j,j]
500
+ return L
501
+
502
+ def cholesky_solve(ctx, A, b, **kwargs):
503
+ """
504
+ Ax = b => x
505
+
506
+ Solve a symmetric positive-definite linear equation system.
507
+ This is twice as efficient as lu_solve.
508
+
509
+ Typical use cases:
510
+ * A.T*A
511
+ * Hessian matrix
512
+ * differential equations
513
+ """
514
+ prec = ctx.prec
515
+ try:
516
+ ctx.prec += 10
517
+ # do not overwrite A nor b
518
+ A, b = ctx.matrix(A, **kwargs).copy(), ctx.matrix(b, **kwargs).copy()
519
+ if A.rows != A.cols:
520
+ raise ValueError('can only solve determined system')
521
+ # Cholesky factorization
522
+ L = ctx.cholesky(A)
523
+ # solve
524
+ n = L.rows
525
+ if len(b) != n:
526
+ raise ValueError("Value should be equal to n")
527
+ for i in xrange(n):
528
+ b[i] -= ctx.fsum(L[i,j] * b[j] for j in xrange(i))
529
+ b[i] /= L[i,i]
530
+ x = ctx.U_solve(L.T, b)
531
+ return x
532
+ finally:
533
+ ctx.prec = prec
534
+
535
+ def det(ctx, A):
536
+ """
537
+ Calculate the determinant of a matrix.
538
+ """
539
+ prec = ctx.prec
540
+ try:
541
+ # do not overwrite A
542
+ A = ctx.matrix(A).copy()
543
+ # use LU factorization to calculate determinant
544
+ try:
545
+ R, p = ctx.LU_decomp(A)
546
+ except ZeroDivisionError:
547
+ return 0
548
+ z = 1
549
+ for i, e in enumerate(p):
550
+ if i != e:
551
+ z *= -1
552
+ for i in xrange(A.rows):
553
+ z *= R[i,i]
554
+ return z
555
+ finally:
556
+ ctx.prec = prec
557
+
558
+ def cond(ctx, A, norm=None):
559
+ """
560
+ Calculate the condition number of a matrix using a specified matrix norm.
561
+
562
+ The condition number estimates the sensitivity of a matrix to errors.
563
+ Example: small input errors for ill-conditioned coefficient matrices
564
+ alter the solution of the system dramatically.
565
+
566
+ For ill-conditioned matrices it's recommended to use qr_solve() instead
567
+ of lu_solve(). This does not help with input errors however, it just avoids
568
+ to add additional errors.
569
+
570
+ Definition: cond(A) = ||A|| * ||A**-1||
571
+ """
572
+ if norm is None:
573
+ norm = lambda x: ctx.mnorm(x,1)
574
+ return norm(A) * norm(ctx.inverse(A))
575
+
576
+ def lu_solve_mat(ctx, a, b):
577
+ """Solve a * x = b where a and b are matrices."""
578
+ r = ctx.matrix(a.rows, b.cols)
579
+ for i in range(b.cols):
580
+ c = ctx.lu_solve(a, b.column(i))
581
+ for j in range(len(c)):
582
+ r[j, i] = c[j]
583
+ return r
584
+
585
+ def qr(ctx, A, mode = 'full', edps = 10):
586
+ """
587
+ Compute a QR factorization $A = QR$ where
588
+ A is an m x n matrix of real or complex numbers where m >= n
589
+
590
+ mode has following meanings:
591
+ (1) mode = 'raw' returns two matrixes (A, tau) in the
592
+ internal format used by LAPACK
593
+ (2) mode = 'skinny' returns the leading n columns of Q
594
+ and n rows of R
595
+ (3) Any other value returns the leading m columns of Q
596
+ and m rows of R
597
+
598
+ edps is the increase in mp precision used for calculations
599
+
600
+ **Examples**
601
+
602
+ >>> from mpmath import *
603
+ >>> mp.dps = 15
604
+ >>> mp.pretty = True
605
+ >>> A = matrix([[1, 2], [3, 4], [1, 1]])
606
+ >>> Q, R = qr(A)
607
+ >>> Q
608
+ [-0.301511344577764 0.861640436855329 0.408248290463863]
609
+ [-0.904534033733291 -0.123091490979333 -0.408248290463863]
610
+ [-0.301511344577764 -0.492365963917331 0.816496580927726]
611
+ >>> R
612
+ [-3.3166247903554 -4.52267016866645]
613
+ [ 0.0 0.738548945875996]
614
+ [ 0.0 0.0]
615
+ >>> Q * R
616
+ [1.0 2.0]
617
+ [3.0 4.0]
618
+ [1.0 1.0]
619
+ >>> chop(Q.T * Q)
620
+ [1.0 0.0 0.0]
621
+ [0.0 1.0 0.0]
622
+ [0.0 0.0 1.0]
623
+ >>> B = matrix([[1+0j, 2-3j], [3+j, 4+5j]])
624
+ >>> Q, R = qr(B)
625
+ >>> nprint(Q)
626
+ [ (-0.301511 + 0.0j) (0.0695795 - 0.95092j)]
627
+ [(-0.904534 - 0.301511j) (-0.115966 + 0.278318j)]
628
+ >>> nprint(R)
629
+ [(-3.31662 + 0.0j) (-5.72872 - 2.41209j)]
630
+ [ 0.0 (3.91965 + 0.0j)]
631
+ >>> Q * R
632
+ [(1.0 + 0.0j) (2.0 - 3.0j)]
633
+ [(3.0 + 1.0j) (4.0 + 5.0j)]
634
+ >>> chop(Q.T * Q.conjugate())
635
+ [1.0 0.0]
636
+ [0.0 1.0]
637
+
638
+ """
639
+
640
+ # check values before continuing
641
+ assert isinstance(A, ctx.matrix)
642
+ m = A.rows
643
+ n = A.cols
644
+ assert n >= 0
645
+ assert m >= n
646
+ assert edps >= 0
647
+
648
+ # check for complex data type
649
+ cmplx = any(type(x) is ctx.mpc for x in A)
650
+
651
+ # temporarily increase the precision and initialize
652
+ with ctx.extradps(edps):
653
+ tau = ctx.matrix(n,1)
654
+ A = A.copy()
655
+
656
+ # ---------------
657
+ # FACTOR MATRIX A
658
+ # ---------------
659
+ if cmplx:
660
+ one = ctx.mpc('1.0', '0.0')
661
+ zero = ctx.mpc('0.0', '0.0')
662
+ rzero = ctx.mpf('0.0')
663
+
664
+ # main loop to factor A (complex)
665
+ for j in xrange(0, n):
666
+ alpha = A[j,j]
667
+ alphr = ctx.re(alpha)
668
+ alphi = ctx.im(alpha)
669
+
670
+ if (m-j) >= 2:
671
+ xnorm = ctx.fsum( A[i,j]*ctx.conj(A[i,j]) for i in xrange(j+1, m) )
672
+ xnorm = ctx.re( ctx.sqrt(xnorm) )
673
+ else:
674
+ xnorm = rzero
675
+
676
+ if (xnorm == rzero) and (alphi == rzero):
677
+ tau[j] = zero
678
+ continue
679
+
680
+ if alphr < rzero:
681
+ beta = ctx.sqrt(alphr**2 + alphi**2 + xnorm**2)
682
+ else:
683
+ beta = -ctx.sqrt(alphr**2 + alphi**2 + xnorm**2)
684
+
685
+ tau[j] = ctx.mpc( (beta - alphr) / beta, -alphi / beta )
686
+ t = -ctx.conj(tau[j])
687
+ za = one / (alpha - beta)
688
+
689
+ for i in xrange(j+1, m):
690
+ A[i,j] *= za
691
+
692
+ A[j,j] = one
693
+ for k in xrange(j+1, n):
694
+ y = ctx.fsum(A[i,j] * ctx.conj(A[i,k]) for i in xrange(j, m))
695
+ temp = t * ctx.conj(y)
696
+ for i in xrange(j, m):
697
+ A[i,k] += A[i,j] * temp
698
+
699
+ A[j,j] = ctx.mpc(beta, '0.0')
700
+ else:
701
+ one = ctx.mpf('1.0')
702
+ zero = ctx.mpf('0.0')
703
+
704
+ # main loop to factor A (real)
705
+ for j in xrange(0, n):
706
+ alpha = A[j,j]
707
+
708
+ if (m-j) > 2:
709
+ xnorm = ctx.fsum( (A[i,j])**2 for i in xrange(j+1, m) )
710
+ xnorm = ctx.sqrt(xnorm)
711
+ elif (m-j) == 2:
712
+ xnorm = abs( A[m-1,j] )
713
+ else:
714
+ xnorm = zero
715
+
716
+ if xnorm == zero:
717
+ tau[j] = zero
718
+ continue
719
+
720
+ if alpha < zero:
721
+ beta = ctx.sqrt(alpha**2 + xnorm**2)
722
+ else:
723
+ beta = -ctx.sqrt(alpha**2 + xnorm**2)
724
+
725
+ tau[j] = (beta - alpha) / beta
726
+ t = -tau[j]
727
+ da = one / (alpha - beta)
728
+
729
+ for i in xrange(j+1, m):
730
+ A[i,j] *= da
731
+
732
+ A[j,j] = one
733
+ for k in xrange(j+1, n):
734
+ y = ctx.fsum( A[i,j] * A[i,k] for i in xrange(j, m) )
735
+ temp = t * y
736
+ for i in xrange(j,m):
737
+ A[i,k] += A[i,j] * temp
738
+
739
+ A[j,j] = beta
740
+
741
+ # return factorization in same internal format as LAPACK
742
+ if (mode == 'raw') or (mode == 'RAW'):
743
+ return A, tau
744
+
745
+ # ----------------------------------
746
+ # FORM Q USING BACKWARD ACCUMULATION
747
+ # ----------------------------------
748
+
749
+ # form R before the values are overwritten
750
+ R = A.copy()
751
+ for j in xrange(0, n):
752
+ for i in xrange(j+1, m):
753
+ R[i,j] = zero
754
+
755
+ # set the value of p (number of columns of Q to return)
756
+ p = m
757
+ if (mode == 'skinny') or (mode == 'SKINNY'):
758
+ p = n
759
+
760
+ # add columns to A if needed and initialize
761
+ A.cols += (p-n)
762
+ for j in xrange(0, p):
763
+ A[j,j] = one
764
+ for i in xrange(0, j):
765
+ A[i,j] = zero
766
+
767
+ # main loop to form Q
768
+ for j in xrange(n-1, -1, -1):
769
+ t = -tau[j]
770
+ A[j,j] += t
771
+
772
+ for k in xrange(j+1, p):
773
+ if cmplx:
774
+ y = ctx.fsum(A[i,j] * ctx.conj(A[i,k]) for i in xrange(j+1, m))
775
+ temp = t * ctx.conj(y)
776
+ else:
777
+ y = ctx.fsum(A[i,j] * A[i,k] for i in xrange(j+1, m))
778
+ temp = t * y
779
+ A[j,k] = temp
780
+ for i in xrange(j+1, m):
781
+ A[i,k] += A[i,j] * temp
782
+
783
+ for i in xrange(j+1, m):
784
+ A[i, j] *= t
785
+
786
+ return A, R[0:p,0:n]
787
+
788
+ # ------------------
789
+ # END OF FUNCTION QR
790
+ # ------------------
lib/python3.11/site-packages/mpmath/matrices/matrices.py ADDED
@@ -0,0 +1,1005 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ..libmp.backend import xrange
2
+ import warnings
3
+
4
+ # TODO: interpret list as vectors (for multiplication)
5
+
6
+ rowsep = '\n'
7
+ colsep = ' '
8
+
9
+ class _matrix(object):
10
+ """
11
+ Numerical matrix.
12
+
13
+ Specify the dimensions or the data as a nested list.
14
+ Elements default to zero.
15
+ Use a flat list to create a column vector easily.
16
+
17
+ The datatype of the context (mpf for mp, mpi for iv, and float for fp) is used to store the data.
18
+
19
+ Creating matrices
20
+ -----------------
21
+
22
+ Matrices in mpmath are implemented using dictionaries. Only non-zero values
23
+ are stored, so it is cheap to represent sparse matrices.
24
+
25
+ The most basic way to create one is to use the ``matrix`` class directly.
26
+ You can create an empty matrix specifying the dimensions:
27
+
28
+ >>> from mpmath import *
29
+ >>> mp.dps = 15
30
+ >>> matrix(2)
31
+ matrix(
32
+ [['0.0', '0.0'],
33
+ ['0.0', '0.0']])
34
+ >>> matrix(2, 3)
35
+ matrix(
36
+ [['0.0', '0.0', '0.0'],
37
+ ['0.0', '0.0', '0.0']])
38
+
39
+ Calling ``matrix`` with one dimension will create a square matrix.
40
+
41
+ To access the dimensions of a matrix, use the ``rows`` or ``cols`` keyword:
42
+
43
+ >>> A = matrix(3, 2)
44
+ >>> A
45
+ matrix(
46
+ [['0.0', '0.0'],
47
+ ['0.0', '0.0'],
48
+ ['0.0', '0.0']])
49
+ >>> A.rows
50
+ 3
51
+ >>> A.cols
52
+ 2
53
+
54
+ You can also change the dimension of an existing matrix. This will set the
55
+ new elements to 0. If the new dimension is smaller than before, the
56
+ concerning elements are discarded:
57
+
58
+ >>> A.rows = 2
59
+ >>> A
60
+ matrix(
61
+ [['0.0', '0.0'],
62
+ ['0.0', '0.0']])
63
+
64
+ Internally ``mpmathify`` is used every time an element is set. This
65
+ is done using the syntax A[row,column], counting from 0:
66
+
67
+ >>> A = matrix(2)
68
+ >>> A[1,1] = 1 + 1j
69
+ >>> A
70
+ matrix(
71
+ [['0.0', '0.0'],
72
+ ['0.0', mpc(real='1.0', imag='1.0')]])
73
+
74
+ A more comfortable way to create a matrix lets you use nested lists:
75
+
76
+ >>> matrix([[1, 2], [3, 4]])
77
+ matrix(
78
+ [['1.0', '2.0'],
79
+ ['3.0', '4.0']])
80
+
81
+ Convenient advanced functions are available for creating various standard
82
+ matrices, see ``zeros``, ``ones``, ``diag``, ``eye``, ``randmatrix`` and
83
+ ``hilbert``.
84
+
85
+ Vectors
86
+ .......
87
+
88
+ Vectors may also be represented by the ``matrix`` class (with rows = 1 or cols = 1).
89
+ For vectors there are some things which make life easier. A column vector can
90
+ be created using a flat list, a row vectors using an almost flat nested list::
91
+
92
+ >>> matrix([1, 2, 3])
93
+ matrix(
94
+ [['1.0'],
95
+ ['2.0'],
96
+ ['3.0']])
97
+ >>> matrix([[1, 2, 3]])
98
+ matrix(
99
+ [['1.0', '2.0', '3.0']])
100
+
101
+ Optionally vectors can be accessed like lists, using only a single index::
102
+
103
+ >>> x = matrix([1, 2, 3])
104
+ >>> x[1]
105
+ mpf('2.0')
106
+ >>> x[1,0]
107
+ mpf('2.0')
108
+
109
+ Other
110
+ .....
111
+
112
+ Like you probably expected, matrices can be printed::
113
+
114
+ >>> print randmatrix(3) # doctest:+SKIP
115
+ [ 0.782963853573023 0.802057689719883 0.427895717335467]
116
+ [0.0541876859348597 0.708243266653103 0.615134039977379]
117
+ [ 0.856151514955773 0.544759264818486 0.686210904770947]
118
+
119
+ Use ``nstr`` or ``nprint`` to specify the number of digits to print::
120
+
121
+ >>> nprint(randmatrix(5), 3) # doctest:+SKIP
122
+ [2.07e-1 1.66e-1 5.06e-1 1.89e-1 8.29e-1]
123
+ [6.62e-1 6.55e-1 4.47e-1 4.82e-1 2.06e-2]
124
+ [4.33e-1 7.75e-1 6.93e-2 2.86e-1 5.71e-1]
125
+ [1.01e-1 2.53e-1 6.13e-1 3.32e-1 2.59e-1]
126
+ [1.56e-1 7.27e-2 6.05e-1 6.67e-2 2.79e-1]
127
+
128
+ As matrices are mutable, you will need to copy them sometimes::
129
+
130
+ >>> A = matrix(2)
131
+ >>> A
132
+ matrix(
133
+ [['0.0', '0.0'],
134
+ ['0.0', '0.0']])
135
+ >>> B = A.copy()
136
+ >>> B[0,0] = 1
137
+ >>> B
138
+ matrix(
139
+ [['1.0', '0.0'],
140
+ ['0.0', '0.0']])
141
+ >>> A
142
+ matrix(
143
+ [['0.0', '0.0'],
144
+ ['0.0', '0.0']])
145
+
146
+ Finally, it is possible to convert a matrix to a nested list. This is very useful,
147
+ as most Python libraries involving matrices or arrays (namely NumPy or SymPy)
148
+ support this format::
149
+
150
+ >>> B.tolist()
151
+ [[mpf('1.0'), mpf('0.0')], [mpf('0.0'), mpf('0.0')]]
152
+
153
+
154
+ Matrix operations
155
+ -----------------
156
+
157
+ You can add and subtract matrices of compatible dimensions::
158
+
159
+ >>> A = matrix([[1, 2], [3, 4]])
160
+ >>> B = matrix([[-2, 4], [5, 9]])
161
+ >>> A + B
162
+ matrix(
163
+ [['-1.0', '6.0'],
164
+ ['8.0', '13.0']])
165
+ >>> A - B
166
+ matrix(
167
+ [['3.0', '-2.0'],
168
+ ['-2.0', '-5.0']])
169
+ >>> A + ones(3) # doctest:+ELLIPSIS
170
+ Traceback (most recent call last):
171
+ ...
172
+ ValueError: incompatible dimensions for addition
173
+
174
+ It is possible to multiply or add matrices and scalars. In the latter case the
175
+ operation will be done element-wise::
176
+
177
+ >>> A * 2
178
+ matrix(
179
+ [['2.0', '4.0'],
180
+ ['6.0', '8.0']])
181
+ >>> A / 4
182
+ matrix(
183
+ [['0.25', '0.5'],
184
+ ['0.75', '1.0']])
185
+ >>> A - 1
186
+ matrix(
187
+ [['0.0', '1.0'],
188
+ ['2.0', '3.0']])
189
+
190
+ Of course you can perform matrix multiplication, if the dimensions are
191
+ compatible, using ``@`` (for Python >= 3.5) or ``*``. For clarity, ``@`` is
192
+ recommended (`PEP 465 <https://www.python.org/dev/peps/pep-0465/>`), because
193
+ the meaning of ``*`` is different in many other Python libraries such as NumPy.
194
+
195
+ >>> A @ B # doctest:+SKIP
196
+ matrix(
197
+ [['8.0', '22.0'],
198
+ ['14.0', '48.0']])
199
+ >>> A * B # same as A @ B
200
+ matrix(
201
+ [['8.0', '22.0'],
202
+ ['14.0', '48.0']])
203
+ >>> matrix([[1, 2, 3]]) * matrix([[-6], [7], [-2]])
204
+ matrix(
205
+ [['2.0']])
206
+
207
+ ..
208
+ COMMENT: TODO: the above "doctest:+SKIP" may be removed as soon as we
209
+ have dropped support for Python 3.5 and below.
210
+
211
+ You can raise powers of square matrices::
212
+
213
+ >>> A**2
214
+ matrix(
215
+ [['7.0', '10.0'],
216
+ ['15.0', '22.0']])
217
+
218
+ Negative powers will calculate the inverse::
219
+
220
+ >>> A**-1
221
+ matrix(
222
+ [['-2.0', '1.0'],
223
+ ['1.5', '-0.5']])
224
+ >>> A * A**-1
225
+ matrix(
226
+ [['1.0', '1.0842021724855e-19'],
227
+ ['-2.16840434497101e-19', '1.0']])
228
+
229
+
230
+
231
+ Matrix transposition is straightforward::
232
+
233
+ >>> A = ones(2, 3)
234
+ >>> A
235
+ matrix(
236
+ [['1.0', '1.0', '1.0'],
237
+ ['1.0', '1.0', '1.0']])
238
+ >>> A.T
239
+ matrix(
240
+ [['1.0', '1.0'],
241
+ ['1.0', '1.0'],
242
+ ['1.0', '1.0']])
243
+
244
+ Norms
245
+ .....
246
+
247
+ Sometimes you need to know how "large" a matrix or vector is. Due to their
248
+ multidimensional nature it's not possible to compare them, but there are
249
+ several functions to map a matrix or a vector to a positive real number, the
250
+ so called norms.
251
+
252
+ For vectors the p-norm is intended, usually the 1-, the 2- and the oo-norm are
253
+ used.
254
+
255
+ >>> x = matrix([-10, 2, 100])
256
+ >>> norm(x, 1)
257
+ mpf('112.0')
258
+ >>> norm(x, 2)
259
+ mpf('100.5186549850325')
260
+ >>> norm(x, inf)
261
+ mpf('100.0')
262
+
263
+ Please note that the 2-norm is the most used one, though it is more expensive
264
+ to calculate than the 1- or oo-norm.
265
+
266
+ It is possible to generalize some vector norms to matrix norm::
267
+
268
+ >>> A = matrix([[1, -1000], [100, 50]])
269
+ >>> mnorm(A, 1)
270
+ mpf('1050.0')
271
+ >>> mnorm(A, inf)
272
+ mpf('1001.0')
273
+ >>> mnorm(A, 'F')
274
+ mpf('1006.2310867787777')
275
+
276
+ The last norm (the "Frobenius-norm") is an approximation for the 2-norm, which
277
+ is hard to calculate and not available. The Frobenius-norm lacks some
278
+ mathematical properties you might expect from a norm.
279
+ """
280
+
281
+ def __init__(self, *args, **kwargs):
282
+ self.__data = {}
283
+ # LU decompostion cache, this is useful when solving the same system
284
+ # multiple times, when calculating the inverse and when calculating the
285
+ # determinant
286
+ self._LU = None
287
+ if "force_type" in kwargs:
288
+ warnings.warn("The force_type argument was removed, it did not work"
289
+ " properly anyway. If you want to force floating-point or"
290
+ " interval computations, use the respective methods from `fp`"
291
+ " or `mp` instead, e.g., `fp.matrix()` or `iv.matrix()`."
292
+ " If you want to truncate values to integer, use .apply(int) instead.")
293
+ if isinstance(args[0], (list, tuple)):
294
+ if isinstance(args[0][0], (list, tuple)):
295
+ # interpret nested list as matrix
296
+ A = args[0]
297
+ self.__rows = len(A)
298
+ self.__cols = len(A[0])
299
+ for i, row in enumerate(A):
300
+ for j, a in enumerate(row):
301
+ # note: this will call __setitem__ which will call self.ctx.convert() to convert the datatype.
302
+ self[i, j] = a
303
+ else:
304
+ # interpret list as row vector
305
+ v = args[0]
306
+ self.__rows = len(v)
307
+ self.__cols = 1
308
+ for i, e in enumerate(v):
309
+ self[i, 0] = e
310
+ elif isinstance(args[0], int):
311
+ # create empty matrix of given dimensions
312
+ if len(args) == 1:
313
+ self.__rows = self.__cols = args[0]
314
+ else:
315
+ if not isinstance(args[1], int):
316
+ raise TypeError("expected int")
317
+ self.__rows = args[0]
318
+ self.__cols = args[1]
319
+ elif isinstance(args[0], _matrix):
320
+ A = args[0]
321
+ self.__rows = A._matrix__rows
322
+ self.__cols = A._matrix__cols
323
+ for i in xrange(A.__rows):
324
+ for j in xrange(A.__cols):
325
+ self[i, j] = A[i, j]
326
+ elif hasattr(args[0], 'tolist'):
327
+ A = self.ctx.matrix(args[0].tolist())
328
+ self.__data = A._matrix__data
329
+ self.__rows = A._matrix__rows
330
+ self.__cols = A._matrix__cols
331
+ else:
332
+ raise TypeError('could not interpret given arguments')
333
+
334
+ def apply(self, f):
335
+ """
336
+ Return a copy of self with the function `f` applied elementwise.
337
+ """
338
+ new = self.ctx.matrix(self.__rows, self.__cols)
339
+ for i in xrange(self.__rows):
340
+ for j in xrange(self.__cols):
341
+ new[i,j] = f(self[i,j])
342
+ return new
343
+
344
+ def __nstr__(self, n=None, **kwargs):
345
+ # Build table of string representations of the elements
346
+ res = []
347
+ # Track per-column max lengths for pretty alignment
348
+ maxlen = [0] * self.cols
349
+ for i in range(self.rows):
350
+ res.append([])
351
+ for j in range(self.cols):
352
+ if n:
353
+ string = self.ctx.nstr(self[i,j], n, **kwargs)
354
+ else:
355
+ string = str(self[i,j])
356
+ res[-1].append(string)
357
+ maxlen[j] = max(len(string), maxlen[j])
358
+ # Patch strings together
359
+ for i, row in enumerate(res):
360
+ for j, elem in enumerate(row):
361
+ # Pad each element up to maxlen so the columns line up
362
+ row[j] = elem.rjust(maxlen[j])
363
+ res[i] = "[" + colsep.join(row) + "]"
364
+ return rowsep.join(res)
365
+
366
+ def __str__(self):
367
+ return self.__nstr__()
368
+
369
+ def _toliststr(self, avoid_type=False):
370
+ """
371
+ Create a list string from a matrix.
372
+
373
+ If avoid_type: avoid multiple 'mpf's.
374
+ """
375
+ # XXX: should be something like self.ctx._types
376
+ typ = self.ctx.mpf
377
+ s = '['
378
+ for i in xrange(self.__rows):
379
+ s += '['
380
+ for j in xrange(self.__cols):
381
+ if not avoid_type or not isinstance(self[i,j], typ):
382
+ a = repr(self[i,j])
383
+ else:
384
+ a = "'" + str(self[i,j]) + "'"
385
+ s += a + ', '
386
+ s = s[:-2]
387
+ s += '],\n '
388
+ s = s[:-3]
389
+ s += ']'
390
+ return s
391
+
392
+ def tolist(self):
393
+ """
394
+ Convert the matrix to a nested list.
395
+ """
396
+ return [[self[i,j] for j in range(self.__cols)] for i in range(self.__rows)]
397
+
398
+ def __repr__(self):
399
+ if self.ctx.pretty:
400
+ return self.__str__()
401
+ s = 'matrix(\n'
402
+ s += self._toliststr(avoid_type=True) + ')'
403
+ return s
404
+
405
+ def __get_element(self, key):
406
+ '''
407
+ Fast extraction of the i,j element from the matrix
408
+ This function is for private use only because is unsafe:
409
+ 1. Does not check on the value of key it expects key to be a integer tuple (i,j)
410
+ 2. Does not check bounds
411
+ '''
412
+ if key in self.__data:
413
+ return self.__data[key]
414
+ else:
415
+ return self.ctx.zero
416
+
417
+ def __set_element(self, key, value):
418
+ '''
419
+ Fast assignment of the i,j element in the matrix
420
+ This function is unsafe:
421
+ 1. Does not check on the value of key it expects key to be a integer tuple (i,j)
422
+ 2. Does not check bounds
423
+ 3. Does not check the value type
424
+ 4. Does not reset the LU cache
425
+ '''
426
+ if value: # only store non-zeros
427
+ self.__data[key] = value
428
+ elif key in self.__data:
429
+ del self.__data[key]
430
+
431
+
432
+ def __getitem__(self, key):
433
+ '''
434
+ Getitem function for mp matrix class with slice index enabled
435
+ it allows the following assingments
436
+ scalar to a slice of the matrix
437
+ B = A[:,2:6]
438
+ '''
439
+ # Convert vector to matrix indexing
440
+ if isinstance(key, int) or isinstance(key,slice):
441
+ # only sufficent for vectors
442
+ if self.__rows == 1:
443
+ key = (0, key)
444
+ elif self.__cols == 1:
445
+ key = (key, 0)
446
+ else:
447
+ raise IndexError('insufficient indices for matrix')
448
+
449
+ if isinstance(key[0],slice) or isinstance(key[1],slice):
450
+
451
+ #Rows
452
+ if isinstance(key[0],slice):
453
+ #Check bounds
454
+ if (key[0].start is None or key[0].start >= 0) and \
455
+ (key[0].stop is None or key[0].stop <= self.__rows+1):
456
+ # Generate indices
457
+ rows = xrange(*key[0].indices(self.__rows))
458
+ else:
459
+ raise IndexError('Row index out of bounds')
460
+ else:
461
+ # Single row
462
+ rows = [key[0]]
463
+
464
+ # Columns
465
+ if isinstance(key[1],slice):
466
+ # Check bounds
467
+ if (key[1].start is None or key[1].start >= 0) and \
468
+ (key[1].stop is None or key[1].stop <= self.__cols+1):
469
+ # Generate indices
470
+ columns = xrange(*key[1].indices(self.__cols))
471
+ else:
472
+ raise IndexError('Column index out of bounds')
473
+
474
+ else:
475
+ # Single column
476
+ columns = [key[1]]
477
+
478
+ # Create matrix slice
479
+ m = self.ctx.matrix(len(rows),len(columns))
480
+
481
+ # Assign elements to the output matrix
482
+ for i,x in enumerate(rows):
483
+ for j,y in enumerate(columns):
484
+ m.__set_element((i,j),self.__get_element((x,y)))
485
+
486
+ return m
487
+
488
+ else:
489
+ # single element extraction
490
+ if key[0] >= self.__rows or key[1] >= self.__cols:
491
+ raise IndexError('matrix index out of range')
492
+ if key in self.__data:
493
+ return self.__data[key]
494
+ else:
495
+ return self.ctx.zero
496
+
497
+ def __setitem__(self, key, value):
498
+ # setitem function for mp matrix class with slice index enabled
499
+ # it allows the following assingments
500
+ # scalar to a slice of the matrix
501
+ # A[:,2:6] = 2.5
502
+ # submatrix to matrix (the value matrix should be the same size as the slice size)
503
+ # A[3,:] = B where A is n x m and B is n x 1
504
+ # Convert vector to matrix indexing
505
+ if isinstance(key, int) or isinstance(key,slice):
506
+ # only sufficent for vectors
507
+ if self.__rows == 1:
508
+ key = (0, key)
509
+ elif self.__cols == 1:
510
+ key = (key, 0)
511
+ else:
512
+ raise IndexError('insufficient indices for matrix')
513
+ # Slice indexing
514
+ if isinstance(key[0],slice) or isinstance(key[1],slice):
515
+ # Rows
516
+ if isinstance(key[0],slice):
517
+ # Check bounds
518
+ if (key[0].start is None or key[0].start >= 0) and \
519
+ (key[0].stop is None or key[0].stop <= self.__rows+1):
520
+ # generate row indices
521
+ rows = xrange(*key[0].indices(self.__rows))
522
+ else:
523
+ raise IndexError('Row index out of bounds')
524
+ else:
525
+ # Single row
526
+ rows = [key[0]]
527
+ # Columns
528
+ if isinstance(key[1],slice):
529
+ # Check bounds
530
+ if (key[1].start is None or key[1].start >= 0) and \
531
+ (key[1].stop is None or key[1].stop <= self.__cols+1):
532
+ # Generate column indices
533
+ columns = xrange(*key[1].indices(self.__cols))
534
+ else:
535
+ raise IndexError('Column index out of bounds')
536
+ else:
537
+ # Single column
538
+ columns = [key[1]]
539
+ # Assign slice with a scalar
540
+ if isinstance(value,self.ctx.matrix):
541
+ # Assign elements to matrix if input and output dimensions match
542
+ if len(rows) == value.rows and len(columns) == value.cols:
543
+ for i,x in enumerate(rows):
544
+ for j,y in enumerate(columns):
545
+ self.__set_element((x,y), value.__get_element((i,j)))
546
+ else:
547
+ raise ValueError('Dimensions do not match')
548
+ else:
549
+ # Assign slice with scalars
550
+ value = self.ctx.convert(value)
551
+ for i in rows:
552
+ for j in columns:
553
+ self.__set_element((i,j), value)
554
+ else:
555
+ # Single element assingment
556
+ # Check bounds
557
+ if key[0] >= self.__rows or key[1] >= self.__cols:
558
+ raise IndexError('matrix index out of range')
559
+ # Convert and store value
560
+ value = self.ctx.convert(value)
561
+ if value: # only store non-zeros
562
+ self.__data[key] = value
563
+ elif key in self.__data:
564
+ del self.__data[key]
565
+
566
+ if self._LU:
567
+ self._LU = None
568
+ return
569
+
570
+ def __iter__(self):
571
+ for i in xrange(self.__rows):
572
+ for j in xrange(self.__cols):
573
+ yield self[i,j]
574
+
575
+ def __mul__(self, other):
576
+ if isinstance(other, self.ctx.matrix):
577
+ # dot multiplication
578
+ if self.__cols != other.__rows:
579
+ raise ValueError('dimensions not compatible for multiplication')
580
+ new = self.ctx.matrix(self.__rows, other.__cols)
581
+ self_zero = self.ctx.zero
582
+ self_get = self.__data.get
583
+ other_zero = other.ctx.zero
584
+ other_get = other.__data.get
585
+ for i in xrange(self.__rows):
586
+ for j in xrange(other.__cols):
587
+ new[i, j] = self.ctx.fdot((self_get((i,k), self_zero), other_get((k,j), other_zero))
588
+ for k in xrange(other.__rows))
589
+ return new
590
+ else:
591
+ # try scalar multiplication
592
+ new = self.ctx.matrix(self.__rows, self.__cols)
593
+ for i in xrange(self.__rows):
594
+ for j in xrange(self.__cols):
595
+ new[i, j] = other * self[i, j]
596
+ return new
597
+
598
+ def __matmul__(self, other):
599
+ return self.__mul__(other)
600
+
601
+ def __rmul__(self, other):
602
+ # assume other is scalar and thus commutative
603
+ if isinstance(other, self.ctx.matrix):
604
+ raise TypeError("other should not be type of ctx.matrix")
605
+ return self.__mul__(other)
606
+
607
+ def __pow__(self, other):
608
+ # avoid cyclic import problems
609
+ #from linalg import inverse
610
+ if not isinstance(other, int):
611
+ raise ValueError('only integer exponents are supported')
612
+ if not self.__rows == self.__cols:
613
+ raise ValueError('only powers of square matrices are defined')
614
+ n = other
615
+ if n == 0:
616
+ return self.ctx.eye(self.__rows)
617
+ if n < 0:
618
+ n = -n
619
+ neg = True
620
+ else:
621
+ neg = False
622
+ i = n
623
+ y = 1
624
+ z = self.copy()
625
+ while i != 0:
626
+ if i % 2 == 1:
627
+ y = y * z
628
+ z = z*z
629
+ i = i // 2
630
+ if neg:
631
+ y = self.ctx.inverse(y)
632
+ return y
633
+
634
+ def __div__(self, other):
635
+ # assume other is scalar and do element-wise divison
636
+ assert not isinstance(other, self.ctx.matrix)
637
+ new = self.ctx.matrix(self.__rows, self.__cols)
638
+ for i in xrange(self.__rows):
639
+ for j in xrange(self.__cols):
640
+ new[i,j] = self[i,j] / other
641
+ return new
642
+
643
+ __truediv__ = __div__
644
+
645
+ def __add__(self, other):
646
+ if isinstance(other, self.ctx.matrix):
647
+ if not (self.__rows == other.__rows and self.__cols == other.__cols):
648
+ raise ValueError('incompatible dimensions for addition')
649
+ new = self.ctx.matrix(self.__rows, self.__cols)
650
+ for i in xrange(self.__rows):
651
+ for j in xrange(self.__cols):
652
+ new[i,j] = self[i,j] + other[i,j]
653
+ return new
654
+ else:
655
+ # assume other is scalar and add element-wise
656
+ new = self.ctx.matrix(self.__rows, self.__cols)
657
+ for i in xrange(self.__rows):
658
+ for j in xrange(self.__cols):
659
+ new[i,j] += self[i,j] + other
660
+ return new
661
+
662
+ def __radd__(self, other):
663
+ return self.__add__(other)
664
+
665
+ def __sub__(self, other):
666
+ if isinstance(other, self.ctx.matrix) and not (self.__rows == other.__rows
667
+ and self.__cols == other.__cols):
668
+ raise ValueError('incompatible dimensions for subtraction')
669
+ return self.__add__(other * (-1))
670
+
671
+ def __pos__(self):
672
+ """
673
+ +M returns a copy of M, rounded to current working precision.
674
+ """
675
+ return (+1) * self
676
+
677
+ def __neg__(self):
678
+ return (-1) * self
679
+
680
+ def __rsub__(self, other):
681
+ return -self + other
682
+
683
+ def __eq__(self, other):
684
+ return self.__rows == other.__rows and self.__cols == other.__cols \
685
+ and self.__data == other.__data
686
+
687
+ def __len__(self):
688
+ if self.rows == 1:
689
+ return self.cols
690
+ elif self.cols == 1:
691
+ return self.rows
692
+ else:
693
+ return self.rows # do it like numpy
694
+
695
+ def __getrows(self):
696
+ return self.__rows
697
+
698
+ def __setrows(self, value):
699
+ for key in self.__data.copy():
700
+ if key[0] >= value:
701
+ del self.__data[key]
702
+ self.__rows = value
703
+
704
+ rows = property(__getrows, __setrows, doc='number of rows')
705
+
706
+ def __getcols(self):
707
+ return self.__cols
708
+
709
+ def __setcols(self, value):
710
+ for key in self.__data.copy():
711
+ if key[1] >= value:
712
+ del self.__data[key]
713
+ self.__cols = value
714
+
715
+ cols = property(__getcols, __setcols, doc='number of columns')
716
+
717
+ def transpose(self):
718
+ new = self.ctx.matrix(self.__cols, self.__rows)
719
+ for i in xrange(self.__rows):
720
+ for j in xrange(self.__cols):
721
+ new[j,i] = self[i,j]
722
+ return new
723
+
724
+ T = property(transpose)
725
+
726
+ def conjugate(self):
727
+ return self.apply(self.ctx.conj)
728
+
729
+ def transpose_conj(self):
730
+ return self.conjugate().transpose()
731
+
732
+ H = property(transpose_conj)
733
+
734
+ def copy(self):
735
+ new = self.ctx.matrix(self.__rows, self.__cols)
736
+ new.__data = self.__data.copy()
737
+ return new
738
+
739
+ __copy__ = copy
740
+
741
+ def column(self, n):
742
+ m = self.ctx.matrix(self.rows, 1)
743
+ for i in range(self.rows):
744
+ m[i] = self[i,n]
745
+ return m
746
+
747
+ class MatrixMethods(object):
748
+
749
+ def __init__(ctx):
750
+ # XXX: subclass
751
+ ctx.matrix = type('matrix', (_matrix,), {})
752
+ ctx.matrix.ctx = ctx
753
+ ctx.matrix.convert = ctx.convert
754
+
755
+ def eye(ctx, n, **kwargs):
756
+ """
757
+ Create square identity matrix n x n.
758
+ """
759
+ A = ctx.matrix(n, **kwargs)
760
+ for i in xrange(n):
761
+ A[i,i] = 1
762
+ return A
763
+
764
+ def diag(ctx, diagonal, **kwargs):
765
+ """
766
+ Create square diagonal matrix using given list.
767
+
768
+ Example:
769
+ >>> from mpmath import diag, mp
770
+ >>> mp.pretty = False
771
+ >>> diag([1, 2, 3])
772
+ matrix(
773
+ [['1.0', '0.0', '0.0'],
774
+ ['0.0', '2.0', '0.0'],
775
+ ['0.0', '0.0', '3.0']])
776
+ """
777
+ A = ctx.matrix(len(diagonal), **kwargs)
778
+ for i in xrange(len(diagonal)):
779
+ A[i,i] = diagonal[i]
780
+ return A
781
+
782
+ def zeros(ctx, *args, **kwargs):
783
+ """
784
+ Create matrix m x n filled with zeros.
785
+ One given dimension will create square matrix n x n.
786
+
787
+ Example:
788
+ >>> from mpmath import zeros, mp
789
+ >>> mp.pretty = False
790
+ >>> zeros(2)
791
+ matrix(
792
+ [['0.0', '0.0'],
793
+ ['0.0', '0.0']])
794
+ """
795
+ if len(args) == 1:
796
+ m = n = args[0]
797
+ elif len(args) == 2:
798
+ m = args[0]
799
+ n = args[1]
800
+ else:
801
+ raise TypeError('zeros expected at most 2 arguments, got %i' % len(args))
802
+ A = ctx.matrix(m, n, **kwargs)
803
+ for i in xrange(m):
804
+ for j in xrange(n):
805
+ A[i,j] = 0
806
+ return A
807
+
808
+ def ones(ctx, *args, **kwargs):
809
+ """
810
+ Create matrix m x n filled with ones.
811
+ One given dimension will create square matrix n x n.
812
+
813
+ Example:
814
+ >>> from mpmath import ones, mp
815
+ >>> mp.pretty = False
816
+ >>> ones(2)
817
+ matrix(
818
+ [['1.0', '1.0'],
819
+ ['1.0', '1.0']])
820
+ """
821
+ if len(args) == 1:
822
+ m = n = args[0]
823
+ elif len(args) == 2:
824
+ m = args[0]
825
+ n = args[1]
826
+ else:
827
+ raise TypeError('ones expected at most 2 arguments, got %i' % len(args))
828
+ A = ctx.matrix(m, n, **kwargs)
829
+ for i in xrange(m):
830
+ for j in xrange(n):
831
+ A[i,j] = 1
832
+ return A
833
+
834
+ def hilbert(ctx, m, n=None):
835
+ """
836
+ Create (pseudo) hilbert matrix m x n.
837
+ One given dimension will create hilbert matrix n x n.
838
+
839
+ The matrix is very ill-conditioned and symmetric, positive definite if
840
+ square.
841
+ """
842
+ if n is None:
843
+ n = m
844
+ A = ctx.matrix(m, n)
845
+ for i in xrange(m):
846
+ for j in xrange(n):
847
+ A[i,j] = ctx.one / (i + j + 1)
848
+ return A
849
+
850
+ def randmatrix(ctx, m, n=None, min=0, max=1, **kwargs):
851
+ """
852
+ Create a random m x n matrix.
853
+
854
+ All values are >= min and <max.
855
+ n defaults to m.
856
+
857
+ Example:
858
+ >>> from mpmath import randmatrix
859
+ >>> randmatrix(2) # doctest:+SKIP
860
+ matrix(
861
+ [['0.53491598236191806', '0.57195669543302752'],
862
+ ['0.85589992269513615', '0.82444367501382143']])
863
+ """
864
+ if not n:
865
+ n = m
866
+ A = ctx.matrix(m, n, **kwargs)
867
+ for i in xrange(m):
868
+ for j in xrange(n):
869
+ A[i,j] = ctx.rand() * (max - min) + min
870
+ return A
871
+
872
+ def swap_row(ctx, A, i, j):
873
+ """
874
+ Swap row i with row j.
875
+ """
876
+ if i == j:
877
+ return
878
+ if isinstance(A, ctx.matrix):
879
+ for k in xrange(A.cols):
880
+ A[i,k], A[j,k] = A[j,k], A[i,k]
881
+ elif isinstance(A, list):
882
+ A[i], A[j] = A[j], A[i]
883
+ else:
884
+ raise TypeError('could not interpret type')
885
+
886
+ def extend(ctx, A, b):
887
+ """
888
+ Extend matrix A with column b and return result.
889
+ """
890
+ if not isinstance(A, ctx.matrix):
891
+ raise TypeError("A should be a type of ctx.matrix")
892
+ if A.rows != len(b):
893
+ raise ValueError("Value should be equal to len(b)")
894
+ A = A.copy()
895
+ A.cols += 1
896
+ for i in xrange(A.rows):
897
+ A[i, A.cols-1] = b[i]
898
+ return A
899
+
900
+ def norm(ctx, x, p=2):
901
+ r"""
902
+ Gives the entrywise `p`-norm of an iterable *x*, i.e. the vector norm
903
+ `\left(\sum_k |x_k|^p\right)^{1/p}`, for any given `1 \le p \le \infty`.
904
+
905
+ Special cases:
906
+
907
+ If *x* is not iterable, this just returns ``absmax(x)``.
908
+
909
+ ``p=1`` gives the sum of absolute values.
910
+
911
+ ``p=2`` is the standard Euclidean vector norm.
912
+
913
+ ``p=inf`` gives the magnitude of the largest element.
914
+
915
+ For *x* a matrix, ``p=2`` is the Frobenius norm.
916
+ For operator matrix norms, use :func:`~mpmath.mnorm` instead.
917
+
918
+ You can use the string 'inf' as well as float('inf') or mpf('inf')
919
+ to specify the infinity norm.
920
+
921
+ **Examples**
922
+
923
+ >>> from mpmath import *
924
+ >>> mp.dps = 15; mp.pretty = False
925
+ >>> x = matrix([-10, 2, 100])
926
+ >>> norm(x, 1)
927
+ mpf('112.0')
928
+ >>> norm(x, 2)
929
+ mpf('100.5186549850325')
930
+ >>> norm(x, inf)
931
+ mpf('100.0')
932
+
933
+ """
934
+ try:
935
+ iter(x)
936
+ except TypeError:
937
+ return ctx.absmax(x)
938
+ if type(p) is not int:
939
+ p = ctx.convert(p)
940
+ if p == ctx.inf:
941
+ return max(ctx.absmax(i) for i in x)
942
+ elif p == 1:
943
+ return ctx.fsum(x, absolute=1)
944
+ elif p == 2:
945
+ return ctx.sqrt(ctx.fsum(x, absolute=1, squared=1))
946
+ elif p > 1:
947
+ return ctx.nthroot(ctx.fsum(abs(i)**p for i in x), p)
948
+ else:
949
+ raise ValueError('p has to be >= 1')
950
+
951
+ def mnorm(ctx, A, p=1):
952
+ r"""
953
+ Gives the matrix (operator) `p`-norm of A. Currently ``p=1`` and ``p=inf``
954
+ are supported:
955
+
956
+ ``p=1`` gives the 1-norm (maximal column sum)
957
+
958
+ ``p=inf`` gives the `\infty`-norm (maximal row sum).
959
+ You can use the string 'inf' as well as float('inf') or mpf('inf')
960
+
961
+ ``p=2`` (not implemented) for a square matrix is the usual spectral
962
+ matrix norm, i.e. the largest singular value.
963
+
964
+ ``p='f'`` (or 'F', 'fro', 'Frobenius, 'frobenius') gives the
965
+ Frobenius norm, which is the elementwise 2-norm. The Frobenius norm is an
966
+ approximation of the spectral norm and satisfies
967
+
968
+ .. math ::
969
+
970
+ \frac{1}{\sqrt{\mathrm{rank}(A)}} \|A\|_F \le \|A\|_2 \le \|A\|_F
971
+
972
+ The Frobenius norm lacks some mathematical properties that might
973
+ be expected of a norm.
974
+
975
+ For general elementwise `p`-norms, use :func:`~mpmath.norm` instead.
976
+
977
+ **Examples**
978
+
979
+ >>> from mpmath import *
980
+ >>> mp.dps = 15; mp.pretty = False
981
+ >>> A = matrix([[1, -1000], [100, 50]])
982
+ >>> mnorm(A, 1)
983
+ mpf('1050.0')
984
+ >>> mnorm(A, inf)
985
+ mpf('1001.0')
986
+ >>> mnorm(A, 'F')
987
+ mpf('1006.2310867787777')
988
+
989
+ """
990
+ A = ctx.matrix(A)
991
+ if type(p) is not int:
992
+ if type(p) is str and 'frobenius'.startswith(p.lower()):
993
+ return ctx.norm(A, 2)
994
+ p = ctx.convert(p)
995
+ m, n = A.rows, A.cols
996
+ if p == 1:
997
+ return max(ctx.fsum((A[i,j] for i in xrange(m)), absolute=1) for j in xrange(n))
998
+ elif p == ctx.inf:
999
+ return max(ctx.fsum((A[i,j] for j in xrange(n)), absolute=1) for i in xrange(m))
1000
+ else:
1001
+ raise NotImplementedError("matrix p-norm for arbitrary p")
1002
+
1003
+ if __name__ == '__main__':
1004
+ import doctest
1005
+ doctest.testmod()
lib/python3.11/site-packages/mpmath/rational.py ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import operator
2
+ import sys
3
+ from .libmp import int_types, mpf_hash, bitcount, from_man_exp, HASH_MODULUS
4
+
5
+ new = object.__new__
6
+
7
+ def create_reduced(p, q, _cache={}):
8
+ key = p, q
9
+ if key in _cache:
10
+ return _cache[key]
11
+ x, y = p, q
12
+ while y:
13
+ x, y = y, x % y
14
+ if x != 1:
15
+ p //= x
16
+ q //= x
17
+ v = new(mpq)
18
+ v._mpq_ = p, q
19
+ # Speedup integers, half-integers and other small fractions
20
+ if q <= 4 and abs(key[0]) < 100:
21
+ _cache[key] = v
22
+ return v
23
+
24
+ class mpq(object):
25
+ """
26
+ Exact rational type, currently only intended for internal use.
27
+ """
28
+
29
+ __slots__ = ["_mpq_"]
30
+
31
+ def __new__(cls, p, q=1):
32
+ if type(p) is tuple:
33
+ p, q = p
34
+ elif hasattr(p, '_mpq_'):
35
+ p, q = p._mpq_
36
+ return create_reduced(p, q)
37
+
38
+ def __repr__(s):
39
+ return "mpq(%s,%s)" % s._mpq_
40
+
41
+ def __str__(s):
42
+ return "(%s/%s)" % s._mpq_
43
+
44
+ def __int__(s):
45
+ a, b = s._mpq_
46
+ return a // b
47
+
48
+ def __nonzero__(s):
49
+ return bool(s._mpq_[0])
50
+
51
+ __bool__ = __nonzero__
52
+
53
+ def __hash__(s):
54
+ a, b = s._mpq_
55
+ if sys.version_info >= (3, 2):
56
+ inverse = pow(b, HASH_MODULUS-2, HASH_MODULUS)
57
+ if not inverse:
58
+ h = sys.hash_info.inf
59
+ else:
60
+ h = (abs(a) * inverse) % HASH_MODULUS
61
+ if a < 0: h = -h
62
+ if h == -1: h = -2
63
+ return h
64
+ else:
65
+ if b == 1:
66
+ return hash(a)
67
+ # Power of two: mpf compatible hash
68
+ if not (b & (b-1)):
69
+ return mpf_hash(from_man_exp(a, 1-bitcount(b)))
70
+ return hash((a,b))
71
+
72
+ def __eq__(s, t):
73
+ ttype = type(t)
74
+ if ttype is mpq:
75
+ return s._mpq_ == t._mpq_
76
+ if ttype in int_types:
77
+ a, b = s._mpq_
78
+ if b != 1:
79
+ return False
80
+ return a == t
81
+ return NotImplemented
82
+
83
+ def __ne__(s, t):
84
+ ttype = type(t)
85
+ if ttype is mpq:
86
+ return s._mpq_ != t._mpq_
87
+ if ttype in int_types:
88
+ a, b = s._mpq_
89
+ if b != 1:
90
+ return True
91
+ return a != t
92
+ return NotImplemented
93
+
94
+ def _cmp(s, t, op):
95
+ ttype = type(t)
96
+ if ttype in int_types:
97
+ a, b = s._mpq_
98
+ return op(a, t*b)
99
+ if ttype is mpq:
100
+ a, b = s._mpq_
101
+ c, d = t._mpq_
102
+ return op(a*d, b*c)
103
+ return NotImplementedError
104
+
105
+ def __lt__(s, t): return s._cmp(t, operator.lt)
106
+ def __le__(s, t): return s._cmp(t, operator.le)
107
+ def __gt__(s, t): return s._cmp(t, operator.gt)
108
+ def __ge__(s, t): return s._cmp(t, operator.ge)
109
+
110
+ def __abs__(s):
111
+ a, b = s._mpq_
112
+ if a >= 0:
113
+ return s
114
+ v = new(mpq)
115
+ v._mpq_ = -a, b
116
+ return v
117
+
118
+ def __neg__(s):
119
+ a, b = s._mpq_
120
+ v = new(mpq)
121
+ v._mpq_ = -a, b
122
+ return v
123
+
124
+ def __pos__(s):
125
+ return s
126
+
127
+ def __add__(s, t):
128
+ ttype = type(t)
129
+ if ttype is mpq:
130
+ a, b = s._mpq_
131
+ c, d = t._mpq_
132
+ return create_reduced(a*d+b*c, b*d)
133
+ if ttype in int_types:
134
+ a, b = s._mpq_
135
+ v = new(mpq)
136
+ v._mpq_ = a+b*t, b
137
+ return v
138
+ return NotImplemented
139
+
140
+ __radd__ = __add__
141
+
142
+ def __sub__(s, t):
143
+ ttype = type(t)
144
+ if ttype is mpq:
145
+ a, b = s._mpq_
146
+ c, d = t._mpq_
147
+ return create_reduced(a*d-b*c, b*d)
148
+ if ttype in int_types:
149
+ a, b = s._mpq_
150
+ v = new(mpq)
151
+ v._mpq_ = a-b*t, b
152
+ return v
153
+ return NotImplemented
154
+
155
+ def __rsub__(s, t):
156
+ ttype = type(t)
157
+ if ttype is mpq:
158
+ a, b = s._mpq_
159
+ c, d = t._mpq_
160
+ return create_reduced(b*c-a*d, b*d)
161
+ if ttype in int_types:
162
+ a, b = s._mpq_
163
+ v = new(mpq)
164
+ v._mpq_ = b*t-a, b
165
+ return v
166
+ return NotImplemented
167
+
168
+ def __mul__(s, t):
169
+ ttype = type(t)
170
+ if ttype is mpq:
171
+ a, b = s._mpq_
172
+ c, d = t._mpq_
173
+ return create_reduced(a*c, b*d)
174
+ if ttype in int_types:
175
+ a, b = s._mpq_
176
+ return create_reduced(a*t, b)
177
+ return NotImplemented
178
+
179
+ __rmul__ = __mul__
180
+
181
+ def __div__(s, t):
182
+ ttype = type(t)
183
+ if ttype is mpq:
184
+ a, b = s._mpq_
185
+ c, d = t._mpq_
186
+ return create_reduced(a*d, b*c)
187
+ if ttype in int_types:
188
+ a, b = s._mpq_
189
+ return create_reduced(a, b*t)
190
+ return NotImplemented
191
+
192
+ def __rdiv__(s, t):
193
+ ttype = type(t)
194
+ if ttype is mpq:
195
+ a, b = s._mpq_
196
+ c, d = t._mpq_
197
+ return create_reduced(b*c, a*d)
198
+ if ttype in int_types:
199
+ a, b = s._mpq_
200
+ return create_reduced(b*t, a)
201
+ return NotImplemented
202
+
203
+ def __pow__(s, t):
204
+ ttype = type(t)
205
+ if ttype in int_types:
206
+ a, b = s._mpq_
207
+ if t:
208
+ if t < 0:
209
+ a, b, t = b, a, -t
210
+ v = new(mpq)
211
+ v._mpq_ = a**t, b**t
212
+ return v
213
+ raise ZeroDivisionError
214
+ return NotImplemented
215
+
216
+
217
+ mpq_1 = mpq((1,1))
218
+ mpq_0 = mpq((0,1))
219
+ mpq_1_2 = mpq((1,2))
220
+ mpq_3_2 = mpq((3,2))
221
+ mpq_1_4 = mpq((1,4))
222
+ mpq_1_16 = mpq((1,16))
223
+ mpq_3_16 = mpq((3,16))
224
+ mpq_5_2 = mpq((5,2))
225
+ mpq_3_4 = mpq((3,4))
226
+ mpq_7_4 = mpq((7,4))
227
+ mpq_5_4 = mpq((5,4))
228
+
229
+
230
+ # Register with "numbers" ABC
231
+ # We do not subclass, hence we do not use the @abstractmethod checks. While
232
+ # this is less invasive it may turn out that we do not actually support
233
+ # parts of the expected interfaces. See
234
+ # http://docs.python.org/2/library/numbers.html for list of abstract
235
+ # methods.
236
+ try:
237
+ import numbers
238
+ numbers.Rational.register(mpq)
239
+ except ImportError:
240
+ pass
lib/python3.11/site-packages/mpmath/tests/__init__.py ADDED
File without changes
lib/python3.11/site-packages/mpmath/tests/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (227 Bytes). View file
 
lib/python3.11/site-packages/mpmath/tests/__pycache__/extratest_gamma.cpython-311.pyc ADDED
Binary file (11.4 kB). View file
 
lib/python3.11/site-packages/mpmath/tests/__pycache__/extratest_zeta.cpython-311.pyc ADDED
Binary file (1.8 kB). View file
 
lib/python3.11/site-packages/mpmath/tests/__pycache__/runtests.cpython-311.pyc ADDED
Binary file (7.55 kB). View file
 
lib/python3.11/site-packages/mpmath/tests/__pycache__/test_basic_ops.cpython-311.pyc ADDED
Binary file (39.2 kB). View file
 
lib/python3.11/site-packages/mpmath/tests/__pycache__/test_bitwise.cpython-311.pyc ADDED
Binary file (16.2 kB). View file
 
lib/python3.11/site-packages/mpmath/tests/__pycache__/test_calculus.cpython-311.pyc ADDED
Binary file (14 kB). View file