Spaces:
Sleeping
Sleeping
""" | |
Utility functions for integer math. | |
TODO: rename, cleanup, perhaps move the gmpy wrapper code | |
here from settings.py | |
""" | |
import math | |
from bisect import bisect | |
from .backend import xrange | |
from .backend import BACKEND, gmpy, sage, sage_utils, MPZ, MPZ_ONE, MPZ_ZERO | |
# Lookup table: small_trailing[k] is the number of trailing zero bits of the
# byte value k (the entry for k == 0 is left at 0).
small_trailing = [0] * 256
for j in range(1, 8):
    # The byte values with exactly j trailing zeros are the odd multiples
    # of 2**j, i.e. every 2**(j+1)-th entry starting at index 2**j.
    small_trailing[2**j::2**(j+1)] = [j] * 2**(7-j)
def giant_steps(start, target, n=2):
    """
    Return a list of integers approximately equal to

        [start, n*start, ..., target/n^2, target/n, target]

    but conservatively rounded so that the quotient between two
    successive elements is actually slightly less than n.

    With n = 2, this describes suitable precision steps for a
    quadratically convergent algorithm such as Newton's method;
    with n = 3 steps for cubic convergence (Halley's method), etc.

        >>> giant_steps(50,1000)
        [66, 128, 253, 502, 1000]
        >>> giant_steps(50,1000,4)
        [65, 252, 1000]

    """
    # Build the sequence backwards from the target, then flip it.
    steps = [target]
    current = target
    while current > start*n:
        current = current//n + 2
        steps.append(current)
    steps.reverse()
    return steps
def rshift(x, n):
    """Return x >> n for an integer x, using the fastest (floor) rounding.

    Unlike the plain Python expression (x >> n), n may be negative, in
    which case a left shift by -n is performed instead."""
    return x >> n if n >= 0 else x << -n
def lshift(x, n):
    """Return x << n for an integer x.

    Unlike the plain Python expression (x << n), n may be negative, in
    which case a right shift by -n with default (floor) rounding is
    performed instead."""
    return x << n if n >= 0 else x >> -n
if BACKEND == 'sage':
    # With the sage backend the plain shift operators replace the Python
    # wrappers defined above.
    import operator
    rshift, lshift = operator.rshift, operator.lshift
def python_trailing(n):
    """Count the number of trailing zero bits in abs(n).

    Returns 0 for n == 0."""
    if not n:
        return 0
    count = 0
    # Strip whole zero bytes, eight bits at a time, then finish with a
    # table lookup on the first nonzero byte.
    while not n & 0xff:
        n >>= 8
        count += 8
    return count + small_trailing[n & 0xff]
if BACKEND == 'gmpy':
    # The bit-scan method is called bit_scan1 when gmpy.version() compares
    # >= '2', and scan1 otherwise.
    if gmpy.version() >= '2':
        def gmpy_trailing(n):
            """Count the number of trailing zero bits in abs(n) using gmpy."""
            if not n:
                return 0
            return MPZ(n).bit_scan1()
    else:
        def gmpy_trailing(n):
            """Count the number of trailing zero bits in abs(n) using gmpy."""
            if not n:
                return 0
            return MPZ(n).scan1()
# Small powers of 2
powers = [1<<_ for _ in range(300)]


def python_bitcount(n):
    """Calculate bit size of the nonnegative integer n."""
    # The number of table entries <= n equals the bit length for n < 2**299.
    size = bisect(powers, n)
    if size == 300:
        # Too large for the table: shift n down to a few leading bits using
        # a logarithm estimate, and finish with the small bit-count table.
        shift = int(math.log(n, 2)) - 4
        return shift + bctable[n >> shift]
    return size
def gmpy_bitcount(n):
    """Calculate bit size of the nonnegative integer n (0 for n == 0)."""
    if not n:
        return 0
    # Number of digits of n in base 2.
    return MPZ(n).numdigits(2)
#def sage_bitcount(n): | |
# if n: return MPZ(n).nbits() | |
# else: return 0 | |
def sage_trailing(n):
    """Count trailing zero bits of n via the backend integer type's
    trailing_zero_bits() method."""
    value = MPZ(n)
    return value.trailing_zero_bits()
# Select the bitcount/trailing implementations for the active backend.
if BACKEND == 'gmpy':
    # Prefer gmpy's own bit_length function when the module provides one.
    if 'bit_length' in dir(gmpy):
        bitcount = gmpy.bit_length
    else:
        bitcount = gmpy_bitcount
    trailing = gmpy_trailing
elif BACKEND == 'sage':
    sage_bitcount = sage_utils.bitcount
    bitcount = sage_bitcount
    trailing = sage_trailing
else:
    bitcount = python_bitcount
    trailing = python_trailing

# Used to avoid slow function calls as far as possible
trailtable = list(map(trailing, range(256)))
bctable = list(map(bitcount, range(1024)))
# TODO: speed up for bases 2, 4, 8, 16, ...
def bin_to_radix(x, xbits, base, bdigits):
    """Change the radix of a fixed-point number.

    x is interpreted as the binary fixed-point value x * 2**(-xbits); the
    result is floor(x * base**bdigits / 2**xbits), i.e. the same value as a
    fixed-point number with bdigits digits in the given base."""
    scale = MPZ(base)**bdigits
    return (x * scale) >> xbits
stddigits = '0123456789abcdefghijklmnopqrstuvwxyz'


def small_numeral(n, base=10, digits=stddigits):
    """Return the string numeral of a positive integer in an arbitrary
    base. Most efficient for small input.

    Base 10 is delegated to str(); for other bases the digits are peeled
    off by repeated division (so n == 0 yields an empty string)."""
    if base == 10:
        return str(n)
    reversed_digits = []
    while n:
        n, remainder = divmod(n, base)
        reversed_digits.append(digits[remainder])
    # Digits were produced least-significant first.
    reversed_digits.reverse()
    return "".join(reversed_digits)
def numeral_python(n, base=10, size=0, digits=stddigits):
    """Represent the integer n as a string of digits in the given base.

    Recursive division is used to make this function about 3x faster
    than Python's str() for converting integers to decimal strings.

    The 'size' parameter specifies the number of digits in n; this
    number is only used to determine splitting points and need not be
    exact."""
    if not n:
        return "0"
    if n < 0:
        return "-" + numeral(-n, base, size, digits)
    # Fast enough to do directly
    if size < 250:
        return small_numeral(n, base, digits)
    # Split into a high and a low part of about size/2 digits each; the
    # low part is zero-padded on the left to exactly `half` digits.
    half = (size >> 1) + (size & 1)
    high, low = divmod(n, base**half)
    high_digits = numeral(high, base, half, digits)
    low_digits = numeral(low, base, half, digits).rjust(half, "0")
    return high_digits + low_digits
def numeral_gmpy(n, base=10, size=0, digits=stddigits):
    """Represent the integer n as a string of digits in the given base.

    Recursive division is used to make this function about 3x faster
    than Python's str() for converting integers to decimal strings.

    The 'size' parameter specifies the number of digits in n; this
    number is only used to determine splitting points and need not be
    exact. Note that the direct conversion path passes only n and base
    to gmpy.digits(); the 'digits' alphabet is merely forwarded to the
    recursive calls."""
    if n < 0:
        return "-" + numeral(-n, base, size, digits)
    # gmpy.digits() may cause a segmentation fault when trying to convert
    # extremely large values to a string. The size limit may need to be
    # adjusted on some platforms, but 1500000 works on Windows and Linux.
    if size < 1500000:
        return gmpy.digits(n, base)
    # Split into a high and a low part of about size/2 digits each; the
    # low part is zero-padded on the left to exactly `half` digits.
    half = (size >> 1) + (size & 1)
    high, low = divmod(n, MPZ(base)**half)
    high_digits = numeral(high, base, half, digits)
    low_digits = numeral(low, base, half, digits).rjust(half, "0")
    return high_digits + low_digits
# Pick the integer-to-string converter matching the active backend.
numeral = numeral_gmpy if BACKEND == "gmpy" else numeral_python
# Size thresholds (powers of two) used by the square root routines.
_1_800 = 1<<800
_1_600 = 1<<600
_1_400 = 1<<400
_1_200 = 1<<200
_1_100 = 1<<100
_1_50 = 1<<50


def isqrt_small_python(x):
    """
    Correctly (floor) rounded integer square root, using
    division. Fast up to ~200 digits.
    """
    if not x:
        return x
    if x >= _1_800:
        # Too large for a direct float conversion: take the float square
        # root of the leading bits and scale it back up.
        nbits = bitcount(x)
        half = nbits//2
        r = int((x >> (2*half-100))**0.5 + 2) << (half-50)  # +2 is to round up
    elif x >= _1_50:
        # Initial estimate can be any integer >= the true root; round up
        r = int(x**0.5 * 1.00000000000001) + 1
    else:
        # Exact with IEEE double precision arithmetic
        return int(x**0.5)
    # The following iteration now precisely computes floor(sqrt(x))
    # See e.g. Crandall & Pomerance, "Prime Numbers: A Computational
    # Perspective"
    while True:
        y = (r + x//r) >> 1
        if y >= r:
            return r
        r = y
def isqrt_fast_python(x):
    """
    Fast approximate integer square root, computed using division-free
    Newton iteration for large x. For random integers the result is almost
    always correct (floor(sqrt(x))), but is 1 ulp too small with a roughly
    0.1% probability. If x is very close to an exact square, the answer is
    1 ulp wrong with high probability.

    With 0 guard bits, the largest error over a set of 10^5 random
    inputs of size 1-10^5 bits was 3 ulp. The use of 10 guard bits
    almost certainly guarantees a max 1 ulp error.
    """
    # Use direct division-based iteration if sqrt(x) < 2^400
    # Assume floating-point square root accurate to within 1 ulp, then:
    # 0 Newton iterations good to 52 bits
    # 1 Newton iterations good to 104 bits
    # 2 Newton iterations good to 208 bits
    # 3 Newton iterations good to 416 bits
    if x < _1_800:
        y = int(x**0.5)
        # The thresholds are increasing, so stop at the first one not met.
        for threshold in (_1_100, _1_200, _1_400):
            if x < threshold:
                break
            y = (y + x//y) >> 1
        return y
    guard_bits = 10
    nbits = bitcount(x) + 2*guard_bits
    x <<= 2*guard_bits
    # Round the working bit count up to an even number.
    nbits += (nbits & 1)
    half_bits = nbits//2
    startprec = min(50, half_bits)
    # Newton iteration for 1/sqrt(x), with floating-point starting value
    r = int(2.0**(2*startprec) * (x >> (nbits-2*startprec)) ** -0.5)
    prev = startprec
    for p in giant_steps(startprec, half_bits):
        # r**2, scaled from real size 2**(-bc) to 2**p
        r_squared = (r*r) >> (2*prev - p)
        # x*r**2, scaled from real size ~1.0 to 2**p
        x_r_squared = ((x >> (nbits-p)) * r_squared) >> p
        # New value of r, scaled from real size 2**(-bc/2) to 2**p
        r = (r * ((3<<p) - x_r_squared)) >> (prev+1)
        prev = p
    # (1/sqrt(x))*x = sqrt(x); prev now holds the final working precision.
    return (r*(x >> half_bits)) >> (prev+guard_bits)
def sqrtrem_python(x):
    """Correctly rounded integer (floor) square root with remainder.

    Returns a pair (y, rem) with y == floor(sqrt(x)) and rem == x - y*y,
    so that 0 <= rem <= 2*y.
    """
    # to check cutoff:
    # plot(lambda x: timing(isqrt, 2**int(x)), [0,2000])
    if x < _1_600:
        y = isqrt_small_python(x)
        return y, x - y*y
    # isqrt_fast_python is only approximate. Start one above its estimate
    # and then correct in whichever direction is needed.
    y = isqrt_fast_python(x) + 1
    rem = x - y*y
    # Both loops maintain the invariant rem == x - y*y; y is the floor
    # square root exactly when 0 <= rem <= 2*y.
    while rem < 0:
        # y is too large: x - (y-1)**2 == rem + 2*(y-1) + 1
        y -= 1
        rem += 2*y + 1
    while rem > 2*y:
        # y is too small: x - (y+1)**2 == rem - (2*y + 1)
        rem -= 2*y + 1
        y += 1
    return y, rem
def isqrt_python(x):
    """Integer square root with correct (floor) rounding."""
    # Reuse the root/remainder routine and discard the remainder.
    root, _remainder = sqrtrem_python(x)
    return root
def sqrt_fixed(x, prec):
    """Return isqrt_fast(x * 2**prec).

    If x is a fixed-point number with prec fractional bits, the result is
    its (approximate) square root in the same fixed-point format."""
    scaled = x << prec
    return isqrt_fast(scaled)


# Alias kept under a second name.
sqrt_fixed2 = sqrt_fixed
# Bind the public square root names to the implementation for the active
# backend.
if BACKEND == 'gmpy':
    # The function names used depend on how gmpy.version() compares to '2'.
    if gmpy.version() >= '2':
        isqrt = gmpy.isqrt
        sqrtrem = gmpy.isqrt_rem
    else:
        isqrt = gmpy.sqrt
        sqrtrem = gmpy.sqrtrem
    isqrt_small = isqrt_fast = isqrt
elif BACKEND == 'sage':
    # Use sage_utils.isqrt when present, else the integer type's method.
    isqrt = getattr(sage_utils, "isqrt", lambda n: MPZ(n).isqrt())
    isqrt_small = isqrt_fast = isqrt
    sqrtrem = lambda n: MPZ(n).sqrtrem()
else:
    isqrt_small = isqrt_small_python
    isqrt_fast = isqrt_fast_python
    isqrt = isqrt_python
    sqrtrem = sqrtrem_python
def ifib(n, _cache={}):
    """Computes the nth Fibonacci number as an integer, for
    integer n.

    Negative indices are reduced to positive ones with the appropriate
    sign; results for indices below 250 are memoized in _cache."""
    if n < 0:
        return (-1)**(-n+1) * ifib(-n)
    try:
        return _cache[n]
    except KeyError:
        pass
    index = n
    # Use Dijkstra's logarithmic algorithm
    # The following implementation is basically equivalent to
    # http://en.literateprograms.org/Fibonacci_numbers_(Scheme)
    a, b, p, q = MPZ_ONE, MPZ_ZERO, MPZ_ZERO, MPZ_ONE
    while n:
        if n & 1:
            # Odd: apply the current (p, q) transformation once.
            cross = a*q
            a, b = b*q + cross + a*p, b*p + cross
            n -= 1
        else:
            # Even: square the (p, q) transformation and halve n.
            q_squared = q*q
            p, q = p*p + q_squared, q_squared + 2*p*q
            n >>= 1
    if index < 250:
        _cache[index] = b
    return b
# Largest argument whose factorial / double factorial is kept in the caches.
MAX_FACTORIAL_CACHE = 1000


def ifac(n, memo={0:1, 1:1}):
    """Return n factorial (for integers n >= 0 only).

    Results for n <= MAX_FACTORIAL_CACHE are memoized in `memo`, which
    always holds a contiguous range of keys starting at 0.

    Raises ValueError for negative n (previously the largest cached
    factorial was silently returned, which depended on cache state).
    """
    f = memo.get(n)
    if f:
        return f
    if n < 0:
        raise ValueError("factorial is not defined for negative integers")
    # Resume the product from the largest cached factorial.
    k = len(memo)
    p = memo[k-1]
    MAX = MAX_FACTORIAL_CACHE
    while k <= n:
        p *= k
        if k <= MAX:
            memo[k] = p
        k += 1
    return p
def ifac2(n, memo_pair=[{0:1}, {1:1}]):
    """Return n!! (double factorial), integers n >= 0 only.

    Even and odd arguments use separate caches (memo_pair[0] and
    memo_pair[1]); values up to MAX_FACTORIAL_CACHE are stored."""
    cache = memo_pair[n & 1]
    cached = cache.get(n)
    if cached:
        return cached
    # Continue the product from the largest cached argument of this parity.
    k = max(cache)
    product = cache[k]
    limit = MAX_FACTORIAL_CACHE
    while k < n:
        k += 2
        product *= k
        if k <= limit:
            cache[k] = product
    return product
# Replace the pure-Python factorial/Fibonacci routines by backend ones.
if BACKEND == 'gmpy':
    ifac = gmpy.fac
elif BACKEND == 'sage':
    # The sage factorial result is converted to a Python int.
    ifac = lambda n: int(sage.factorial(n))
    ifib = sage.fibonacci
def list_primes(n):
    """Return the list of all primes <= n, using a sieve of Eratosthenes."""
    limit = n + 1
    # sieve[k] holds k while k is still a prime candidate, and 0 otherwise.
    sieve = list(xrange(limit))
    sieve[:2] = [0, 0]
    for i in xrange(2, int(limit**0.5)+1):
        if not sieve[i]:
            continue
        # Cross out the multiples of i, starting from i**2.
        for multiple in xrange(i**2, limit, i):
            sieve[multiple] = 0
    return [p for p in sieve if p]
if BACKEND == 'sage':
    # Note: it is *VERY* important for performance that we convert
    # the list to Python ints.
    def list_primes(n):
        """Return the primes yielded by sage.primes(n+1) as Python ints."""
        return list(map(int, sage.primes(n+1)))
small_odd_primes = (3,5,7,11,13,17,19,23,29,31,37,41,43,47)
small_odd_primes_set = set(small_odd_primes)


def isprime(n):
    """
    Determines whether n is a prime number. A probabilistic test is
    performed if n is very large. No special trick is used for detecting
    perfect powers.

        >>> sum(list_primes(100000))
        454396537
        >>> sum(n*isprime(n) for n in range(100000))
        454396537

    """
    n = int(n)
    # Even numbers: only 2 is prime.
    if not n & 1:
        return n == 2
    # Small odd numbers are decided by table lookup.
    if n < 50:
        return n in small_odd_primes_set
    # Trial division by the small odd primes.
    if any(not n % p for p in small_odd_primes):
        return False
    # Write n - 1 = d * 2**s with d odd.
    m = n-1
    s = trailing(m)
    d = m >> s

    def test(a):
        # Strong probable-prime test of n to base a.
        x = pow(a, d, n)
        if x in (1, m):
            return True
        for r in xrange(1, s):
            x = x*x % n
            if x == m:
                return True
        return False

    # See http://primes.utm.edu/prove/prove2_3.html
    if n < 1373653:
        witnesses = [2,3]
    elif n < 341550071728321:
        witnesses = [2,3,5,7,11,13,17]
    else:
        witnesses = small_odd_primes
    return all(test(a) for a in witnesses)
def moebius(n):
    """
    Evaluates the Moebius function which is `mu(n) = (-1)^k` if `n`
    is a product of `k` distinct primes and `mu(n) = 0` otherwise.

    The argument is converted with abs(int(n)); the values 0 and 1 are
    returned unchanged. The sign is found by trial division up to
    sqrt(n), stopping as soon as a squared prime factor is found.
    """
    n = abs(int(n))
    if n < 2:
        return n
    sign = 1
    p = 2
    while p*p <= n:
        if not n % p:
            n //= p
            if not n % p:
                # p**2 divides the original n: not squarefree.
                return 0
            sign = -sign
        # After 2, only odd trial divisors are needed.
        p += 1 if p == 2 else 2
    # Whatever remains above 1 is a single prime factor > sqrt(n).
    if n > 1:
        sign = -sign
    return sign
def gcd(*args):
    """Return the greatest common divisor of all arguments, computed with
    the Euclidean algorithm (0 when called without arguments).

    The sign of the result follows Python's % operator and may be
    negative for negative inputs."""
    result = 0
    for value in args:
        if not result:
            # Nothing accumulated yet (or only zeros): take the value as is.
            result = value
            continue
        while value:
            result, value = value, result % value
    return result
# Comment by Juan Arias de Reyna:
#
# I learned this method for computing EulerE[2n] from van de Lune.
#
# We apply the formula EulerE[2n] = (-1)^n 2**(-2n) sum_{j=0}^n a(2n,2j+1)
#
# where the numbers a(n,j) vanish for j > n+1 or j <= -1 and satisfy
#
# a(0,-1) = a(0,0) = 0; a(0,1) = 1; a(0,2) = a(0,3) = 0
#
# a(n,j) = a(n-1,j) when n+j is even
# a(n,j) = (j-1) a(n-1,j-1) + (j+1) a(n-1,j+1) when n+j is odd
#
#
# A single one-dimensional array a(j) suffices, because computing a(n,j)
# only requires the values a(n-1,k) where k and j have different parity,
# and the values already used need not be kept.
#
# The values of the Euler numbers are cached up to a sufficiently high order.
#
# Important observation: if we intend to use the numbers
# EulerE[1], EulerE[2], ... , EulerE[n]
# it is convenient to compute EulerE[n] first, since the algorithm
# computes all the previous ones along the way and keeps them in the cache.
# Largest index whose Euler number is stored in the cache.
MAX_EULER_CACHE = 500


def eulernum(m, _cache={0:MPZ_ONE}):
    r"""
    Computes the Euler numbers `E(n)`, which can be defined as
    coefficients of the Taylor expansion of `1/cosh x`:

    .. math ::

        \frac{1}{\cosh x} = \sum_{n=0}^\infty \frac{E_n}{n!} x^n

    Example::

        >>> [int(eulernum(n)) for n in range(11)]
        [1, 0, -1, 0, 5, 0, -61, 0, 1385, 0, -50521]
        >>> [int(eulernum(n)) for n in range(11)]   # test cache
        [1, 0, -1, 0, 5, 0, -61, 0, 1385, 0, -50521]

    """
    # for odd m, the Euler numbers are zero
    if m & 1:
        return MPZ_ZERO
    cached = _cache.get(m)
    if cached:
        return cached
    limit = MAX_EULER_CACHE
    # Working array for the one-dimensional recurrence, updated in place
    # and extended by one slot per step.
    a = [MPZ(_) for _ in [0,0,1,0,0,0]]
    for n in range(1, m+1):
        for j in range(n+1, -1, -2):
            a[j+1] = (j-1)*a[j] + (j+1)*a[j+2]
        a.append(0)
        suma = sum(a[k+1] for k in range(n+1, -1, -2))
        sign = (-1)**(n//2)
        # Every intermediate step is stored while it fits in the cache.
        if n <= limit:
            _cache[n] = sign*(suma // 2**n)
        if n == m:
            return sign*suma // 2**n
def stirling1(n, k):
    """
    Stirling number of the first kind.

    Raises ValueError if n or k is negative.
    """
    if min(n, k) < 0:
        raise ValueError
    if k >= n:
        return MPZ(n == k)
    if k < 1:
        return MPZ_ZERO
    # row[j] holds the unsigned number for (m, j); it is updated in place
    # for m = 2..n, iterating j downwards so row[j-1] is still the old value.
    row = [MPZ_ZERO] * (k+1)
    row[1] = MPZ_ONE
    for m in xrange(2, n+1):
        factor = m - 1
        for j in xrange(min(k, m), 0, -1):
            row[j] = factor * row[j] + row[j-1]
    # Attach the sign (-1)**(n+k).
    return (-1)**(n+k) * row[k]
def stirling2(n, k):
    """
    Stirling number of the second kind.

    Raises ValueError if n or k is negative.
    """
    if min(n, k) < 0:
        raise ValueError
    if k >= n:
        return MPZ(n == k)
    if k <= 1:
        return MPZ(k == 1)
    # Alternating sum of binomial(k, j) * j**n over j = 0..k, divided by k!
    # at the end; `coeff` runs through the binomial coefficients.
    total = MPZ_ZERO
    coeff = MPZ_ONE
    for j in xrange(k+1):
        term = coeff * MPZ(j)**n
        if (k + j) & 1:
            total -= term
        else:
            total += term
        coeff = coeff * (k - j) // (j + 1)
    return total // ifac(k)