前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >通过欧拉计划学Rust编程(第650题)

通过欧拉计划学Rust编程(第650题)

作者头像
申龙斌
发布2020-02-17 12:59:24
7810
发布2020-02-17 12:59:24
举报
文章被收录于专栏:申龙斌的程序人生

由于研究Libra等数字货币编程技术的需要,学习了一段时间的Rust编程,一不小心刷题上瘾。

刷完欧拉计划中的63道基础题,能学会Rust编程吗?

“欧拉计划”的网址: https://projecteuler.net

英文如果不过关,可以到中文翻译的网站: http://pe-cn.github.io/

这个网站提供了几百道由易到难的数学问题,你可以用任何办法去解决它,当然主要还得靠编程,编程语言不限,论坛里已经有Java、C#、Python、Lisp、Haskell等各种解法,当然如果你直接用google搜索答案就没任何乐趣了。

这次解答的是第650题:

https://projecteuler.net/problem=650

题目描述:

这道题挺折磨人,花了我好几天时间才搞定。

如果你会其它编程语言,也可以一试。

解题过程:

遇到一个复杂的问题,可以尝试先解决简单的情况,然后慢慢逼近最终的问题。

第一步: 手工试算几个,找找规律

可以看到B(n)的求解与排列组合数的因子有关系。

代码语言:javascript
复制
B(2) = 2^1
B(3) = 3^2
B(4) = (2^5) * (3^1)
B(5) = (2^2) * (5^4)

如果一个数能够分解成几个素数乘积,利用等比数列的求和公式,可以推导出所有因子之和的公式:

公式中只有2个不同的质因子,可以推广到更多质因子的情况。

现在可以暴力计算了:

代码语言:javascript
复制
extern crate num_bigint;
use num_bigint::BigUint;

fn main() {
    // 先把一些阶乘计算好,保存起来
    let mut fact = vec![BigUint::from(1 as u64); 101];
    let mut a = BigUint::from(1 as u64);
    for n in 2..=100 {
        a *= BigUint::from(n as u64);
        fact[n] = a.clone();
    }
    //println!("{:?}", fact);

    let mut s = 0;
    for n in 1..=10 {
        let mut prod = BigUint::from(1_u64);
        for r in 1..=n {
            let comb = &fact[n] / &fact[r] / &fact[n - r];
            prod *= &comb;
            //println!("{} {} {}", n, r, comb.to_string());
        }
        let b = prod.to_string().parse::<u64>().unwrap();
        println!("B({}) = {}", n, b);

        let f_all = primes::factors(b);
        let f_uniq = primes::factors_uniq(b);
        //println!("{:?}", f_all);
        //println!("{:?}", f_uniq);

        let mut d = 1;
        for f in f_uniq {
            let c = f_all.iter().filter(|&n| *n == f).count();
            d *= (f.pow(c as u32 + 1) - 1) / (f-1);
        }
        //println!("D({}) = {}", n, d);

        s += d;
        println!("S({}) = {}", n, s);
    }
}

可以正确地计算出S(10),但在计算S(11)时就会出现溢出错误。

第二步:

溢出发生在pow()函数的计算上,求排列组合数和求阶乘的运算量太大,没必要把乘积计算出来,可以将因子保存在一个向量中,不断添加和删除相应的元素即可。

代码语言:javascript
复制
extern crate num_bigint;
use num_bigint::BigUint;

fn main() {
    let mut s = 0;
    for n in 1..=100 {
        let mut factors = vec![];
        for i in 1..=n {
            let mut f = comb_factors(n, i);
            factors.append(&mut f);
        }
        factors.sort();

        let d = factors_sum(&factors);
        println!("D({}) = {:?}", n, d);

        s = (s + d) % 1_000_000_007_u64;
        println!("S({}) = {}", n, s);
    }
}

fn factors_sum(v: &Vec<u64>) -> u64 {
    let mut uniq = v.clone();
    uniq.dedup();

    let mut prod = BigUint::from(1_u64);
    for p in uniq {
        let c = v.iter().filter(|&x| *x == p).count() as u64;
        let t = (big_pow(p, c + 1) - BigUint::from(1_u64)) / (BigUint::from(p - 1));
        //println!("{} {} {}", p, c, t);
        prod = prod * t % 1_000_000_007_u64;
    }
    let prod = prod % 1_000_000_007_u64;
    prod.to_string().parse::<u64>().unwrap()
}

fn big_pow(a: u64, b: u64) -> BigUint {
    let mut prod = BigUint::from(1 as u64);
    for _i in 0..b {
        prod *= BigUint::from(a as u64);
    }
    prod
}

// va中元素已经排好序
fn vec_remove(va: &mut Vec<u64>, vb: &Vec<u64>) {
    for item in vb {
        let index = va.binary_search(&item).unwrap();
        //println!("{:?} {:?} {}", va, vb, index);
        va.remove(index);
    }
}

fn comb_factors(m: u64, n: u64) -> Vec<u64> {
    let mut factors = vec![];
    let mut x = m;
    for i in 0..n {
        let mut f = primes::factors(x);
        factors.append(&mut f);
        x -= 1;
    }
    factors.sort();
    //println!("{:?}", factors);
    for i in 2..=n {
        let f = primes::factors(i);
        //println!("{} {:?}", n, f);
        vec_remove(&mut factors, &f);
    }
    factors.to_vec()
}

这次可以很容易地计算出S(100),但要计算S(1000)则要花相当长的时间,主要是因为数组越来越大,数组排序太花时间。

还得优化。

第三步 用HashMap

数组中元素的排序太慢,尝试换成字典来存储因子,在Rust中用HashMap来实现。比如,B(10)中含有12个2,10个3,8个5,3个7。

而B(100)中含有如下因子,字典中key是质因子,value是个数。

代码语言:javascript
复制
{97: 93, 19: 65, 83: 65, 41: 44, 59: 17, 79: 57, 61: 21, 67: 33, 2: 335, 3: 192, 37: 20, 23: 56, 31: 69, 47: 80, 13: 21, 71: 41, 7: 148, 73: 45, 11: 81,
43: 56, 17: 5, 53: 5, 5: 176, 89: 77, 29: 45}

修改后的代码:

代码语言:javascript
复制
extern crate num_bigint;
use num_bigint::BigUint;

use std::collections::HashMap;
fn main() {
    let mut s = 0;
    for n in 1..=10000 {
        let map = comb_factors_hash_map(n);
        //println!("{:?}", map);

        let temp = factors_hash_map_sum(&map);
        //println!("D({}) = {:?}", n, temp);

        s = (s + temp) % 1_000_000_007_u64;
        println!("S({}) = {}", n, s);
    }
}

fn big_pow(a: u64, b: u64) -> BigUint {
    let mut prod = BigUint::from(1 as u64);
    for _i in 0..b {
        prod *= BigUint::from(a as u64);
    }
    prod
}

fn factors_hash_map_sum(map: &HashMap<u64, u64>) -> u64 {
    let mut prod = BigUint::from(1_u64);
    for (&f, count) in map {
        let t = (big_pow(f, count + 1) - BigUint::from(1_u64)) / (BigUint::from(f - 1));
        prod = prod * t % 1_000_000_007_u64;
    }
    let prod = prod % 1_000_000_007_u64;
    prod.to_string().parse::<u64>().unwrap()
}

fn comb_factors_hash_map(x: u64) -> HashMap<u64, u64> {
    let mut map = HashMap::new();

    let mut count = x as i64 - 1;
    for n in (2..=x).rev() {
        let f = primes::factors(n);
        let a = factors_to_hash_map(&f);
        if count >= 0 {
            hash_map_add_count(&mut map, &a, count as u64);
        } else {
            hash_map_substract_count(&mut map, &a, count.abs() as u64);
        }
        count -= 2;
    }
    map
}

fn factors_to_hash_map(factors: &Vec<u64>) -> HashMap<u64, u64> {
    let mut map = HashMap::new();
    for f in factors {
        let v = map.get(f).cloned(); // 如果不写cloned(),有警告,不理解原因
        match v {
            Some(x) => {
                map.insert(*f, x + 1);
            }
            None => {
                map.insert(*f, 1);
            }
        }
    }
    map
}

fn hash_map_add_count(map: &mut HashMap<u64, u64>, a: &HashMap<u64, u64>, times: u64) {
    for (f, count) in a {
        let v = map.get(f).cloned();
        match v {
            Some(x) => {
                map.insert(*f, x + count * times);
            }
            None => {
                map.insert(*f, *count * times);
            }
        }
    }
}

fn hash_map_substract_count(map: &mut HashMap<u64, u64>, a: &HashMap<u64, u64>, times: u64) {
    for (f, count) in a {
        let v = map.get(f).cloned();
        match v {
            Some(x) => {
                map.insert(*f, x - count * times);
            }
            None => {}
        }
    }
}

这次可以轻松地计算到S(1000),但后面越来越慢。还得继续优化。

第四步 加速pow()运算

计算一个数的n次方,有一个常用的算法可以加速,称为Exponentiation by squaring算法,相关网站:

https://en.wikipedia.org/wiki/Exponentiation_by_squaring

修改后程序运行速度提高,多运行一段时间,可以计算出S(5000)。

代码语言:javascript
复制
fn big_pow(base: u64, exp: u64) -> BigUint {
    let mut result = BigUint::from(1 as u64);
    let mut b = BigUint::from(base);
    let mut e = exp;
    while e > 0 {
        if e % 2 == 1 {
            result = result * &b;
        }
        e = e >> 1;
        b = &b * &b;
    }
    result
}

第五步 递推公式

程序运行几个小时,应该能够计算出S(20000),在它计算的过程中,我还要尝试进一步的优化。但在这之后,我开始走弯路了,尝试缓存一些中间的计算结果来进行加速,效果都不理想。

google "euler project problem 650",发现一篇文章。

http://goatleaps.xyz/euler/maths/Project-Euler-650.html

里面提到一个B(n-1)递推出B(n)的公式,可以进一步优化。

我也用笨办法推导出了这条公式:

代码语言:javascript
复制
use std::collections::HashMap;

const MODULUS: u64 = 1_000_000_007_u64;
extern crate num_bigint;
use num_bigint::BigUint;

fn main() {
    let mut factorial_factors = HashMap::new();
    let mut map = HashMap::new();
    let mut s = 1;
    for n in 2..=20000 {
        let factor_n = factors_map(n);
        map_add_count(&mut map, &factor_n, n);

        // 缓存 n! 阶乘的因子
        map_add(&mut factorial_factors, &factor_n);

        map_substract(&mut map, &factorial_factors);

        //println!("{:?}", &map);
        let d = map_sum(&map);
        
        s = (s + d) % MODULUS;
        if n == 10 {
            assert_eq!(s, 141740594713218418 % MODULUS);
        }
        if n == 100 {
            assert_eq!(s, 332792866_u64);
        }
        if n % 100 == 0 {
            println!("D({}) = {:?} \t S({}) = {}", n, d, n, s);
        }
    }
    println!("{}", s);

}

fn map_sum(map: &HashMap<u64, u64>) -> u64 {
    let mut prod = BigUint::from(1_u64);
    for (&f, count) in map {
        let t = (big_pow(f, count + 1) - BigUint::from(1_u64)) / (BigUint::from(f - 1));
        prod = prod * t % 1_000_000_007_u64;
    }
    let prod = prod % 1_000_000_007_u64;
    prod.to_string().parse::<u64>().unwrap()
}

fn big_pow(base: u64, exp: u64) -> BigUint {
    let mut result = BigUint::from(1 as u64);
    let mut b = BigUint::from(base);
    let mut e = exp;
    while e > 0 {
        if e % 2 == 1 {
            result = result * &b;
        }
        e = e >> 1;
        b = &b * &b;
    }
    result
}

fn factors_map(n:u64) -> HashMap<u64, u64> {
    let mut map = HashMap::new();
    let all_factors = primes::factors(n);
    for f in &all_factors {
        let v = map.get(f).cloned(); // 如果不写cloned(),有警告,不理解原因
        match v {
            Some(x) => {
                map.insert(*f, x + 1);
            }
            None => {
                map.insert(*f, 1);
            }
        }
    }
    map
}

fn map_add(map: &mut HashMap<u64, u64>, a: &HashMap<u64, u64>) {
    for (f, count) in a {
        let v = map.get(f).cloned();
        match v {
            Some(x) => {
                map.insert(*f, x + count);
            }
            None => {
                map.insert(*f, *count);
            }
        }
    }
}

fn map_add_count(map: &mut HashMap<u64, u64>, a: &HashMap<u64, u64>, times: u64) {
    for (f, count) in a {
        let v = map.get(f).cloned();
        match v {
            Some(x) => {
                map.insert(*f, x + count * times);
            }
            None => {
                map.insert(*f, *count * times);
            }
        }
    }
}

fn map_substract(map: &mut HashMap<u64, u64>, a: &HashMap<u64, u64>) {
    for (f, count) in a {
        let v = map.get(f).cloned();
        match v {
            Some(x) => {
                map.insert(*f, x - count);
            }
            None => {}
        }
    }
}

程序并没有想像中提高运行速度,还得想办法。

第六步 倒数的模运算

问题仍出在计算因子和的公式上,计算乘方时用大整数库,当数字越来越大时,速度越来越慢。

这时得利用模运算的数学知识了,模运算有一些基本性质:如果 a ≡ b (mod m) 且有 c ≡ d (mod m),那么下面的模运算律成立:

代码语言:javascript
复制
a + c ≡ b + d (mod m)
a − c ≡ b − d (mod m)
a × c ≡ b × d (mod m)

这里可以看到加、减、乘都满足这种同余定理,但没有除法!

问题没完,在数学运算中:a / b == a * (1 / b),其中1 / b叫b的倒数。那么在模运算中,有没有这种类似于倒数的存在呢?答案是有!它就是乘法逆元!

相关文章链接: http://conw.net/archives/6/

有人已经写好了Rust的代码,直接可以拿来用。

http://rosettacode.org/wiki/Modular_inverse#Rust

pow()乘方函数也有带模运算的Rust代码,直接拿来用。

https://rob.co.bb/posts/2019-02-10-modular-exponentiation-in-rust/

这次优化彻底去掉了大整数运算,速度提升了不知多少倍。

在我的笔记本电脑上,8秒钟计算出S(20000)!

最终的源代码:

代码语言:javascript
复制
use std::collections::HashMap;

const MODULUS: u64 = 1_000_000_007_u64;

fn main() {
    let mut factorial_factors = HashMap::new();
    let mut map = HashMap::new();
    let mut s = 1;
    for n in 2..=20000 {
        let factor_n = factors_map(n);
        map_add_count(&mut map, &factor_n, n);

        // 缓存 n! 阶乘的因子
        map_add(&mut factorial_factors, &factor_n);

        map_substract(&mut map, &factorial_factors);

        //println!("{:?}", &map);
        let d = map_sum(&map);
        
        s = (s + d) % MODULUS;
        if n == 10 {
            assert_eq!(s, 141740594713218418 % MODULUS);
        }
        if n == 100 {
            assert_eq!(s, 332792866_u64);
        }
        if n % 100 == 0 {
            println!("D({}) = {:?} \t S({}) = {}", n, d, n, s);
        }
    }
    println!("{}", s);

}

// http://rosettacode.org/wiki/Modular_inverse#Rust
// 求乘法逆元
fn mod_inv(a: isize, module: isize) -> isize {
    let mut mn = (module, a);
    let mut xy = (0, 1);

    while mn.1 != 0 {
        xy = (xy.1, xy.0 - (mn.0 / mn.1) * xy.1);
        mn = (mn.1, mn.0 % mn.1);
    }

    while xy.0 < 0 {
        xy.0 += module;
    }
    xy.0
}

// https://rob.co.bb/posts/2019-02-10-modular-exponentiation-in-rust/
fn mod_pow(mut base: u64, mut exp: u64, modulus: u64) -> u64 {
    let mut result = 1;
    base = base % modulus;
    while exp > 0 {
        if exp % 2 == 1 {
            result = result * base % modulus;
        }
        exp = exp >> 1;
        base = base * base % modulus
    }
    result
}


fn map_sum(map: &HashMap<u64, u64>) -> u64 {
    let mut prod = 1;
    for (&f, count) in map {
        if *count > 0 {
            // 计算 f^(count+1) / (f-1)
            let temp = mod_pow(f, count + 1, MODULUS) - 1;
            // 计算(f-1)的乘法逆元
            let inv = mod_inv(f as isize - 1, MODULUS as isize) as u64;
            //println!("inv:{}", inv);
            let temp = temp * inv % MODULUS;
            prod = prod * temp % MODULUS;
            //println!("f:{} count+1:{} prod: {}", f, count+1, prod);
        }
    }
    prod
}

fn factors_map(n:u64) -> HashMap<u64, u64> {
    let mut map = HashMap::new();
    let all_factors = primes::factors(n);
    for f in &all_factors {
        let v = map.get(f).cloned(); // 如果不写cloned(),有警告,不理解原因
        match v {
            Some(x) => {
                map.insert(*f, x + 1);
            }
            None => {
                map.insert(*f, 1);
            }
        }
    }
    map
}

fn map_add(map: &mut HashMap<u64, u64>, a: &HashMap<u64, u64>) {
    for (f, count) in a {
        let v = map.get(f).cloned();
        match v {
            Some(x) => {
                map.insert(*f, x + count);
            }
            None => {
                map.insert(*f, *count);
            }
        }
    }
}

fn map_add_count(map: &mut HashMap<u64, u64>, a: &HashMap<u64, u64>, times: u64) {
    for (f, count) in a {
        let v = map.get(f).cloned();
        match v {
            Some(x) => {
                map.insert(*f, x + count * times);
            }
            None => {
                map.insert(*f, *count * times);
            }
        }
    }
}

fn map_substract(map: &mut HashMap<u64, u64>, a: &HashMap<u64, u64>) {
    for (f, count) in a {
        let v = map.get(f).cloned();
        match v {
            Some(x) => {
                map.insert(*f, x - count);
            }
            None => {}
        }
    }
}

优化要点:

1)不要轻易用大整数运算库

2)质因子分解

3)同余定理

4)Exponentiation by squaring乘方运算加速

5)找到递推公式

6)乘法逆元

卡了我好几天的优化算法是因为以前不知道“乘法逆元”的存在。

--- END ---

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2020-01-26,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 申龙斌的程序人生 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档