面试中的斐波那契数列问题

更新于 2024-05-24 04:14

起因和面试规则

很久以前,在我呆过的一个团队中,因为某些客观原因,比较难招到合适的程序员。为了改进招聘和面试的效率,我们改进了技术面试的流程。在前后端技术面试中,安排了一个现场写代码的环节,主要规则是这样的:

  • 时间不限语言不限使用的工具不限,
  • 可以 Google 可以百度,但不可向他人求助,
  • 每个问题的代码运行时间应该不超过 5 分钟,
  • 从 6 个大问题 14 个小问题中任选 N 个小问题解决,
  • 如果能成功解决了 2 个或 2 个以上小问题,则进入下一轮技术终面,否则面试流程结束。

题目

在数学上,著名的 斐波那契数列 / Fibonacci Sequence 以递归的方法来定义:

\begin{align} F_n = \begin{cases} 0 & \text{for} \;n = 0, \\ 1 & \text{for} \;n = 1, \\ F_{n-1} + F_{n-2} & \text{for} \; n \geq 2. \end{cases} \end{align}
  1. 试算出 F_{100} 的准确数值。
  2. 请问斐波那契数列中,第一个有 1000 位数字的是第几项?
  3. 试算出第 10 亿项除以 1000000007 的余数,即 F_{10^{9}} \ mod \ (10^{9}+7) .
  4. 试算出第 10^{1000} 项除以 1000000007 的余数,即 F_{10^{1000}} \ mod \ (10^{9}+7) .

虽然我们当时主要招 Scala 后端程序员,但这道面试题通常也用在前端程序员的面试中,规则也一样。

质疑

自从在知乎一个回答里抱怨了一篇,透露了这道题目以后,收到很多质疑:

  • 题目太简单了,一个循环就出来了。
  • 题目太简单了,一个递归就出来了。
  • 题目太简单了,一个矩阵快速幂就全秒了。
  • 题目太难了,搬转工程师不应该考算法。
  • 题目太难了,不应该考数学。
  • 题目太难了,不应该考递归。

为什么要考这道题?

出这道题,倒不是为了为难人选,只是为了初步筛选出对计算机有一定的理解,可以用代码解决某类问题(尤其是初级的抽象和计算),具有一定工程能力的人。这 个目标人选,我相信是可以通过测试大致上测出来的,即使测得不那么准,具有一定的假阳性错误(False Positive Errors), 也是可以在后续的面谈环节中加以鉴 别的。

这道题不是考察数学或者数论知识的。否则,第一个小问题会变成这样:

某人上台阶,可以一步上一级,也可以上两级。请问上100级的台阶,一共有多少种走法?

(这是一道考察斐波那契数列递归的经典小学奥数题。)

也不是考算法的,因为这么明显的问题,算法总是可以很容易 Google 到,而面试规则并不禁止 Google.

计算是需要认真对待的问题。在充分掌握了一门编程语言,已知数学公式和算法的条件下,完成一个正确,高效的计算实现往往不如预料中的容易。

举个明显的例子,当我们知道拉马努金这个著名的计算 \pi 的公式:

\frac{1}{\pi} = \frac{2\sqrt{2}}{99^2} \sum_{k=0}^{\infty} \frac{(4k)!(1103+26390k)}{(k!)^4 396^{4k}}

又已知这个公式的 k 每计算多 1 项,结果精度可增加 8 位。是不是非常浅显易懂,清晰明了?好,现在你尝试用这个公式计算 \pi 到小数点后一百万位吧。这是不是马上就不那么通俗易懂了?

一些典型的错误

当时包括本题在内的整套笔试题通过率只有 10\% ~ 20\% , 这道看起来人畜无害的斐波那契题是失败的重灾区。我总结了一些典型的错误,列举如下。

一上来就直接递归的,这类最为常见:

def fib(n):
    if n == 1 or n == 2:
        return 1
    else:
        return fib(n-1) + fib(n-2)

# 或者更简洁的, 一行搞掂
def fib(n): return fib(n-1) + fib(n-2) if (n > 2) else 1

这种直接递归的,无论机器性能多猛,一般在 n = 40 左右就是极限了,n = 100 更是递归到太阳熄灭都不会有结果。

犯这种错的,一般对算法的复杂度没有概念,仅仅会用简单的递归函数。我们是重度使用 Scala 函数式编程的团队,最怕的就是这种写法。

之所以把第一个小问题设计为 n = 100, 就是为了暴露这个问题。

第二类错误,不能处理大整数溢出的。

Java:

public class Fib {

    public static long fib(int n) {
        long a = 1L, b = 1L, c = 1L;
        for (int i = 1; i < n; i++) {
            c = b;
            b = a + b;
            a = c;
        }
        return c;
    }

    public static void main(String args[]) {
        System.out.println(fib(100)); // 3736710778780434371
    }
}

JavaScript:

function fib(n) {
    for (i = 0, x = 0, y = 1; i < n; i++) {
        var z = y
        y = x + y
        x = z
    }
    return x
}

console.log(fib(100)) // 354224848179262000000

还有用其他语言算出来负数的,不一一列举。

有人抱怨这里故意埋大整数溢出的坑考他们。我认为,能够正确地处理比较明显的数值溢出,是每一位程序员应该具备的基本素质。就算考虑实际的项目,很多时候 仍然需要注意数值的溢出。比如,我们以前就有过把某些业务数据的统计值类型错误地设计成 32 位整数,导致溢出。当时的业务是互联网广告,一张月报表的广告 请求数,真的很容易超过 21 亿的。那么我们无脑地把所有统计值的类型设计成 64 位整数不行么?好像也不行,因为这样又太浪费了。

有不少人选意识到了大整数溢出的问题,会直接用数组之类的实现大数加法。严格来说,这其实也可以归结为一种“错误”。当前主流的编程语言,他们犯的错叫做 “不熟悉自己使用的编程语言,重复制造无意义的轮子”。

第一个小问题设计成 n = 100 而不成 n = 50 的原因在于制造大于 64 位无符号整数的结果。美中不足的地方在于,用 Python 做题的人可能在毫无意识的 情况下避开了这个坑。

第三类错误:直接用通项公式算的。

我们知道,斐波那契数列有通项公式

F_n = \cfrac{1}{\sqrt {5}}[(\cfrac {1 + \sqrt {5}} {2})^{n} - (\cfrac {1 - \sqrt {5}} {2})^{n}]

直接拿这个公式去计算,一般是算不出正确答案的。犯这个错的人,一般对工程学意义上的“计算”是没有概念的。本题不太考查数学能力,因此即使知道了这个公式 也没什么作用。如果能用这个公式算出正确结果,说明已经能够从容处理高精度小数在大量计算中的问题,达到考察目的。

第四类错误:两个数颠来倒去反复加的时候,加错的。

def fib(n):
    if n == 1 or n == 2:
        return 1
    else:
        (a, b) = (1, 1)
        for i in range(n):
            b = a + b
            a = b
        return b

fib(10) # 1024 WTF!!!!

这类错误难以一一列举,基本属于久没写代码,或者不会写代码,一想好简单,一写就错的。

第五类错误:一上来就用动态规划(DP),但是实现错误的。实际上这个问题不属于典型的动态规划问题,最多只能算是动态规划的极端版 本。犯这个错的,一般刷过算法题,但对算法理解过浅,不能处理实际工程项目里更加普遍更加复杂的算法问题。

第六错误,一上来就祭出矩阵快速幂,因姿势不对而失败的。这种一般属于理论知识较好,但工程能力稍欠,实现已有算法能力不够强 的。

前3小题“标准”答案

因为可用的编程语言和算法不限,本题其实是没有绝对标准答案的。不过我们可以探讨下一些可行的做法。

第1小题

计算第 100​ 项的准确值。

直接循环

这个问题,其实非常奇怪:如果我们让面试者直接用笔算,相信 90% 的人算到 F_{50} 都是没问题的,然而写代码算,能算对的却很少。

其实这个这个问题,直接把笔算的过程翻译成代码就行了。

首先要解决大整数的问题。我们看看目前主流语言对大整数的支持情况:

  • Scala 有 BigInt

  • Java 有 BigInteger

  • JavaScript 有 BigInt (MDN文档)

  • Python 默认的 int 就是

  • Haksell 有 Integer

    Scala:

def fib(n: Int) = {
  var (a, b, c) = (BigInt(1), BigInt(1), BigInt(1))
  var i = 1
  while (i < n) {
    c = b
    b = a + b
    a = c
    i += 1
  }
  c
}

@main def main(): Unit = println(fib(100)) // 354224848179261915075

Java:

import java.math.BigInteger;

public class FibJava {

    public static BigInteger fib(int n) {
        BigInteger a = BigInteger.valueOf(1L), b = BigInteger.valueOf(1L), c = BigInteger.valueOf(1L);
        for (int i = 1; i < n; i++) {
            c = b;
            b = a.add(b);
            a = c;
        }
        return c;
    }

    public static void main(String args[]) {
        System.out.println(fib(100)); // 354224848179261915075
    }
}

Python:

def fib(n):
    a, b = 1, 1
    i = 1
    while (i < n):
        a, b = b, a + b
        i = i + 1
    return a

print(fib(100)) # 354224848179261915075

JavaScript:

function fib(n) {
    for (i = 0, x = 0n, y = 1n; i < n; i++) {
        y = x + y
        x = y - x
    }
    return x
}

console.log(fib(100)) // 354224848179261915075n

Haskell:

-- 这个循环的版本,本人不会

递归

这里要用尾递归

Scala:

@scala.annotation.tailrec
def fibRec(x: BigInt, y: BigInt, i: Int): BigInt = if (i == 0) x else fibRec(y, x + y, i - 1)

def fib(n: Int): BigInt = fibRec(0, 1, n)

@main def printFib100(): Unit = println(fib(100)) // 354224848179261915075

Java:

import java.math.BigInteger;

public class FibJava {

    private static BigInteger _fib(BigInteger x, BigInteger y, int i) {
        return (i == 0) ? x : _fib(y, x.add(y), i - 1);
    }

    public static BigInteger fib(int n) {
        return _fib(BigInteger.valueOf(0L), BigInteger.valueOf(1L), n);
    }

    public static void main(String args[]) {
        System.out.println(fib(100)); // 354224848179261915075
    }
}

Python

def fib(n):
    def f(x, y, i): return x if not i else f(y, x + y, i-1)
    return f(0, 1, n)

print(fib(100)) # 354224848179261915075

JavaScript:

"use strict";
function fib(n) {
    let f = (x, y, i) => (i === 0) ? x : f(y, x+y, i-1)
    return f(0n, 1n, n)
}

console.log(fib(100)) // 354224848179261915075n

Haskell:

fib' x y 0 = x
fib' x y i = fib' y (x + y) (i-1)

fib n = fib' 0 1 n

main = do print $ fib 100

其他解法:

Haksell 有个著名的生成公式:

fibs = 0 : 1 : (zipWith (+) fibs (tail fibs))
~ $ ghci
GHCi, version 8.6.5: http://www.haskell.org/ghc/  :? for help
λ> fibs = 0 : 1 : (zipWith (+) fibs (tail fibs))
λ> take 101 $ fibs
    [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610, 987, 1597,
     2584, 4181, 6765, 10946, 17711, 28657, 46368, 75025, 121393, 196418,
     317811, 514229, 832040, 1346269, 2178309, 3524578, 5702887, 9227465,
     14930352, 24157817, 39088169, 63245986, 102334155, 165580141, 267914296,
     433494437, 701408733, 1134903170, 1836311903, 2971215073, 4807526976,
     7778742049, 12586269025, 20365011074, 32951280099, 53316291173,
     86267571272, 139583862445, 225851433717, 365435296162, 591286729879,
     956722026041, 1548008755920, 2504730781961, 4052739537881, 6557470319842,
     10610209857723, 17167680177565, 27777890035288, 44945570212853,
     72723460248141, 117669030460994, 190392490709135, 308061521170129,
     498454011879264, 806515533049393, 1304969544928657, 2111485077978050,
     3416454622906707, 5527939700884757, 8944394323791464, 14472334024676221,
     23416728348467685, 37889062373143906, 61305790721611591,
     99194853094755497, 160500643816367088, 259695496911122585,
     420196140727489673, 679891637638612258, 1100087778366101931,
     1779979416004714189, 2880067194370816120, 4660046610375530309,
     7540113804746346429, 12200160415121876738, 19740274219868223167,
     31940434634990099905, 51680708854858323072, 83621143489848422977,
     135301852344706746049, 218922995834555169026, 354224848179261915075]
λ> fibs !! 100
    354224848179261915075

Scala 的 LazyList 和 Haskell 的 List 有相似的“懒惰”性,其实也可以做到:

$ scala-cli
Welcome to Scala 3.4.1 (21, Java OpenJDK 64-Bit Server VM).
Type in expressions for evaluation. Or try :help.

scala> val fibs: LazyList[BigInt] = BigInt(0) #:: BigInt(1) #:: fibs.zip(fibs.tail).map(x => x._1 + x._2)
val fibs: LazyList[BigInt] = LazyList(<not computed>)

scala> fibs.take(101).zipWithIndex.map(x => x._2 -> x._1).foreach(println)
(0,0)
(1,1)
(2,1)
(3,2)
(4,3)
(5,5)
(6,8)
(7,13)
(8,21)
(9,34)
(10,55)
(11,89)
(12,144)
(13,233)
(14,377)
(15,610)
(16,987)
(17,1597)
(18,2584)
(19,4181)
(20,6765)
(21,10946)
(22,17711)
(23,28657)
(24,46368)
(25,75025)
(26,121393)
(27,196418)
(28,317811)
(29,514229)
(30,832040)
(31,1346269)
(32,2178309)
(33,3524578)
(34,5702887)
(35,9227465)
(36,14930352)
(37,24157817)
(38,39088169)
(39,63245986)
(40,102334155)
(41,165580141)
(42,267914296)
(43,433494437)
(44,701408733)
(45,1134903170)
(46,1836311903)
(47,2971215073)
(48,4807526976)
(49,7778742049)
(50,12586269025)
(51,20365011074)
(52,32951280099)
(53,53316291173)
(54,86267571272)
(55,139583862445)
(56,225851433717)
(57,365435296162)
(58,591286729879)
(59,956722026041)
(60,1548008755920)
(61,2504730781961)
(62,4052739537881)
(63,6557470319842)
(64,10610209857723)
(65,17167680177565)
(66,27777890035288)
(67,44945570212853)
(68,72723460248141)
(69,117669030460994)
(70,190392490709135)
(71,308061521170129)
(72,498454011879264)
(73,806515533049393)
(74,1304969544928657)
(75,2111485077978050)
(76,3416454622906707)
(77,5527939700884757)
(78,8944394323791464)
(79,14472334024676221)
(80,23416728348467685)
(81,37889062373143906)
(82,61305790721611591)
(83,99194853094755497)
(84,160500643816367088)
(85,259695496911122585)
(86,420196140727489673)
(87,679891637638612258)
(88,1100087778366101931)
(89,1779979416004714189)
(90,2880067194370816120)
(91,4660046610375530309)
(92,7540113804746346429)
(93,12200160415121876738)
(94,19740274219868223167)
(95,31940434634990099905)
(96,51680708854858323072)
(97,83621143489848422977)
(98,135301852344706746049)
(99,218922995834555169026)
(100,354224848179261915075)

scala> fibs(100)
val res1: BigInt = 354224848179261915075

scala> fibs(1000).toString.length
val res1: Int = 209

scala> fibs(10000).toString.length
val res2: Int = 2090

scala> fibs(100000).toString.length
val res3: Int = 20899

第2小题

求第一个有 1000 位数字的项数

* 在做出第一问的情况下,这个小问题是很容易的。反正数据很小,直接暴力转成字符串再数一下就好了。

~ $ scala-cli
Welcome to Scala 3.4.1 (21, Java OpenJDK 64-Bit Server VM).
Type in expressions for evaluation. Or try :help.

scala> val fibs: LazyList[BigInt] = BigInt(0) #:: BigInt(1) #:: fibs.zip(fibs.tail).map(x => x._1 + x._2)
val fibs: LazyList[BigInt] = LazyList(<not computed>)

scala> fibs.zipWithIndex.dropWhile(_._1.toString.length < 1000).head._2
val res0: Int = 4782

验证一下:

scala> fibs(4781)
val res23: BigInt = 661337322839244020529448762061498838625822185407709121508540998594682904585308540198620399347330470400653280666540992810413273565832926917994398713638718230004471776413151146318909185589434770973461837586080318941036906375808818987789429626217366616261579890860369055555894732519548901040157604298546742957666096845809463021127999592568557636384690462092341124092401245911116667639704650585763476554594656814646920798543755041993472555505701157143290739289844688760895075749533130953287080934600205342326904398216342904642143410026582439596278979961166556037913414174756579068802168337413360918567937945101952123966744780579475400056938418442029794142856905251486526028946063939727834195575354173400454719814829506586601492013100468780852140836021109820860257494909387619656513364393990237289782342713423952649505274336811640790273740875206602634583227097792146230883268422965993230492657630338344967767776216497296016612316692840479306680187737608910337658122657075262622220319456797676633847360981

scala> fibs(4782)
val res24: BigInt = 1070066266382758936764980584457396885083683896632151665013235203375314520604694040621889147582489792657804694888177591957484336466672569959512996030461262748092482186144069433051234774442750273781753087579391666192149259186759553966422837148943113074699503439547001985432609723067290192870526447243726117715821825548491120525013201478612965931381792235559657452039506137551467837543229119602129934048260706175397706847068202895486902666185435124521900369480641357447470911707619766945691070098024393439617474103736912503231365532164773697023167755051595173518460579954919410967778373229665796581646513903488154256310184224190259846088000110186255550245493937113651657039447629584714548523425950428582425306083544435428212611008992863795048006894330309773217834864543113205765659868456288616808718693835297350643986297640660000723562917905207051164077614812491885830945940566688339109350944456576357666151619317753792891661581327159616877487983821820492520348473874384736771934512787029218636250627816

scala> fibs(4781).toString.length
val res25: Int = 999

scala> fibs(4782).toString.length
val res26: Int = 1000

4781 项刚好有 999 位数字,而第 4782 项刚好有 1000 位数字。

第3小题

计算第 10 亿项除以 1000000007 的余数,即 F_{10^{9}} \ mod \ (10^{9}+7)

为方便期间,我们记 f(n) = F_n\mod(10^9+7)

Scala 的暴力递归

很多朋友的第一反应是“快速矩阵幂”。其实,此时祭出快速矩阵幂完全是杀鸡用牛刀了。下面这个直接用尾递归暴力计算的办法,在我本人的电脑上只需要不到 3 秒钟就出结果了:

def fibM(n: Int) = {
  @scala.annotation.tailrec
  def f(x: Int, y: Int, i: Int): Int = if (i == 0) x else f(y, (x + y) % 1_000_000_007, i - 1)
  f(0, 1, n)
}

验证:

~ $ scala-cli
Welcome to Scala 3.4.1 (21, Java OpenJDK 64-Bit Server VM).
Type in expressions for evaluation. Or try :help.

scala> def fibM(n: Int) = {
         @scala.annotation.tailrec
         def f(x: Int, y: Int, i: Int): Int = if (i == 0) x else f(y, (x + y) % 1_000_000_007, i - 1)
         f(0, 1, n)
       }
def fibM(n: Int): Int

scala> def main() = {
         val t = System.currentTimeMillis
         val r = fibM(1_000_000_000)
         println(s"result : $r, time cost: ${System.currentTimeMillis-t} ms")
       }
def main(): Unit

scala> main()
result : 21, time cost: 2443 ms

需要注意的技巧:

  1. 在每次相加以后对 10^9+7 取模,极大减少计算量,否则性能会有几个数量级的差异。
  2. 使用 Int 而不是 LongBigInt 之类的大数结构,进一步减少运算代价。
  3. 因为编译器会把尾递归优化成循环,所以这个写法和直接写循环的效果一样。

其他语言的解法

C++ 的循环:

#include <iostream>

int fib(int n) {
    int a = 0, b = 1, c = 0;
    for (int i = 0 ; i < n ; i++) {
        c = b;
        b = (a + b) % 1000000007;
        a = c;
    }
    return a;
}

int main() {
    std::cout << fib(1000000000) << "\n";
    return 0;
}
~ $ g++ -O2 fibonacci.cc -o fibonacci-cpp
~ $ time ./fibonacci-cpp
21
./fibonacci-cpp  2.47s user 0.00s system 99% cpu 2.473 total

C++ 启用 gcc -O2 编译优化,运行耗时 2.47s .

Rust 尾递归:

fn main() {
    fn fib_m(n: i32) -> i32 {
        fn _fib(x: i32, y: i32, i: i32, n: i32) -> i32 {
            if i == n { x } else { _fib(y, (x+y) % 1000000007, i+1, n) }
        }
        _fib(0, 1, 0, n)
    }
    println!("fibM(10^9) = {}", fib_m(1_000_000_000));
}
~ $ time ./fibonacci_rs
fibM(10^9) = 21
./fibonacci_rs  2.53s user 0.00s system 99% cpu 2.500 total

Rust 这个解法耗时 2.50s.

JavaScript 循环:

function fibM(n) {
    for (i = 0, x = 0, y = 1; i < n; i++) {
        var z = y
        y = (x + y) % 1000000007
        x = z
    }
    return x
}

console.log(fibM(1000000000))
~ $ time node fibonacci-loop.js
21
node fibonacci-loop.js  4.08s user 0.01s system 99% cpu 4.088 total

JavaScript 这个解法通过 NodeJS 运行耗时 4.1s .

JavaScript 尾递归:

'use strict'; 

function fibM(n) {
    let f = (x, y, i) => (i === 0) ? x : f(y, (x+y) % 1000000007, i-1)
    return f(0, 1, n)
}

console.log(fibM(1000000000)) // RangeError: Maximum call stack size exceeded

栈溢出!

这是因为不同的 JavaScript 运行时对尾递归优化的支持不一样,当前的 NodeJS 还没有支持。

Python 循环:

def fib(n):
    a, b = 0, 1
    i = 1
    while (i <= n):
        a, b = b, (a + b) % 1000000007
        i = i + 1
    return a

print(fib(1000000000))

Python 循环测试

首先在一个比较旧的系统上,比较 Python 2.7, 3.7, 3.83.9 的执行效率:

~ $ time python2.7 fibonacci-loop.py
21
python2.7 fibonacci-loop.py  57.04s user 0.01s system 99% cpu 57.103 total
~ $ time python fibonacci-loop.py
21
python3.7 fibonacci-loop.py  108.94s user 0.02s system 99% cpu 1:49.00 total
~ $ time python3.8 fibonacci-loop.py
21
python3.8 fibonacci-loop.py  112.31s user 0.01s system 99% cpu 1:52.33 total
~ $ time python3.9 fibonacci-loop.py
21
python3.9 fibonacci-loop.py  120.74s user 0.01s system 99% cpu 2:00.81 total
~ $ time pypy fibonacci-loop.py
21
pypy fibonacci-loop.py  3.11s user 0.01s system 99% cpu 3.134 total
~ $ time pypy3 fibonacci-loop.py
21
pypy3 fibonacci-loop.py  3.13s user 0.01s system 99% cpu 3.150 total

然后在一个比较新的系统上(也就是写作本文的系统),比较 Python 2.7, 3.9, 3.103.11 的执行效率:

~ $ time python2.7 fibonacci-loop.py
21
python2.7 fibonacci-loop.py  53.61s user 0.01s system 99% cpu 53.764 total

~ $ time python3.9 fibonacci-loop.py
21
python3.9 fibonacci-loop.py  84.24s user 0.01s system 99% cpu 1:24.49 total

~ $ time python3.10 fibonacci-loop.py
21
python3.10 fibonacci-loop.py  82.86s user 0.01s system 99% cpu 1:23.10 total

~ $ time python3.11 fibonacci-loop.py
21
python3.11 fibonacci-loop.py  73.94s user 0.00s system 99% cpu 1:14.14 total

~ $ time python3.12 fibonacci-loop.py
21
python3.12 fibonacci-loop.py  81.61s user 0.01s system 99% cpu 1:21.83 total

$ time python3.13 fibonacci-loop.py
21
python3.13 fibonacci-loop.py  82.50s user 0.00s system 99% cpu 1:22.76 total

从两个系统上不同 Python 版本的运行结果可以看出,在直接循环的性能表现上,从 CPython 2.73.7 发生了断崖式下降,然后从 3.7 一路稳 定下降到 3.9, 然后稳定下来直到最新的 3.13 Alpha 5. 而从 pypy 相较于同期的 Cpython, 性能又呈飞跃式下降。

Python 尾递归:

def fib(n):
    def f(x, y, i): return x if not i else f(y, (x + y) % 1000000007, i-1)
    return f(0, 1, n)

print(fib(1000000000)) # RecursionError: maximum recursion depth exceeded

栈溢出。这是因为 Python 并不直接支持尾递归优化,递归无法优化为循环。

矩阵快速幂

公式

根据矩阵乘法的定义

\begin{align*} \begin{bmatrix} a_{11} & a_{12} \\ a_{21} & a_{22} \end{bmatrix} \begin{bmatrix} b_{11} & b_{12} \\ b_{21} & b_{22} \end{bmatrix} &= \begin{bmatrix} a_{11}b_{11} + a_{12}b_{21} & a_{11}b_{12} + a_{12}b_{22} \\ a_{21}b_{11} + a_{21}b_{22} & a_{21}b_{12} + a_{22}b_{22} \end{bmatrix} \\ \begin{bmatrix} a_{11} & a_{12} \\ a_{21} & a_{22} \end{bmatrix} \begin{bmatrix} b_{11} \\ b_{21} \end{bmatrix} &= \begin{bmatrix} a_{11}b_{11} + a_{12}b_{21} \\ a_{21}b_{11} + a_{21}b_{22} \end{bmatrix}\\ \end{align*}

\begin{align*} \begin{bmatrix} F_{n} \\ F_{n -1} \end{bmatrix} &= \begin{bmatrix} 1 && 1 \\ 1 && 0 \end{bmatrix} \begin{bmatrix} F_{n-1} \\ F_{n - 2} \end{bmatrix} \\ &= \begin{bmatrix} 1 && 1 \\ 1 && 0 \end{bmatrix}^2 \begin{bmatrix} F_{n - 2} \\ F_{n - 3} \end{bmatrix} \\ &= \begin{bmatrix} 1 && 1 \\ 1 && 0 \end{bmatrix}^3 \begin{bmatrix} F_{n - 3} \\ F_{n - 4} \end{bmatrix} \\ &= \begin{bmatrix} 1 && 1 \\ 1 && 0 \end{bmatrix}^{n-1} \begin{bmatrix} F_2 \\ F_1 \end{bmatrix} \\ &= \begin{bmatrix} 1 && 1 \\ 1 && 0 \end{bmatrix}^n \begin{bmatrix} F_1 \\ F_0 \end{bmatrix} \\ \end{align*}

由此得到

\begin{align} \begin{bmatrix} F_{n} \\ F_{n -1} \end{bmatrix} &= \begin{bmatrix} 1 && 1 \\ 1 && 0 \end{bmatrix}^n \begin{bmatrix} 1 \\ 0 \end{bmatrix} \\ \end{align}

下面用数学归纳法证明:

\begin{align} \begin{bmatrix} F_{n + 1} && F_{n} \\ F_{n} && F_{n-1} \end{bmatrix} = \begin{bmatrix} 1 && 1 \\ 1 && 0 \end{bmatrix}^n = A^n, n \geq 1 \end{align}

显然 n = 1 时上式成立。

现在假设 n=k 时上式成立,即

\begin{align*} \begin{bmatrix} F_{k + 1} && F_{k} \\ F_{k} && F_{k-1} \end{bmatrix} = \begin{bmatrix} 1 && 1 \\ 1 && 0 \end{bmatrix}^k = A^k \end{align*}

在上式两边同时乘以 \begin{bmatrix} 1 & 1 \\ 1 & 0 \end{bmatrix}, 容易验证得当 n=k+1 时式子也成立,证毕。

Scala 实现

将一个数 n 转换为它的二进制表示时,可以使用以下算法:

不断进行 n = \lfloor n / 2 \rfloor 的操作,并且将每一步的 n2 的结果收集起来,最后倒序排列。

例如,将十进制数 13 转换为二进制:

\begin{align*} 13 / 2 &= 6 \text{ 余 } 1 \\ 6 / 2 &= 3 \text{ 余 } 0 \\ 3 / 2 &= 1 \text{ 余 } 1 \\ 1 / 2 &= 0 \text{ 余 } 1 \\ \end{align*}

因此,数 13 的二进制表示为 1101 .

把这个过程逆过来,就是快速乘到 13 的一个路径:

\begin{align*} 13 / 2 &= 6 \text{ 余 } 1, 6 \cdot 2 + 1 = 13, \\ 6 / 2 &= 3 \text{ 余 } 0, 3 \cdot 2 + 0 = 6, \\ 3 / 2 &= 1 \text{ 余 } 1, 1 \cdot 2 + 1 = 3, \\ 1 / 2 &= 0 \text{ 余 } 1, 0 \cdot 2 + 1 = 1, \\ \end{align*}

上述过程用 Scala 翻译为矩阵的快速乘法,就是

def fastPow(n: BigInt): Matrix = {
  @scala.annotation.tailrec
  def f(x: Matrix, z: Matrix, ys: List[Char]): Matrix =
    ys match {
      case '0' :: xs => f(x * x, z, xs)
      case '1' :: xs => f(x * x, x * z, xs)
      case _         => z
    }
  val A = Matrix(1, 1, 1, 0)
  val I = Matrix(1, 0, 0, 1)
  f(A, I, toBinary(n).toList)
}

其中, I 是单位矩阵 \begin{bmatrix} 1 & 0 \\ 0 & 1 \end{bmatrix}, 而 A 是需要反复相乘的矩阵 \begin{bmatrix} 1 & 1 \\ 1 & 0 \end{bmatrix}.

于是我们可以写出整个快速矩阵幂的完整实现:


object FastPowMatrix {

  val Q = 1000000007L

  final case class Matrix(a11: Long, a12: Long, a21: Long, a22: Long) {

    /** 取模矩阵乘法
      * @param x
      *   被乘数
      * @param m
      *   模
      * @return
      */
    def *(x: Matrix, m: Long = Q): Matrix =
      Matrix(
        a11 = (a11 * x.a11 + a12 * x.a21) % m,
        a12 = (a11 * x.a12 + a12 * x.a22) % m,
        a21 = (a21 * x.a11 + a22 * x.a21) % m,
        a22 = (a21 * x.a12 + a22 * x.a22) % m
      )

    def fastPow(n: Long): Matrix = {
      @scala.annotation.tailrec
      def f(x: Matrix, z: Matrix, ys: List[Char]): Matrix =
        ys match {
          case '0' :: xs => f(x * x, z, xs)
          case '1' :: xs => f(x * x, x * z, xs)
          case _         => z
        }
      val A                                               = Matrix(1, 1, 1, 0)
      val I                                               = Matrix(1, 0, 0, 1)
      f(A, I, toBinary(n).toList)
    }

    private def toBinary(n: Long): List[Char] = {
      @scala.annotation.tailrec
      def f(n: Long, xs: List[Char]): List[Char] =
        if (n == 0L) xs else f(n >> 1, (if ((n & 1L) == 1L) '1' else '0') :: xs)
      if (n == 0L) List('0') else f(n, Nil).reverse
    }

    infix def ^(p: Long): Matrix   = fastPow(p)
  }

  def fibM(n: Long): Long = {
    val m = Matrix(1, 1, 1, 0)
    val r = m ^ n
    r.a21
  }
}

def testFastPowMatrix(n: Long): Unit = {
  import FastPowMatrix.*
  val t0 = System.nanoTime()
  val r  = fibM(n)
  val t1 = System.nanoTime()
  println(s"fibM($n) = $r, time cost: ${(t1 - t0) / 1_000_000.0} ms")
}

验证一下:

scala> testFastPowMatrix(1_000_000_000L)
fibM(1000000000) = 21, time cost: 0.034571 ms

emmmmmm,只需要 0.035ms. 继续加大剂量:

scala> testFastPowMatrix(1_000_000_000L)
fibM(1000000000) = 21, time cost: 0.034571 ms

scala> testFastPowMatrix(10_000_000_000L)
fibM(10000000000) = 815449418, time cost: 0.03718 ms

scala> testFastPowMatrix(100_000_000_000L)
fibM(100000000000) = 224788301, time cost: 0.039211 ms

scala> testFastPowMatrix(1_000_000_000_000L)
fibM(1000000000000) = 730695249, time cost: 0.038571 ms

scala> testFastPowMatrix(1000_000_000_000_000L)
fibM(1000000000000000) = 648325137, time cost: 0.0467 ms

scala> testFastPowMatrix(1_000_000_000_000_000_000L)
fibM(1000000000000000000) = 209783453, time cost: 0.054571 ms

scala> testFastPowMatrix(Long.MaxValue)
fibM(9223372036854775807) = 884968410, time cost: 0.064401 ms

可见我们一直加大 nLong.MaxValue = 9223372036854775807 也即 2^{63} - 1, 仍然只花费了不到 0.07ms.

这其实很好理解,因为我们的快速幂算法时间复杂度为 O(\log_2{n}) .

由此我们可以猜测,其实这个算法对第3小题求 f(10^{1000}) = F_{10^{1000}} \ mod \ (10^{9}+7) 也是胜任的,只是这个 10^{1000} 远超我们实现中的 Long 数值范畴,需要考虑能处理 nBigInt 的实现。

可以考虑一个基于数值泛型的实现,因为不管输入 n 为什么类型,只需要取得 n 的二进制表示即可。考虑以下 type class:

given BinaryNumeric[Long] with {
  def half(x: Long): Long      = x >> 1
  def isZero(x: Long): Boolean = x == 0L
  def isOdd(x: Long): Boolean  = (x & 1L) == 1L
}

重写 toBinary() 使之泛型化:

def toBinary[T](n: T)(using ev: BinaryNumeric[T]): List[Char] = {
  @scala.annotation.tailrec
  def f(n: T, xs: List[Char]): List[Char] =
    if (ev.isZero(n)) xs else f(ev.half(n), (if (ev.isOdd(n)) '1' else '0') :: xs)
  if (ev.isZero(n)) List('0') else f(n, Nil).reverse
}

或者直接用 Scala 标准库里的 type class Integral 更方便一些(虽然会带来轻微的性能下降):

object FastPowMatrix {
  val Q = 1000000007L

  def toBinary[T](n: T)(using ev: Integral[T]): List[Char] = {
    val two                                 = ev.fromInt(2)
    @scala.annotation.tailrec
    def f(n: T, xs: List[Char]): List[Char] =
      if (n == ev.zero) xs else f(ev.quot(n, two), (if (ev.rem(n, two) == ev.one) '1' else '0') :: xs)
    if (n == ev.zero) List('0') else f(n, Nil).reverse
  }

  final case class Matrix(a11: Long, a12: Long, a21: Long, a22: Long) {

    /** 取模矩阵乘法
      * @param x
      *   被乘数
      * @param m
      *   模
      * @return
      */
    infix def *(x: Matrix, m: Long = Q): Matrix =
      Matrix(
        a11 = (a11 * x.a11 + a12 * x.a21) % m,
        a12 = (a11 * x.a12 + a12 * x.a22) % m,
        a21 = (a21 * x.a11 + a22 * x.a21) % m,
        a22 = (a21 * x.a12 + a22 * x.a22) % m
      )

    def fastPow[T: Integral](n: T): Matrix = {
      @scala.annotation.tailrec
      def f(x: Matrix, z: Matrix, ys: List[Char]): Matrix =
        ys match {
          case '0' :: xs => f(x * x, z, xs)
          case '1' :: xs => f(x * x, x * z, xs)
          case _         => z
        }
      val A                                               = Matrix(1, 1, 1, 0)
      val I                                               = Matrix(1, 0, 0, 1)
      f(A, I, toBinary(n).toList)
    }

    infix def ^^[T: Integral](n: T): Matrix = fastPow(n)
  }

  private def fibM_[T: Integral](n: T): Long = {
    val m = Matrix(1, 1, 1, 0)
    val r = m ^^ n
    r.a21
  }

  def fibM(n: Int | Long | BigInt): Long = n match
  case n: Int    => fibM_(n)
  case n: Long   => fibM_(n)
  case n: BigInt => fibM_(n)
}

def testFastPowMatrix(n: Int | Long | BigInt): Unit = {
  import FastPowMatrix.*
  val t0 = System.nanoTime()
  val r  = fibM(n)
  val t1 = System.nanoTime()
  println(s"time cost: ${(t1 - t0) / 1_000_000.0} ms, fib($n) = $r")
}

直接令 n=10^{1000} 代入运行试试看:

scala> val b = List.fill(1000)(BigInt(10)).product
val b: BigInt = 10000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000

scala> testFastPowMatrix(b)
time cost: 6.901779 ms, result = 552179166

只需要 6.9ms . 这其实很好理解,算法时间复杂度为 O(\log_2{n}) , 而 \log_2{10^{1000}} \approx 1000 \cdot \log_2{10}.

至此,本题已经完全解决。

超级矩阵快速幂

\begin{align*} A^{p + q} &= A^p \cdot A^q \\ A^{p \cdot q} &= (A^p)^q \\ A^{p^q} &= A^{p^{q - 1} \cdot p} \\ \end{align*}

\begin{align*} A^{p^q} &= A^{p^{q - 1} \cdot p} \\ &= (A^{p^{q - 1}})^ p \\ &= ((A^{p^{q - 2}})^p)^ p \\ &= {\left(A^{p^{q - 1}}\right)}^p\\ &= {\left({\left(A^{p^{q - 2}}\right)}^p\right)}^p\\ &= {\left({\left({\left(A^{p^{q - 3}}\right)}^p\right)}^p\right)}^p\\ &= \underbrace{{\left({\left({\left(A^p\right)}^p \cdots \right)}^p\right)}^p}_{q}\\ \end{align*}

就是说,如果我们要计算 A^{p^q}, 只需要将 A 不断作 p 次方操作,连续做 q 次就好了。

于是很容易在原来的实现基础上写出更快的改进:

package org.weiwen.math

object FastPowMatrix {

  val Q = 1000000007L

  def toBinary[T](n: T)(using ev: Integral[T]): List[Char] = {
    val two                                 = ev.fromInt(2)
    @scala.annotation.tailrec
    def f(n: T, xs: List[Char]): List[Char] =
      if (n == ev.zero) xs else f(ev.quot(n, two), (if (ev.rem(n, two) == ev.one) '1' else '0') :: xs)
    if (n == ev.zero) List('0') else f(n, Nil).reverse
  }

  final case class Matrix(a11: Long, a12: Long, a21: Long, a22: Long) {

    /** 取模矩阵乘法
      * @param x
      *   被乘数
      * @param m
      *   模
      * @return
      */
    infix def *(x: Matrix, m: Long = Q): Matrix =
      Matrix(
        a11 = (a11 * x.a11 + a12 * x.a21) % m,
        a12 = (a11 * x.a12 + a12 * x.a22) % m,
        a21 = (a21 * x.a11 + a22 * x.a21) % m,
        a22 = (a21 * x.a12 + a22 * x.a22) % m
      )

    // for A^n
    def fastPow[T: Integral](n: T): Matrix = fastPow(toBinary(n))

    def fastPow[T: Integral](binaryList: List[Char]): Matrix = {
      @scala.annotation.tailrec
      def f(x: Matrix, z: Matrix, ys: List[Char]): Matrix =
        ys match {
          case '0' :: xs => f(x * x, z, xs)
          case '1' :: xs => f(x * x, x * z, xs)
          case _         => z
        }
      f(this, Matrix(1, 0, 0, 1), binaryList)
    }

    // for A^{p^q}
    def fastPow[T: Integral](p: T, q: Long): Matrix = {
      val binaryList                    = toBinary(p)
      @scala.annotation.tailrec
      def f(z: Matrix, r: Long): Matrix = if r == 0L then z else f(z.fastPow(binaryList), r - 1)
      q match
        case 0L => this // A^{p^0} = A^1 = A
        case _  => f(this.fastPow(p), q - 1)
    }

    // for A^n
    infix def ^^[T: Integral](n: T): Matrix = fastPow(n)

    // for A^{p^q}
    infix def ^^[T: Integral](p: T, q: Long): Matrix = fastPow(p, q)
  }

  val A = Matrix(1, 1, 1, 0)

  def fibM_[T: Integral](n: T): Long = (A ^^ n).a21

  def fibM_[T: Integral](p: T, q: Long): Long = (A ^^ (p, q)).a21

  def fibM(n: Int | Long | BigInt): Long = n match
    case n: Int    => fibM_(n)
    case n: Long   => fibM_(n)
    case n: BigInt => fibM_(n)

  def fibM(p: Int | Long | BigInt, q: Long): Long = p match
    case p: Int    => fibM_(p, q)
    case p: Long   => fibM_(p, q)
    case p: BigInt => fibM_(p, q)
}

def testFastPowMatrix(n: Int | Long | BigInt): Unit = {
  import FastPowMatrix.*
  val t0 = System.nanoTime()
  val r  = fibM(n)
  val t1 = System.nanoTime()
  println(s"time cost: ${(t1 - t0) / 1_000_000.0} ms, result = $r")
}

def testFastPowMatrix(p: BigInt, q: Long): Unit = {
  import FastPowMatrix.*
  val t0 = System.nanoTime()
  val r  = fibM(p, q)
  val t1 = System.nanoTime()
  println(s"time cost: ${(t1 - t0) / 1_000_000.0} ms, result = $r, p = $p, q = $q")
}  

测试一下:

scala> testFastPowMatrix(10, 9)
time cost: 48.3245 ms, result = 21, p = 10, q = 9

scala> testFastPowMatrix(10, 9)
time cost: 0.186812 ms, result = 21, p = 10, q = 9

scala> testFastPowMatrix(10, 9)
time cost: 0.191382 ms, result = 21, p = 10, q = 9

scala> testFastPowMatrix(10, 1000)
time cost: 1.510469 ms, result = 552179166, p = 10, q = 1000

scala> testFastPowMatrix(10, 1000)
time cost: 0.612577 ms, result = 552179166, p = 10, q = 1000

scala> testFastPowMatrix(10, 1000)
time cost: 0.539977 ms, result = 552179166, p = 10, q = 1000

scala> testFastPowMatrix(10, 9)
time cost: 0.141252 ms, result = 21, p = 10, q = 9

scala> testFastPowMatrix(10, 1000)
time cost: 0.687648 ms, result = 552179166, p = 10, q = 1000

scala> testFastPowMatrix(10, 10000)
time cost: 3.667316 ms, result = 508797063, p = 10, q = 10000

scala> testFastPowMatrix(10, 100_000)
time cost: 16.345542 ms, result = 322994487, p = 10, q = 100000

scala> testFastPowMatrix(10, 1_000_000)
time cost: 168.761524 ms, result = 305562778, p = 10, q = 1000000

scala> testFastPowMatrix(10, 10_000_000)
time cost: 730.109804 ms, result = 63990505, p = 10, q = 10000000

scala> testFastPowMatrix(10, 100_000_000)
time cost: 4525.576374 ms, result = 761244216, p = 10, q = 100000000

可以看到计算 f(10^{1000}) 的耗时从之前的 6.9ms 下降到 0.69ms, 而计算 f(10^{100000000}) = f(10^{10^8}) 也仅需 4.5s.

注意到 10^{10^8} = 10^{100000000} = 10^{1000 \cdot 100000} = (10^{1000})^{100000}, testFastPowMatrix(10, 100_000_000) 可以用 testFastPowMatrix(10^1000, 100_000) 替代计算:

scala> testFastPowMatrix(List.fill(1000)(BigInt(10)).product, 100_000)
time cost: 3268.0884 ms, result = 761244216, p = 10000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000, q = 100000

使用这个小技巧,成功地把计算 f(10^{10^8}) 的耗时从 4.5s 缩减到 3.3s.

注意到特征矩阵 Matrix 中, a12 总是和 a21 相等,可以从 case class 中去掉 a21, 进一步优化计算代价。重构如下:

object FibModMatrix {

  val Q = 1000000007L

  def toBinary[T](n: T)(using ev: Integral[T]): List[Char] = {
    val two                                 = ev.fromInt(2)
    @scala.annotation.tailrec
    def f(n: T, xs: List[Char]): List[Char] =
      if (n == ev.zero) xs else f(ev.quot(n, two), (if (ev.rem(n, two) == ev.one) '1' else '0') :: xs)
    if (n == ev.zero) List('0') else f(n, Nil).reverse
  }

  trait ModMatrixOps[M <: ModMatrixOps[M]] { this: M =>

    // modulus
    def mod: Long

    // identity matrix
    def I: M

    //  characteristic matrix
    def A: M

    infix def *(y: M): M

    def fastPow[T: Integral](n: T): M = {
      val bs = toBinary(n)
      fastPow(bs)
    }

    def fastPow[T: Integral](binaryList: List[Char]): M = {
      @scala.annotation.tailrec
      def f(x: M, z: M, ys: List[Char]): M =
        ys match {
          case '0' :: xs => f(x * x, z, xs)
          case '1' :: xs => f(x * x, x * z, xs)
          case _         => z
        }
      f(this, I, binaryList)
    }

    def fastPow[T: Integral](p: T, q: Long): M = {
      val binaryList          = toBinary(p)
      @scala.annotation.tailrec
      def f(z: M, r: Long): M = if r == 0L then z else f(z.fastPow(binaryList), r - 1)
      q match
        case 0L => this // A^{p^0} = A^1 = A
        case _  => f(this.fastPow(p), q - 1)
    }
  }

  final case class FibModMatrix(a11: Long, a12: Long, a22: Long) extends ModMatrixOps[FibModMatrix] {

    def a21: Long = a12

    def mod = Q

    def I = FibModMatrix.I
    def A = FibModMatrix.A

    /** 取模矩阵乘法
      * @param x
      *   被乘数
      * @param m
      *   模
      * @return
      */
    override infix def *(y: FibModMatrix): FibModMatrix = {
      FibModMatrix(
        a11 = (a11 * y.a11 + a12 * y.a21) % mod,
        a12 = (a11 * y.a12 + a12 * y.a22) % mod,
        a22 = (a21 * y.a12 + a22 * y.a22) % mod
      )
    }

    // for A^n
    infix def ^^[T: Integral](n: T): FibModMatrix = fastPow(n)

    // for A^{p^q}
    infix def ^^[T: Integral](p: T, q: Long): FibModMatrix = fastPow(p, q)

  }

  object FibModMatrix {
    val A = FibModMatrix(1, 1, 0)
    val I = FibModMatrix(1, 0, 1)
  }

  def fibM_[T: Integral](n: T): Long = (FibModMatrix.A ^^ n).a12

  def fibM_[T: Integral](p: T, q: Long): Long = (FibModMatrix.A ^^ (p, q)).a12

  def fibM(n: Int | Long | BigInt): Long = n match
    case n: Int    => fibM_(n)
    case n: Long   => fibM_(n)
    case n: BigInt => fibM_(n)

  def fibM(p: Int | Long | BigInt, q: Long): Long = p match
    case p: Int    => fibM_(p, q)
    case p: Long   => fibM_(p, q)
    case p: BigInt => fibM_(p, q)
}

def testFastPowMatrixOpt(n: Int | Long | BigInt): Unit = {
  import FibModMatrix.*
  val t0 = System.nanoTime()
  val r  = fibM(n)
  val t1 = System.nanoTime()
  println(s"time cost: ${(t1 - t0) / 1_000_000.0} ms, result = $r")
}

def testFastPowMatrixOpt(p: BigInt, q: Long): Unit = {
  import FibModMatrix.*
  val t0 = System.nanoTime()
  val r  = fibM(p, q)
  val t1 = System.nanoTime()
  println(s"time cost: ${(t1 - t0) / 1_000_000.0} ms, result = $r")
}

测试一下:

scala> testFastPowMatrixOpt(List.fill(1000)(BigInt(10)).product, 100_000)
time cost: 2788.598254 ms, result = 761244216

scala> testFastPowMatrixOpt(List.fill(1000)(BigInt(10)).product, 100_000)
time cost: 2785.836456 ms, result = 761244216

scala> testFastPowMatrixOpt(List.fill(1000)(BigInt(10)).product, 100_000)
time cost: 2828.199939 ms, result = 761244216

可见,计算 f(10^{100000000}) = f(10^{10^8}) 的耗时又从 3.3s 缩减到 2.8s.

scala> testFastPowMatrixOpt(10, 1000)
time cost: 0.090972 ms, result = 552179166

而计算 f(10^{1000}) 的耗时,更是直接下降到了 0.09ms .

再来计算第一小题中的 f(10^9):

scala> testFastPowMatrixOpt(10, 9)
time cost: 0.02341 ms, result = 21

耗时 0.02ms, 简直不费吹灰之力。

拓展至 k 阶常系数齐次线性递推数列

公式

对于 k 阶常系数齐次线性递推数列

\begin{align} a_n = \begin{cases} a_n & \text{for} \; 0 \leq n \leq k-1, \\ c_0a_{n-1} + c_1a_{n-2} + c_2a_{n-3} + c_3a_{n-4} + \ldots + c_{k-1}a_{n-k} = \sum_{i=0}^{k-1}c_{i}a_{n-i-1} & \text{for} \;n \geq k. \\ \end{cases} \end{align}

对任意 n \geq -1 亦有递推式

\begin{align*} a_{n+k+1} &= c_0a_{n+k} + c_1a_{n+k-1} + c_2a_{n+k-2} + c_3a_{n+k-3} + \ldots + c_{k-1}a_{n+1} \\ &= \sum_{i=0}^{k-1}c_{i}a_{n+k-i} \end{align*}

容易验证矩阵递推式

\begin{align} \begin{bmatrix} a_{n+k+1} \\ a_{n+k} \\ a_{n+k-1} \\ a_{n+k-2} \\ a_{n+k-3} \\ \vdots \\ a_{n} \\ \end{bmatrix} &= \begin{bmatrix} c_0 && c_1 && c_2 && c_3 && \ldots && c_{k-2} && c_{k-1} \\ 1 && 0 && 0 && 0 && \ldots && 0 && 0 \\ 0 && 1 && 0 && 0 && \ldots && 0 && 0 \\ 0 && 0 && 1 && 0 && \ldots && 0 && 0 \\ \vdots && \vdots && \vdots && \vdots && \ldots && \vdots && \vdots \\ 0 && 0 && 0 && 0 && \ldots && 0 && 0 \\ 0 && 0 && 0 && 0 && \ldots && 1 && 0 \\ \end{bmatrix} \begin{bmatrix} a_{n+k} \\ a_{n+k-1} \\ a_{n+k-2} \\ a_{n+k-3} \\ a_{n+k-4} \\ \vdots \\ a_{n+1} \\ \end{bmatrix} \\ &= \begin{bmatrix} c_0 && c_1 && c_2 && c_3 && \ldots && c_{k-2} && c_{k-1} \\ 1 && 0 && 0 && 0 && \ldots && 0 && 0 \\ 0 && 1 && 0 && 0 && \ldots && 0 && 0 \\ 0 && 0 && 1 && 0 && \ldots && 0 && 0 \\ \vdots && \vdots && \vdots && \vdots && \ldots && \vdots && \vdots \\ 0 && 0 && 0 && 0 && \ldots && 0 && 0 \\ 0 && 0 && 0 && 0 && \ldots && 1 && 0 \\ \end{bmatrix}^{\large 2} \begin{bmatrix} a_{n+k-1} \\ a_{n+k-2} \\ a_{n+k-3} \\ a_{n+k-4} \\ a_{n+k-5} \\ \vdots \\ a_{n} \\ \end{bmatrix} \\ &= \begin{bmatrix} c_0 && c_1 && c_2 && c_3 && \ldots && c_{k-2} && c_{k-1} \\ 1 && 0 && 0 && 0 && \ldots && 0 && 0 \\ 0 && 1 && 0 && 0 && \ldots && 0 && 0 \\ 0 && 0 && 1 && 0 && \ldots && 0 && 0 \\ \vdots && \vdots && \vdots && \vdots && \ldots && \vdots && \vdots \\ 0 && 0 && 0 && 0 && \ldots && 0 && 0 \\ 0 && 0 && 0 && 0 && \ldots && 1 && 0 \\ \end{bmatrix}^{\large 3} \begin{bmatrix} a_{n+k-2} \\ a_{n+k-3} \\ a_{n+k-4} \\ a_{n+k-5} \\ a_{n+k-6} \\ \vdots \\ a_{n-1} \\ \end{bmatrix} \\ &= \ldots \\ &= \begin{bmatrix} c_0 && c_1 && c_2 && c_3 && \ldots && c_{k-2} && c_{k-1} \\ 1 && 0 && 0 && 0 && \ldots && 0 && 0 \\ 0 && 1 && 0 && 0 && \ldots && 0 && 0 \\ 0 && 0 && 1 && 0 && \ldots && 0 && 0 \\ \vdots && \vdots && \vdots && \vdots && \ldots && \vdots && \vdots \\ 0 && 0 && 0 && 0 && \ldots && 0 && 0 \\ 0 && 0 && 0 && 0 && \ldots && 1 && 0 \\ \end{bmatrix}^{\large n} \begin{bmatrix} a_{k-1} \\ a_{k-2} \\ a_{k-3} \\ a_{k-4} \\ a_{k-5} \\ \vdots \\ a_{0} \\ \end{bmatrix} \\ &= A^nB \end{align}

可见仍然可以用类似前面讨论过的矩阵快速幂来计算 a_n 的项。

Scala 实现

由于暂时没有找到 Scala / Java 下既方便又快的矩阵运算库,我们可以简单手搓一个。至于其性能嘛,先不作追求。

手搓代码如下:

object MatrixOps {

  private[MatrixOps] def toBinary[T](n: T)(using ev: Integral[T]): List[Char] = {
    val two                                 = ev.fromInt(2)
    @scala.annotation.tailrec
    def f(n: T, xs: List[Char]): List[Char] =
      if (n == ev.zero) xs else f(ev.quot(n, two), (if (ev.rem(n, two) == ev.one) '1' else '0') :: xs)
    if (n == ev.zero) List('0') else f(n, Nil).reverse
  }
}

trait MatrixOps[A: Numeric, M <: MatrixOps[A, M]] { this: M =>

  // modulus
  def mod: A => A = identity

  // identity matrix
  def I: M

  //  characteristic matrix
  def A: M

  infix def *(y: M): M

  def ^[T: Integral](n: T): M = {
    val bs = MatrixOps.toBinary(n)
    ^(bs)
  }

  // for A^n
  private def ^[T: Integral](binaryList: List[Char]): M = {
    @scala.annotation.tailrec
    def f(x: M, z: M, ys: List[Char]): M =
      ys match {
        case '0' :: xs => f(x * x, z, xs)
        case '1' :: xs => f(x * x, x * z, xs)
        case _         => z
      }
    f(this, I, binaryList)
  }

  // for A^{p^q}
  def ^[T: Integral, S: Integral](p: T, q: S): M = {
    import scala.math.Numeric.Implicits.infixNumericOps
    val binaryList          = MatrixOps.toBinary(p)
    @scala.annotation.tailrec
    def f(z: M, r: S): M = if r == 0L then z else f(z.^[T](binaryList), r - Integral[S].one)
    q match
      case 0L => this // A^{p^0} = A^1 = A
      case _  => f(this.^(p), q - Integral[S].one)
  }
}

trait Hflorcc[A: Numeric](cx: List[A], initial: List[A]) {

  lazy val size = cx.size

  val one = Numeric[A].one

  val zero = Numeric[A].zero

  lazy val id = Matrix(
    size,
    size,
    Vector.tabulate(size) { i =>
      Vector.tabulate(size) { j =>
        if i == j then one else zero
      }
    }
  )

  lazy val a = Matrix(
    size,
    size,
    Vector.tabulate(size) {
      case 0 => cx.toVector
      case i =>
        Vector.tabulate(size) { j =>
          if j == i - 1 then one else zero
        }
    }
  )

  def modulus: A => A = identity

  final case class Matrix(rows: Int, cols: Int, data: Vector[Vector[A]]) extends MatrixOps[A, Matrix] {

    override def mod: A => A = modulus
    def I                    = id

    def A = a

    infix def *(y: Matrix): Matrix = {
      import scala.math.Numeric.Implicits.infixNumericOps
      require(cols == y.rows)
      val colsRange = (0 until y.cols).toVector
      Matrix(
        rows = rows,
        cols = y.cols,
        data = data.map { r =>
          colsRange.map { j =>
            r.foldLeft((0, zero)) { case ((k, acc), a) => (k + 1, mod(acc + a * y.data(k)(j))) }._2
          }
        }
      )
    }
  }

  val initM = Matrix(size, 1, initial.reverse.toVector.map(Vector(_)))

  // calculate a_n
  def apply[T: Integral](n: T): A = ((a ^ n) * initM).data.last.last

  // calculate a_{p^q}
  def apply[T: Integral, S: Integral](p: T, q: S): A = ((a ^ (p, q)) * initM).data.last.last
}

这样,求菲波那挈数列也就可以这样实现了:

object fib extends Hflorcc(List(BigInt(1), BigInt(1)), List(BigInt(0), BigInt(1)))

object fibM extends Hflorcc[Long](List(1, 1), List(0, 1)) {
  override val modulus: Long => Long = _ % 1_000_000_007L
}

测试一下:

println(fib(0)) // 0
println(fib(1)) // 1
println(fib(2)) // 1
println(fib(10)) // 55
println(fib(100)) // 354224848179261915075
println(fib(200)) // 280571172992510140037611932413038677189525
println(fib(300)) // 222232244629420445529739893461909967206666939096499764990979600
println(fibM(1000000000)) // 21
println(fibM(10, 9)) // 21
println(fibM(10, 1000)) // 552179166
println(fibM(10, 100000000)) // 761244216

我们来看菲波那挈数列的相似数列: 一个 3 阶线性递推数列

\begin{align} a_n = \begin{cases} 1 & \text{for} \;n = 0, \\ 1 & \text{for} \;n = 1, \\ 1 & \text{for} \;n = 2, \\ a_{n-1} + a_{n-2} + a_{n-3}& \text{for} \; n \geq 3. \end{cases} \end{align}

可以简单实现其通项公式的计算:

object Fa extends Hflorcc[BigInt](List(1, 1, 1).map(BigInt(_)), List(1, 1, 1).map(BigInt(_)))

打印这个数列的前几项:

scala> (0 to 100).map(x => x -> Fa(x)).foreach(println)
(0,1)
(1,1)
(2,1)
(3,3)
(4,5)
(5,9)
(6,17)
(7,31)
(8,57)
(9,105)
(10,193)
(11,355)
(12,653)
(13,1201)
(14,2209)
(15,4063)
(16,7473)
(17,13745)
(18,25281)
(19,46499)
(20,85525)
(21,157305)
(22,289329)
(23,532159)
(24,978793)
(25,1800281)
(26,3311233)
(27,6090307)
(28,11201821)
(29,20603361)
(30,37895489)
(31,69700671)
(32,128199521)
(33,235795681)
(34,433695873)
(35,797691075)
(36,1467182629)
(37,2698569577)
(38,4963443281)
(39,9129195487)
(40,16791208345)
(41,30883847113)
(42,56804250945)
(43,104479306403)
(44,192167404461)
(45,353450961809)
(46,650097672673)
(47,1195716038943)
(48,2199264673425)
(49,4045078385041)
(50,7440059097409)
(51,13684402155875)
(52,25169539638325)
(53,46294000891609)
(54,85147942685809)
(55,156611483215743)
(56,288053426793161)
(57,529812852694713)
(58,974477762703617)
(59,1792344042191491)
(60,3296634657589821)
(61,6063456462484929)
(62,11152435162266241)
(63,20512526282340991)
(64,37728417907092161)
(65,69393379351699393)
(66,127634323541132545)
(67,234756120799924099)
(68,431783823692756037)
(69,794174268033812681)
(70,1460714212526492817)
(71,2686672304253061535)
(72,4941560784813367033)
(73,9088947301592921385)
(74,16717180390659349953)
(75,30747688477065638371)
(76,56553816169317909709)
(77,104018685037042898033)
(78,191320189683426446113)
(79,351892690889787253855)
(80,647231565610256598001)
(81,1190444446183470297969)
(82,2189568702683514149825)
(83,4027244714477241045795)
(84,7407257863344225493589)
(85,13624071280504980689209)
(86,25058573858326447228593)
(87,46089903002175653411391)
(88,84772548141007081329193)
(89,155921025001509181969177)
(90,286783476144691916709761)
(91,527477049287208180008131)
(92,970181550433409278687069)
(93,1784442075865309375404961)
(94,3282100675585926834100161)
(95,6036724301884645488192191)
(96,11103267053335881697697313)
(97,20422092030806454019989665)
(98,37562083386026981205879169)
(99,69087442470169316923566147)
(100,127071617887002752149434981)

而对于 4 阶线性递推数列

\begin{align} a_n = \begin{cases} 1 & \text{for} \;n = 0, 1, 2, 3 \\ 19a_{n-1} + 17a_{n-2} + 5a_{n-3} + 2a_{n-4}& \text{for} \; n \geq 4. \end{cases} \end{align}

其通项公式:

object Fb extends Hflorcc[BigInt](List(19, 17, 5, 2).map(BigInt(_)), List(1, 1, 1, 1).map(BigInt(_)))

打印前几项:

scala> (0 to 50).map(x => x -> Fb(x)).foreach(println)
(0,1)
(1,1)
(2,1)
(3,1)
(4,43)
(5,841)
(6,16717)
(7,332137)
(8,6599083)
(9,131114173)
(10,2605047817)
(11,51758509153)
(12,1028366255827)
(13,20432140983745)
(14,405956907681613)
(15,8065772990971753)
(16,160255171696481107)
(17,3184037051900036389)
(18,63262182683711092201)
(19,1256927508277275719473)
(20,24973320458494220381563)
(21,496183535339626230014521)
(22,9858444781154053917544237)
(23,195872951400009695996934121)
(24,3891710502197418260130836059)
(25,77322624306527561296481515405)
(26,1536288022007942131643192764777)
(27,30523807329775656153953496341185)
(28,606463632182409515806053376732867)
(29,12049549821430619720713032522579313)
(30,239407019966975630627283579200701165)
(31,4756669092112429116061847731573517161)
(32,94508092765946256843474105015675108163)
(33,1877738271318367914391637250695625511285)
(34,37307926891570572836621731354299630541321)
(35,741254215354267094582920223214060136552081)
(36,14727642556429898267383698808553745791464747)
(37,292616825344191097288211813088433081762944745)
(38,5813870491929493595640135173004421131894693901)
(39,115513146098724484996588253444076560745199672681)
(40,2295078113650400421344034939841589086301401230475)
(41,45599862229146260100651441265030693144218017173597)
(42,906002903757313230487400199550456575853255934967625)
(43,18000959246144887252048391644170292823216566461833761)
(44,357652864508000214808551189907798162074513806557620019)
(45,7106041947079712023319933570079342401675998670155139617)
(46,141186712493188771157545324818560026764378498788606297101)
(47,2805174516711974047238346894976816660747702456541868786025)
(48,55734735445376130567753078296705249317533310174665704923315)
(49,1107368085892610142870194751343238579084469079571143504670149)
(50,22001792380539533181964009431690043444360246054101330128953513)

数学通项公式简便的情形

对于一些具有容易计算的通项公式的情形,直接用通项公式计算会更快。

比如,对于 3 阶线性递推数列

\begin{align} a_n = \begin{cases} 10 & \text{for} \;n = 0, \\ 27 & \text{for} \;n = 1, \\ 87 & \text{for} \;n = 2, \\ 10 a_{n-1} - 31 a_{n-2} + 30 a_{n-3}& \text{for} \; n \geq 3. \end{cases} \end{align}

因为其特征方程 x^3 - 10x^2 + 31x - 30 = 0 容易化成 (x-2)(x-3)(x-5) = 0, 因此有根 x_{1,2,3} = \{2,3,5\}. 由此可以得到递推数列的通项公式

a_n = 7 \cdot 2^n + 3^n + 2 \cdot 5^n, \ \text{for}\ n \geq 0

我们直接用上式计算数列的通项公式,会比前面实现的矩阵快速幂更加高效,尤其在 n 比较大的时候,还可以对 2^n, 3^n, 5^n 的计算进行快速幂优化。

参考资料

  • 我的朋友 counter 也写过一篇介绍菲波那挈数列和 Scala 的文章:5分钟学不会斐波那契数列,如果你对这个话题感兴趣,推荐阅读。