由于研究Libra等数字货币编程技术的需要,学习了一段时间的Rust编程,一不小心刷题上瘾。
“欧拉计划”的网址: https://projecteuler.net
英文如果不过关,可以到中文翻译的网站: http://pe-cn.github.io/
这个网站提供了几百道由易到难的数学问题,你可以用任何办法去解决它,当然主要还得靠编程,编程语言不限,论坛里已经有Java、C#、Python、Lisp、Haskell等各种解法,当然如果你直接用google搜索答案就没任何乐趣了。
这次解答的是第650题:
https://projecteuler.net/problem=650
题目描述:
这道题挺折磨人,花了我好几天时间才搞定。
如果你会其它编程语言,也可以一试。
请
先
不
要
直
接
看
答
案
,
最
好
自
己
先
尝
试
一
下
。
解题过程:
遇到一个复杂的问题,可以尝试先解决简单的情况,然后慢慢逼近最终的问题。
第一步: 手工试算几个,找找规律
可以看到B(n)的求解与排列组合数的因子有关系。
B(2) = 2^1
B(3) = 3^2
B(4) = (2^5) * (3^1)
B(5) = (2^2) * (5^4)
如果一个数能够分解成几个素数乘积,利用等比数列的求和公式,可以推导出所有因子之和的公式:
公式中只有2个不同的质因子,可以推广到更多质因子的情况。
现在可以暴力计算了:
extern crate num_bigint;
use num_bigint::BigUint;
fn main() {
// 先把一些阶乘计算好,保存起来
let mut fact = vec![BigUint::from(1 as u64); 101];
let mut a = BigUint::from(1 as u64);
for n in 2..=100 {
a *= BigUint::from(n as u64);
fact[n] = a.clone();
}
//println!("{:?}", fact);
let mut s = 0;
for n in 1..=10 {
let mut prod = BigUint::from(1_u64);
for r in 1..=n {
let comb = &fact[n] / &fact[r] / &fact[n - r];
prod *= &comb;
//println!("{} {} {}", n, r, comb.to_string());
}
let b = prod.to_string().parse::<u64>().unwrap();
println!("B({}) = {}", n, b);
let f_all = primes::factors(b);
let f_uniq = primes::factors_uniq(b);
//println!("{:?}", f_all);
//println!("{:?}", f_uniq);
let mut d = 1;
for f in f_uniq {
let c = f_all.iter().filter(|&n| *n == f).count();
d *= (f.pow(c as u32 + 1) - 1) / (f-1);
}
//println!("D({}) = {}", n, d);
s += d;
println!("S({}) = {}", n, s);
}
}
可以正确地计算出S(10),但在计算S(11)时就会出现溢出错误。
第二步:
溢出发生在pow()函数的计算上,求排列组合数和求阶乘的运算量太大,没必要把乘积计算出来,可以将因子保存在一个向量中,不断添加和删除相应的元素即可。
extern crate num_bigint;
use num_bigint::BigUint;
fn main() {
let mut s = 0;
for n in 1..=100 {
let mut factors = vec![];
for i in 1..=n {
let mut f = comb_factors(n, i);
factors.append(&mut f);
}
factors.sort();
let d = factors_sum(&factors);
println!("D({}) = {:?}", n, d);
s = (s + d) % 1_000_000_007_u64;
println!("S({}) = {}", n, s);
}
}
fn factors_sum(v: &Vec<u64>) -> u64 {
let mut uniq = v.clone();
uniq.dedup();
let mut prod = BigUint::from(1_u64);
for p in uniq {
let c = v.iter().filter(|&x| *x == p).count() as u64;
let t = (big_pow(p, c + 1) - BigUint::from(1_u64)) / (BigUint::from(p - 1));
//println!("{} {} {}", p, c, t);
prod = prod * t % 1_000_000_007_u64;
}
let prod = prod % 1_000_000_007_u64;
prod.to_string().parse::<u64>().unwrap()
}
fn big_pow(a: u64, b: u64) -> BigUint {
let mut prod = BigUint::from(1 as u64);
for _i in 0..b {
prod *= BigUint::from(a as u64);
}
prod
}
// va中元素已经排好序
fn vec_remove(va: &mut Vec<u64>, vb: &Vec<u64>) {
for item in vb {
let index = va.binary_search(&item).unwrap();
//println!("{:?} {:?} {}", va, vb, index);
va.remove(index);
}
}
fn comb_factors(m: u64, n: u64) -> Vec<u64> {
let mut factors = vec![];
let mut x = m;
for i in 0..n {
let mut f = primes::factors(x);
factors.append(&mut f);
x -= 1;
}
factors.sort();
//println!("{:?}", factors);
for i in 2..=n {
let f = primes::factors(i);
//println!("{} {:?}", n, f);
vec_remove(&mut factors, &f);
}
factors.to_vec()
}
这次可以很容易地计算出S(100),但要计算S(1000)则要花相当长的时间,主要是因为数组越来越大,数组排序太花时间。
还得优化。
第三步 用HashMap
数组中元素的排序太慢,尝试换成字典来存储因子,在Rust中用HashMap来实现。比如,B(10)中含有12个2,10个3,8个5,3个7。
而B(100)中含有如下因子,字典中key是质因子,value是个数。
{97: 93, 19: 65, 83: 65, 41: 44, 59: 17, 79: 57, 61: 21, 67: 33, 2: 335, 3: 192, 37: 20, 23: 56, 31: 69, 47: 80, 13: 21, 71: 41, 7: 148, 73: 45, 11: 81,
43: 56, 17: 5, 53: 5, 5: 176, 89: 77, 29: 45}
修改后的代码:
extern crate num_bigint;
use num_bigint::BigUint;
use std::collections::HashMap;
fn main() {
let mut s = 0;
for n in 1..=10000 {
let map = comb_factors_hash_map(n);
//println!("{:?}", map);
let temp = factors_hash_map_sum(&map);
//println!("D({}) = {:?}", n, temp);
s = (s + temp) % 1_000_000_007_u64;
println!("S({}) = {}", n, s);
}
}
fn big_pow(a: u64, b: u64) -> BigUint {
let mut prod = BigUint::from(1 as u64);
for _i in 0..b {
prod *= BigUint::from(a as u64);
}
prod
}
fn factors_hash_map_sum(map: &HashMap<u64, u64>) -> u64 {
let mut prod = BigUint::from(1_u64);
for (&f, count) in map {
let t = (big_pow(f, count + 1) - BigUint::from(1_u64)) / (BigUint::from(f - 1));
prod = prod * t % 1_000_000_007_u64;
}
let prod = prod % 1_000_000_007_u64;
prod.to_string().parse::<u64>().unwrap()
}
fn comb_factors_hash_map(x: u64) -> HashMap<u64, u64> {
let mut map = HashMap::new();
let mut count = x as i64 - 1;
for n in (2..=x).rev() {
let f = primes::factors(n);
let a = factors_to_hash_map(&f);
if count >= 0 {
hash_map_add_count(&mut map, &a, count as u64);
} else {
hash_map_substract_count(&mut map, &a, count.abs() as u64);
}
count -= 2;
}
map
}
fn factors_to_hash_map(factors: &Vec<u64>) -> HashMap<u64, u64> {
let mut map = HashMap::new();
for f in factors {
let v = map.get(f).cloned(); // 如果不写cloned(),有警告,不理解原因
match v {
Some(x) => {
map.insert(*f, x + 1);
}
None => {
map.insert(*f, 1);
}
}
}
map
}
fn hash_map_add_count(map: &mut HashMap<u64, u64>, a: &HashMap<u64, u64>, times: u64) {
for (f, count) in a {
let v = map.get(f).cloned();
match v {
Some(x) => {
map.insert(*f, x + count * times);
}
None => {
map.insert(*f, *count * times);
}
}
}
}
fn hash_map_substract_count(map: &mut HashMap<u64, u64>, a: &HashMap<u64, u64>, times: u64) {
for (f, count) in a {
let v = map.get(f).cloned();
match v {
Some(x) => {
map.insert(*f, x - count * times);
}
None => {}
}
}
}
这次可以轻松地计算到S(1000),但后面越来越慢。还得继续优化。
第四步 加速pow()运算
计算一个数的n次方,有一个常用的算法可以加速,称为Exponentiation by squaring算法,相关网站:
https://en.wikipedia.org/wiki/Exponentiation_by_squaring
修改后程序运行速度提高,多运行一段时间,可以计算出S(5000)。
fn big_pow(base: u64, exp: u64) -> BigUint {
let mut result = BigUint::from(1 as u64);
let mut b = BigUint::from(base);
let mut e = exp;
while e > 0 {
if e % 2 == 1 {
result = result * &b;
}
e = e >> 1;
b = &b * &b;
}
result
}
第五步 递推公式
程序运行几个小时,应该能够计算出S(20000),在它计算的过程中,我还要尝试进一步的优化。但在这之后,我开始走弯路了,尝试缓存一些中间的计算结果来进行加速,效果都不理想。
google "euler project problem 650",发现一篇文章。
http://goatleaps.xyz/euler/maths/Project-Euler-650.html
里面提到一个B(n-1)递推出B(n)的公式,可以进一步优化。
我也用笨办法推导出了这条公式:
use std::collections::HashMap;
const MODULUS: u64 = 1_000_000_007_u64;
extern crate num_bigint;
use num_bigint::BigUint;
fn main() {
let mut factorial_factors = HashMap::new();
let mut map = HashMap::new();
let mut s = 1;
for n in 2..=20000 {
let factor_n = factors_map(n);
map_add_count(&mut map, &factor_n, n);
// 缓存 n! 阶乘的因子
map_add(&mut factorial_factors, &factor_n);
map_substract(&mut map, &factorial_factors);
//println!("{:?}", &map);
let d = map_sum(&map);
s = (s + d) % MODULUS;
if n == 10 {
assert_eq!(s, 141740594713218418 % MODULUS);
}
if n == 100 {
assert_eq!(s, 332792866_u64);
}
if n % 100 == 0 {
println!("D({}) = {:?} \t S({}) = {}", n, d, n, s);
}
}
println!("{}", s);
}
fn map_sum(map: &HashMap<u64, u64>) -> u64 {
let mut prod = BigUint::from(1_u64);
for (&f, count) in map {
let t = (big_pow(f, count + 1) - BigUint::from(1_u64)) / (BigUint::from(f - 1));
prod = prod * t % 1_000_000_007_u64;
}
let prod = prod % 1_000_000_007_u64;
prod.to_string().parse::<u64>().unwrap()
}
fn big_pow(base: u64, exp: u64) -> BigUint {
let mut result = BigUint::from(1 as u64);
let mut b = BigUint::from(base);
let mut e = exp;
while e > 0 {
if e % 2 == 1 {
result = result * &b;
}
e = e >> 1;
b = &b * &b;
}
result
}
fn factors_map(n:u64) -> HashMap<u64, u64> {
let mut map = HashMap::new();
let all_factors = primes::factors(n);
for f in &all_factors {
let v = map.get(f).cloned(); // 如果不写cloned(),有警告,不理解原因
match v {
Some(x) => {
map.insert(*f, x + 1);
}
None => {
map.insert(*f, 1);
}
}
}
map
}
fn map_add(map: &mut HashMap<u64, u64>, a: &HashMap<u64, u64>) {
for (f, count) in a {
let v = map.get(f).cloned();
match v {
Some(x) => {
map.insert(*f, x + count);
}
None => {
map.insert(*f, *count);
}
}
}
}
fn map_add_count(map: &mut HashMap<u64, u64>, a: &HashMap<u64, u64>, times: u64) {
for (f, count) in a {
let v = map.get(f).cloned();
match v {
Some(x) => {
map.insert(*f, x + count * times);
}
None => {
map.insert(*f, *count * times);
}
}
}
}
fn map_substract(map: &mut HashMap<u64, u64>, a: &HashMap<u64, u64>) {
for (f, count) in a {
let v = map.get(f).cloned();
match v {
Some(x) => {
map.insert(*f, x - count);
}
None => {}
}
}
}
程序并没有想像中提高运行速度,还得想办法。
第六步 倒数的模运算
问题仍出在计算因子和的公式上,计算乘方时用大整数库,当数字越来越大时,速度越来越慢。
这时得利用模运算的数学知识了,模运算有一些基本性质:如果 a ≡ b (mod m) 且有 c ≡ d (mod m),那么下面的模运算律成立:
a + c ≡ b + d (mod m)
a − c ≡ b − d (mod m)
a × c ≡ b × d (mod m)
这里可以看到加、减、乘都满足这种同余定理,但没有除法!
问题没完,在数学运算中:a / b == a * (1 / b),其中1 / b叫b的倒数。那么在模运算中,有没有这种类似于倒数的存在呢?答案是有!它就是乘法逆元!
相关文章链接: http://conw.net/archives/6/
有人已经写好了Rust的代码,直接可以拿来用。
http://rosettacode.org/wiki/Modular_inverse#Rust
pow()乘方函数也有带模运算的Rust代码,直接拿来用。
https://rob.co.bb/posts/2019-02-10-modular-exponentiation-in-rust/
这次优化彻底去掉了大整数运算,速度提升了不知多少倍。
在我的笔记本电脑上,8秒钟计算出S(20000)!
最终的源代码:
use std::collections::HashMap;
const MODULUS: u64 = 1_000_000_007_u64;
fn main() {
let mut factorial_factors = HashMap::new();
let mut map = HashMap::new();
let mut s = 1;
for n in 2..=20000 {
let factor_n = factors_map(n);
map_add_count(&mut map, &factor_n, n);
// 缓存 n! 阶乘的因子
map_add(&mut factorial_factors, &factor_n);
map_substract(&mut map, &factorial_factors);
//println!("{:?}", &map);
let d = map_sum(&map);
s = (s + d) % MODULUS;
if n == 10 {
assert_eq!(s, 141740594713218418 % MODULUS);
}
if n == 100 {
assert_eq!(s, 332792866_u64);
}
if n % 100 == 0 {
println!("D({}) = {:?} \t S({}) = {}", n, d, n, s);
}
}
println!("{}", s);
}
// http://rosettacode.org/wiki/Modular_inverse#Rust
// 求乘法逆元
fn mod_inv(a: isize, module: isize) -> isize {
let mut mn = (module, a);
let mut xy = (0, 1);
while mn.1 != 0 {
xy = (xy.1, xy.0 - (mn.0 / mn.1) * xy.1);
mn = (mn.1, mn.0 % mn.1);
}
while xy.0 < 0 {
xy.0 += module;
}
xy.0
}
// https://rob.co.bb/posts/2019-02-10-modular-exponentiation-in-rust/
fn mod_pow(mut base: u64, mut exp: u64, modulus: u64) -> u64 {
let mut result = 1;
base = base % modulus;
while exp > 0 {
if exp % 2 == 1 {
result = result * base % modulus;
}
exp = exp >> 1;
base = base * base % modulus
}
result
}
fn map_sum(map: &HashMap<u64, u64>) -> u64 {
let mut prod = 1;
for (&f, count) in map {
if *count > 0 {
// 计算 f^(count+1) / (f-1)
let temp = mod_pow(f, count + 1, MODULUS) - 1;
// 计算(f-1)的乘法逆元
let inv = mod_inv(f as isize - 1, MODULUS as isize) as u64;
//println!("inv:{}", inv);
let temp = temp * inv % MODULUS;
prod = prod * temp % MODULUS;
//println!("f:{} count+1:{} prod: {}", f, count+1, prod);
}
}
prod
}
fn factors_map(n:u64) -> HashMap<u64, u64> {
let mut map = HashMap::new();
let all_factors = primes::factors(n);
for f in &all_factors {
let v = map.get(f).cloned(); // 如果不写cloned(),有警告,不理解原因
match v {
Some(x) => {
map.insert(*f, x + 1);
}
None => {
map.insert(*f, 1);
}
}
}
map
}
fn map_add(map: &mut HashMap<u64, u64>, a: &HashMap<u64, u64>) {
for (f, count) in a {
let v = map.get(f).cloned();
match v {
Some(x) => {
map.insert(*f, x + count);
}
None => {
map.insert(*f, *count);
}
}
}
}
fn map_add_count(map: &mut HashMap<u64, u64>, a: &HashMap<u64, u64>, times: u64) {
for (f, count) in a {
let v = map.get(f).cloned();
match v {
Some(x) => {
map.insert(*f, x + count * times);
}
None => {
map.insert(*f, *count * times);
}
}
}
}
fn map_substract(map: &mut HashMap<u64, u64>, a: &HashMap<u64, u64>) {
for (f, count) in a {
let v = map.get(f).cloned();
match v {
Some(x) => {
map.insert(*f, x - count);
}
None => {}
}
}
}
优化要点:
1)不要轻易用大整数运算库
2)质因子分解
3)同余定理
4)Exponentiation by squaring乘方运算加速
5)找到递推公式
6)乘法逆元
卡了我好几天的优化算法是因为以前不知道“乘法逆元”的存在。
--- END ---