Tech Blog

旅行商问题(TSP)是计算机算法中经典的NP hard 问题。 在本系列文章中,我们将首先使用动态规划 AC aizu中的TSP问题,然后再利用深度学习求大规模下的近似解。深度学习应用解决问题时先以PyTorch实现监督学习算法 Pointer Network,进而结合强化学习来无监督学习,提高数据使用效率。 本系列完整列表如下:

  • 第一篇: 递归DP方法 AC AIZU TSP问题

  • 第二篇: 二维空间TSP数据集及其DP解法

  • 第三篇: 深度学习 Pointer Networks 的 Pytorch实现

  • 第四篇: 搜寻最有可能路径:Viterbi算法和其他

  • 第五篇: 深度强化学习无监督算法的 Pytorch实现

TSP 问题回顾

TSP可以用图模型来表达,无论有向图或无向图,无论全连通图或者部分连通的图都可以作为TSP问题。 Wikipedia TSP 中举了一个无向全连通的TSP例子。如下图所示,四个顶点A,B,C,D构成无向全连通图。TSP问题要求在所有遍历所有点后返回初始点的回路中找到最短的回路。例如,\(A \rightarrow B \rightarrow C \rightarrow D \rightarrow A\)\(A \rightarrow C \rightarrow B \rightarrow D \rightarrow A\) 都是有效的回路,但是TSP需要返回这些回路中的最短回路(注意,最短回路可能会有多条)。

Wikipedia 4个顶点组成的图

无论是哪种类型的图,我们都能用邻接矩阵表示出一个图。上面的Wikipedia中的图可以用下面的矩阵来描述。

\[ \begin{matrix} & \begin{matrix}A&B&C&D\end{matrix} \\\\ \begin{matrix}A\\\\B\\\\C\\\\D\end{matrix} & \begin{bmatrix}-&20&42&35\\\\20&-&30&34\\\\42&30&-&12\\\\35&34&12&-\end{bmatrix}\\\\ \end{matrix} \]

当然,大多数情况下,TSP问题会被限定在欧氏空间,即二维地图中的全连通无向图。因为,如果将顶点表示一个地理位置,一般来说它可以和其他所有顶点连通,回来的距离相同,由此构成无向图。

AIZU TSP 问题

AIZU在线题库 有一道有向不完全连通图的TSP问题。给定V个顶点和E条边,输出最小回路值。例如,题目里的例子如下所示,由4个顶点和6条单向边构成。

这个示例的答案是16,对应的回路是 \(0\rightarrow1\rightarrow3\rightarrow2\rightarrow0\),由下图的红色边构成。注意,这个题目可能不存在合法解,原因是无回路存在,此时返回-1,可以合理地理解成无穷大。

暴力解法

一种暴力方法是枚举所有可能的从某一顶点的回路,取其中的最小值即可。下面的 Python 示例如何枚举4个顶点构成的图中从顶点0出发的所有回路。

{linenos
1
2
3
4
5
from itertools import permutations
v = [1,2,3]
p = permutations(v)
for t in list(p):
print([0] + list(t) + [0])

所有从顶点0出发的回路如下:

{linenos
1
2
3
4
5
6
[0, 1, 2, 3, 0]
[0, 1, 3, 2, 0]
[0, 2, 1, 3, 0]
[0, 2, 3, 1, 0]
[0, 3, 1, 2, 0]
[0, 3, 2, 1, 0]

很显然,这种方式的时间复杂度是 O(\(n!\)),无法通过AIZU。

动态规划求解

我们可以使用位状态压缩的动态规划来AC这道题。 首先,需要将回路过程中的状态编码成二进制的表示。例如,在四顶点的例子中,如果顶点2和1都被访问过,并且此时停留在顶点1。将已经访问的顶点对应的位置1,那么编码成0110,此外,还需要保存当前顶点的位置,因此我们将代表状态的数组扩展成二维,第一维是位状态,第二维是顶点所在位置,即 \(dp[bitstate][v]\)。这个例子的状态表示就是 \(dp["0110"][1]\)

状态转移方程如下: \[ dp[bitstate][v] = \min ( dp[bitstate \cup \{u\}][u] + dist(v,u) \mid u \notin bitstate ) \] 这种方法对应的时间复杂度是 O(\(n^2*2^n\) ),因为总共有 \(2^n * n\) 个状态,而每个状态又需要一次遍历。虽然都是指数级复杂度,但是它们的巨大区别由下面可以看出区别。

\(n!\) \(n^2*2^n\)
n=8 40320 16384
n=10 3628800 102400
n=12 479001600 589824
n=14 87178291200 3211264

暂停思考一下为什么状态压缩DP能工作。注意到之前暴力解法中其实是有很多重复计算,下面红圈表示重复的计算节点。

在本篇中,我们将会用Python 3和Java 8 实现自顶向下的DP 缓存版本。这种方式比较符合直觉,因为我们不需要预先考虑计算节点的依赖关系。在Java中我们使用了一个小技巧,dp数组初始化成Integer.MAX_VALUE,如此只需要一条语句就能完成更新dp值。

1
res = Math.min(res, s + g.edges[v][u]);

当然,为了AC 这道题,我们需要区分出真正无法到达的情况并返回-1。 在Python实现中,也可以使用同样的技巧,但是这次示例一般的实现方法:将dp数组初始化成-1并通过 if-else 来区分不同情况。

1
2
3
4
5
6
7
INT_INF = -1

if s != INT_INF and edges[v][u] != INT_INF:
if ret == INT_INF:
ret = s + edges[v][u]
else:
ret = min(ret, s + edges[v][u])

下面附完整的Python 3和Java 8的AC代码,同步在 github

AIZU Java 8 递归DP版本

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
// passed http://judge.u-aizu.ac.jp/onlinejudge/description.jsp?id=DPL_2_A
import java.util.Arrays;
import java.util.Scanner;

public class Main {
public static class Graph {
public final int V_NUM;
public final int[][] edges;

public Graph(int V_NUM) {
this.V_NUM = V_NUM;
this.edges = new int[V_NUM][V_NUM];
for (int i = 0; i < V_NUM; i++) {
Arrays.fill(this.edges[i], Integer.MAX_VALUE);
}
}

public void setDist(int src, int dest, int dist) {
this.edges[src][dest] = dist;
}

}

public static class TSP {
public final Graph g;
long[][] dp;

public TSP(Graph g) {
this.g = g;
}

public long solve() {
int N = g.V_NUM;
dp = new long[1 << N][N];
for (int i = 0; i < dp.length; i++) {
Arrays.fill(dp[i], -1);
}

long ret = recurse(0, 0);
return ret == Integer.MAX_VALUE ? -1 : ret;
}

private long recurse(int state, int v) {
int ALL = (1 << g.V_NUM) - 1;
if (dp[state][v] >= 0) {
return dp[state][v];
}
if (state == ALL && v == 0) {
dp[state][v] = 0;
return 0;
}
long res = Integer.MAX_VALUE;
for (int u = 0; u < g.V_NUM; u++) {
if ((state & (1 << u)) == 0) {
long s = recurse(state | 1 << u, u);
res = Math.min(res, s + g.edges[v][u]);
}
}
dp[state][v] = res;
return res;

}

}

public static void main(String[] args) {

Scanner in = new Scanner(System.in);
int V = in.nextInt();
int E = in.nextInt();
Graph g = new Graph(V);
while (E > 0) {
int src = in.nextInt();
int dest = in.nextInt();
int dist = in.nextInt();
g.setDist(src, dest, dist);
E--;
}
System.out.println(new TSP(g).solve());
}
}

AIZU Python 3 递归DP版本

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from typing import List

INT_INF = -1

class Graph:
v_num: int
edges: List[List[int]]

def __init__(self, v_num: int):
self.v_num = v_num
self.edges = [[INT_INF for c in range(v_num)] for r in range(v_num)]

def setDist(self, src: int, dest: int, dist: int):
self.edges[src][dest] = dist


class TSPSolver:
g: Graph
dp: List[List[int]]

def __init__(self, g: Graph):
self.g = g
self.dp = [[None for c in range(g.v_num)] for r in range(1 << g.v_num)]

def solve(self) -> int:
return self._recurse(0, 0)

def _recurse(self, v: int, state: int) -> int:
"""

:param v:
:param state:
:return: -1 means INF
"""
dp = self.dp
edges = self.g.edges

if dp[state][v] is not None:
return dp[state][v]

if (state == (1 << self.g.v_num) - 1) and (v == 0):
dp[state][v] = 0
return dp[state][v]

ret: int = INT_INF
for u in range(self.g.v_num):
if (state & (1 << u)) == 0:
s: int = self._recurse(u, state | 1 << u)
if s != INT_INF and edges[v][u] != INT_INF:
if ret == INT_INF:
ret = s + edges[v][u]
else:
ret = min(ret, s + edges[v][u])
dp[state][v] = ret
return ret


def main():
V, E = map(int, input().split())
g: Graph = Graph(V)
for _ in range(E):
src, dest, dist = map(int, input().split())
g.setDist(src, dest, dist)

tsp: TSPSolver = TSPSolver(g)
print(tsp.solve())


if __name__ == "__main__":
main()

Leetcode 1227 是一道有意思的概率题,本篇将从多个角度来讨论这道题。题目如下

有 n 位乘客即将登机,飞机正好有 n 个座位。第一位乘客的票丢了,他随便选了一个座位坐下。 剩下的乘客将会: 如果他们自己的座位还空着,就坐到自己的座位上, 当他们自己的座位被占用时,随机选择其他座位,第 n 位乘客坐在自己的座位上的概率是多少?

示例 1: 输入:n = 1 输出:1.00000 解释:第一个人只会坐在自己的位置上。

示例 2: 输入: n = 2 输出: 0.50000 解释:在第一个人选好座位坐下后,第二个人坐在自己的座位上的概率是 0.5。

提示: 1 <= n <= 10^5

假设规模为n时答案为f(n),一般来说,这种递推问题在数学形式上可能有关于n的简单数学表达式(closed form),或者肯定有f(n)关于f(n-k)的递推表达式。工程上,我们可以通过通过多次模拟即蒙特卡罗模拟来算得近似的数值解。

Monte Carlo 模拟发现规律

首先,我们先来看看如何高效的用代码来模拟。根据题意的描述过程,直接可以写出下面代码。seats为n大小的bool 数组,每个位置表示此位置是否已经被占据。然后依次给第i个人按题意分配座位。注意,每次参数随机数范围在[0,n-1],因此,会出现已经被占据的情况,此时需要再次随机,直至分配到空位。

暴力直接模拟

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
def simulate_bruteforce(n: int) -> bool:
"""
Simulates one round. Unbounded time complexity.
:param n: total number of seats
:return: True if last one has last seat, otherwise False
"""

seats = [False for _ in range(n)]

for i in range(n-1):
if i == 0: # first one, always random
seats[random.randint(0, n - 1)] = True
else:
if not seats[i]: # i-th has his seat
seats[i] = True
else:
while True:
rnd = random.randint(0, n - 1) # random until no conflicts
if not seats[rnd]:
seats[rnd] = True
break
return not seats[n-1]

运行上面的代码来模拟 n 从 2 到10 的情况,每种情况跑500次模拟,输出如下

1
2
3
4
5
6
7
8
9
10
1 => 1.0
2 => 0.55
3 => 0.54
4 => 0.486
5 => 0.488
6 => 0.498
7 => 0.526
8 => 0.504
9 => 0.482
10 => 0.494

发现当 n>=2 时,似乎概率都是0.5。

标准答案

其实,这道题的标准答案就是 n=1 为1,n>=2 为0.5。下面是 python 3 标准答案。本篇后面会从多个角度来探讨为什么是0.5 。

{linenos
1
2
3
class Solution:
def nthPersonGetsNthSeat(self, n: int) -> float:
return 1.0 if n == 1 else 0.5

O(n) 改进算法

上面的暴力直接模拟版本有个最大的问题是当n很大时,随机分配座位会产生大量冲突,因此,最坏复杂度是没有任何上限的。解决方法是每次发生随机分配时保证不冲突,能直接选到空位。下面是一种最坏复杂度O(n)的模拟过程,seats数组初始话成 0,1,...,n-1,表示座位号。当第i个人登机时,seats[i:n] 的值为他可以选择的座位集合,而seats[0:i]为已经被占据的座位集合。由于[i: n]是连续空间,产生随机数就能保证不冲突。当第i个人选完座位时,将他选中的seats[k]和seats[i] 交换,保证第i+i个人面临的seats[i+1:n]依然为可选座位集合。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
def simulate_online(n: int) -> bool:
"""
Simulates one round of complexity O(N).
:param n: total number of seats
:return: True if last one has last seat, otherwise False
"""

seats = [i for i in range(n)]

def swap(i, j):
tmp = seats[i]
seats[i] = seats[j]
seats[j] = tmp

# for each person, the seats array idx available are [i, n-1]
for i in range(n-1):
if i == 0: # first one, always random
rnd = random.randint(0, n - 1)
swap(rnd, 0)
else:
if seats[i] == i: # i-th still has his seat
pass
else:
rnd = random.randint(i, n - 1) # selects idx from [i, n-1]
swap(rnd, i)
return seats[n-1] == n - 1

递推思维

这一节我们用数学递推思维来解释0.5的解。令f(n) 为第 n 位乘客坐在自己的座位上的概率,考察第一个人的情况(first step analysis),有三种可能

  1. 第一个人选了第一个即自己的座位,那么最后一个人一定能保证坐在自己的座位。
  2. 第一个人选了最后一个人的座位,无论中间什么过程,最后一个人无法坐到自己座位
  3. 第一个人选了第i个座位,(1<i<n),那么第i个人前面的除了第一个外的人都会坐在自己位置上,第i个人由于没有自己座位,随机在剩余的座位1,座位 [i+1,n] 中随机选择,此时,问题转变为f(n-i+1),如下图所示。
第一个人选了位置i
第i个人将问题转换成f(n-i+1)

通过上面分析,得到概率递推关系如下

\[ f(n) = \begin{align*} \left\lbrace \begin{array}{r@{}l} 1 & & p=\frac{1}{n} \quad \text{选了第一个位置} \\\\\\ f(n-i+1) & & p=\frac{1}{n} \quad \text{选了第i个位置,1<i<n} \\\\\\ 0 & & p=\frac{1}{n} \quad \text{选了第n个位置} \end{array} \right. \end{align*} \]

即f(n)的递推式为: \[ f(n) = \frac{1}{n} + \frac{1}{n} \times [ f(n-1) + f(n-2) + ...+ f(2)], \quad n>=2 \] 同理,f(n+1)递推式如下 \[ f(n+1) = \frac{1}{n+1} + \frac{1}{n+1} \times [ f(n) + f(n-1) + ...+ f(2)] \] \((n+1)f(n+1) - nf(n)\) 抵消 \(f(n-1) + ...f(2)\) 项,可得 \[ (n+1)f(n+1) - nf(n) = f(n) \]\[ f(n+1) = f(n) = \frac{1}{2} \quad n>=2 \]

用数学归纳法也可以证明 n>=2 时 f(n)=0.5。

简化的思考方式

我们再仔细思考一下上面的第三种情况,就是第一个人坐了第i个座位,1<i<n,此时,程序继续,不产生结果,直至产生结局1或者2,也就是case 1和2是真正的结局节点,它们产生的概率相同,因此答案是1/2。

从调用图可以看出这种关系,由于中间节点 f(4),f(3),f(2)生成Case 1和2的概率一样,因此无论它们之间是什么关系,最后结果都是1/2.

知乎上有个很形象的类比理解方式

考虑一枚硬币,正面向上的概率为 1/n,反面也是,立起来的概率为 (n-2)/n 。我们规定硬币立起来重新抛,但重新抛时,n会至少减小1。求结果为反面的概率。这样很显然结果为 1/2 。

这里,正面向上对应Case 2,反面对应Case 1。

这种思想可以写出如下代码,seats为 n 大小的bool 数组,当第i个人(0<i<n)发现自己座位被占的话,此时必然seats[0]没有被占,同时seats[i+1:]都是空的。假设seats[0]被占的话,要么是第一个人占的,要么是第p个人(p<i)坐了,两种情况下乱序都已经恢复了,此时第i个座位一定是空的。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
def simulate(n: int) -> bool:
"""
Simulates one round of complexity O(N).
:param n: total number of seats
:return: True if last one has last seat, otherwise False
"""

seats = [False for _ in range(n)]

for i in range(n-1):
if i == 0: # first one, always random
rnd = random.randint(0, n - 1)
seats[rnd] = True
else:
if not seats[i]: # i-th still has his seat
seats[i] = True
else:
# 0 must not be available, now we have 0 and [i+1, n-1],
rnd = random.randint(i, n - 1)
if rnd == i:
seats[0] = True
else:
seats[rnd] = True
return not seats[n-1]

上一篇我们从原理层面解析了AlphaGo Zero如何改进MCTS算法,通过不断自我对弈,最终实现从零棋力开始训练直至能够打败任何高手。在本篇中,我们在已有的N子棋OpenAI Gym 环境中用Pytorch实现一个简化版的AlphaGo Zero算法。本篇所有代码在 github MyEncyclopedia/ConnectNGym 中,其中部分参考了SongXiaoJun 的 AlphaZero_Gomoku

AlphaGo Zero MCTS 树节点

上一篇中,我们知道AlphaGo Zero 的MCTS树搜索是基于传统MCTS 的UCT (UCB for Tree)的改进版PUCT(Polynomial Upper Confidence Trees)。局面节点的PUCT值由两部分组成,分别是代表Exploitation的action value Q值,和代表Exploration的U值。 \[ PUCT(s, a) =Q(s,a) + U(s,a) \] U值计算由这些参数决定:系数\(c_{puct}\),节点先验概率P(s, a) ,父节点访问次数,本节点的访问次数。具体公式如下 \[ U(s, a)=c_{p u c t} \cdot P(s, a) \cdot \frac{\sqrt{\Sigma_{b} N(s, b)}}{1+N(s, a)} \]

因此在实现过程中,对于一个树节点来说,需要保存其Q值、节点访问次数 _visit_num和先验概率 _prior。其中,_prior在节点初始化后不变,Q值和 visit_num随着游戏MCTS模拟进程而改变。此外,节点保存了 parent和_children变量,用于维护父子关系。c_puct为class variable,作为全局参数。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
class TreeNode:
"""
MCTS Tree Node
"""

c_puct: ClassVar[int] = 5 # class-wise global param c_puct, exploration weight factor.

_parent: TreeNode
_children: Dict[int, TreeNode] # map from action to TreeNode
_visit_num: int
_Q: float # Q value of the node, which is the mean action value.
_prior: float

和上面的计算公式相对应,下列代码根据节点状态计算PUCT(s, a)。

{linenos
1
2
3
4
5
6
7
8
9
10
class TreeNode:

def get_puct(self) -> float:
"""
Computes AlphaGo Zero PUCT (polynomial upper confidence trees) of the node.

:return: Node PUCT value.
"""
U = (TreeNode.c_puct * self._prior * np.sqrt(self._parent._visit_num) / (1 + self._visit_num))
return self._Q + U

AlphaGo Zero MCTS在playout时遇到已经被展开的节点,会根据selection规则选择子节点,该规则本质上是在所有子节点中选择最大的PUCT值的节点。

\[ a=\operatorname{argmax}_a(PUCT(s, a))=\operatorname{argmax}_a(Q(s,a) + U(s,a)) \]

{linenos
1
2
3
4
5
6
7
8
9
class TreeNode:

def select(self) -> Tuple[Pos, TreeNode]:
"""
Selects an action(Pos) having max UCB value.

:return: Action and corresponding node
"""
return max(self._children.items(), key=lambda act_node: act_node[1].get_puct())

新的叶节点一旦在playout时产生,关联的 v 值会一路向上更新至根节点,具体新节点的v值将在下一节中解释。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
class TreeNode:

def propagate_to_root(self, leaf_value: float):
"""
Updates current node with observed leaf_value and propagates to root node.

:param leaf_value:
:return:
"""
if self._parent:
self._parent.propagate_to_root(-leaf_value)
self._update(leaf_value)

def _update(self, leaf_value: float):
"""
Updates the node by newly observed leaf_value.

:param leaf_value:
:return:
"""
self._visit_num += 1
# new Q is updated towards deviation from existing Q
self._Q += 0.5 * (leaf_value - self._Q)

AlphaGo Zero MCTS Player 实现

AlphaGo Zero MCTS 在训练阶段分为如下几个步骤。游戏初始局面下,整个局面树的建立由子节点的不断被探索而丰富起来。AlphaGo Zero对弈一次即产生了一次完整的游戏开始到结束的动作系列。在对弈过程中的某一游戏局面,需要采样海量的playout,又称MCTS模拟,以此来决定此局面的下一步动作。一次playout可视为在真实游戏状态树的一种特定采样,playout可能会产生游戏结局,生成真实的v值;也可能explore 到新的叶子节点,此时v值依赖策略价值网络的输出,目的是利用训练的神经网络来产生高质量的游戏对战局面。每次playout会从当前给定局面递归向下,向下的过程中会遇到下面三种节点情况。

  • 若局面节点是游戏结局(叶子节点),可以得到游戏的真实价值 z。从底部节点带着z向上更新沿途节点的Q值,直至根节点(初始局面)。
  • 若局面节点从未被扩展过(叶子节点),此时会将局面编码输入到策略价值双头网络,输出结果为网络预估的action分布和v值。Action分布作为节点先验概率P(s, a)来初始化子节点,预估的v值和上面真实游戏价值z一样,从叶子节点向上沿途更新到根节点。
  • 若局面节点已经被扩展过,则根据PUCT的select规则继续选择下一节点。

海量的playout模拟后,建立了游戏状态树的节点信息。但至此,AI玩家只是收集了信息,还仍未给定局面落子,而落子的决定由Play规则产生。下图展示了给定局面(Current节点)下,MCST模拟进行的多次playout探索后生成的局面树,play规则根据这些节点信息,产生Current 节点的动作分布 \(\pi\) ,确定下一步落子。

MCTS Playout和Play关系

Play 给定局面

对于当前需要做落子决定的某游戏局面\(s_0\),根据如下play公式生成落子分布 $$ ,子局面的落子概率正比于其访问次数的某次方。其中,某次方的倒数称为温度参数(Temperature)。

\[ \pi\left(a \mid s_{0}\right)=\frac{N\left(s_{0}, a\right)^{1 / \tau}}{\sum_{b} N\left(s_{0}, b\right)^{1 / \tau}} \]

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
class MCTSAlphaGoZeroPlayer(BaseAgent):

def _next_step_play_act_probs(self, game: ConnectNGame) -> Tuple[List[Pos], ActionProbs]:
"""
For the given game status, run playouts number of times specified by self._playout_num.
Returns the action distribution according to AlphaGo Zero MCTS play formula.

:param game:
:return: actions and their probability
"""

for n in range(self._playout_num):
self._playout(copy.deepcopy(game))

act_visits = [(act, node._visit_num) for act, node in self._current_root._children.items()]
acts, visits = zip(*act_visits)
act_probs = softmax(1.0 / MCTSAlphaGoZeroPlayer.temperature * np.log(np.array(visits) + 1e-10))

return acts, act_probs

在训练模式时,考虑到偏向exploration的目的,在\(\pi\) 落子分布的基础上增加了 Dirichlet 分布。

\[ P(s,a) = (1-\epsilon)*\pi(a \mid s) + \epsilon * \boldsymbol{\eta} \quad (\boldsymbol{\eta} \sim \operatorname{Dir}(0.3)) \]

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
class MCTSAlphaGoZeroPlayer(BaseAgent):

def get_action(self, board: PyGameBoard) -> Pos:
"""
Method defined in BaseAgent.

:param board:
:return: next move for the given game board.
"""
return self._get_action(copy.deepcopy(board.connect_n_game))[0]

def _get_action(self, game: ConnectNGame) -> Tuple[MoveWithProb]:
epsilon = 0.25
avail_pos = game.get_avail_pos()
move_probs: ActionProbs = np.zeros(game.board_size * game.board_size)
assert len(avail_pos) > 0

# the pi defined in AlphaGo Zero paper
acts, act_probs = self._next_step_play_act_probs(game)
move_probs[list(acts)] = act_probs
if self._is_training:
# add Dirichlet Noise when training in favour of exploration
p_ = (1-epsilon) * act_probs + epsilon * np.random.dirichlet(0.3 * np.ones(len(act_probs)))
move = np.random.choice(acts, p=p_)
assert move in game.get_avail_pos()
else:
move = np.random.choice(acts, p=act_probs)

self.reset()
return move, move_probs

一次完整的对弈

一次完整的AI对弈就是从初始局面迭代play直至游戏结束,对弈生成的数据是一系列的 $(s, , z) $。

如下图 s0 到 s5 是某次井字棋的对弈。最终结局是先手黑棋玩家赢,即对于黑棋玩家 z = +1。需要注意的是:z = +1 是对于所有黑棋面临的局面,即s0, s2, s4,而对应的其余白棋玩家来说 z = -1。

一局完整对弈

\[ \begin{align*} &0: (s_0, \vec{\pi_0}, +1) \\ &1: (s_1, \vec{\pi_1}, -1) \\ &2: (s_2, \vec{\pi_2}, +1) \\ &3: (s_3, \vec{\pi_3}, -1) \\ &4: (s_4, \vec{\pi_4}, +1) \end{align*} \]

以下代码展示如何在AI对弈时收集数据 $(s, , z) $

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
class MCTSAlphaGoZeroPlayer(BaseAgent):

def self_play_one_game(self, game: ConnectNGame) \
-> List[Tuple[NetGameState, ActionProbs, NDArray[(Any), np.float]]]:
"""

:param game:
:return:
Sequence of (s, pi, z) of a complete game play. The number of list is the game play length.
"""

states: List[NetGameState] = []
probs: List[ActionProbs] = []
current_players: List[np.float] = []

while not game.game_over:
move, move_probs = self._get_action(game)
states.append(convert_game_state(game))
probs.append(move_probs)
current_players.append(game.current_player)
game.move(move)

current_player_z = np.zeros(len(current_players))
current_player_z[np.array(current_players) == game.game_result] = 1.0
current_player_z[np.array(current_players) == -game.game_result] = -1.0
self.reset()

return list(zip(states, probs, current_player_z))

Playout 代码实现

一次playout会从当前局面根据PUCT selection规则下沉到叶子节点,如果此叶子节点非游戏终结点,则会扩展当前节点生成下一层新节点,其先验分布由策略价值网络输出的action分布决定。一次playout最终会得到叶子节点的 v 值,并沿着MCTS树向上更新沿途的所有父节点 Q值。 从上一篇文章已知,游戏节点的数量随着参数而指数级增长,举例来说,井字棋(k=3,m=n=3)的状态数量是5478,k=3,m=n=4时是6035992 ,k=m=n=4时是9722011 。如果我们将初始局面节点作为根节点,同时保存海量playout探索得到的局面节点,实现时会发现我们无法将所有探索到的局面节点都保存在内存中。这里的一种解决方法是在一次self play中每轮playout之后,将根节点重置成落子的节点,从而有效控制整颗局面树中的节点数量。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
class MCTSAlphaGoZeroPlayer(BaseAgent):

def _playout(self, game: ConnectNGame):
"""
From current game status, run a sequence down to a leaf node, either because game ends or unexplored node.
Get the leaf value of the leaf node, either the actual reward of game or action value returned by policy net.
And propagate upwards to root node.

:param game:
"""
player_id = game.current_player

node = self._current_root
while True:
if node.is_leaf():
break
act, node = node.select()
game.move(act)

# now game state is a leaf node in the tree, either a terminal node or an unexplored node
act_and_probs: Iterator[MoveWithProb]
act_and_probs, leaf_value = self._policy_value_net.policy_value_fn(game)

if not game.game_over:
# case where encountering an unexplored leaf node, update leaf_value estimated by policy net to root
for act, prob in act_and_probs:
game.move(act)
child_node = node.expand(act, prob)
game.undo()
else:
# case where game ends, update actual leaf_value to root
if game.game_result == ConnectNGame.RESULT_TIE:
leaf_value = ConnectNGame.RESULT_TIE
else:
leaf_value = 1 if game.game_result == player_id else -1
leaf_value = float(leaf_value)

# Update leaf_value and propagate up to root node
node.propagate_to_root(-leaf_value)

编码游戏局面

为了将信息有效的传递给策略神经网络,必须从当前玩家的角度编码游戏局面。局面不仅要反映棋盘上黑白棋子的位置,也需要考虑最后一个落子的位置以及是否为当前玩家棋局。因此,我们将某局面按照当前玩家来编码,返回类型为4个棋盘大小组成的ndarray,即shape [4, board_size, board_size],其中

  1. 第一个数组编码当前玩家的棋子位置
  2. 第二个数组编码对手玩家棋子位置
  3. 第三个表示最后落子位置
  4. 第四个全1表示此局面为先手(黑棋)局面,全0表示白棋局面

例如之前游戏对弈中的前四步:

s1->s2 后局面s2的编码:当前玩家为黑棋玩家,编码局面s2 返回如下ndarray,数组[0] 为s2黑子位置,[1]为白子位置,[2]表示最后一个落子(1, 1) ,[3] 全1表示当前是黑棋落子的局面。

编码黑棋玩家局面 s2
s2->s3 后局面s3的编码:当前玩家为白棋玩家,编码返回如下,数组[0] 为s3白子位置,[1]为黑子位置,[2]表示最后一个落子(1, 0) ,[3] 全0表示当前是白棋落子的局面。
编码白棋玩家局面 s3

具体代码实现如下。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
NetGameState = NDArray[(4, Any, Any), np.int]


def convert_game_state(game: ConnectNGame) -> NetGameState:
"""
Converts game state to type NetGameState as ndarray.

:param game:
:return:
Of shape 4 * board_size * board_size.
[0] is current player positions.
[1] is opponent positions.
[2] is last move location.
[3] all 1 meaning move by black player, all 0 meaning move by white.
"""
state_matrix = np.zeros((4, game.board_size, game.board_size))

if game.action_stack:
actions = np.array(game.action_stack)
move_curr = actions[::2]
move_oppo = actions[1::2]
for move in move_curr:
state_matrix[0][move] = 1.0
for move in move_oppo:
state_matrix[1][move] = 1.0
# indicate the last move location
state_matrix[2][actions[-1]] = 1.0
if len(game.action_stack) % 2 == 0:
state_matrix[3][:, :] = 1.0 # indicate the colour to play
return state_matrix[:, ::-1, :]

策略价值网络训练

策略价值网络是一个共享参数 \(\theta\) 的双头网络,给定上面的游戏局面编码会产生预估的p和v。

\[ \vec{p_{\theta}}, v_{\theta}=f_{\theta}(s) \] 结合真实游戏对弈后产生三元组数据 $(s, , z) $ ,按照论文中的loss 来训练神经网络。 \[ l=\sum_{t}\left(v_{\theta}\left(s_{t}\right)-z_{t}\right)^{2}-\vec{\pi_{t}} \cdot \log \left(\vec{p_{\theta}}\left(s_{t}\right)\right) + c {\lVert \theta \rVert}^2 \]

下面代码为Pytorch backward部分。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
def backward_step(self, state_batch: List[NetGameState], probs_batch: List[ActionProbs],
value_batch: List[NDArray[(Any), np.float]], lr) -> Tuple[float, float]:
if self.use_gpu:
state_batch = Variable(torch.FloatTensor(state_batch).cuda())
probs_batch = Variable(torch.FloatTensor(probs_batch).cuda())
value_batch = Variable(torch.FloatTensor(value_batch).cuda())
else:
state_batch = Variable(torch.FloatTensor(state_batch))
probs_batch = Variable(torch.FloatTensor(probs_batch))
value_batch = Variable(torch.FloatTensor(value_batch))

self.optimizer.zero_grad()
for param_group in self.optimizer.param_groups:
param_group['lr'] = lr

log_act_probs, value = self.policy_value_net(state_batch)
# loss = (z - v)^2 - pi*T * log(p) + c||theta||^2
value_loss = F.mse_loss(value.view(-1), value_batch)
policy_loss = -torch.mean(torch.sum(probs_batch * log_act_probs, 1))
loss = value_loss + policy_loss
loss.backward()
self.optimizer.step()
entropy = -torch.mean(torch.sum(torch.exp(log_act_probs) * log_act_probs, 1))
return loss.item(), entropy.item()

参考资料

AlphaGo Zero是Deepmind 最后一代AI围棋算法,因为已经达到了棋类游戏AI的终极目的:给定任何游戏规则,AI从零出发只通过自我对弈的方式提高,最终可以取得超越任何对手(包括顶级人类棋手和上一代AlphaGo)的能力。换种方式说,当给定足够多的时间和计算资源,可以取得无限逼近游戏真实解的能力。这一篇,我们深入分析AlphaGo Zero的设计理念和关键组件的细节并解释组件之间的关联。下一篇中,我们将在已有的N子棋OpenAI Gym 环境中用Pytorch实现一个简化版的AlphaGo Zero算法。

AlphaGo Zero 综述

AlphaGo Zero 作为Deepmind在围棋领域的最后一代AI Agent,已经可以达到棋类游戏的终极目标:在只给定游戏规则的情况下,AI 棋手从最初始的随机状态开始,通过不断的自我对弈的强化学习来实现超越以往任何人类棋手和上一代Alpha的能力,并且同样的算法和模型应用到了其他棋类也得出相同的效果。这一篇,从原理上来解析AlphaGo Zero的运行方式。

AlphaGo Zero 算法由三种元素构成:强化学习(RL)、深度学习(DL)和蒙特卡洛树搜索(MCTS,Monte Carlo Tree Search)。核心思想是基于神经网络的Policy Iteration强化学习,即最终学的是一个深度学习的policy network,输入是某棋盘局面 s,输出是此局面下可走位的概率分布:\(p(a|s)\)

在第一代AlphaGo算法中,这个初始policy network通过收集专业人类棋手的海量棋局训练得来,再采用传统RL 的Monte Carlo Tree Search Rollout 技术来强化现有的AI对于局面落子(Policy Network)的判断。Monte Carlo Tree Search Rollout 简单说来就是海量棋局模拟,AI Agent在通过现有的Policy Network策略完成一次从某局面节点到最终游戏胜负结束的对弈,这个完整的对弈叫做rollout,又称playout。完成一次rollout之后,通过局面树层层回溯到初始局面节点,并在回溯过程中同步修订所有经过的局面节点的统计指标,修正原先policy network对于落子导致输赢的判断。通过海量并发的棋局模拟来提升基准policy network,即在各种局面下提高好的落子的\(p(a_{win}|s)\),降低坏的落子的\(p(a_{lose}|s)\)

举例如下井字棋局面:
局面s

基准policy network返回 p(s) 如下 \[ p(a|s) = \begin{align*} \left\lbrace \begin{array}{r@{}l} 0.1, & & a = (0,2) \\ 0.05, & & a = (1,0) \\ 0.5, & & a = (1,1) \\ 0.05, & & a = (1,2)\\ 0.2, & & a = (2,0) \\ 0.05, & & a = (2,1) \\ 0.05, & & a = (2,2) \end{array} \right. \end{align*} \] 通过海量并发模拟后,修订成如下的action概率分布,然后通过policy iteration迭代新的网络来逼近 \(p'\) 就提高了棋力。 \[ p'(a|s) = \begin{align*} \left\lbrace \begin{array}{r@{}l} 0, & & a = (0,2) \\ 0, & & a = (1,0) \\ 0.9, & & a = (1,1) \\ 0, & & a = (1,2)\\ 0, & & a = (2,0) \\ 0, & & a = (2,1) \\ 0.1, & & a = (2,2) \end{array} \right. \end{align*} \]

蒙特卡洛树搜索(MCTS)概述

Monte Carlo Tree Search 是Monte Carlo 在棋类游戏中的变种,棋类游戏的一大特点是可以用动作(move)联系的决策树来表示,树的节点数量取决于分支的数量和树的深度。MCTS的目的是在树节点非常多的情况下,通过实验模拟(rollout, playout)的方式来收集尽可能多的局面输赢情况,并基于这些统计信息,将搜索资源的重点均衡地放在未被探索的节点和值得探索的节点上,减少在大概率输的节点上的模拟资源投入。传统MCTS有四个过程:Selection, Expansion, Simulation 和Backpropagation。下图是Wikipedia 的例子:

  • Selection:从根节点出发,根据现有统计的信息和selection规则,选择子节点递归向下做决定,后面我们会详细介绍AlphaGo的UCB规则。图中节点的数字,例如根节点11/21,分别代表赢的次数和总模拟次数。从根节点一路向下分别选择节点 7/10, 1/6直到叶子节点3/3,叶子节点表示它未被探索过。
  • Expansion:由于3/3节点未被探索过,初始化其所有子节点为0/0,图中3/3只有一个子节点。后面我们会看到神经网络在初始化子节点的时候起到的指导作用,即所有子节点初始权重并非相同,而是由神经网络给出估计。
  • Simulation:重复selection和expansion,根据游戏规则递归向下直至游戏结束。
  • Backpropagation:游戏结束在终点节点产生游戏真实的价值,回溯向上调整所有父节点的统计状态。

权衡 Exploration 和 Exploitation

在不断扩张决策树并收集节点统计信息的同时,MCTS根据规则来权衡探索目的(采样不足)或利用目的来做决策,这个权衡规则叫做Upper Confidence Bound(UCB)。典型的UCB公式如下:w表示通过节点的赢的次数,n表示通过节点的总次数,N是父节点的访问次数,c是调节Exploration 和 Exploitation权重的超参。

\[ {\frac{w_i}{n_i}} + c \sqrt{\frac{\ln N_i}{n_i}} \]

假设某节点有两个子节点s1, s2,它们的统计指标为 s1: w/n = 3/4,s2: w/n = 6/8,由于两者输赢比率一样,因此根据公式,访问次数少的节点出于Exploration的目的胜出,MCTS最终决定从s局面走向s1。

从第一性原理来理解AlphaGo Zero

前一代的AlphaGo已经战胜了世界冠军,取得了空前的成就,AlphaGo Zero 的设计目标变得更加General,去除围棋相关的处理和知识,用统一的框架和算法来解决棋类问题。 1. 无人工先验数据

改进之前需要专家棋手对弈数据来冷启动初始棋力

  1. 无特定游戏特征工程

    无需围棋特定技巧,只包含下棋规则,可以适用到所有棋类游戏

  2. 单一神经网络

    统一Policy Network和Value Network,使用一个共享参数的双头神经网络

  3. 简单树搜索

    去除传统MCTS的Rollout 方式,用神经网络来指导MCTS更有效产生搜索策略

搜索空间的两个优化原则

尽管理论上围棋是有解的,即先手必赢、被逼平或必输,通过遍历所有可能局面可以求得解。同理,通过海量模拟所有可能游戏局面,也可以无限逼近所有局面下的真实输赢概率,直至收敛于局面落子的确切最佳结果。但由于围棋棋局的数目远远大于宇宙原子数目,3^361 >> 10^80,因此需要将计算资源有效的去模拟值得探索的局面,例如对于显然的被动局面减小模拟次数,所以如何有效地减小搜索空间是AlphaGo Zero 需要解决的重大问题。David Silver 在Deepmind AlphaZero - Mastering Games Without Human Knowledge中提到AlphaGo Zero 采用两个原则来有效减小搜索空间。

原则1: 通过Value Network减少搜索的深度

Value Network 通过预测给定局面的value来直接预测最终结果,思想和上一期Minimax DP 策略中直接缓存当前局面的胜负状态一样,减少每次必须靠模拟到最后才能知道当前局面的输赢概率,或者需要多层树搜索才能知道输赢概率。

原则2: 通过Policy Network减少搜索的宽度

搜索广度的减少是由Policy Network预估来达成的,将下一步搜索局限在高概率的动作上,大幅度提升原先MCTS新节点生成后冷启动的搜索宽度。

神经网络结构

AlphaGo Zero 使用一个单一的深度神经网络来完成policy 和value的预测。具体实现方式是将policy network和value network合并成一个共享参数 $ $ 的双头网络。其中z是真实游戏结局的效用,范围为[-1, 1] 。

\[ (p, v)=f_{\theta}(s) \] \[ p_{a}=\operatorname{Pr}(a \mid s) \] \[ v = \mathop{\mathbb{E}}[z|s] \]

Monte Carlo Tree Search (MCTS) 建立了棋局搜索树,节点的初始状态由神经网络输出的p和v值来估计,由此初始的动作策略和价值预判就会建立在高手的水平之上。模拟一局游戏之后向上回溯,会同步更新路径上节点的统计数值并生成更好的MCTS搜索策略 \(\vec{\pi}\)。进一步来看,MCTS和神经网络互相形成了正循环。神经网络指导了未知节点的MCTS初始搜索策略,产生自我对弈游戏结局后,通过减小 \(\vec{p}\)\(\vec{\pi}\)的 Loss ,最终又提高了神经网络对于局面的估计能力。神经网络value network的提升也是通过不断减小网络预测的结果和最终结果的差异来提升。 因此,具体神经网络的Loss函数由三部分组成,value network的损失,policy network的损失以及正则项。 \[ l=\sum_{t}\left(v_{\theta}\left(s_{t}\right)-z_{t}\right)^{2}-\vec{\pi}_{t} \cdot \log \left(\vec{p}_{\theta}\left(s_{t}\right)\right) + c {\lVert \theta \rVert}^2 \]

AlphaGo Zero MCTS 具体过程

AlphaGo Plays Games Against Itself

AlphaGo Zero的MCTS和传统MCTS都有相似的四个过程,但AlphaGo Zero的MCTS步骤相对更复杂。 首先,除了W/N统计指标之外,AlphaGo Zero的MCTS保存了决策边 a|s 的Q(s,a):Action Value,也就是Q-Learning中的Q值,其初始值由神经网络给出。此外,Q 值也用于串联自底向上更新节点的Value值。具体说来,当某个新节点被Explore后,会将网络给出的Q值向上传递,并逐层更新父节点的Q值。当游戏结局产生时,也会向上更新所有父节点的Q值。 此外对于某一游戏局面s进行多次模拟,每次在局面s出发向下探索,每次探索在已知节点按Selection规则深入一步,直至达到未探索的局面或者游戏结束,产生Q值后向上回溯到最初局面s,回溯过程中更新路径上的局面的统计值或者Q值。在多次模拟结束后根据Play的算法,决定局面s的下一步行动。尽管每次模拟探索可能会深入多层,但最终play阶段的算法规则仅决定给定局面s的下一层落子动作。多次向下探索的优势在于:

  1. 探索和采样更多的叶子节点,在更多信息下做决策。

  2. 通过average out多次模拟下一层落子决定,尽可能提升MCTS策略的下一步判断能力,提高 \(\pi\) 能力,更有效指导神经网络,提高其学习效率。

New Policy Network V' is Trained to Predict Winner
  1. Selection:

从游戏局面s开始,选择a向下递归,直至未展开的节点(搜索树中的叶子节点)或者游戏结局。具体在局面s下选择a的规则由以下UCB(Upper Confidence Bound)决定
\[ a=\operatorname{argmax}_a(Q(s,a) + u(s,a)) \]

其中,Q(s,a) 和u(s,a) 项分别代表Exploitation 和Exploration。两项相加来均衡Exploitation和Exploration,保证初始时每个节点被explore,在有足够多的信息时逐渐偏向exploitation。

\[ u(s, a)=c_{p u c t} \cdot P(s, a) \cdot \frac{\sqrt{\Sigma_{b} N(s, b)}}{1+N(s, a)} \]

  1. Expand

当遇到一个未展开的节点(搜索树中的叶子节点)时,对其每个子节点使用现有网络进行预估,即

\[ (p(s), v(s))=f_{\theta}(s) \]

  1. Backup

当新的叶子节点展开时或者到达终点局面时,向上更新父节点的Q值,具体公式为 \[ Q(s, a)=\frac{1}{N(s, a)} \sum_{s^{\prime} \mid s, a \rightarrow s^{\prime}} V\left(s^{\prime}\right) \]

  1. Play

多次模拟结束后,使用得到搜索概率分布 ${a} $来确定最终的落子动作。正比于访问次数的某次方 $ {a} N(s, a)^{1 / }\(,其中\)$为温度参数(temperature parameter)。

New Policy Network V' is Trained to Predict Winner

参考资料

继上一篇完成了井字棋(N子棋)的minimax 最佳策略后,我们基于Pygame来创造一个图形游戏环境,可供人机和机器对弈,为后续模拟AlphaGo的自我强化学习算法做环境准备。OpenAI Gym 在强化学习领域是事实标准,我们最终封装成OpenAI Gym的接口。本篇所有代码都在github.com/MyEncyclopedia/ConnectNGym

井字棋、五子棋 Pygame 实现

Pygame 井字棋玩家对弈效果

Python 上有Tkinter,PyQt等跨平台GUI类库,主要用于桌面程序编程,但此类库容量较大,编程也相对麻烦。Pygame具有代码少,开发快的优势,比较适合快速开发五子棋这类桌面小游戏。 ### Pygame 极简入门

与所有的GUI开发相同,Pygame也是基于事件的单线程编程模型。下面的例子包含了显示一个最简单GUI窗口,操作系统产生事件并发送到Pygame窗口,while True 控制了python主线程永远轮询事件。我们在这里仅仅判断了当前是否是关闭应用程序事件,如果是则退出进程。此外,clock 用于控制FPS。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
import sys
import pygame
pygame.init()
display = pygame.display.set_mode((800,600))
clock = pygame.time.Clock()

while True:
for event in pygame.event.get():
if event.type == pygame.QUIT:
sys.exit(0)
else:
pygame.display.update()
clock.tick(1)

PyGameBoard 主体代码

PyGameBoard类封装了Pygame实现游戏交互和显示的逻辑。上一篇中,我们完成了ConnectNGame逻辑,这里PyGameBoard需要在初始化时,指定传入ConnectNGame 实例(见下图),支持通过API 方式改变其状态,也支持GUI交互方式等待人类玩家的输入。next_user_input(self)实现了等待人类玩家输入的逻辑,本质上是循环检查GUI事件直到有合法的落子产生。
PyGameBoard Class Diagram
{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
class PyGameBoard:

def __init__(self, connectNGame: ConnectNGame):
self.connectNGame = connectNGame
pygame.init()

def next_user_input(self) -> Tuple[int, int]:
self.action = None
while not self.action:
self.check_event()
self._render()
self.clock.tick(60)
return self.action

def move(self, r: int, c: int) -> int:
return self.connectNGame.move(r, c)

if __name__ == '__main__':
connectNGame = ConnectNGame()
pygameBoard = PyGameBoard(connectNGame)
while not pygameBoard.isGameOver():
pos = pygameBoard.next_user_input()
pygameBoard.move(*pos)

pygame.quit()

check_event 较之极简版本增加了处理用户输入事件,这里我们仅支持人类玩家鼠标输入。方法_handle_user_input 将鼠标点击事件转换成棋盘行列值,并判断点击位置是否合法,合法则返回落子位置,类型为Tuple[int, int],例如(0, 0)表示棋盘最左上角位置。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
def check_event(self):
for e in pygame.event.get():
if e.type == pygame.QUIT:
pygame.quit()
sys.exit(0)
elif e.type == pygame.MOUSEBUTTONDOWN:
self._handle_user_input(e)

def _handle_user_input(self, e: Event) -> Tuple[int, int]:
origin_x = self.start_x - self.edge_size
origin_y = self.start_y - self.edge_size
size = (self.board_size - 1) * self.grid_size + self.edge_size * 2
pos = e.pos
if origin_x <= pos[0] <= origin_x + size and origin_y <= pos[1] <= origin_y + size:
if not self.connectNGame.gameOver:
x = pos[0] - origin_x
y = pos[1] - origin_y
r = int(y // self.grid_size)
c = int(x // self.grid_size)
valid = self.connectNGame.checkAction(r, c)
if valid:
self.action = (r, c)
return self.action

OpenAI Gym 接口规范

OpenAI Gym规范了Agent和环境(Env)之间的互动,核心抽象接口类是gym.Env,自定义的游戏环境需要继承Env,并实现 reset、step和render方法。下面我们看一下如何具体实现ConnectNGym的这几个方法:

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
class ConnectNGym(gym.Env):

def reset(self) -> ConnectNGame:
"""Resets the state of the environment and returns an initial observation.

Returns:
observation (object): the initial observation.
"""
raise NotImplementedError


def step(self, action: Tuple[int, int]) -> Tuple[ConnectNGame, int, bool, None]:
"""Run one timestep of the environment's dynamics. When end of
episode is reached, you are responsible for calling `reset()`
to reset this environment's state.

Accepts an action and returns a tuple (observation, reward, done, info).

Args:
action (object): an action provided by the agent

Returns:
observation (object): agent's observation of the current environment
reward (float) : amount of reward returned after previous action
done (bool): whether the episode has ended, in which case further step() calls will return undefined results
info (dict): contains auxiliary diagnostic information (helpful for debugging, and sometimes learning)
"""
raise NotImplementedError



def render(self, mode='human'):
"""
Renders the environment.

The set of supported modes varies per environment. (And some
environments do not support rendering at all.) By convention,
if mode is:

- human: render to the current display or terminal and
return nothing. Usually for human consumption.
- rgb_array: Return an numpy.ndarray with shape (x, y, 3),
representing RGB values for an x-by-y pixel image, suitable
for turning into a video.
- ansi: Return a string (str) or StringIO.StringIO containing a
terminal-style text representation. The text can include newlines
and ANSI escape sequences (e.g. for colors).

Note:
Make sure that your class's metadata 'render.modes' key includes
the list of supported modes. It's recommended to call super()
in implementations to use the functionality of this method.

Args:
mode (str): the mode to render with
"""
raise NotImplementedError

reset 方法

1
def reset(self) -> ConnectNGame

重置环境状态,并返回给Agent重置后环境下观察到的状态。ConnectNGym内部维护了ConnectNGame实例作为自身状态,每个agent落子后会更新这个实例。由于棋类游戏对于玩家来说是完全信息的,我们直接返回ConnectNGame的deepcopy。

step 方法

1
def step(self, action: Tuple[int, int]) -> Tuple[ConnectNGame, int, bool, None]

Agent 选择了某一action后,由环境来执行这个action并返回4个值:1. 执行后的环境Agent观察到的状态;2. 环境执行了这个action回馈给agent的reward;3. 环境是否结束;4. 其余信息。

step方法是最核心的接口,因此举例来说明ConnectNGym中的输入和输出:

初始状态
状态 ((0, 0, 0), (0, 0, 0), (0, 0, 0))

Agent A 选择action = (0, 0),执行ConnectNGym.step 后返回值:status = ((1, 0, 0), (0, 0, 0), (0, 0, 0)),reward = 0,game_end = False

状态 ((1, 0, 0), (0, 0, 0), (0, 0, 0))

Agent B 选择action = (1, 1),执行ConnectNGym.step 后返回值:status = ((1, 0, 0), (0, -1, 0), (0, 0, 0)),reward = 0,game_end = False

状态 ((1, 0, 0), (0, -1, 0), (0, 0, 0))
重复此过程直至游戏结束,下面是5步后游戏可能达到的最终状态
终结状态 ((1, 1, 1), (-1, -1, 0), (0, 0, 0))

此时step的返回值为:status = ((1, 1, 1), (-1, -1, 0), (0, 0, 0)),reward = 1,game_end = True

render 方法

1
def render(self, mode='human')

展现环境,通过mode区分是否是人类玩家。

ConnectNGym 代码

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
class ConnectNGym(gym.Env):

def __init__(self, pygameBoard: PyGameBoard, isGUI=True, displaySec=2):
self.pygameBoard = pygameBoard
self.isGUI = isGUI
self.displaySec = displaySec
self.action_space = spaces.Discrete(pygameBoard.board_size * pygameBoard.board_size)
self.observation_space = spaces.Discrete(pygameBoard.board_size * pygameBoard.board_size)
self.seed()
self.reset()

def reset(self) -> ConnectNGame:
self.pygameBoard.connectNGame.reset()
return copy.deepcopy(self.pygameBoard.connectNGame)

def step(self, action: Tuple[int, int]) -> Tuple[ConnectNGame, int, bool, None]:
# assert self.action_space.contains(action)

r, c = action
reward = REWARD_NONE
result = self.pygameBoard.move(r, c)
if self.pygameBoard.isGameOver():
reward = result

return copy.deepcopy(self.pygameBoard.connectNGame), reward, not result is None, None

def render(self, mode='human'):
if not self.isGUI:
self.pygameBoard.connectNGame.drawText()
time.sleep(self.displaySec)
else:
self.pygameBoard.display(sec=self.displaySec)

def get_available_actions(self) -> List[Tuple[int, int]]:
return self.pygameBoard.getAvailablePositions()

井字棋(N子棋)Minimax策略玩家

图中当k=3,m=n=3即井字棋游戏中,两个minimax策略玩家的对弈效果,游戏结局符合已知的结论:井字棋的解是先手被对方逼平。

Minimax策略AI对弈

镜像游戏状态的DP处理

上一篇中,我们确认了井字棋的总状态数是5478。当k=3, m=n=4时是6035992,k=4, m=n=4时是9722011,总的来说游戏状态数是以指数级增长的。上一版minimax DP策略还有改善的空间,第一种是旋转格局的处理。对于任意一种棋盘格局可以得到90度旋转后的另外三种格局,它们的最佳结局是一致的。因此,我们在递归过程中解得某一棋盘格局后,将其另外三种旋转后格局的解也一起缓存起来。例如:

游戏状态1
旋转后的三种游戏状态
{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
def similarStatus(self, status: Tuple[Tuple[int, ...]]) -> List[Tuple[Tuple[int, ...]]]:
ret = []
rotatedS = status
for _ in range(4):
rotatedS = self.rotate(rotatedS)
ret.append(rotatedS)
return ret

def rotate(self, status: Tuple[Tuple[int, ...]]) -> Tuple[Tuple[int, ...]]:
N = len(status)
board = [[ConnectNGame.AVAILABLE] * N for _ in range(N)]

for r in range(N):
for c in range(N):
board[c][N - 1 - r] = status[r][c]

return tuple([tuple(board[i]) for i in range(N)])

Minimax 策略预计算

之前我们对每个棋局去计算最佳的下一步,并在此过程中做了剪枝,即当已经找到当前玩家必胜落子时直接返回。这对于单一局面的计算是较优的,但是AI Agent 需要在每一步都重复这个过程,当棋盘大小>3时运算非常耗时,因此我们来做第二种优化。初始空棋盘时使用Minimax来保证遍历所有状态,缓存所有棋局的最佳结果。对于AI Agent面临的每个棋局只需查找此棋局下所有的可能落子位置,并返回最佳决定,这样大大减少了每次棋局下重复的minimax递归计算。相关代码如下。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
class PlannedMinimaxStrategy(Strategy):
def __init__(self, game: ConnectNGame):
super().__init__()
self.game = copy.deepcopy(game)
self.dpMap = {} # game_status => result, move
self.result = self.minimax(game.getStatus())


def action(self, game: ConnectNGame) -> Tuple[int, Tuple[int, int]]:
game = copy.deepcopy(game)

player = game.currentPlayer
bestResult = player * -1 # assume opponent win as worst result
bestMove = None
for move in game.getAvailablePositions():
game.move(*move)
status = game.getStatus()
game.undo()

result = self.dpMap[status]

if player == ConnectNGame.PLAYER_A:
bestResult = max(bestResult, result)
else:
bestResult = min(bestResult, result)
# update bestMove if any improvement
bestMove = move if bestResult == result else bestMove
print(f'move {move} => {result}')

return bestResult, bestMove

Agent 类和对弈逻辑

Agent 类的抽象并不是 OpenAI Gym的规范,出于代码扩展性,我们也封装了Agent基类及其子类,包括AI玩家和人类玩家。BaseAgent需要子类实现 act方法,默认实现为随机决定。

{linenos
1
2
3
4
5
6
class BaseAgent(object):
def __init__(self):
pass

def act(self, game: PyGameBoard, available_actions):
return random.choice(available_actions)

AIAgent 实现act并代理给 strategy 的action方法。

{linenos
1
2
3
4
5
6
7
8
class AIAgent(BaseAgent):
def __init__(self, strategy: Strategy):
self.strategy = strategy

def act(self, game: PyGameBoard, available_actions):
result, move = self.strategy.action(game.connectNGame)
assert move in available_actions
return move

HumanAgent 实现act并代理给 PyGameBoard 的next_user_input方法。

{linenos
1
2
3
4
5
6
class HumanAgent(BaseAgent):
def __init__(self):
pass

def act(self, game: PyGameBoard, available_actions):
return game.next_user_input()
Agent Class Diagram

下面代码展示如何将Agent,ConnectNGym,PyGameBoard 等所有上述类串联起来,完成人人对弈,人机对弈。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
def play_ai_vs_ai(env: ConnectNGym):
plannedMinimaxAgent = AIAgent(PlannedMinimaxStrategy(env.pygameBoard.connectNGame))
play(env, plannedMinimaxAgent, plannedMinimaxAgent)


def play(env: ConnectNGym, agent1: BaseAgent, agent2: BaseAgent):
agents = [agent1, agent2]

while True:
env.reset()
done = False
agent_id = -1
while not done:
agent_id = (agent_id + 1) % 2
available_actions = env.get_available_actions()
agent = agents[agent_id]
action = agent.act(pygameBoard, available_actions)
_, reward, done, info = env.step(action)
env.render(True)

if done:
print(f'result={reward}')
time.sleep(3)
break


if __name__ == '__main__':
pygameBoard = PyGameBoard(connectNGame=ConnectNGame(board_size=3, N=3))
env = ConnectNGym(pygameBoard)
env.render(True)

play_ai_vs_ai(env)
Class Diagram 总览

继上一篇介绍了Minimax 和Alpha Beta 剪枝算法之后,本篇选择了Leetcode中的井字棋游戏题目,积累相关代码后实现井字棋游戏并扩展到五子棋和N子棋(战略井字棋),随后用Minimax和Alpha Beta剪枝算法解得小规模下N子棋的游戏结局,并分析其状态数量和每一步的最佳策略。后续篇章中,我们基于本篇代码完成一个N子棋的OpenAI Gym 图形环境,可用于人机对战或机器对战,并最终实现棋盘规模稍大的五子棋或者N子棋中的蒙特卡洛树搜索(MCTS)算法。

Leetcode 上的井字棋系列

Leetcode 1275. 找出井字棋的获胜者 (简单)

A 和 B 在一个 3 x 3 的网格上玩井字棋。
井字棋游戏的规则如下:
玩家轮流将棋子放在空方格 (" ") 上。
第一个玩家 A 总是用 "X" 作为棋子,而第二个玩家 B 总是用 "O" 作为棋子。
"X" 和 "O" 只能放在空方格中,而不能放在已经被占用的方格上。
只要有 3 个相同的(非空)棋子排成一条直线(行、列、对角线)时,游戏结束。
如果所有方块都放满棋子(不为空),游戏也会结束。
游戏结束后,棋子无法再进行任何移动。
给你一个数组 moves,其中每个元素是大小为 2 的另一个数组(元素分别对应网格的行和列),它按照 A 和 B 的行动顺序(先 A 后 B)记录了两人各自的棋子位置。
如果游戏存在获胜者(A 或 B),就返回该游戏的获胜者;如果游戏以平局结束,则返回 "Draw";如果仍会有行动(游戏未结束),则返回 "Pending"。
你可以假设 moves 都 有效(遵循井字棋规则),网格最初是空的,A 将先行动。

示例 1:
输入:moves = [[0,0],[2,0],[1,1],[2,1],[2,2]]
输出:"A"
解释:"A" 获胜,他总是先走。
"X " "X " "X " "X " "X "
" " -> " " -> " X " -> " X " -> " X "
" " "O " "O " "OO " "OOX"

示例 2: 输入:moves = [[0,0],[1,1],[0,1],[0,2],[1,0],[2,0]]
输出:"B"
解释:"B" 获胜。
"X " "X " "XX " "XXO" "XXO" "XXO"
" " -> " O " -> " O " -> " O " -> "XO " -> "XO "
" " " " " " " " " " "O "

第一种解法,检查A或者B赢的所有可能情况:某玩家占据8种连线的任意一种情况则胜利,我们使用八个变量来保存所有情况。下面的代码使用了一个小技巧,将moves转换成3x3的棋盘状态数组,元素的值为1,-1和0。1,-1代表两个玩家,0代表空的棋盘格子,其优势在于后续我们只需累加棋盘的值到八个变量中关联的若干个,再检查这八个变量是否满足取胜条件。例如,row[0]表示第一行的状态,当遍历一次所有棋盘格局后,row[0]为第一行的3个格子的总和,只有当row[0] == 3 才表明玩家A占据了第一行,-3表明玩家B占据了第一行。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# AC
from typing import List

class Solution:
def tictactoe(self, moves: List[List[int]]) -> str:
board = [[0] * 3 for _ in range(3)]
for idx, xy in enumerate(moves):
player = 1 if idx % 2 == 0 else -1
board[xy[0]][xy[1]] = player

turn = 0
row, col = [0, 0, 0], [0, 0, 0]
diag1, diag2 = False, False
for r in range(3):
for c in range(3):
turn += board[r][c]
row[r] += board[r][c]
col[c] += board[r][c]
if r == c:
diag1 += board[r][c]
if r + c == 2:
diag2 += board[r][c]

oWin = any(row[r] == 3 for r in range(3)) or any(col[c] == 3 for c in range(3)) or diag1 == 3 or diag2 == 3
xWin = any(row[r] == -3 for r in range(3)) or any(col[c] == -3 for c in range(3)) or diag1 == -3 or diag2 == -3

return "A" if oWin else "B" if xWin else "Draw" if len(moves) == 9 else "Pending"

下面我们给出另一种解法,这种解法虽然代码较多,但可以不必遍历棋盘每个格子,比上一种严格遍历一次棋盘的解法略为高效。原理如下,题目保证了moves过程中不会产生输赢结果,因此我们直接检查最后一个棋子向外的八个方向,若任意方向有三连子,则此玩家获胜。这种解法主要是为后续井字棋扩展到五子棋时判断每个落子是否产生输赢做代码准备。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
# AC
from typing import List

class Solution:
def checkWin(self, r: int, c: int) -> bool:
north = self.getConnectedNum(r, c, -1, 0)
south = self.getConnectedNum(r, c, 1, 0)

east = self.getConnectedNum(r, c, 0, 1)
west = self.getConnectedNum(r, c, 0, -1)

south_east = self.getConnectedNum(r, c, 1, 1)
north_west = self.getConnectedNum(r, c, -1, -1)

north_east = self.getConnectedNum(r, c, -1, 1)
south_west = self.getConnectedNum(r, c, 1, -1)

if (north + south + 1 >= 3) or (east + west + 1 >= 3) or \
(south_east + north_west + 1 >= 3) or (north_east + south_west + 1 >= 3):
return True
return False

def getConnectedNum(self, r: int, c: int, dr: int, dc: int) -> int:
player = self.board[r][c]
result = 0
i = 1
while True:
new_r = r + dr * i
new_c = c + dc * i
if 0 <= new_r < 3 and 0 <= new_c < 3:
if self.board[new_r][new_c] == player:
result += 1
else:
break
else:
break
i += 1
return result

def tictactoe(self, moves: List[List[int]]) -> str:
self.board = [[0] * 3 for _ in range(3)]
for idx, xy in enumerate(moves):
player = 1 if idx % 2 == 0 else -1
self.board[xy[0]][xy[1]] = player

# only check last move
r, c = moves[-1]
win = self.checkWin(r, c)
if win:
return "A" if len(moves) % 2 == 1 else "B"

return "Draw" if len(moves) == 9 else "Pending"

Leetcode 794. 有效的井字游戏 (中等)

用字符串数组作为井字游戏的游戏板 board。当且仅当在井字游戏过程中,玩家有可能将字符放置成游戏板所显示的状态时,才返回 true。
该游戏板是一个 3 x 3 数组,由字符 " ","X" 和 "O" 组成。字符 " " 代表一个空位。
以下是井字游戏的规则:
玩家轮流将字符放入空位(" ")中。
第一个玩家总是放字符 “X”,且第二个玩家总是放字符 “O”。
“X” 和 “O” 只允许放置在空位中,不允许对已放有字符的位置进行填充。
当有 3 个相同(且非空)的字符填充任何行、列或对角线时,游戏结束。
当所有位置非空时,也算为游戏结束。
如果游戏结束,玩家不允许再放置字符。

示例 1:
输入: board = ["O ", " ", " "]
输出: false
解释: 第一个玩家总是放置“X”。

示例 2:
输入: board = ["XOX", " X ", " "]
输出: false
解释: 玩家应该是轮流放置的。

示例 3:
输入: board = ["XXX", " ", "OOO"]
输出: false

示例 4:
输入: board = ["XOX", "O O", "XOX"]
输出: true
说明:

游戏板 board 是长度为 3 的字符串数组,其中每个字符串 board[i] 的长度为 3。 board[i][j] 是集合 {" ", "X", "O"} 中的一个字符。

这道题第一反应是需要DFS来判断给定状态是否可达,但其实可以用上面1275的思路,即通过检验最终棋盘的一些特点来判断给定状态是否合法。比如,X和O的数量只有可能相同,或X比O多一个。其关键在于需要找到判断状态合法的充要条件,就可以在\(O(1)\) 时间复杂度完成判断。 此外,这道题给了我们井字棋所有可能状态数量的启示。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# AC
from typing import List

class Solution:

def convertCell(self, c:str):
return 1 if c == 'X' else -1 if c == 'O' else 0

def validTicTacToe(self, board: List[str]) -> bool:
turn = 0
row, col = [0, 0, 0], [0, 0, 0]
diag1, diag2 = False, False
for r in range(3):
for c in range(3):
turn += self.convertCell(board[r][c])
row[r] += self.convertCell(board[r][c])
col[c] += self.convertCell(board[r][c])
if r == c:
diag1 += self.convertCell(board[r][c])
if r + c == 2:
diag2 += self.convertCell(board[r][c])

xWin = any(row[r] == 3 for r in range(3)) or any(col[c] == 3 for c in range(3)) or diag1 == 3 or diag2 == 3
oWin = any(row[r] == -3 for r in range(3)) or any(col[c] == -3 for c in range(3)) or diag1 == -3 or diag2 == -3
if (xWin and turn == 0) or (oWin and turn == 1):
return False
return (turn == 0 or turn == 1) and (not xWin or not oWin)

Leetcode 348. 判定井字棋胜负 (中等,加锁)

请在 n × n 的棋盘上,实现一个判定井字棋(Tic-Tac-Toe)胜负的神器,判断每一次玩家落子后,是否有胜出的玩家。
在这个井字棋游戏中,会有 2 名玩家,他们将轮流在棋盘上放置自己的棋子。
在实现这个判定器的过程中,你可以假设以下这些规则一定成立:
每一步棋都是在棋盘内的,并且只能被放置在一个空的格子里;
一旦游戏中有一名玩家胜出的话,游戏将不能再继续;
一个玩家如果在同一行、同一列或者同一斜对角线上都放置了自己的棋子,那么他便获得胜利。

示例: 给定棋盘边长 n = 3, 玩家 1 的棋子符号是 "X",玩家 2 的棋子符号是 "O"。
TicTacToe toe = new TicTacToe(3);
toe.move(0, 0, 1); -> 函数返回 0 (此时,暂时没有玩家赢得这场对决)
|X| | |
| | | | // 玩家 1 在 (0, 0) 落子。
| | | |

toe.move(0, 2, 2); -> 函数返回 0 (暂时没有玩家赢得本场比赛)
|X| |O|
| | | | // 玩家 2 在 (0, 2) 落子。
| | | |

toe.move(2, 2, 1); -> 函数返回 0 (暂时没有玩家赢得比赛)
|X| |O|
| | | | // 玩家 1 在 (2, 2) 落子。
| | |X|

toe.move(1, 1, 2); -> 函数返回 0 (暂没有玩家赢得比赛)
|X| |O|
| |O| | // 玩家 2 在 (1, 1) 落子。
| | |X|

toe.move(2, 0, 1); -> 函数返回 0 (暂无玩家赢得比赛)
|X| |O|
| |O| | // 玩家 1 在 (2, 0) 落子。
|X| |X|

toe.move(1, 0, 2); -> 函数返回 0 (没有玩家赢得比赛)
|X| |O|
|O|O| | // 玩家 2 在 (1, 0) 落子.
|X| |X|

toe.move(2, 1, 1); -> 函数返回 1 (此时,玩家 1 赢得了该场比赛)
|X| |O|
|O|O| | // 玩家 1 在 (2, 1) 落子。
|X|X|X|

348 是道加锁题,对于每次玩家的move,可以用1275第二种解法中的checkWin 函数。下面代码给出了另一种基于1275解法一的方法:保存八个关键变量,每次落子后更新这个子所关联的某几个变量。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
# AC
class TicTacToe:

def __init__(self, n:int):
"""
Initialize your data structure here.
:type n: int
"""
self.row, self.col, self.diag1, self.diag2, self.n = [0] * n, [0] * n, 0, 0, n

def move(self, row:int, col:int, player:int) -> int:
"""
Player {player} makes a move at ({row}, {col}).
@param row The row of the board.
@param col The column of the board.
@param player The player, can be either 1 or 2.
@return The current winning condition, can be either:
0: No one wins.
1: Player 1 wins.
2: Player 2 wins.
"""
if player == 2:
player = -1

self.row[row] += player
self.col[col] += player
if row == col:
self.diag1 += player
if row + col == self.n - 1:
self.diag2 += player

if self.n in [self.row[row], self.col[col], self.diag1, self.diag2]:
return 1
if -self.n in [self.row[row], self.col[col], self.diag1, self.diag2]:
return 2
return 0


井字棋最佳策略

井字棋的规模可以很自然的扩展成四子棋或五子棋等,区别在于棋盘大小和胜利时的连子数量。这类游戏最一般的形式为 M,n,k-game,中文可能翻译为战略井字游戏,表示棋盘大小为M x N,当k连子时获胜。 下面的ConnectNGame类实现了战略井字游戏(M=N)中,两个玩家轮流下子、更新棋盘状态和判断每次落子输赢等逻辑封装。其中undo方法用于撤销最后一个落子,方便在后续寻找最佳策略时回溯。

ConnectNGame

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
class ConnectNGame:

PLAYER_A = 1
PLAYER_B = -1
AVAILABLE = 0
RESULT_TIE = 0
RESULT_A_WIN = 1
RESULT_B_WIN = -1

def __init__(self, N:int = 3, board_size:int = 3):
assert N <= board_size
self.N = N
self.board_size = board_size
self.board = [[ConnectNGame.AVAILABLE] * board_size for _ in range(board_size)]
self.gameOver = False
self.gameResult = None
self.currentPlayer = ConnectNGame.PLAYER_A
self.remainingPosNum = board_size * board_size
self.actionStack = []

def move(self, r: int, c: int) -> int:
"""

:param r:
:param c:
:return: None: game ongoing
"""
assert self.board[r][c] == ConnectNGame.AVAILABLE
self.board[r][c] = self.currentPlayer
self.actionStack.append((r, c))
self.remainingPosNum -= 1
if self.checkWin(r, c):
self.gameOver = True
self.gameResult = self.currentPlayer
return self.currentPlayer
if self.remainingPosNum == 0:
self.gameOver = True
self.gameResult = ConnectNGame.RESULT_TIE
return ConnectNGame.RESULT_TIE
self.currentPlayer *= -1

def undo(self):
if len(self.actionStack) > 0:
lastAction = self.actionStack.pop()
r, c = lastAction
self.board[r][c] = ConnectNGame.AVAILABLE
self.currentPlayer = ConnectNGame.PLAYER_A if len(self.actionStack) % 2 == 0 else ConnectNGame.PLAYER_B
self.remainingPosNum += 1
self.gameOver = False
self.gameResult = None
else:
raise Exception('No lastAction')

def getAvailablePositions(self) -> List[Tuple[int, int]]:
return [(i,j) for i in range(self.board_size) for j in range(self.board_size) if self.board[i][j] == ConnectNGame.AVAILABLE]

def getStatus(self) -> Tuple[Tuple[int, ...]]:
return tuple([tuple(self.board[i]) for i in range(self.board_size)])

其中checkWin和1275解法二中的逻辑一致。

Minimax 算法

此战略井字游戏的逻辑代码,结合之前的minimax算法,可以实现游戏最佳策略。

先定义一个通用的策略基类和抽象方法 action。action表示给定一个棋盘状态,返回一个动作决定。返回Tuple的第一个int值表示估计走这一步的结局,第二个值类型是Tuple[int, int],表示这次落子的位置,例如(1,1)。

{linenos
1
2
3
4
5
6
7
8
class Strategy(ABC):

def __init__(self):
super().__init__()

@abstractmethod
def action(self, game: ConnectNGame) -> Tuple[int, Tuple[int, int]]:
pass
MinimaxStrategy 的逻辑和之前的minimax模版算法大致相同,多了保存最佳move对应的动作,用于最后返回。
{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
class MinimaxStrategy(Strategy):
def action(self, game: ConnectNGame) -> Tuple[int, Tuple[int, int]]:
self.game = copy.deepcopy(game)
result, move = self.minimax()
return result, move

def minimax(self) -> Tuple[int, Tuple[int, int]]:
game = self.game
bestMove = None
assert not game.gameOver
if game.currentPlayer == ConnectNGame.PLAYER_A:
ret = -math.inf
for pos in game.getAvailablePositions():
move = pos
result = game.move(*pos)
if result is None:
assert not game.gameOver
result, oppMove = self.minimax()
game.undo()
ret = max(ret, result)
bestMove = move if ret == result else bestMove
if ret == 1:
return 1, move
return ret, bestMove
else:
ret = math.inf
for pos in game.getAvailablePositions():
move = pos
result = game.move(*pos)
if result is None:
assert not game.gameOver
result, oppMove = self.minimax()
game.undo()
ret = min(ret, result)
bestMove = move if ret == result else bestMove
if ret == -1:
return -1, move
return ret, bestMove
通过上面的代码可以画出初始两步的井字棋最终结局。对于先手O来说可以落9个位置,排除对称位置后只有三种,分别为角落,边上和正中。但无论哪一个位置作为先手,最好的结局都是被对方逼平,不存在必赢的开局。所以井字棋的结局是:如果两个玩家都采用最优策略(无失误),游戏结果为双方逼平。
井字棋第一步结局
下面分别画出三种开局后进一步的游戏结局。
井字棋角落开局
井字棋边上开局
井字棋中间开局

井字棋游戏状态数和解

有趣的是井字棋游戏的状态数量,简单的上限估算是\(3^9=19683\)。这显然是个较宽泛的上限,因为很多状态在游戏结束后无法达到。 这篇文章 Tic-Tac-Toe (Naughts and Crosses, Cheese and Crackers, etc 中列出了每一步的状态数,合计5478个。

Moves Positions Terminal Positions
0 1
1 9
2 72
3 252
4 756
5 1260 120
6 1520 148
7 1140 444
8 390 168
9 78 78
Total 5478 958

我们已经实现了井字棋的minimax策略,算法本质上遍历了所有情况,稍加改造后增加dp数组,就可以确认上面的总状态数。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58

class CountingMinimaxStrategy(Strategy):
def action(self, game: ConnectNGame) -> Tuple[int, Tuple[int, int]]:
self.game = copy.deepcopy(game)
self.dpMap = {}
result, move = self.minimax(game.getStatus())
return result, move

def minimax(self, gameStatus: Tuple[Tuple[int, ...]]) -> Tuple[int, Tuple[int, int]]:
# print(f'Current {len(strategy.dpMap)}')

if gameStatus in self.dpMap:
return self.dpMap[gameStatus]

game = self.game
bestMove = None
assert not game.gameOver
if game.currentPlayer == ConnectNGame.PLAYER_A:
ret = -math.inf
for pos in game.getAvailablePositions():
move = pos
result = game.move(*pos)
if result is None:
assert not game.gameOver
result, oppMove = self.minimax(game.getStatus())
self.dpMap[game.getStatus()] = result, oppMove
else:
self.dpMap[game.getStatus()] = result, move
game.undo()
ret = max(ret, result)
bestMove = move if ret == result else bestMove
self.dpMap[gameStatus] = ret, bestMove
return ret, bestMove
else:
ret = math.inf
for pos in game.getAvailablePositions():
move = pos
result = game.move(*pos)

if result is None:
assert not game.gameOver
result, oppMove = self.minimax(game.getStatus())
self.dpMap[game.getStatus()] = result, oppMove
else:
self.dpMap[game.getStatus()] = result, move
game.undo()
ret = min(ret, result)
bestMove = move if ret == result else bestMove
self.dpMap[gameStatus] = ret, bestMove
return ret, bestMove


if __name__ == '__main__':
tic_tac_toe = ConnectNGame(N=3, board_size=3)
strategy = CountingMinimaxStrategy()
strategy.action(tic_tac_toe)
print(f'Game States Number {len(strategy.dpMap)}')

运行程序证实了井字棋状态数为5478,下面是一些极小规模时代码运行结果:

3x3 4x4
k=3 5478 (Draw) 6035992 (Win)
k=4 9722011 (Draw)
k=5

根据 Wikipedia M,n,k-game, 列出了一些小规模下的游戏解:

3x3 4x4 5x5 6x6
k=3 Draw Win Win Win
k=4 Draw Draw Win
k=5 Draw Draw

值得一提的是,五子棋(棋盘15x15或以上)被 L. Victor Allis证明是先手赢。

Alpha-Beta剪枝策略

Alpha Beta 剪枝策略的代码如下(和之前代码比较类似,不再赘述):

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
class AlphaBetaStrategy(Strategy):
def action(self, game: ConnectNGame) -> Tuple[int, Tuple[int, int]]:
self.game = game
result, move = self.alpha_beta(self.game.getStatus(), -math.inf, math.inf)
return result, move

def alpha_beta(self, gameStatus: Tuple[Tuple[int, ...]], alpha:int=None, beta:int=None) -> Tuple[int, Tuple[int, int]]:
game = self.game
bestMove = None
assert not game.gameOver
if game.currentPlayer == ConnectNGame.PLAYER_A:
ret = -math.inf
for pos in game.getAvailablePositions():
move = pos
result = game.move(*pos)
if result is None:
assert not game.gameOver
result, oppMove = self.alpha_beta(game.getStatus(), alpha, beta)
game.undo()
alpha = max(alpha, result)
ret = max(ret, result)
bestMove = move if ret == result else bestMove
if alpha >= beta or ret == 1:
return ret, move
return ret, bestMove
else:
ret = math.inf
for pos in game.getAvailablePositions():
move = pos
result = game.move(*pos)
if result is None:
assert not game.gameOver
result, oppMove = self.alpha_beta(game.getStatus(), alpha, beta)
game.undo()
beta = min(beta, result)
ret = min(ret, result)
bestMove = move if ret == result else bestMove
if alpha >= beta or ret == -1:
return ret, move
return ret, bestMove

Alpha Beta 的DP版本中,由于lru_cache无法指定cache的有效参数,递归函数并没有传入alpha, beta。因此我们将alpha,beta参数隐式放入自己维护的栈中,并保证栈的状态和alpha_beta_dp函数调用状态一致。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
class AlphaBetaDPStrategy(Strategy):
def action(self, game: ConnectNGame) -> Tuple[int, Tuple[int, int]]:
self.game = game
self.alphaBetaStack = [(-math.inf, math.inf)]
result, move = self.alpha_beta_dp(self.game.getStatus())
return result, move

@lru_cache(maxsize=None)
def alpha_beta_dp(self, gameStatus: Tuple[Tuple[int, ...]]) -> Tuple[int, Tuple[int, int]]:
alpha, beta = self.alphaBetaStack[-1]
game = self.game
bestMove = None
assert not game.gameOver
if game.currentPlayer == ConnectNGame.PLAYER_A:
ret = -math.inf
for pos in game.getAvailablePositions():
move = pos
result = game.move(*pos)
if result is None:
assert not game.gameOver
self.alphaBetaStack.append((alpha, beta))
result, oppMove = self.alpha_beta_dp(game.getStatus())
self.alphaBetaStack.pop()
game.undo()
alpha = max(alpha, result)
ret = max(ret, result)
bestMove = move if ret == result else bestMove
if alpha >= beta or ret == 1:
return ret, move
return ret, bestMove
else:
ret = math.inf
for pos in game.getAvailablePositions():
move = pos
result = game.move(*pos)
if result is None:
assert not game.gameOver
self.alphaBetaStack.append((alpha, beta))
result, oppMove = self.alpha_beta_dp(game.getStatus())
self.alphaBetaStack.pop()
game.undo()
beta = min(beta, result)
ret = min(ret, result)
bestMove = move if ret == result else bestMove
if alpha >= beta or ret == -1:
return ret, move
return ret, bestMove

本系列,我们来看看在一种常见的组合游戏——回合制棋盘类游戏中,如何用算法来解决问题。首先,我们会介绍并解决搜索空间较小的问题,引入经典的博弈算法和相关理论,最终实现在大搜索空间中的Deep RL近似算法。在此基础上可以理解AlphaGo的原理和工作方式。 本系列的第一篇,我们介绍3个Leetcode中的零和回合制游戏,从最初的暴力解法,到动态规划最终演变成博弈论里的经典算法: minimax 以及 alpha beta 剪枝。

Leetcode 292 Nim Game (简单)

简单题 Leetcode 292 Nim Game

你和你的朋友,两个人一起玩 Nim游戏:桌子上有一堆石头,每次你们轮流拿掉 1 - 3 块石头。 拿掉最后一块石头的人就是获胜者。你作为先手。
你们是聪明人,每一步都是最优解。 编写一个函数,来判断你是否可以在给定石头数量的情况下赢得游戏。

示例:
输入: 4
输出: false
解释: 如果堆中有 4 块石头,那么你永远不会赢得比赛;因为无论你拿走 1 块、2 块 还是 3 块石头,最后一块石头总是会被你的朋友拿走。

定义 \(f(n)\) 为有\(n\)个石头并采取最优策略的游戏结果, \(f(n)\)的值只有可能是赢或者输。考察前几个结果:\(f(1) = f(2) = f(3) = Win\),然后来计算\(f(4)\)。因为玩家采取最优策略(只要有一种走法让对方必输,玩家获胜),对于4来说,玩家能走的可能是拿掉1块、2块或3块,但是无论剩余何种局面,对方都是必赢,因此,4就是必输。总的说来,递归关系如下: \[ f(n) = \neg (f(n-1) \land f(n-2) \land f(n-3)) \]

这个递归式可以直接翻译成Python 3代码
{linenos
1
2
3
4
5
6
7
8
9
10
11
# TLE
# Time Complexity: O(exponential)
class Solution_BruteForce:

def canWinNim(self, n: int) -> bool:
if n <= 3:
return True
for i in range(1, 4):
if not self.canWinNim(n - i):
return True
return False
以上的递归公式和代码很像fibonacci数的递归定义和暴力解法,因此对应的时间复杂度也是指数级的,提交代码以后会TLE。下图画出了当n=7时的递归调用,注意 5 被扩展向下重复执行了两次,4重复了4次。
292 Nim Game 暴力解法调用图 n=7
我们采用和fibonacci一样的方式来优化算法:缓存较小n的结果以此来计算较大n的结果。 Python 中,我们可以只加一行lru_cache decorator,来取得这种动态规划效果,下面的代码将复杂度降到了 \(O(N)\)
{linenos
1
2
3
4
5
6
7
8
9
10
11
12
# RecursionError: maximum recursion depth exceeded in comparison n=1348820612
# Time Complexity: O(N)
class Solution_DP:
from functools import lru_cache
@lru_cache(maxsize=None)
def canWinNim(self, n: int) -> bool:
if n <= 3:
return True
for i in range(1, 4):
if not self.canWinNim(n - i):
return True
return False
再来画出调用图:这次5和4就不再被展开重复计算,图中绿色的节点表示缓存命中。
292 Nim Game 动归解法调用图 n=7

但还是没有AC,因为当n=1348820612时,这种方式会导致栈溢出。再改成下面的循环版本,可惜还是TLE。

{linenos
1
2
3
4
5
6
7
8
9
10
11
# TLE for 1348820612
# Time Complexity: O(N)
class Solution:
def canWinNim(self, n: int) -> bool:
if n <= 3:
return True
last3, last2, last1 = True, True, True
for i in range(4, n+1):
this = not (last3 and last2 and last1)
last3, last2, last1 = last2, last1, this
return last1

由此看来,AC 版本需要低于\(O(n)\)的算法复杂度。上面的写法似乎暗示输赢有周期性的规律。事实上,如果将输赢按照顺序画出来,就马上得出规律了:只要\(n \mod 4 = 0\) 就是输,否则赢。原因如下:当面临不能被4整除的数量时 \(4k+i (i=1,2,3)\) ,一方总是可以拿走 \(i\) 个,将\(4k\) 留给对手,而对方下轮又将返回不能被4整除的数,如此循环往复,直到这一方有1, 2, 3 个,最终获胜。

输赢分布

最终AC版本,只有一句语句。

{linenos
1
2
3
4
5
# AC
# Time Complexity: O(1)
class Solution:
def canWinNim(self, n: int) -> bool:
return not (n % 4 == 0)

Leetcode 486 Predict the Winner (中等)

中等难度题目: Leetcode 486 Predict the Winner.

给定一个表示分数的非负整数数组。 玩家1从数组任意一端拿取一个分数,随后玩家2继续从剩余数组任意一端拿取分数,然后玩家1拿,……。每次一个玩家只能拿取一个分数,分数被拿取之后不再可取。直到没有剩余分数可取时游戏结束。最终获得分数总和最多的玩家获胜。
给定一个表示分数的数组,预测玩家1是否会成为赢家。你可以假设每个玩家的玩法都会使他的分数最大化。

示例 1:
输入: [1, 5, 2]
输出: False
解释: 一开始,玩家1可以从1和2中进行选择。
如果他选择2(或者1),那么玩家2可以从1(或者2)和5中进行选择。如果玩家2选择了5,那么玩家1则只剩下1(或者2)可选。
所以,玩家1的最终分数为 1 + 2 = 3,而玩家2为 5。
因此,玩家1永远不会成为赢家,返回 False。

示例 2:
输入: [1, 5, 233, 7]
输出: True
解释: 玩家1一开始选择1。然后玩家2必须从5和7中进行选择。无论玩家2选择了哪个,玩家1都可以选择233。
最终,玩家1(234分)比玩家2(12分)获得更多的分数,所以返回 True,表示玩家1可以成为赢家。

对于当前玩家,他有两种选择:左边或者右边的数。定义 maxDiff(l, r) 为剩余子数组\([l,r]\)时,当前玩家能取得的最大分差,那么

\[ \begin{equation*} \operatorname{maxDiff}(l, r) = \max \begin{cases} nums[l] - \operatorname{maxDiff}(l + 1, r)\\\\ nums[r] - \operatorname{maxDiff}(l, r - 1) \end{cases} \end{equation*} \]

对应的时间复杂度可以写出递归式,显然是指数级的: \[ f(n) = 2f(n-1) = O(2^n) \]

采用暴力解法可以AC,但运算时间很长,接近TLE边缘 (6300ms)。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# AC
# Time Complexity: O(2^N)
# Slow: 6300ms
from typing import List

class Solution:

def maxDiff(self, l: int, r:int) -> int:
if l == r:
return self.nums[l]
return max(self.nums[l] - self.maxDiff(l + 1, r), self.nums[r] - self.maxDiff(l, r - 1))

def PredictTheWinner(self, nums: List[int]) -> bool:
self.nums = nums
return self.maxDiff(0, len(nums) - 1) >= 0

从调用图也很容易看出是指数级的复杂度
486 Predict the Winner 暴力解法调用图 n=4

上图中我们有重复计算的节点,例如[1-2]节点被计算了两次。使用 lru_cache 大法,在maxDiff 上仅加了一句,就能以复杂度 \(O(n^2)\)和运行时间 43ms AC。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# AC
# Time Complexity: O(N^2)
# Fast: 43ms
from functools import lru_cache
from typing import List

class Solution:

@lru_cache(maxsize=None)
def maxDiff(self, l: int, r:int) -> int:
if l == r:
return self.nums[l]
return max(self.nums[l] - self.maxDiff(l + 1, r), self.nums[r] - self.maxDiff(l, r - 1))

def PredictTheWinner(self, nums: List[int]) -> bool:
self.nums = nums
return self.maxDiff(0, len(nums) - 1) >= 0
动态规划解法调用图可以看出节点 [1-2] 这次没有被计算两次。
486 Predict the Winner 动归解法调用图 n=4

Leetcode 464 Can I Win (中等)

类似但稍有难度的题目 Leetcode 464 Can I Win。难点在于使用了位的状态压缩。

在 "100 game" 这个游戏中,两名玩家轮流选择从 1 到 10 的任意整数,累计整数和,先使得累计整数和达到 100 的玩家,即为胜者。
如果我们将游戏规则改为 “玩家不能重复使用整数” 呢?
例如,两个玩家可以轮流从公共整数池中抽取从 1 到 15 的整数(不放回),直到累计整数和 >= 100。
给定一个整数 maxChoosableInteger (整数池中可选择的最大数)和另一个整数 desiredTotal(累计和),判断先出手的玩家是否能稳赢(假设两位玩家游戏时都表现最佳)?
你可以假设 maxChoosableInteger 不会大于 20, desiredTotal 不会大于 300。

示例:
输入:
maxChoosableInteger = 10
desiredTotal = 11
输出:
false
解释:
无论第一个玩家选择哪个整数,他都会失败。
第一个玩家可以选择从 1 到 10 的整数。
如果第一个玩家选择 1,那么第二个玩家只能选择从 2 到 10 的整数。
第二个玩家可以通过选择整数 10(那么累积和为 11 >= desiredTotal),从而取得胜利.
同样地,第一个玩家选择任意其他整数,第二个玩家都会赢。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# AC
# Time Complexity: O:(2^m*m), m: maxChoosableInteger
class Solution:
from functools import lru_cache
@lru_cache(maxsize=None)
def recurse(self, status: int, currentTotal: int) -> bool:
for i in range(1, self.maxChoosableInteger + 1):
if not (status >> i & 1):
new_status = 1 << i | status
if currentTotal + i >= self.desiredTotal:
return True
if not self.recurse(new_status, currentTotal + i):
return True
return False


def canIWin(self, maxChoosableInteger: int, desiredTotal: int) -> bool:
self.maxChoosableInteger = maxChoosableInteger
self.desiredTotal = desiredTotal

sum = maxChoosableInteger * (maxChoosableInteger + 1) / 2
if sum < desiredTotal:
return False
return self.recurse(0, 0)

上面的代码算法复杂度为\(O(m 2^m)\),m是maxChoosableInteger。由于所有状态的数量是\(2^m\),对于每个状态,最多会尝试 \(m\) 走法。

Minimax 算法

至此,我们AC了leetcode中的几道零和回合制博弈游戏。事实上,在这个领域有通用的算法:回合制博弈下的minimax。算法背景如下,两个玩家轮流玩,第一个玩家max的目的是将游戏的效用最大化,第二个玩家min则是最小化效用。比如,下面的节点表示玩家选取节点后游戏的效用,当两个玩家都能采取最优策略,Minimax 算法从底层节点来计算,游戏的结果是最终max 玩家会得到-7。

Wikipedia Minimax 例子

Minimax Python 3伪代码如下。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
def minimax(node: Node, depth: int, maximizingPlayer: bool) -> int:
if depth == 0 or is_terminal(node):
return evaluate_terminal(node)
if maximizingPlayer:
value:int = −∞
for child in node:
value = max(value, minimax(child, depth − 1, False))
return value
else: # minimizing player
value := +∞
for child in node:
value = min(value, minimax(child, depth − 1, True))
return value

Minimax: 486 Predict the Winner

我们知道486 Predict the Winner 是有minimax解法的,但如何具体实现,其难点在于如何定义合适的游戏价值或者效用。之前的解法中,我们定义maxDiff(l, r) 来表示当前玩家面临子区间 \([l, r]\) 时能取得的最大分差。对于minimax算法,max 玩家要最大化游戏价值,min玩家要最小化游戏价值。先考虑最简单情况即只有一个数x时,若定义max玩家在此局面下得到这个数时游戏价值为 +x,则min玩家为-x,即max玩家得到的所有数为正(\(+a_1 + a_2 + ... = A\)),min玩家得到的所有数为负(\(-b_1 - b_2 - ... = -B\))。至此,max玩家的目标就是 \(max(A-B)\) ,min玩家是 \(min(A-B)\)。有了精确的定义和优化目标,代码只需要套一下上面的模版。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# AC
from functools import lru_cache
from typing import List

class Solution:
# max_player: max(A - B)
# min_player: min(A - B)
@lru_cache(maxsize=None)
def minimax(self, l: int, r: int, isMaxPlayer: bool) -> int:
if l == r:
return self.nums[l] * (1 if isMaxPlayer else -1)

if isMaxPlayer:
return max(
self.nums[l] + self.minimax(l + 1, r, not isMaxPlayer),
self.nums[r] + self.minimax(l, r - 1, not isMaxPlayer))
else:
return min(
-self.nums[l] + self.minimax(l + 1, r, not isMaxPlayer),
-self.nums[r] + self.minimax(l, r - 1, not isMaxPlayer))

def PredictTheWinner(self, nums: List[int]) -> bool:
self.nums = nums
v = self.minimax(0, len(nums) - 1, True)
return v >= 0
Minimax 486 调用图 nums=[1, 5, 2, 4]

Minimax: 464 Can I Win

该题目是很典型的此类游戏,即结果为赢输平,但是中间的状态没有直接对应的游戏价值。对于这样的问题,一般定义为,max玩家胜,价值 +1,min玩家胜,价值-1,平则0。下面的AC代码实现了 Minimax 算法。算法中针对两个玩家都有剪枝(没有剪枝无法AC)。具体来说,max玩家一旦在某一节点取得胜利(value=1),就停止继续向下搜索,因为这是他能取得的最好分数。同理,min玩家一旦取得-1也直接返回上层节点。这个剪枝可以泛化成 alpha beta剪枝算法。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
# AC
class Solution:
from functools import lru_cache
@lru_cache(maxsize=None)
# currentTotal < desiredTotal
def minimax(self, status: int, currentTotal: int, isMaxPlayer: bool) -> int:
import math
if status == self.allUsed:
return 0 # draw: no winner

if isMaxPlayer:
value = -math.inf
for i in range(1, self.maxChoosableInteger + 1):
if not (status >> i & 1):
new_status = 1 << i | status
if currentTotal + i >= self.desiredTotal:
return 1 # shortcut
value = max(value, self.minimax(new_status, currentTotal + i, not isMaxPlayer))
if value == 1:
return 1
return value
else:
value = math.inf
for i in range(1, self.maxChoosableInteger + 1):
if not (status >> i & 1):
new_status = 1 << i | status
if currentTotal + i >= self.desiredTotal:
return -1 # shortcut
value = min(value, self.minimax(new_status, currentTotal + i, not isMaxPlayer))
if value == -1:
return -1
return value

Alpha-Beta 剪枝

在464 Can I Win minimax 算法代码实现中,我们发现有剪枝优化空间。对于每个节点,定义两个值alpha 和 beta,表示从根节点到目前局面时,max玩家保证能取得的最小值以及min玩家能保证取得的最大值。初始时,根节点alpha = −∞ , beta = +∞,表示游戏最终的价值在区间 [−∞, +∞]中。在向下遍历的过程中,子节点先继承父节点的 alpha beta 值进而继承区间 [alpha, beta]。当子节点在向下遍历的时候同步更新alpha 或者 beta,一旦区间[alpha, beta]非法就立即向上返回。举个Wikimedia的例子来进一步说明:

  1. 根节点初始时: alpha = −∞, beta = +∞

  2. 根节点,最左边子节点返回4后: alpha = 4, beta = +∞

  3. 根节点,中间子节点返回5后: alpha = 5, beta = +∞

  4. 最右Min节点(标1节点),初始时: alpha = 5, beta = +∞

  5. 最右Min节点(标1节点),第一个子节点返回1后: alpha = 5, beta = 1

此时,最右Min节点的alpha, beta形成了无效区间[5, 1],满足了剪枝条件,因此可以不用计算它的第二个和第三个子节点。如果剩余子节点返回值 > 1,比如2,由于这是个min节点,将会被已经到手的1替换。若其他子节点返回值 < 1,但由于min的父节点有效区间是[5, +∞],已经保证了>=5,小于5的值也会被忽略。

Wikimedia Alpha Beta 剪枝例子

Alpha Beta 剪枝 Python 3伪代码如下

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
def alpha_beta(node: Node, depth: int, α: int, β: int, maximizingPlayer: bool) -> int:
if depth == 0 or is_terminal(node):
return evaluate_terminal(node)
if maximizingPlayer:
value: int = −∞
for child in node:
value = max(value, alphabeta(child, depth − 1, α, β, False))
α = max(α, value)
if α >= β:
break # β cut-off
return value
else:
value: int = +∞
for child in node:
value = min(value, alphabeta(child, depth − 1, α, β, True))
β = min(β, value)
if β <= α:
break # α cut-off
return value

Alpha-Beta Pruning: 486 Predict the Winner

用 Alpha-Beta 剪枝 再次AC 486。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
# AC
import math
from functools import lru_cache
from typing import List

class Solution:
def alpha_beta(self, l: int, r: int, curr: int, isMaxPlayer: bool, alpha: int, beta: int) -> int:
if l == r:
return curr + self.nums[l] * (1 if isMaxPlayer else -1)

if isMaxPlayer:
ret = self.alpha_beta(l + 1, r, curr + self.nums[l], not isMaxPlayer, alpha, beta)
alpha = max(alpha, ret)
if alpha >= beta:
return alpha
ret = max(ret, self.alpha_beta(l, r - 1, curr + self.nums[r], not isMaxPlayer, alpha, beta))
return ret
else:
ret = self.alpha_beta(l + 1, r, curr - self.nums[l], not isMaxPlayer, alpha, beta)
beta = min(beta, ret)
if alpha >= beta:
return beta
ret = min(ret, self.alpha_beta(l, r - 1, curr - self.nums[r], not isMaxPlayer, alpha, beta))
return ret

def PredictTheWinner(self, nums: List[int]) -> bool:
self.nums = nums
v = self.alpha_beta(0, len(nums) - 1, 0, True, -math.inf, math.inf)
return v >= 0

Alpha-Beta Pruning: 464 Can I Win

464 Alpha-Beta 剪枝版本。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
# AC
class Solution:
from functools import lru_cache
@lru_cache(maxsize=None)
# currentTotal < desiredTotal
def alpha_beta(self, status: int, currentTotal: int, isMaxPlayer: bool, alpha: int, beta: int) -> int:
import math
if status == self.allUsed:
return 0 # draw: no winner

if isMaxPlayer:
value = -math.inf
for i in range(1, self.maxChoosableInteger + 1):
if not (status >> i & 1):
new_status = 1 << i | status
if currentTotal + i >= self.desiredTotal:
return 1 # shortcut
value = max(value, self.alpha_beta(new_status, currentTotal + i, not isMaxPlayer, alpha, beta))
alpha = max(alpha, value)
if alpha >= beta:
return value
return value
else:
value = math.inf
for i in range(1, self.maxChoosableInteger + 1):
if not (status >> i & 1):
new_status = 1 << i | status
if currentTotal + i >= self.desiredTotal:
return -1 # shortcut
value = min(value, self.alpha_beta(new_status, currentTotal + i, not isMaxPlayer, alpha, beta))
beta = min(beta, value)
if alpha >= beta:
return value
return value

C++, Java, Javascript AC 486 Predict the Winner

最后介绍一种不同的DP实现:用C++, Java, Javascript 实现自底向上的DP解法来AC leetcode 486,当然其他语言没有Python的lru_cache大法。以下实现中,注意DP解的构建顺序,先解决小规模的问题,并在此基础上计算稍大的问题。值得一提的是,以下的循环写法严格保证了 \(n^2\) 次循环,但是自顶向下的计划递归可能会少于 \(n^2\)次循环。

Java AC Code

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
// AC
class Solution {
public boolean PredictTheWinner(int[] nums) {
int n = nums.length;
int[][] dp = new int[n][n];
for (int i = 0; i < n; i++) {
dp[i][i] = nums[i];
}

for (int l = n - 1; l >= 0; l--) {
for (int r = l + 1; r < n; r++) {
dp[l][r] = Math.max(
nums[l] - dp[l + 1][r],
nums[r] - dp[l][r - 1]);
}
}
return dp[0][n - 1] >= 0;
}
}

C++ AC Code

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
// AC
class Solution {
public:
bool PredictTheWinner(vector<int>& nums) {
int n = nums.size();
vector<vector<int>> dp(n, vector<int>(n, 0));
for (int i = 0; i < n; i++) {
dp[i][i] = nums[i];
}
for (int l = n - 1; l >= 0; l--) {
for (int r = l + 1; r < n; r++) {
dp[l][r] = max(nums[l] - dp[l + 1][r], nums[r] - dp[l][r - 1]);
}
}
return dp[0][n - 1] >= 0;
}
};

Javascript AC Code

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
/**
* @param {number[]} nums
* @return {boolean}
*/
var PredictTheWinner = function(nums) {
const n = nums.length;
const dp = new Array(n).fill().map(() => new Array(n));

for (let i = 0; i < n; i++) {
dp[i][i] = nums[i];
}

for (let l = n - 1; l >=0; l--) {
for (let r = i + 1; r < n; r++) {
dp[l][r] = Math.max(nums[l] - dp[l + 1][r],nums[r] - dp[l][r - 1]);
}
}

return dp[0][n-1] >=0;
};
Your browser is out-of-date!

Update your browser to view this website correctly. Update my browser now

×