10:22 PM
I swear binary exponentiation on matrices is actually magic.
While explaining this CSES problem to a friend, it occurred to me that I could improve my O(n) solution, which used state-machine DP, to an O(log n) one by using matrix exp.
But even though I've used it quite a few times before, it still feels like magic. Take this task for example:
Efficiently compute the nth Fibonacci number mod 1000.
Sounds easy, right? It should be a simple constructive DP: keep track of the last two Fibonacci numbers, and iterate n times...
def fibonacci(n):
    pre = [0, 1]
    for i in range(n):
        pre = [pre[1], (pre[0] + pre[1]) % 1000]
    return pre[0]
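As a quick sanity check, the iterative version (copied here so the snippet stands alone) lines up with the familiar sequence:

```python
def fibonacci(n):
    # track the last two Fibonacci numbers mod 1000
    pre = [0, 1]
    for i in range(n):
        pre = [pre[1], (pre[0] + pre[1]) % 1000]
    return pre[0]

print([fibonacci(i) for i in range(10)])  # [0, 1, 1, 2, 3, 5, 8, 13, 21, 34]
print(fibonacci(17))                      # 597, since F(17) = 1597
```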
This works fine, though we can do better.
To accelerate our solution with matrix multiplication, the first step is to represent the transitions between DP states as a matrix, rather than performing them explicitly:
def vmul(v, m): # multiply vector by matrix
    ret = [ m[0][0]*v[0] + m[0][1]*v[1],
            m[1][0]*v[0] + m[1][1]*v[1] ]
    return [i % 1000 for i in ret]
def fibonacci(n):
    pre, transform = [0, 1], [[0, 1], [1, 1]]
    for i in range(n):
        pre = vmul(pre, transform)
    return pre[0]
If this doesn't quite make sense, consider:
[0]   [0 1]   [1]
[1] x [1 1] = [1]

[2]   [0 1]   [3]
[3] x [1 1] = [5]

[2]   [0 1]   [0 1]   [5]
[3] x [1 1] x [1 1] = [8]
...
Each multiplication by the matrix converts the vector of terms i, i+1 into the vector of terms i+1, i+2 - and since matrix multiplication is associative, we are free to combine the matrices themselves before ever applying them to the vector.
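To convince yourself of that associativity step: applying the matrix twice is the same as applying the squared matrix once. A minimal sketch, with small inline versions of the helpers:

```python
def vmul(v, m):  # apply a 2x2 matrix to a length-2 vector, mod 1000
    return [(m[0][0]*v[0] + m[0][1]*v[1]) % 1000,
            (m[1][0]*v[0] + m[1][1]*v[1]) % 1000]

def mmul(a, b):  # 2x2 matrix product, mod 1000
    return [[(a[i][0]*b[0][j] + a[i][1]*b[1][j]) % 1000 for j in range(2)]
            for i in range(2)]

M = [[0, 1], [1, 1]]
v = [2, 3]                       # terms F(3), F(4)
step_twice = vmul(vmul(v, M), M) # two single steps
jump_once = vmul(v, mmul(M, M))  # one doubled step
print(step_twice, jump_once)     # both [5, 8], i.e. terms F(5), F(6)
```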
Now for the magic..
def mmul(m1, m2): # multiply matrix by matrix
    ret = [[0, 0], [0, 0]]
    for i in range(2):
        for j in range(2):
            for k in range(2):
                ret[i][j] = (ret[i][j] + m1[i][k]*m2[k][j]) % 1000
    return ret
def fibonacci(n):
    pre, transform = [0, 1], [[0, 1], [1, 1]]
    while n: # binary exponentiation!
        if n % 2:
            pre = vmul(pre, transform)
        transform = mmul(transform, transform)
        n //= 2
    return pre[0]
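The two versions should agree on every input; here's a self-contained cross-check combining the naive and fast approaches (names fib_naive and fib_fast are mine, just to keep both in one snippet):

```python
MOD = 1000

def fib_naive(n):  # O(n) DP
    pre = [0, 1]
    for _ in range(n):
        pre = [pre[1], (pre[0] + pre[1]) % MOD]
    return pre[0]

def vmul(v, m):
    return [(m[0][0]*v[0] + m[0][1]*v[1]) % MOD,
            (m[1][0]*v[0] + m[1][1]*v[1]) % MOD]

def mmul(a, b):
    return [[(a[i][0]*b[0][j] + a[i][1]*b[1][j]) % MOD for j in range(2)]
            for i in range(2)]

def fib_fast(n):  # O(log n) via binary exponentiation of the transform
    pre, transform = [0, 1], [[0, 1], [1, 1]]
    while n:
        if n % 2:
            pre = vmul(pre, transform)
        transform = mmul(transform, transform)
        n //= 2
    return pre[0]

# the O(log n) version matches the O(n) one everywhere
assert all(fib_naive(i) == fib_fast(i) for i in range(200))
```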
Now, we can find Fibonacci number n in O(log n) steps.. but how?
Most other log n speedups make immediate sense. Binary search, for example: knowing that the array or other ordered structure is sorted, it's obvious that we can skip over the majority of entries. Other divide-and-conquer algorithms physically skip over or divide the search space, so it makes sense there.. but how are we able to just skip over so many Fibonacci numbers like that?
I suppose it's easier to see when looking at how the transformation matrix compounds upon itself, chaining jumps of 1 to jumps of 2, then 4, and so on, but for more complex matrix exp problems, especially ones that count paths through graphs, the transformation matrix just.. feels like cheating? Like, how does one pack so much information into such a small structure without any problematic overlaps?
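The path-counting case is the same trick in disguise: if A is a graph's adjacency matrix, entry (i, j) of A^k counts the walks of length k from i to j, and binary exponentiation of A skips over huge k the same way. A small sketch on a made-up 3-node graph (the graph itself is just an illustration):

```python
def matmul(a, b):  # plain n x n matrix product
    n = len(a)
    return [[sum(a[i][k] * b[k][j] for k in range(n)) for j in range(n)]
            for i in range(n)]

# hypothetical directed graph: edges 0->1, 0->2, 1->0, 1->2, 2->2
adj = [[0, 1, 1],
       [1, 0, 1],
       [0, 0, 1]]

walks2 = matmul(adj, adj)  # entry (i, j) counts length-2 walks i -> j
print(walks2[0][2])        # 2: the walks 0->1->2 and 0->2->2
```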
Definitely magic.
tags: programming cses math