#Java

快速幂运算是一种利用位运算和DP思想求的\(x^n\)的数值算法,它将时间复杂度\(O(n)\)降到\(O(log(n))\)。快速幂运算结合矩阵乘法,可以巧解不少DP问题。本篇会由浅入深,从最基本的快速幂运算算法,到应用矩阵快速幂运算解DP问题,结合三道Leetcode题目来具体讲解。

Leetcode 50. Pow(x, n) (Medium)

Leetcode 50. Pow(x, n) 是实数的快速幂运算问题,题目如下。

Implement pow(x, n), which calculates x raised to the power n (i.e. \(x^n\)).

Example 1:

1
2
Input: x = 2.00000, n = 10
Output: 1024.00000

Example 2:

1
2
Input: x = 2.10000, n = 3
Output: 9.26100

Example 3:

1
2
3
Input: x = 2.00000, n = -2
Output: 0.25000
Explanation: 2-2 = 1/22 = 1/4 = 0.25

快速幂运算解法分析

假设n是32位的int类型,将n写成二进制形式,那么n可以写成最多32个某位为 1(第k位为1则值为\(2^k\))的和。那么\(x^n\)最多可以由32个 \(x^{2^k}\)的乘积组合,例如:

\[ x^{\text{10011101}_{2}} = x^{1} \times x^{\text{100}_{2}} \times x^{\text{1000}_{2}} \times x^{\text{10000}_{2}} \times x^{\text{10000000}_{2}} \]

快速幂运算的特点就是通过32次循环,每次循环根据上轮\(x^{2^k}\)的值进行平方后得出这一轮的值:\(x^{2^k} \times x^{2^k} = x^{2^{k+1}}\),即循环计算出如下数列

\[ x^{1}, x^2=x^{\text{10}_{2}}, x^4=x^{\text{100}_{2}}, x^8=x^{\text{1000}_{2}}, x^{16}=x^{\text{10000}_{2}}, ..., x^{128} = x^{\text{10000000}_{2}} \]

在循环时,如果n的二进制形式在本轮对应的位的值是1,则将这次结果累乘计入最终结果。

下面是python 3 的代码,由于循环为32次,所以容易看出算法复杂度为 \(O(log(n))\)

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# AC
# Runtime: 32 ms, faster than 54.28% of Python3 online submissions for Pow(x, n).
# Memory Usage: 14.2 MB, less than 5.04% of Python3 online submissions for Pow(x, n).

class Solution:
def myPow(self, x: float, n: int) -> float:
ret = 1.0
i = abs(n)
while i != 0:
if i & 1:
ret *= x
x *= x
i = i >> 1
return 1.0 / ret if n < 0 else ret

对应的 Java 的代码。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
// AC
// Runtime: 1 ms, faster than 42.98% of Java online submissions for Pow(x, n).
// Memory Usage: 38.7 MB, less than 48.31% of Java online submissions for Pow(x, n).

class Solution {
public double myPow(double x, int n) {
double ret = 1.0;
long i = Math.abs((long) n);
while (i != 0) {
if ((i & 1) > 0) {
ret *= x;
}
x *= x;
i = i >> 1;
}

return n < 0 ? 1.0 / ret : ret;
}
}

矩阵快速幂运算

快速幂运算也可以应用到计算矩阵的幂,即上面的x从实数变为方形矩阵。实现上,矩阵的幂需要矩阵乘法:$ A_{r c} B_{c p}$ ,Python中可以用numpy的 np.matmul(A, B)来完成,而Java版本中我们手动实现简单的矩阵相乘算法,从三重循环看出其算法复杂度为\(O(r \times c \times p)\)

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
public int[][] matrixProd(int[][] A, int[][] B) {
int R = A.length;
int C = B[0].length;
int P = A[0].length;
int[][] ret = new int[R][C];
for (int r = 0; r < R; r++) {
for (int c = 0; c < C; c++) {
for (int p = 0; p < P; p++) {
ret[r][c] += A[r][p] * B[p][c];
}
}
}
return ret;
}

Leetcode 509. Fibonacci Number (Easy)

有了快速矩阵幂运算,我们来看看如何具体解题。Fibonacci问题作为最基本的DP问题,在上一篇Leetcode 679 24 Game 的 Python 函数式实现中我们用python独有的yield来巧解,这次再拿它来做演示。

The Fibonacci numbers, commonly denoted F(n) form a sequence, called the Fibonacci sequence, such that each number is the sum of the two preceding ones, starting from 0 and 1. That is,

1
2
F(0) = 0,   F(1) = 1
F(N) = F(N - 1) + F(N - 2), for N > 1.

Given N, calculate F(N).

Example 1:

1
2
3
Input: 2
Output: 1
Explanation: F(2) = F(1) + F(0) = 1 + 0 = 1.

Example 2:

1
2
3
Input: 3
Output: 2
Explanation: F(3) = F(2) + F(1) = 1 + 1 = 2.

Example 3:

1
2
3
Input: 4
Output: 3
Explanation: F(4) = F(3) + F(2) = 2 + 1 = 3.

转换为矩阵幂运算

Fibonacci的二阶递推式如下:

\[ \begin{align*} F(n) =& F(n-1) + F(n-2) \\ F(n-1) =& F(n-1) \end{align*} \]

等价的矩阵递推形式为:

\[ \begin{bmatrix}F(n)\\F(n-1)\end{bmatrix} = \begin{bmatrix}1 & 1\\1 & 0\end{bmatrix} \begin{bmatrix}F(n-1)\\F(n-2)\end{bmatrix} \]

也就是每轮左乘一个2维矩阵。其循环形式为,即矩阵幂的形式:

\[ \begin{bmatrix}F(n)\\F(n-1)\end{bmatrix} = \begin{bmatrix}1 & 1\\1 & 0\end{bmatrix}^{n-1} \begin{bmatrix}F(1)\\F(0)\end{bmatrix} \]

AC代码

有了上面的矩阵幂公式,代码稍作改动即可。Java 版本代码。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
/**
* AC
* Runtime: 0 ms, faster than 100.00% of Java online submissions for Fibonacci Number.
* Memory Usage: 37.9 MB, less than 18.62% of Java online submissions for Fibonacci Number.
*
* Method: Matrix Fast Power Exponentiation
* Time Complexity: O(log(N))
**/
class Solution {
public int fib(int N) {
if (N <= 1) {
return N;
}
int[][] M = {{1, 1}, {1, 0}};
// powers = M^(N-1)
N--;
int[][] powerDouble = M;
int[][] powers = {{1, 0}, {0, 1}};
while (N > 0) {
if (N % 2 == 1) {
powers = matrixProd(powers, powerDouble);
}
powerDouble = matrixProd(powerDouble, powerDouble);
N = N / 2;
}

return powers[0][0];
}

public int[][] matrixProd(int[][] A, int[][] B) {
int R = A.length;
int C = B[0].length;
int P = A[0].length;
int[][] ret = new int[R][C];
for (int r = 0; r < R; r++) {
for (int c = 0; c < C; c++) {
for (int p = 0; p < P; p++) {
ret[r][c] += A[r][p] * B[p][c];
}
}
}
return ret;
}

}

Python 3的numpy.matmul() 版本代码。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# AC
# Runtime: 256 ms, faster than 26.21% of Python3 online submissions for Fibonacci Number.
# Memory Usage: 29.4 MB, less than 5.25% of Python3 online submissions for Fibonacci Number.

class Solution:

def fib(self, N: int) -> int:
if N <= 1:
return N

import numpy as np
F = np.array([[1, 1], [1, 0]])

N -= 1
powerDouble = F
powers = np.array([[1, 0], [0, 1]])
while N > 0:
if N % 2 == 1:
powers = np.matmul(powers, powerDouble)
powerDouble = np.matmul(powerDouble, powerDouble)
N = N // 2

return powers[0][0]

或者也可以直接调用numpy.matrix_power() 代替手动的快速矩阵幂运算。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# AC
# Runtime: 116 ms, faster than 26.25% of Python3 online submissions for Fibonacci Number.
# Memory Usage: 29.2 MB, less than 5.25% of Python3 online submissions for Fibonacci Number.

class Solution:

def fib(self, N: int) -> int:
if N <= 1:
return N

from numpy.linalg import matrix_power
import numpy as np
F = np.array([[1, 1], [1, 0]])
F = matrix_power(F, N - 1)

return F[0][0]

Leetcode 1411. Number of Ways to Paint N × 3 Grid (Hard)

下面来看一道稍难一点的DP问题,1411. Number of Ways to Paint N × 3 Grid

You have a grid of size n x 3 and you want to paint each cell of the grid with exactly one of the three colours: Red, Yellow or Green while making sure that no two adjacent cells have the same colour (i.e no two cells that share vertical or horizontal sides have the same colour).

You are given n the number of rows of the grid.

Return the number of ways you can paint this grid. As the answer may grow large, the answer must be computed modulo 10^9 + 7.

Example 1:

1
2
3
Input: n = 1
Output: 12
Explanation: There are 12 possible way to paint the grid as shown:

Example 2:

1
2
Input: n = 2
Output: 54

Example 3:

1
2
Input: n = 3
Output: 246

Example 4:

1
2
Input: n = 7
Output: 106494

Example 5:

1
2
Input: n = 5000
Output: 30228214

标准DP解法

分析题目容易发现第i行的状态只取决于第i-1行的状态,第i行会有两种不同状态:三种颜色都有或者只有两种颜色。这个问题容易识别出是经典的双状态DP问题,那么我们定义dp2[i]为第i行只有两种颜色的数量,dp3[i]为第i行有三种颜色的数量。

先考虑dp3[i]和i-1行的关系。假设第i行包含3种颜色,即dp3[i],假设具体颜色为红,绿,黄,若i-1行包含两种颜色(即dp2[i-1]),此时dp2[i-1]只有以下2种可能:

dp2[i-1] -> dp3[i]
还是dp3[i] 红,绿,黄情况,若i-1行包含三种颜色(从dp3[i-1]转移过来),此时dp3[i-1]也只有以下2种可能:
dp3[i-1] -> dp3[i]

因此,dp3[i]= dp2[i-1] * 2 + dp3[i-1] * 2。

同理,若第i行包含两种颜色,即dp2[i],假设具体颜色为绿,黄,绿,若i-1行是两种颜色(dp2[i-1]),此时dp2[i-1]有如下3种可能:

dp2[i-1] -> dp2[i]
dp2[i]的另一种情况是由dp3[i-1]转移过来,则dp3[i-1]有2种可能,枚举如下:
dp3[i-1] -> dp2[i]

因此,dp2[i] = dp2[i-1] * 3 + dp3[i-1] * 2。 初始值dp2[1] = 6,dp3[1] = 6,最终答案为dp2[i] + dp3[i]。

很容易写出普通DP版本的Python 3代码,时间复杂度为\(O(n)\)

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
# AC
# Runtime: 36 ms, faster than 98.88% of Python3 online submissions for Number of Ways to Paint N × 3 Grid.
# Memory Usage: 13.9 MB, less than 58.66% of Python3 online submissions for Number of Ways to Paint N × 3 Grid.

class Solution:
def numOfWays(self, n: int) -> int:
MOD = 10 ** 9 + 7
dp2, dp3 = 6, 6
n -= 1
while n > 0:
dp2, dp3 = (dp2 * 3 + dp3 * 2) % MOD, (dp2 * 2 + dp3 * 2) % MOD
n -= 1
return (dp2 + dp3) % MOD

快速矩阵幂运算解法

和Fibonacci一样,我们将DP状态转移方程转换成矩阵乘法:

\[ \begin{bmatrix}dp2(n)\\dp3(n)\end{bmatrix} = \begin{bmatrix}3 & 2\\2 & 2\end{bmatrix} \begin{bmatrix}dp2(n-1)\\dp3(n-1)\end{bmatrix} \]

代入初始值,转换成矩阵幂形式

\[ \begin{bmatrix}dp2(n)\\dp3(n)\end{bmatrix} = \begin{bmatrix}3 & 2\\2 & 2\end{bmatrix}^{n-1}\begin{bmatrix}6\\6\end{bmatrix} \]

代码几乎和Fibonacci一模一样,仅仅多了mod 计算。下面是Java版本。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41

/**
AC
Runtime: 0 ms, faster than 100.00% of Java online submissions for Number of Ways to Paint N × 3 Grid.
Memory Usage: 35.7 MB, less than 97.21% of Java online submissions for Number of Ways to Paint N × 3 Grid.
**/

class Solution {
public int numOfWays(int n) {
long MOD = (long) (1e9 + 7);
long[][] ret = {{6, 6}};
long[][] m = {{3, 2}, {2, 2}};
n -= 1;
while(n > 0) {
if ((n & 1) > 0) {
ret = matrixProd(ret, m, MOD);
}
m = matrixProd(m, m, MOD);
n >>= 1;
}
return (int) ((ret[0][0] + ret[0][1]) % MOD);

}

public long[][] matrixProd(long[][] A, long[][] B, long MOD) {
int R = A.length;
int C = B[0].length;
int P = A[0].length;
long[][] ret = new long[R][C];
for (int r = 0; r < R; r++) {
for (int c = 0; c < C; c++) {
for (int p = 0; p < P; p++) {
ret[r][c] += A[r][p] * B[p][c];
ret[r][c] = ret[r][c] % MOD;
}
}
}
return ret;
}

}

Python 3实现为

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# AC
# Runtime: 88 ms, faster than 39.07% of Python3 online submissions for Number of Ways to Paint N × 3 Grid.
# Memory Usage: 30.2 MB, less than 11.59% of Python3 online submissions for Number of Ways to Paint N × 3 Grid.

class Solution:
def numOfWays(self, n: int) -> int:
import numpy as np

MOD = int(1e9 + 7)
ret = np.array([[6, 6]])
m = np.array([[3, 2], [2, 2]])

n -= 1
while n > 0:
if n % 2 == 1:
ret = np.matmul(ret, m) % MOD
m = np.matmul(m, m) % MOD
n = n // 2
return int((ret[0][0] + ret[0][1]) % MOD)

本篇是TSP问题从DP算法到深度学习系列第二篇。

AIZU TSP 自底向上迭代DP解

上一篇中,我们用Python 3和Java 8完成了自顶向下递归版本的DP解。我们继续改进代码,将它转换成标准DP方式:自底向上的迭代DP版本。下图是3个点TSP问题的递归调用图。

将这个图反过来检查状态的依赖关系,可以很容易发现规律:首先计算状态位含有一个1的点,接着是两个1的节点,最后是状态位三个1的点。简而言之,在计算状态位为n+1个1的节点时需要用到n个1的节点的计算结果,如果能依照这样的 topological 顺序来的话,就可以去除递归,写成迭代(循环)版本的DP。

迭代算法的Java 伪代码如下

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
for (int bitset_num = N; bitset_num >=0; bitset_num++) {
while(hasNextCombination(bitset_num)) {
int state = nextCombination(bitset_num);
// compute dp[state][v], v-th bit is set in state
for (int v = 0; v < n; v++) {
for (int u = 0; u < n; u++) {
// for each u not reached by this state
if (!include(state, u)) {
dp[state][v] = min(dp[state][v],
dp[new_state_include_u][u] + dist[v][u]);
}
}
}
}
}

举例来说,dp[00010][1] 是从顶点0出发,刚经过顶点1的最小距离 \(0 \rightarrow 1 \rightarrow ? \rightarrow ? \rightarrow ? \rightarrow 0\)

为了找到最小距离值,就必须遍历所有可能的下一个可能的顶点u (第一个问号位置)。 \[ (0 \rightarrow 1) + \begin{align*} \min \left\lbrace \begin{array}{r@{}l} 2 \rightarrow ? \rightarrow ? \rightarrow 0 + dist(1,2) \qquad\text{ new_state=[00110][2] } \qquad\\\\ 3 \rightarrow ? \rightarrow ? \rightarrow 0 + dist(1,3) \qquad\text{ new_state=[01010][3] } \qquad\\\\ 4 \rightarrow ? \rightarrow ? \rightarrow 0 + dist(1,4) \qquad\text{ new_state=[10010][4] } \qquad \end{array} \right. \end{align*} \]

迭代DP AC代码

以下是AC 的Java 算法核心代码。完整代码在 github/MyEncyclopedia 的tsp/alg_aizu/Main_loop.java

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
public long solve() {
int N = g.V_NUM;
long[][] dp = new long[1 << N][N];
// init dp[][] with MAX
for (int i = 0; i < dp.length; i++) {
Arrays.fill(dp[i], Integer.MAX_VALUE);
}
dp[(1 << N) - 1][0] = 0;

for (int state = (1 << N) - 2; state >= 0; state--) {
for (int v = 0; v < N; v++) {
for (int u = 0; u < N; u++) {
if (((state >> u) & 1) == 0) {
dp[state][v] = Math.min(dp[state][v], dp[state | 1 << u][u] + g.edges[v][u]);
}
}
}
}
return dp[0][0] == Integer.MAX_VALUE ? -1 : dp[0][0];
}

很显然,时间算法复杂度对应了三重 for 循环,为 O(\(2^n * n * n\)) = O(\(2^n*n^2\) )。

类似的,Python 3 AC 代码如下。完整代码在 github/MyEncyclopedia 的tsp/alg_aizu/TSP_loop.py

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
class TSPSolver:
g: Graph

def __init__(self, g: Graph):
self.g = g

def solve(self) -> int:
"""
:param v:
:param state:
:return: -1 means INF
"""
N = self.g.v_num
dp = [[INT_INF for c in range(N)] for r in range(1 << N)]

dp[(1 << N) - 1][0] = 0

for state in range((1 << N) - 2, -1, -1):
for v in range(N):
for u in range(N):
if ((state >> u) & 1) == 0:
if dp[state | 1 << u][u] != INT_INF and self.g.edges[v][u] != INT_INF:
if dp[state][v] == INT_INF:
dp[state][v] = dp[state | 1 << u][u] + self.g.edges[v][u]
else:
dp[state][v] = min(dp[state][v], dp[state | 1 << u][u] + self.g.edges[v][u])
return dp[0][0]

一个欧式空间TSP数据集

至此,TSP的DP解法全部讲解完毕。接下去,我们引入一个二维欧式空间的TSP数据集 PTR_NET on Google Drive ,这个数据集是 Pointer Networks 的作者 Oriol Vinyals 用于模型的训练测试而引入的。

数据集的每一行格式如下:

1
x1, y1, x2, y2, ... output 1 v1 v2 v3 ... 1

一行开始为n个点的x, y坐标,接着是 output,再接着是1,表示从顶点1出发,经v1,v2,...,返回1,注意顶点编号从1开始。

十个顶点数据集的一些数据示例如下:

1
2
3
4
0.607122 0.664447 0.953593 0.021519 0.757626 0.921024 0.586376 0.433565 0.786837 0.052959 0.016088 0.581436 0.496714 0.633571 0.227777 0.971433 0.665490 0.074331 0.383556 0.104392 output 1 3 8 6 10 9 5 2 4 7 1 
0.930534 0.747036 0.277412 0.938252 0.794592 0.794285 0.961946 0.261223 0.070796 0.384302 0.097035 0.796306 0.452332 0.412415 0.341413 0.566108 0.247172 0.890329 0.429978 0.232970 output 1 3 2 9 6 5 8 7 10 4 1
0.686712 0.087942 0.443054 0.277818 0.494769 0.985289 0.559706 0.861138 0.532884 0.351913 0.712561 0.199273 0.554681 0.657214 0.909986 0.277141 0.931064 0.639287 0.398927 0.406909 output 1 5 2 10 7 4 3 9 8 6 1

画出第一个例子的全部顶点和边。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import matplotlib.pyplot as plt
points='0.607122 0.664447 0.953593 0.021519 0.757626 0.921024 0.586376 0.433565 0.786837 0.052959 0.016088 0.581436 0.496714 0.633571 0.227777 0.971433 0.665490 0.074331 0.383556 0.104392'
float_list = list(map(lambda x: float(x), points.split(' ')))

x,y = [],[]
for idx, p in enumerate(float_list):
if idx % 2 == 0:
x.append(p)
else:
y.append(p)

for i in range(0, len(x)):
for j in range(0, len(x)):
if i == j:
continue
plt.plot((x[i],x[j]),(y[i],y[j]))

plt.show()
全连接的图

这个例子的最短TSP旅程为 \[ 1 \rightarrow 3 \rightarrow 8 \rightarrow 6 \rightarrow 10 \rightarrow 9 \rightarrow 5 \rightarrow 2 \rightarrow 4 \rightarrow 7 \rightarrow 1 \]

{linenos
1
2
3
4
5
6
7
8
tour_str = '1 3 8 6 10 9 5 2 4 7 1'
tour = list(map(lambda x: int(x), tour_str.split(' ')))

for i in range(0, len(tour)-1):
p1 = tour[i] - 1
p2 = tour[i + 1] - 1
plt.plot((x[p1],x[p2]),(y[p1],y[p2]))
plt.show()
最短路径

PTR_NET TSP 的Python代码

初始化Init Graph Edges

在之前的自顶向下的递归版本中,需要做一些改动。首先,是图的初始化,我们依然延续之前的邻接矩阵来表示,由于这次的图是无向图,对于任意两个顶点,需要初始化双向的边。

{linenos
1
2
3
4
5
6
7
8
g: Graph = Graph(N)
for v in range(N):
for u in range(N):
diff_x = coordinates[v][0] - coordinates[u][0]
diff_y = coordinates[v][1] - coordinates[u][1]
dist: float = math.sqrt(diff_x * diff_x + diff_y * diff_y)
g.setDist(u, v, dist)
g.setDist(v, u, dist)

辅助变量记录父节点

另一大改动是需要在遍历过程中保存的顶点关联信息,以便在最终找到最短路径值时可以回溯对应的完整路径。在下面代码中,使用parent[bitstate][v] 来保存此状态下最小路径对应的顶点u。

{linenos
1
2
3
4
5
6
7
8
9
10
11
ret: float = FLOAT_INF
u_min: int = -1
for u in range(self.g.v_num):
if (state & (1 << u)) == 0:
s: float = self._recurse(u, state | 1 << u)
if s + edges[v][u] < ret:
ret = s + edges[v][u]
u_min = u
dp[state][v] = ret
self.parent[state][v] = u_min

当最终最短行程确定后,根据parent的信息可以按图索骥找到完整的行程顶点信息。

{linenos
1
2
3
4
5
6
7
8
9
def _form_tour(self):
self.tour = [0]
bit = 0
v = 0
for _ in range(self.g.v_num - 1):
v = self.parent[bit][v]
self.tour.append(v)
bit = bit | (1 << v)
self.tour.append(0)

需要注意的是,有可能存在多个最短行程,它们的距离值是一致的。这种情况下,代码输出的最短路径可能和数据集output后行程路径不一致,但是的两者的总距离是一致的。下面的代码验证了这一点。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
tsp: TSPSolver = TSPSolver(g)
tsp.solve()

output_dist: float = 0.0
output_tour = list(map(lambda x: int(x) - 1, output.split(' ')))
for v in range(1, len(output_tour)):
pre_v = output_tour[v-1]
curr_v = output_tour[v]
diff_x = coordinates[pre_v][0] - coordinates[curr_v][0]
diff_y = coordinates[pre_v][1] - coordinates[curr_v][1]
dist: float = math.sqrt(diff_x * diff_x + diff_y * diff_y)
output_dist += dist

passed = abs(tsp.dist - output_dist) < 10e-5
if passed:
print(f'passed dist={tsp.tour}')
else:
print(f'Min Tour Distance = {output_dist}, Computed Tour Distance = {tsp.dist}, Expected Tour = {output_tour}, Result = {tsp.tour}')

本文所有代码在 github/MyEncyclopedia tsp/alg_plane 中。

旅行商问题(TSP)是计算机算法中经典的NP hard 问题。 在本系列文章中,我们将首先使用动态规划 AC aizu中的TSP问题,然后再利用深度学习求大规模下的近似解。深度学习应用解决问题时先以PyTorch实现监督学习算法 Pointer Network,进而结合强化学习来无监督学习,提高数据使用效率。 本系列完整列表如下:

  • 第一篇: 递归DP方法 AC AIZU TSP问题

  • 第二篇: 二维空间TSP数据集及其DP解法

  • 第三篇: 深度学习 Pointer Networks 的 Pytorch实现

  • 第四篇: 搜寻最有可能路径:Viterbi算法和其他

  • 第五篇: 深度强化学习无监督算法的 Pytorch实现

TSP 问题回顾

TSP可以用图模型来表达,无论有向图或无向图,无论全连通图或者部分连通的图都可以作为TSP问题。 Wikipedia TSP 中举了一个无向全连通的TSP例子。如下图所示,四个顶点A,B,C,D构成无向全连通图。TSP问题要求在所有遍历所有点后返回初始点的回路中找到最短的回路。例如,\(A \rightarrow B \rightarrow C \rightarrow D \rightarrow A\)\(A \rightarrow C \rightarrow B \rightarrow D \rightarrow A\) 都是有效的回路,但是TSP需要返回这些回路中的最短回路(注意,最短回路可能会有多条)。

Wikipedia 4个顶点组成的图

无论是哪种类型的图,我们都能用邻接矩阵表示出一个图。上面的Wikipedia中的图可以用下面的矩阵来描述。

\[ \begin{matrix} & \begin{matrix}A&B&C&D\end{matrix} \\\\ \begin{matrix}A\\\\B\\\\C\\\\D\end{matrix} & \begin{bmatrix}-&20&42&35\\\\20&-&30&34\\\\42&30&-&12\\\\35&34&12&-\end{bmatrix}\\\\ \end{matrix} \]

当然,大多数情况下,TSP问题会被限定在欧氏空间,即二维地图中的全连通无向图。因为,如果将顶点表示一个地理位置,一般来说它可以和其他所有顶点连通,回来的距离相同,由此构成无向图。

AIZU TSP 问题

AIZU在线题库 有一道有向不完全连通图的TSP问题。给定V个顶点和E条边,输出最小回路值。例如,题目里的例子如下所示,由4个顶点和6条单向边构成。

这个示例的答案是16,对应的回路是 \(0\rightarrow1\rightarrow3\rightarrow2\rightarrow0\),由下图的红色边构成。注意,这个题目可能不存在合法解,原因是无回路存在,此时返回-1,可以合理地理解成无穷大。

暴力解法

一种暴力方法是枚举所有可能的从某一顶点的回路,取其中的最小值即可。下面的 Python 示例如何枚举4个顶点构成的图中从顶点0出发的所有回路。

{linenos
1
2
3
4
5
from itertools import permutations
v = [1,2,3]
p = permutations(v)
for t in list(p):
print([0] + list(t) + [0])

所有从顶点0出发的回路如下:

{linenos
1
2
3
4
5
6
[0, 1, 2, 3, 0]
[0, 1, 3, 2, 0]
[0, 2, 1, 3, 0]
[0, 2, 3, 1, 0]
[0, 3, 1, 2, 0]
[0, 3, 2, 1, 0]

很显然,这种方式的时间复杂度是 O(\(n!\)),无法通过AIZU。

动态规划求解

我们可以使用位状态压缩的动态规划来AC这道题。 首先,需要将回路过程中的状态编码成二进制的表示。例如,在四顶点的例子中,如果顶点2和1都被访问过,并且此时停留在顶点1。将已经访问的顶点对应的位置1,那么编码成0110,此外,还需要保存当前顶点的位置,因此我们将代表状态的数组扩展成二维,第一维是位状态,第二维是顶点所在位置,即 \(dp[bitstate][v]\)。这个例子的状态表示就是 \(dp["0110"][1]\)

状态转移方程如下: \[ dp[bitstate][v] = \min ( dp[bitstate \cup \{u\}][u] + dist(v,u) \mid u \notin bitstate ) \] 这种方法对应的时间复杂度是 O(\(n^2*2^n\) ),因为总共有 \(2^n * n\) 个状态,而每个状态又需要一次遍历。虽然都是指数级复杂度,但是它们的巨大区别由下面可以看出区别。

\(n!\) \(n^2*2^n\)
n=8 40320 16384
n=10 3628800 102400
n=12 479001600 589824
n=14 87178291200 3211264

暂停思考一下为什么状态压缩DP能工作。注意到之前暴力解法中其实是有很多重复计算,下面红圈表示重复的计算节点。

在本篇中,我们将会用Python 3和Java 8 实现自顶向下的DP 缓存版本。这种方式比较符合直觉,因为我们不需要预先考虑计算节点的依赖关系。在Java中我们使用了一个小技巧,dp数组初始化成Integer.MAX_VALUE,如此只需要一条语句就能完成更新dp值。

1
res = Math.min(res, s + g.edges[v][u]);

当然,为了AC 这道题,我们需要区分出真正无法到达的情况并返回-1。 在Python实现中,也可以使用同样的技巧,但是这次示例一般的实现方法:将dp数组初始化成-1并通过 if-else 来区分不同情况。

1
2
3
4
5
6
7
INT_INF = -1

if s != INT_INF and edges[v][u] != INT_INF:
if ret == INT_INF:
ret = s + edges[v][u]
else:
ret = min(ret, s + edges[v][u])

下面附完整的Python 3和Java 8的AC代码,同步在 github

AIZU Java 8 递归DP版本

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
// passed http://judge.u-aizu.ac.jp/onlinejudge/description.jsp?id=DPL_2_A
import java.util.Arrays;
import java.util.Scanner;

public class Main {
public static class Graph {
public final int V_NUM;
public final int[][] edges;

public Graph(int V_NUM) {
this.V_NUM = V_NUM;
this.edges = new int[V_NUM][V_NUM];
for (int i = 0; i < V_NUM; i++) {
Arrays.fill(this.edges[i], Integer.MAX_VALUE);
}
}

public void setDist(int src, int dest, int dist) {
this.edges[src][dest] = dist;
}

}

public static class TSP {
public final Graph g;
long[][] dp;

public TSP(Graph g) {
this.g = g;
}

public long solve() {
int N = g.V_NUM;
dp = new long[1 << N][N];
for (int i = 0; i < dp.length; i++) {
Arrays.fill(dp[i], -1);
}

long ret = recurse(0, 0);
return ret == Integer.MAX_VALUE ? -1 : ret;
}

private long recurse(int state, int v) {
int ALL = (1 << g.V_NUM) - 1;
if (dp[state][v] >= 0) {
return dp[state][v];
}
if (state == ALL && v == 0) {
dp[state][v] = 0;
return 0;
}
long res = Integer.MAX_VALUE;
for (int u = 0; u < g.V_NUM; u++) {
if ((state & (1 << u)) == 0) {
long s = recurse(state | 1 << u, u);
res = Math.min(res, s + g.edges[v][u]);
}
}
dp[state][v] = res;
return res;

}

}

public static void main(String[] args) {

Scanner in = new Scanner(System.in);
int V = in.nextInt();
int E = in.nextInt();
Graph g = new Graph(V);
while (E > 0) {
int src = in.nextInt();
int dest = in.nextInt();
int dist = in.nextInt();
g.setDist(src, dest, dist);
E--;
}
System.out.println(new TSP(g).solve());
}
}

AIZU Python 3 递归DP版本

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from typing import List

INT_INF = -1

class Graph:
v_num: int
edges: List[List[int]]

def __init__(self, v_num: int):
self.v_num = v_num
self.edges = [[INT_INF for c in range(v_num)] for r in range(v_num)]

def setDist(self, src: int, dest: int, dist: int):
self.edges[src][dest] = dist


class TSPSolver:
g: Graph
dp: List[List[int]]

def __init__(self, g: Graph):
self.g = g
self.dp = [[None for c in range(g.v_num)] for r in range(1 << g.v_num)]

def solve(self) -> int:
return self._recurse(0, 0)

def _recurse(self, v: int, state: int) -> int:
"""

:param v:
:param state:
:return: -1 means INF
"""
dp = self.dp
edges = self.g.edges

if dp[state][v] is not None:
return dp[state][v]

if (state == (1 << self.g.v_num) - 1) and (v == 0):
dp[state][v] = 0
return dp[state][v]

ret: int = INT_INF
for u in range(self.g.v_num):
if (state & (1 << u)) == 0:
s: int = self._recurse(u, state | 1 << u)
if s != INT_INF and edges[v][u] != INT_INF:
if ret == INT_INF:
ret = s + edges[v][u]
else:
ret = min(ret, s + edges[v][u])
dp[state][v] = ret
return ret


def main():
V, E = map(int, input().split())
g: Graph = Graph(V)
for _ in range(E):
src, dest, dist = map(int, input().split())
g.setDist(src, dest, dist)

tsp: TSPSolver = TSPSolver(g)
print(tsp.solve())


if __name__ == "__main__":
main()

本系列,我们来看看在一种常见的组合游戏——回合制棋盘类游戏中,如何用算法来解决问题。首先,我们会介绍并解决搜索空间较小的问题,引入经典的博弈算法和相关理论,最终实现在大搜索空间中的Deep RL近似算法。在此基础上可以理解AlphaGo的原理和工作方式。 本系列的第一篇,我们介绍3个Leetcode中的零和回合制游戏,从最初的暴力解法,到动态规划最终演变成博弈论里的经典算法: minimax 以及 alpha beta 剪枝。

Leetcode 292 Nim Game (简单)

简单题 Leetcode 292 Nim Game

你和你的朋友,两个人一起玩 Nim游戏:桌子上有一堆石头,每次你们轮流拿掉 1 - 3 块石头。 拿掉最后一块石头的人就是获胜者。你作为先手。
你们是聪明人,每一步都是最优解。 编写一个函数,来判断你是否可以在给定石头数量的情况下赢得游戏。

示例:
输入: 4
输出: false
解释: 如果堆中有 4 块石头,那么你永远不会赢得比赛;因为无论你拿走 1 块、2 块 还是 3 块石头,最后一块石头总是会被你的朋友拿走。

定义 \(f(n)\) 为有\(n\)个石头并采取最优策略的游戏结果, \(f(n)\)的值只有可能是赢或者输。考察前几个结果:\(f(1) = f(2) = f(3) = Win\),然后来计算\(f(4)\)。因为玩家采取最优策略(只要有一种走法让对方必输,玩家获胜),对于4来说,玩家能走的可能是拿掉1块、2块或3块,但是无论剩余何种局面,对方都是必赢,因此,4就是必输。总的说来,递归关系如下: \[ f(n) = \neg (f(n-1) \land f(n-2) \land f(n-3)) \]

这个递归式可以直接翻译成Python 3代码
{linenos
1
2
3
4
5
6
7
8
9
10
11
# TLE
# Time Complexity: O(exponential)
class Solution_BruteForce:

def canWinNim(self, n: int) -> bool:
if n <= 3:
return True
for i in range(1, 4):
if not self.canWinNim(n - i):
return True
return False
以上的递归公式和代码很像fibonacci数的递归定义和暴力解法,因此对应的时间复杂度也是指数级的,提交代码以后会TLE。下图画出了当n=7时的递归调用,注意 5 被扩展向下重复执行了两次,4重复了4次。
292 Nim Game 暴力解法调用图 n=7
我们采用和fibonacci一样的方式来优化算法:缓存较小n的结果以此来计算较大n的结果。 Python 中,我们可以只加一行lru_cache decorator,来取得这种动态规划效果,下面的代码将复杂度降到了 \(O(N)\)
{linenos
1
2
3
4
5
6
7
8
9
10
11
12
# RecursionError: maximum recursion depth exceeded in comparison n=1348820612
# Time Complexity: O(N)
class Solution_DP:
from functools import lru_cache
@lru_cache(maxsize=None)
def canWinNim(self, n: int) -> bool:
if n <= 3:
return True
for i in range(1, 4):
if not self.canWinNim(n - i):
return True
return False
再来画出调用图:这次5和4就不再被展开重复计算,图中绿色的节点表示缓存命中。
292 Nim Game 动归解法调用图 n=7

但还是没有AC,因为当n=1348820612时,这种方式会导致栈溢出。再改成下面的循环版本,可惜还是TLE。

{linenos
1
2
3
4
5
6
7
8
9
10
11
# TLE for 1348820612
# Time Complexity: O(N)
class Solution:
def canWinNim(self, n: int) -> bool:
if n <= 3:
return True
last3, last2, last1 = True, True, True
for i in range(4, n+1):
this = not (last3 and last2 and last1)
last3, last2, last1 = last2, last1, this
return last1

由此看来,AC 版本需要低于\(O(n)\)的算法复杂度。上面的写法似乎暗示输赢有周期性的规律。事实上,如果将输赢按照顺序画出来,就马上得出规律了:只要\(n \mod 4 = 0\) 就是输,否则赢。原因如下:当面临不能被4整除的数量时 \(4k+i (i=1,2,3)\) ,一方总是可以拿走 \(i\) 个,将\(4k\) 留给对手,而对方下轮又将返回不能被4整除的数,如此循环往复,直到这一方有1, 2, 3 个,最终获胜。

输赢分布

最终AC版本,只有一句语句。

{linenos
1
2
3
4
5
# AC
# Time Complexity: O(1)
class Solution:
def canWinNim(self, n: int) -> bool:
return not (n % 4 == 0)

Leetcode 486 Predict the Winner (中等)

中等难度题目: Leetcode 486 Predict the Winner.

给定一个表示分数的非负整数数组。 玩家1从数组任意一端拿取一个分数,随后玩家2继续从剩余数组任意一端拿取分数,然后玩家1拿,……。每次一个玩家只能拿取一个分数,分数被拿取之后不再可取。直到没有剩余分数可取时游戏结束。最终获得分数总和最多的玩家获胜。
给定一个表示分数的数组,预测玩家1是否会成为赢家。你可以假设每个玩家的玩法都会使他的分数最大化。

示例 1:
输入: [1, 5, 2]
输出: False
解释: 一开始,玩家1可以从1和2中进行选择。
如果他选择2(或者1),那么玩家2可以从1(或者2)和5中进行选择。如果玩家2选择了5,那么玩家1则只剩下1(或者2)可选。
所以,玩家1的最终分数为 1 + 2 = 3,而玩家2为 5。
因此,玩家1永远不会成为赢家,返回 False。

示例 2:
输入: [1, 5, 233, 7]
输出: True
解释: 玩家1一开始选择1。然后玩家2必须从5和7中进行选择。无论玩家2选择了哪个,玩家1都可以选择233。
最终,玩家1(234分)比玩家2(12分)获得更多的分数,所以返回 True,表示玩家1可以成为赢家。

对于当前玩家,他有两种选择:左边或者右边的数。定义 maxDiff(l, r) 为剩余子数组\([l,r]\)时,当前玩家能取得的最大分差,那么

\[ \begin{equation*} \operatorname{maxDiff}(l, r) = \max \begin{cases} nums[l] - \operatorname{maxDiff}(l + 1, r)\\\\ nums[r] - \operatorname{maxDiff}(l, r - 1) \end{cases} \end{equation*} \]

对应的时间复杂度可以写出递归式,显然是指数级的: \[ f(n) = 2f(n-1) = O(2^n) \]

采用暴力解法可以AC,但运算时间很长,接近TLE边缘 (6300ms)。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# AC
# Time Complexity: O(2^N)
# Slow: 6300ms
from typing import List

class Solution:

def maxDiff(self, l: int, r:int) -> int:
if l == r:
return self.nums[l]
return max(self.nums[l] - self.maxDiff(l + 1, r), self.nums[r] - self.maxDiff(l, r - 1))

def PredictTheWinner(self, nums: List[int]) -> bool:
self.nums = nums
return self.maxDiff(0, len(nums) - 1) >= 0

从调用图也很容易看出是指数级的复杂度
486 Predict the Winner 暴力解法调用图 n=4

上图中我们有重复计算的节点,例如[1-2]节点被计算了两次。使用 lru_cache 大法,在maxDiff 上仅加了一句,就能以复杂度 \(O(n^2)\)和运行时间 43ms AC。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# AC
# Time Complexity: O(N^2)
# Fast: 43ms
from functools import lru_cache
from typing import List

class Solution:

@lru_cache(maxsize=None)
def maxDiff(self, l: int, r:int) -> int:
if l == r:
return self.nums[l]
return max(self.nums[l] - self.maxDiff(l + 1, r), self.nums[r] - self.maxDiff(l, r - 1))

def PredictTheWinner(self, nums: List[int]) -> bool:
self.nums = nums
return self.maxDiff(0, len(nums) - 1) >= 0
动态规划解法调用图可以看出节点 [1-2] 这次没有被计算两次。
486 Predict the Winner 动归解法调用图 n=4

Leetcode 464 Can I Win (中等)

类似但稍有难度的题目 Leetcode 464 Can I Win。难点在于使用了位的状态压缩。

在 "100 game" 这个游戏中,两名玩家轮流选择从 1 到 10 的任意整数,累计整数和,先使得累计整数和达到 100 的玩家,即为胜者。
如果我们将游戏规则改为 “玩家不能重复使用整数” 呢?
例如,两个玩家可以轮流从公共整数池中抽取从 1 到 15 的整数(不放回),直到累计整数和 >= 100。
给定一个整数 maxChoosableInteger (整数池中可选择的最大数)和另一个整数 desiredTotal(累计和),判断先出手的玩家是否能稳赢(假设两位玩家游戏时都表现最佳)?
你可以假设 maxChoosableInteger 不会大于 20, desiredTotal 不会大于 300。

示例:
输入:
maxChoosableInteger = 10
desiredTotal = 11
输出:
false
解释:
无论第一个玩家选择哪个整数,他都会失败。
第一个玩家可以选择从 1 到 10 的整数。
如果第一个玩家选择 1,那么第二个玩家只能选择从 2 到 10 的整数。
第二个玩家可以通过选择整数 10(那么累积和为 11 >= desiredTotal),从而取得胜利.
同样地,第一个玩家选择任意其他整数,第二个玩家都会赢。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# AC
# Time Complexity: O:(2^m*m), m: maxChoosableInteger
class Solution:
from functools import lru_cache
@lru_cache(maxsize=None)
def recurse(self, status: int, currentTotal: int) -> bool:
for i in range(1, self.maxChoosableInteger + 1):
if not (status >> i & 1):
new_status = 1 << i | status
if currentTotal + i >= self.desiredTotal:
return True
if not self.recurse(new_status, currentTotal + i):
return True
return False


def canIWin(self, maxChoosableInteger: int, desiredTotal: int) -> bool:
self.maxChoosableInteger = maxChoosableInteger
self.desiredTotal = desiredTotal

sum = maxChoosableInteger * (maxChoosableInteger + 1) / 2
if sum < desiredTotal:
return False
return self.recurse(0, 0)

上面的代码算法复杂度为\(O(m 2^m)\),m是maxChoosableInteger。由于所有状态的数量是\(2^m\),对于每个状态,最多会尝试 \(m\) 走法。

Minimax 算法

至此,我们AC了leetcode中的几道零和回合制博弈游戏。事实上,在这个领域有通用的算法:回合制博弈下的minimax。算法背景如下,两个玩家轮流玩,第一个玩家max的目的是将游戏的效用最大化,第二个玩家min则是最小化效用。比如,下面的节点表示玩家选取节点后游戏的效用,当两个玩家都能采取最优策略,Minimax 算法从底层节点来计算,游戏的结果是最终max 玩家会得到-7。

Wikipedia Minimax 例子

Minimax Python 3伪代码如下。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
def minimax(node: Node, depth: int, maximizingPlayer: bool) -> int:
if depth == 0 or is_terminal(node):
return evaluate_terminal(node)
if maximizingPlayer:
value:int = −∞
for child in node:
value = max(value, minimax(child, depth − 1, False))
return value
else: # minimizing player
value := +∞
for child in node:
value = min(value, minimax(child, depth − 1, True))
return value

Minimax: 486 Predict the Winner

我们知道486 Predict the Winner 是有minimax解法的,但如何具体实现,其难点在于如何定义合适的游戏价值或者效用。之前的解法中,我们定义maxDiff(l, r) 来表示当前玩家面临子区间 \([l, r]\) 时能取得的最大分差。对于minimax算法,max 玩家要最大化游戏价值,min玩家要最小化游戏价值。先考虑最简单情况即只有一个数x时,若定义max玩家在此局面下得到这个数时游戏价值为 +x,则min玩家为-x,即max玩家得到的所有数为正(\(+a_1 + a_2 + ... = A\)),min玩家得到的所有数为负(\(-b_1 - b_2 - ... = -B\))。至此,max玩家的目标就是 \(max(A-B)\) ,min玩家是 \(min(A-B)\)。有了精确的定义和优化目标,代码只需要套一下上面的模版。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# AC
from functools import lru_cache
from typing import List

class Solution:
# max_player: max(A - B)
# min_player: min(A - B)
@lru_cache(maxsize=None)
def minimax(self, l: int, r: int, isMaxPlayer: bool) -> int:
if l == r:
return self.nums[l] * (1 if isMaxPlayer else -1)

if isMaxPlayer:
return max(
self.nums[l] + self.minimax(l + 1, r, not isMaxPlayer),
self.nums[r] + self.minimax(l, r - 1, not isMaxPlayer))
else:
return min(
-self.nums[l] + self.minimax(l + 1, r, not isMaxPlayer),
-self.nums[r] + self.minimax(l, r - 1, not isMaxPlayer))

def PredictTheWinner(self, nums: List[int]) -> bool:
self.nums = nums
v = self.minimax(0, len(nums) - 1, True)
return v >= 0
Minimax 486 调用图 nums=[1, 5, 2, 4]

Minimax: 464 Can I Win

该题目是很典型的此类游戏,即结果为赢输平,但是中间的状态没有直接对应的游戏价值。对于这样的问题,一般定义为,max玩家胜,价值 +1,min玩家胜,价值-1,平则0。下面的AC代码实现了 Minimax 算法。算法中针对两个玩家都有剪枝(没有剪枝无法AC)。具体来说,max玩家一旦在某一节点取得胜利(value=1),就停止继续向下搜索,因为这是他能取得的最好分数。同理,min玩家一旦取得-1也直接返回上层节点。这个剪枝可以泛化成 alpha beta剪枝算法。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
# AC
class Solution:
from functools import lru_cache
@lru_cache(maxsize=None)
# currentTotal < desiredTotal
def minimax(self, status: int, currentTotal: int, isMaxPlayer: bool) -> int:
import math
if status == self.allUsed:
return 0 # draw: no winner

if isMaxPlayer:
value = -math.inf
for i in range(1, self.maxChoosableInteger + 1):
if not (status >> i & 1):
new_status = 1 << i | status
if currentTotal + i >= self.desiredTotal:
return 1 # shortcut
value = max(value, self.minimax(new_status, currentTotal + i, not isMaxPlayer))
if value == 1:
return 1
return value
else:
value = math.inf
for i in range(1, self.maxChoosableInteger + 1):
if not (status >> i & 1):
new_status = 1 << i | status
if currentTotal + i >= self.desiredTotal:
return -1 # shortcut
value = min(value, self.minimax(new_status, currentTotal + i, not isMaxPlayer))
if value == -1:
return -1
return value

Alpha-Beta 剪枝

在464 Can I Win minimax 算法代码实现中,我们发现有剪枝优化空间。对于每个节点,定义两个值alpha 和 beta,表示从根节点到目前局面时,max玩家保证能取得的最小值以及min玩家能保证取得的最大值。初始时,根节点alpha = −∞ , beta = +∞,表示游戏最终的价值在区间 [−∞, +∞]中。在向下遍历的过程中,子节点先继承父节点的 alpha beta 值进而继承区间 [alpha, beta]。当子节点在向下遍历的时候同步更新alpha 或者 beta,一旦区间[alpha, beta]非法就立即向上返回。举个Wikimedia的例子来进一步说明:

  1. 根节点初始时: alpha = −∞, beta = +∞

  2. 根节点,最左边子节点返回4后: alpha = 4, beta = +∞

  3. 根节点,中间子节点返回5后: alpha = 5, beta = +∞

  4. 最右Min节点(标1节点),初始时: alpha = 5, beta = +∞

  5. 最右Min节点(标1节点),第一个子节点返回1后: alpha = 5, beta = 1

此时,最右Min节点的alpha, beta形成了无效区间[5, 1],满足了剪枝条件,因此可以不用计算它的第二个和第三个子节点。如果剩余子节点返回值 > 1,比如2,由于这是个min节点,将会被已经到手的1替换。若其他子节点返回值 < 1,但由于min的父节点有效区间是[5, +∞],已经保证了>=5,小于5的值也会被忽略。

Wikimedia Alpha Beta 剪枝例子

Alpha Beta 剪枝 Python 3伪代码如下

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
def alpha_beta(node: Node, depth: int, α: int, β: int, maximizingPlayer: bool) -> int:
if depth == 0 or is_terminal(node):
return evaluate_terminal(node)
if maximizingPlayer:
value: int = −∞
for child in node:
value = max(value, alphabeta(child, depth − 1, α, β, False))
α = max(α, value)
if α >= β:
break # β cut-off
return value
else:
value: int = +∞
for child in node:
value = min(value, alphabeta(child, depth − 1, α, β, True))
β = min(β, value)
if β <= α:
break # α cut-off
return value

Alpha-Beta Pruning: 486 Predict the Winner

用 Alpha-Beta 剪枝 再次AC 486。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
# AC
import math
from functools import lru_cache
from typing import List

class Solution:
def alpha_beta(self, l: int, r: int, curr: int, isMaxPlayer: bool, alpha: int, beta: int) -> int:
if l == r:
return curr + self.nums[l] * (1 if isMaxPlayer else -1)

if isMaxPlayer:
ret = self.alpha_beta(l + 1, r, curr + self.nums[l], not isMaxPlayer, alpha, beta)
alpha = max(alpha, ret)
if alpha >= beta:
return alpha
ret = max(ret, self.alpha_beta(l, r - 1, curr + self.nums[r], not isMaxPlayer, alpha, beta))
return ret
else:
ret = self.alpha_beta(l + 1, r, curr - self.nums[l], not isMaxPlayer, alpha, beta)
beta = min(beta, ret)
if alpha >= beta:
return beta
ret = min(ret, self.alpha_beta(l, r - 1, curr - self.nums[r], not isMaxPlayer, alpha, beta))
return ret

def PredictTheWinner(self, nums: List[int]) -> bool:
self.nums = nums
v = self.alpha_beta(0, len(nums) - 1, 0, True, -math.inf, math.inf)
return v >= 0

Alpha-Beta Pruning: 464 Can I Win

464 Alpha-Beta 剪枝版本。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
# AC
class Solution:
from functools import lru_cache
@lru_cache(maxsize=None)
# currentTotal < desiredTotal
def alpha_beta(self, status: int, currentTotal: int, isMaxPlayer: bool, alpha: int, beta: int) -> int:
import math
if status == self.allUsed:
return 0 # draw: no winner

if isMaxPlayer:
value = -math.inf
for i in range(1, self.maxChoosableInteger + 1):
if not (status >> i & 1):
new_status = 1 << i | status
if currentTotal + i >= self.desiredTotal:
return 1 # shortcut
value = max(value, self.alpha_beta(new_status, currentTotal + i, not isMaxPlayer, alpha, beta))
alpha = max(alpha, value)
if alpha >= beta:
return value
return value
else:
value = math.inf
for i in range(1, self.maxChoosableInteger + 1):
if not (status >> i & 1):
new_status = 1 << i | status
if currentTotal + i >= self.desiredTotal:
return -1 # shortcut
value = min(value, self.alpha_beta(new_status, currentTotal + i, not isMaxPlayer, alpha, beta))
beta = min(beta, value)
if alpha >= beta:
return value
return value

C++, Java, Javascript AC 486 Predict the Winner

最后介绍一种不同的DP实现:用C++, Java, Javascript 实现自底向上的DP解法来AC leetcode 486,当然其他语言没有Python的lru_cache大法。以下实现中,注意DP解的构建顺序,先解决小规模的问题,并在此基础上计算稍大的问题。值得一提的是,以下的循环写法严格保证了 \(n^2\) 次循环,但是自顶向下的计划递归可能会少于 \(n^2\)次循环。

Java AC Code

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
// AC
class Solution {
public boolean PredictTheWinner(int[] nums) {
int n = nums.length;
int[][] dp = new int[n][n];
for (int i = 0; i < n; i++) {
dp[i][i] = nums[i];
}

for (int l = n - 1; l >= 0; l--) {
for (int r = l + 1; r < n; r++) {
dp[l][r] = Math.max(
nums[l] - dp[l + 1][r],
nums[r] - dp[l][r - 1]);
}
}
return dp[0][n - 1] >= 0;
}
}

C++ AC Code

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
// AC
class Solution {
public:
bool PredictTheWinner(vector<int>& nums) {
int n = nums.size();
vector<vector<int>> dp(n, vector<int>(n, 0));
for (int i = 0; i < n; i++) {
dp[i][i] = nums[i];
}
for (int l = n - 1; l >= 0; l--) {
for (int r = l + 1; r < n; r++) {
dp[l][r] = max(nums[l] - dp[l + 1][r], nums[r] - dp[l][r - 1]);
}
}
return dp[0][n - 1] >= 0;
}
};

Javascript AC Code

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
/**
* @param {number[]} nums
* @return {boolean}
*/
var PredictTheWinner = function(nums) {
const n = nums.length;
const dp = new Array(n).fill().map(() => new Array(n));

for (let i = 0; i < n; i++) {
dp[i][i] = nums[i];
}

for (let l = n - 1; l >=0; l--) {
for (let r = i + 1; r < n; r++) {
dp[l][r] = Math.max(nums[l] - dp[l + 1][r],nums[r] - dp[l][r - 1]);
}
}

return dp[0][n-1] >=0;
};
Your browser is out-of-date!

Update your browser to view this website correctly. Update my browser now

×