| use super::monty::monty_modpow; |
| use super::BigUint; |
| |
| use crate::big_digit::{self, BigDigit}; |
| |
| use num_integer::Integer; |
| use num_traits::{One, Pow, ToPrimitive, Zero}; |
| |
| impl Pow<&BigUint> for BigUint { |
| type Output = BigUint; |
| |
| #[inline] |
| fn pow(self, exp: &BigUint) -> BigUint { |
| if self.is_one() || exp.is_zero() { |
| BigUint::one() |
| } else if self.is_zero() { |
| BigUint::zero() |
| } else if let Some(exp) = exp.to_u64() { |
| self.pow(exp) |
| } else if let Some(exp) = exp.to_u128() { |
| self.pow(exp) |
| } else { |
| // At this point, `self >= 2` and `exp >= 2¹²⁸`. The smallest possible result given |
| // `2.pow(2¹²⁸)` would require far more memory than 64-bit targets can address! |
| panic!("memory overflow") |
| } |
| } |
| } |
| |
| impl Pow<BigUint> for BigUint { |
| type Output = BigUint; |
| |
| #[inline] |
| fn pow(self, exp: BigUint) -> BigUint { |
| Pow::pow(self, &exp) |
| } |
| } |
| |
| impl Pow<&BigUint> for &BigUint { |
| type Output = BigUint; |
| |
| #[inline] |
| fn pow(self, exp: &BigUint) -> BigUint { |
| if self.is_one() || exp.is_zero() { |
| BigUint::one() |
| } else if self.is_zero() { |
| BigUint::zero() |
| } else { |
| self.clone().pow(exp) |
| } |
| } |
| } |
| |
| impl Pow<BigUint> for &BigUint { |
| type Output = BigUint; |
| |
| #[inline] |
| fn pow(self, exp: BigUint) -> BigUint { |
| Pow::pow(self, &exp) |
| } |
| } |
| |
| macro_rules! pow_impl { |
| ($T:ty) => { |
| impl Pow<$T> for BigUint { |
| type Output = BigUint; |
| |
| fn pow(self, mut exp: $T) -> BigUint { |
| if exp == 0 { |
| return BigUint::one(); |
| } |
| let mut base = self; |
| |
| while exp & 1 == 0 { |
| base = &base * &base; |
| exp >>= 1; |
| } |
| |
| if exp == 1 { |
| return base; |
| } |
| |
| let mut acc = base.clone(); |
| while exp > 1 { |
| exp >>= 1; |
| base = &base * &base; |
| if exp & 1 == 1 { |
| acc *= &base; |
| } |
| } |
| acc |
| } |
| } |
| |
| impl Pow<&$T> for BigUint { |
| type Output = BigUint; |
| |
| #[inline] |
| fn pow(self, exp: &$T) -> BigUint { |
| Pow::pow(self, *exp) |
| } |
| } |
| |
| impl Pow<$T> for &BigUint { |
| type Output = BigUint; |
| |
| #[inline] |
| fn pow(self, exp: $T) -> BigUint { |
| if exp == 0 { |
| return BigUint::one(); |
| } |
| Pow::pow(self.clone(), exp) |
| } |
| } |
| |
| impl Pow<&$T> for &BigUint { |
| type Output = BigUint; |
| |
| #[inline] |
| fn pow(self, exp: &$T) -> BigUint { |
| Pow::pow(self, *exp) |
| } |
| } |
| }; |
| } |
| |
| pow_impl!(u8); |
| pow_impl!(u16); |
| pow_impl!(u32); |
| pow_impl!(u64); |
| pow_impl!(usize); |
| pow_impl!(u128); |
| |
| pub(super) fn modpow(x: &BigUint, exponent: &BigUint, modulus: &BigUint) -> BigUint { |
| assert!( |
| !modulus.is_zero(), |
| "attempt to calculate with zero modulus!" |
| ); |
| |
| if modulus.is_odd() { |
| // For an odd modulus, we can use Montgomery multiplication in base 2^32. |
| monty_modpow(x, exponent, modulus) |
| } else { |
| // Otherwise do basically the same as `num::pow`, but with a modulus. |
| plain_modpow(x, &exponent.data, modulus) |
| } |
| } |
| |
| fn plain_modpow(base: &BigUint, exp_data: &[BigDigit], modulus: &BigUint) -> BigUint { |
| assert!( |
| !modulus.is_zero(), |
| "attempt to calculate with zero modulus!" |
| ); |
| |
| let i = match exp_data.iter().position(|&r| r != 0) { |
| None => return BigUint::one(), |
| Some(i) => i, |
| }; |
| |
| let mut base = base % modulus; |
| for _ in 0..i { |
| for _ in 0..big_digit::BITS { |
| base = &base * &base % modulus; |
| } |
| } |
| |
| let mut r = exp_data[i]; |
| let mut b = 0u8; |
| while r.is_even() { |
| base = &base * &base % modulus; |
| r >>= 1; |
| b += 1; |
| } |
| |
| let mut exp_iter = exp_data[i + 1..].iter(); |
| if exp_iter.len() == 0 && r.is_one() { |
| return base; |
| } |
| |
| let mut acc = base.clone(); |
| r >>= 1; |
| b += 1; |
| |
| { |
| let mut unit = |exp_is_odd| { |
| base = &base * &base % modulus; |
| if exp_is_odd { |
| acc *= &base; |
| acc %= modulus; |
| } |
| }; |
| |
| if let Some(&last) = exp_iter.next_back() { |
| // consume exp_data[i] |
| for _ in b..big_digit::BITS { |
| unit(r.is_odd()); |
| r >>= 1; |
| } |
| |
| // consume all other digits before the last |
| for &r in exp_iter { |
| let mut r = r; |
| for _ in 0..big_digit::BITS { |
| unit(r.is_odd()); |
| r >>= 1; |
| } |
| } |
| r = last; |
| } |
| |
| debug_assert_ne!(r, 0); |
| while !r.is_zero() { |
| unit(r.is_odd()); |
| r >>= 1; |
| } |
| } |
| acc |
| } |
| |
/// Exercises `plain_modpow` on multi-digit exponents, including zero low and
/// middle digits.
#[test]
fn test_plain_modpow() {
    let two = &BigUint::from(2u32);
    let modulus = BigUint::from(0x1100u32);

    // NOTE: the reference exponents are written with 8-bit digit groups and
    // are not the values the digit vectors encode at the real digit width.
    // The comparison still holds because `0x1100 = 2^8 * 17` and 2 has order
    // 8 modulo 17, so for exponents >= 8 the residue of a power of two
    // depends only on the exponent modulo 8, which both encodings share
    // (assuming the digit width is at least 3 bits).
    let check = |digits: Vec<BigDigit>, reference_exp: u32| {
        assert_eq!(
            two.pow(reference_exp) % &modulus,
            plain_modpow(two, &digits, &modulus)
        );
    };

    check(vec![0, 0b1], 0b1_00000000_u32);
    check(vec![0, 0b10], 0b10_00000000_u32);
    check(vec![0, 0b110010], 0b110010_00000000_u32);
    check(vec![0b1, 0b1], 0b1_00000001_u32);
    check(vec![0b1100, 0, 0b1], 0b1_00000000_00001100_u32);
}
| |
/// Checks that an owned `BigUint` exponent works: 5³ = 125.
#[test]
fn test_pow_biguint() {
    let five = BigUint::from(5u8);
    let three = BigUint::from(3u8);

    let cubed = five.pow(three);

    assert_eq!(BigUint::from(125u8), cubed);
}