|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from ._IntegerBase import IntegerBase |
|
|
|
from Crypto.Util.number import long_to_bytes, bytes_to_long, inverse, GCD |
|
|
|
|
|
class IntegerNative(IntegerBase): |
|
"""A class to model a natural integer (including zero)""" |
|
|
|
def __init__(self, value): |
|
if isinstance(value, float): |
|
raise ValueError("A floating point type is not a natural number") |
|
try: |
|
self._value = value._value |
|
except AttributeError: |
|
self._value = value |
|
|
|
|
|
def __int__(self): |
|
return self._value |
|
|
|
def __str__(self): |
|
return str(int(self)) |
|
|
|
def __repr__(self): |
|
return "Integer(%s)" % str(self) |
|
|
|
|
|
def __hex__(self): |
|
return hex(self._value) |
|
|
|
|
|
def __index__(self): |
|
return int(self._value) |
|
|
|
def to_bytes(self, block_size=0, byteorder='big'): |
|
if self._value < 0: |
|
raise ValueError("Conversion only valid for non-negative numbers") |
|
result = long_to_bytes(self._value, block_size) |
|
if len(result) > block_size > 0: |
|
raise ValueError("Value too large to encode") |
|
if byteorder == 'big': |
|
pass |
|
elif byteorder == 'little': |
|
result = bytearray(result) |
|
result.reverse() |
|
result = bytes(result) |
|
else: |
|
raise ValueError("Incorrect byteorder") |
|
return result |
|
|
|
@classmethod |
|
def from_bytes(cls, byte_string, byteorder='big'): |
|
if byteorder == 'big': |
|
pass |
|
elif byteorder == 'little': |
|
byte_string = bytearray(byte_string) |
|
byte_string.reverse() |
|
else: |
|
raise ValueError("Incorrect byteorder") |
|
return cls(bytes_to_long(byte_string)) |
|
|
|
|
|
def __eq__(self, term): |
|
if term is None: |
|
return False |
|
return self._value == int(term) |
|
|
|
def __ne__(self, term): |
|
return not self.__eq__(term) |
|
|
|
def __lt__(self, term): |
|
return self._value < int(term) |
|
|
|
def __le__(self, term): |
|
return self.__lt__(term) or self.__eq__(term) |
|
|
|
def __gt__(self, term): |
|
return not self.__le__(term) |
|
|
|
def __ge__(self, term): |
|
return not self.__lt__(term) |
|
|
|
def __nonzero__(self): |
|
return self._value != 0 |
|
__bool__ = __nonzero__ |
|
|
|
def is_negative(self): |
|
return self._value < 0 |
|
|
|
|
|
def __add__(self, term): |
|
try: |
|
return self.__class__(self._value + int(term)) |
|
except (ValueError, AttributeError, TypeError): |
|
return NotImplemented |
|
|
|
def __sub__(self, term): |
|
try: |
|
return self.__class__(self._value - int(term)) |
|
except (ValueError, AttributeError, TypeError): |
|
return NotImplemented |
|
|
|
def __mul__(self, factor): |
|
try: |
|
return self.__class__(self._value * int(factor)) |
|
except (ValueError, AttributeError, TypeError): |
|
return NotImplemented |
|
|
|
def __floordiv__(self, divisor): |
|
return self.__class__(self._value // int(divisor)) |
|
|
|
def __mod__(self, divisor): |
|
divisor_value = int(divisor) |
|
if divisor_value < 0: |
|
raise ValueError("Modulus must be positive") |
|
return self.__class__(self._value % divisor_value) |
|
|
|
def inplace_pow(self, exponent, modulus=None): |
|
exp_value = int(exponent) |
|
if exp_value < 0: |
|
raise ValueError("Exponent must not be negative") |
|
|
|
if modulus is not None: |
|
mod_value = int(modulus) |
|
if mod_value < 0: |
|
raise ValueError("Modulus must be positive") |
|
if mod_value == 0: |
|
raise ZeroDivisionError("Modulus cannot be zero") |
|
else: |
|
mod_value = None |
|
self._value = pow(self._value, exp_value, mod_value) |
|
return self |
|
|
|
def __pow__(self, exponent, modulus=None): |
|
result = self.__class__(self) |
|
return result.inplace_pow(exponent, modulus) |
|
|
|
def __abs__(self): |
|
return abs(self._value) |
|
|
|
def sqrt(self, modulus=None): |
|
|
|
value = self._value |
|
if modulus is None: |
|
if value < 0: |
|
raise ValueError("Square root of negative value") |
|
|
|
|
|
x = value |
|
y = (x + 1) // 2 |
|
while y < x: |
|
x = y |
|
y = (x + value // x) // 2 |
|
result = x |
|
else: |
|
if modulus <= 0: |
|
raise ValueError("Modulus must be positive") |
|
result = self._tonelli_shanks(self % modulus, modulus) |
|
|
|
return self.__class__(result) |
|
|
|
def __iadd__(self, term): |
|
self._value += int(term) |
|
return self |
|
|
|
def __isub__(self, term): |
|
self._value -= int(term) |
|
return self |
|
|
|
def __imul__(self, term): |
|
self._value *= int(term) |
|
return self |
|
|
|
def __imod__(self, term): |
|
modulus = int(term) |
|
if modulus == 0: |
|
raise ZeroDivisionError("Division by zero") |
|
if modulus < 0: |
|
raise ValueError("Modulus must be positive") |
|
self._value %= modulus |
|
return self |
|
|
|
|
|
def __and__(self, term): |
|
return self.__class__(self._value & int(term)) |
|
|
|
def __or__(self, term): |
|
return self.__class__(self._value | int(term)) |
|
|
|
def __rshift__(self, pos): |
|
try: |
|
return self.__class__(self._value >> int(pos)) |
|
except OverflowError: |
|
if self._value >= 0: |
|
return 0 |
|
else: |
|
return -1 |
|
|
|
def __irshift__(self, pos): |
|
try: |
|
self._value >>= int(pos) |
|
except OverflowError: |
|
if self._value >= 0: |
|
return 0 |
|
else: |
|
return -1 |
|
return self |
|
|
|
def __lshift__(self, pos): |
|
try: |
|
return self.__class__(self._value << int(pos)) |
|
except OverflowError: |
|
raise ValueError("Incorrect shift count") |
|
|
|
def __ilshift__(self, pos): |
|
try: |
|
self._value <<= int(pos) |
|
except OverflowError: |
|
raise ValueError("Incorrect shift count") |
|
return self |
|
|
|
def get_bit(self, n): |
|
if self._value < 0: |
|
raise ValueError("no bit representation for negative values") |
|
try: |
|
try: |
|
result = (self._value >> n._value) & 1 |
|
if n._value < 0: |
|
raise ValueError("negative bit count") |
|
except AttributeError: |
|
result = (self._value >> n) & 1 |
|
if n < 0: |
|
raise ValueError("negative bit count") |
|
except OverflowError: |
|
result = 0 |
|
return result |
|
|
|
|
|
def is_odd(self): |
|
return (self._value & 1) == 1 |
|
|
|
def is_even(self): |
|
return (self._value & 1) == 0 |
|
|
|
def size_in_bits(self): |
|
|
|
if self._value < 0: |
|
raise ValueError("Conversion only valid for non-negative numbers") |
|
|
|
if self._value == 0: |
|
return 1 |
|
|
|
return self._value.bit_length() |
|
|
|
def size_in_bytes(self): |
|
return (self.size_in_bits() - 1) // 8 + 1 |
|
|
|
def is_perfect_square(self): |
|
if self._value < 0: |
|
return False |
|
if self._value in (0, 1): |
|
return True |
|
|
|
x = self._value // 2 |
|
square_x = x ** 2 |
|
|
|
while square_x > self._value: |
|
x = (square_x + self._value) // (2 * x) |
|
square_x = x ** 2 |
|
|
|
return self._value == x ** 2 |
|
|
|
def fail_if_divisible_by(self, small_prime): |
|
if (self._value % int(small_prime)) == 0: |
|
raise ValueError("Value is composite") |
|
|
|
def multiply_accumulate(self, a, b): |
|
self._value += int(a) * int(b) |
|
return self |
|
|
|
def set(self, source): |
|
self._value = int(source) |
|
|
|
def inplace_inverse(self, modulus): |
|
self._value = inverse(self._value, int(modulus)) |
|
return self |
|
|
|
def inverse(self, modulus): |
|
result = self.__class__(self) |
|
result.inplace_inverse(modulus) |
|
return result |
|
|
|
def gcd(self, term): |
|
return self.__class__(GCD(abs(self._value), abs(int(term)))) |
|
|
|
def lcm(self, term): |
|
term = int(term) |
|
if self._value == 0 or term == 0: |
|
return self.__class__(0) |
|
return self.__class__(abs((self._value * term) // self.gcd(term)._value)) |
|
|
|
@staticmethod |
|
def jacobi_symbol(a, n): |
|
a = int(a) |
|
n = int(n) |
|
|
|
if n <= 0: |
|
raise ValueError("n must be a positive integer") |
|
|
|
if (n & 1) == 0: |
|
raise ValueError("n must be odd for the Jacobi symbol") |
|
|
|
|
|
a = a % n |
|
|
|
if a == 1 or n == 1: |
|
return 1 |
|
|
|
if a == 0: |
|
return 0 |
|
|
|
e = 0 |
|
a1 = a |
|
while (a1 & 1) == 0: |
|
a1 >>= 1 |
|
e += 1 |
|
|
|
if (e & 1) == 0: |
|
s = 1 |
|
elif n % 8 in (1, 7): |
|
s = 1 |
|
else: |
|
s = -1 |
|
|
|
if n % 4 == 3 and a1 % 4 == 3: |
|
s = -s |
|
|
|
n1 = n % a1 |
|
|
|
return s * IntegerNative.jacobi_symbol(n1, a1) |
|
|
|
@staticmethod |
|
def _mult_modulo_bytes(term1, term2, modulus): |
|
if modulus < 0: |
|
raise ValueError("Modulus must be positive") |
|
if modulus == 0: |
|
raise ZeroDivisionError("Modulus cannot be zero") |
|
if (modulus & 1) == 0: |
|
raise ValueError("Odd modulus is required") |
|
|
|
number_len = len(long_to_bytes(modulus)) |
|
return long_to_bytes((term1 * term2) % modulus, number_len) |
|
|