#Python

Leetcode 679 24 Game (Hard)

先来介绍一下24点游戏题目,大家一定都玩过,就是给定4个牌面数字,用加减乘除计算24点。

本篇会用两种偏函数式的 Python 3解法来AC 24 Game。

Leetcode 679 24 Game (Hard) > You have 4 cards each containing a number from 1 to 9. You need to judge whether they could operated through *, /, +, -, (, ) to get the value of 24.

Example 1:

Input: [4, 1, 8, 7]

Output: True

Explanation: (8-4) * (7-1) = 24

Example 2:

Input: [1, 2, 1, 2]

Output: False

itertools.permutations

先来介绍一下Python itertools.permutations 的用法,正好用Leetcode 中的Permutation问题来示例。Permutations 的输入可以是List,返回是 generator 实例,用于生成所有排列。简而言之,python 的 generator 可以和List一样,用 for 语句来全部遍历产生的值。和List不同的是,generator 的所有值并不必须全部初始化,一般按需产生从而大量减少内存占用。下面在介绍 yield 时我们会看到如何合理构造 generator。

Leetcode 46 Permutations (Medium) > Given a collection of distinct integers, return all possible permutations.

Example:

Input: [1,2,3]

Output: [ [1,2,3], [1,3,2], [2,1,3], [2,3,1], [3,1,2], [3,2,1]]

用 permutations 很直白,代码只有一行。

{linenos
1
2
3
4
5
6
7
8
9
10
11
# AC
# Runtime: 36 ms, faster than 91.78% of Python3 online submissions for Permutations.
# Memory Usage: 13.9 MB, less than 66.52% of Python3 online submissions for Permutations.

from itertools import permutations
from typing import List


class Solution:
def permute(self, nums: List[int]) -> List[List[int]]:
return [p for p in permutations(nums)]

itertools.combinations

有了排列就少不了组合,itertools.combinations 可以产生给定List的k个元素组合 \(\binom{n}{k}\),用一道算法题来举例,同样也是一句语句就可以AC。

Leetcode 77 Combinations (Medium)

Given two integers n and k, return all possible combinations of k numbers out of 1 ... n. You may return the answer in any order.

Example 1:

Input: n = 4, k = 2

Output: [ [2,4], [3,4], [2,3], [1,2], [1,3], [1,4],]

Example 2:

Input: n = 1, k = 1

Output: [[1]]

{linenos
1
2
3
4
5
6
7
8
9
# AC
# Runtime: 84 ms, faster than 95.43% of Python3 online submissions for Combinations.
# Memory Usage: 15.2 MB, less than 68.98% of Python3 online submissions for Combinations.
from itertools import combinations
from typing import List

class Solution:
def combine(self, n: int, k: int) -> List[List[int]]:
return [c for c in combinations(list(range(1, n + 1)), k)]

itertools.product

当有多维度的对象需要迭代笛卡尔积时,可以用 product(iter1, iter2, ...)来生成generator,等价于多重 for 循环。

1
2
[lst for lst in product([1, 2, 3], ['a', 'b'])]
[(i, s) for i in [1, 2, 3] for s in ['a', 'b']]

这两种方式都生成了如下结果

1
[(1, 'a'), (1, 'b'), (2, 'a'), (2, 'b'), (3, 'a'), (3, 'b')]

再举一个Leetcode的例子来实战product generator。

Leetcode 17. Letter Combinations of a Phone Number (Medium)

Given a string containing digits from 2-9 inclusive, return all possible letter combinations that the number could represent. A mapping of digit to letters (just like on the telephone buttons) is given below. Note that 1 does not map to any letters.

Example:

Input: "23"

Output: ["ad", "ae", "af", "bd", "be", "bf", "cd", "ce", "cf"].

举例来说,下面的代码当输入 digits 是 '352' 时,iter_dims 的值是 ['def', 'jkl', 'abc'],再输入给 product 后会产生 'dja', 'djb', 'djc', 'eja', 共 3 x 3 x 3 = 27个组合的值。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# AC
# Runtime: 24 ms, faster than 94.50% of Python3 online submissions for Letter Combinations of a Phone Number.
# Memory Usage: 13.7 MB, less than 83.64% of Python3 online submissions for Letter Combinations of a Phone Number.

from itertools import product
from typing import List


class Solution:
def letterCombinations(self, digits: str) -> List[str]:
if digits == "":
return []
mapping = {'2':'abc', '3':'def', '4':'ghi', '5':'jkl', '6':'mno', '7':'pqrs', '8':'tuv', '9':'wxyz'}
iter_dims = [mapping[i] for i in digits]

result = []
for lst in product(*iter_dims):
result.append(''.join(lst))

return result

yield 示例

Python具有独特的itertools generator,可以花式AC代码,接下来讲解如何进一步构造 generator。Python 定义只要函数中使用了yield关键字,这个函数就是 generator。Generator 在计算机领域的标准名称是 coroutine,即协程,是一种特殊的函数:当返回上层调用时自身能够保存调用栈状态,并在上层函数处理完逻辑后跳入到这个 generator,恢复之前的状态再继续运行下去。Yield语句也举一道经典的Fibonacci 问题。

Leetcode 509. Fibonacci Number (Easy)

The Fibonacci numbers, commonly denoted F(n) form a sequence, called the Fibonacci sequence, such that each number is the sum of the two preceding ones, starting from 0 and 1. That is, F(0) = 0, F(1) = 1 F(N) = F(N - 1) + F(N - 2), for N > 1. Given N, calculate F(N).

Example 1:

Input: 2

Output: 1

Explanation: F(2) = F(1) + F(0) = 1 + 0 = 1.

Example 2:

Input: 3

Output: 2

Explanation: F(3) = F(2) + F(1) = 1 + 1 = 2.

Example 3:

Input: 4

Output: 3

Explanation: F(4) = F(3) + F(2) = 2 + 1 = 3.

Fibonacci 的一般标准解法是循环迭代方式,可以以O(n)时间复杂度和O(1) 空间复杂度来AC。下面的 yield 版本中,我们构造了fib_next generator,它保存了最后两个值作为内部迭代状态,外部每调用一次可以得到下一个fib(n),如此外部只需不断调用直到满足题目给定次数。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# AC
# Runtime: 28 ms, faster than 85.56% of Python3 online submissions for Fibonacci Number.
# Memory Usage: 13.8 MB, less than 58.41% of Python3 online submissions for Fibonacci Number.

class Solution:
def fib(self, N: int) -> int:
if N <= 1:
return N
i = 2
for fib in self.fib_next():
if i == N:
return fib
i += 1

def fib_next(self):
f_last2, f_last = 0, 1
while True:
f = f_last2 + f_last
f_last2, f_last = f_last, f
yield f

yield from 示例

上述yield用法之后,再来演示 yield from 的用法。Yield from 始于Python 3.3,用于嵌套generator时的控制转移,一种典型的用法是有多个generator嵌套时,外层的outer_generator 用 yield from 这种方式等价代替如下代码。

1
2
3
def outer_generator():
for i in inner_generator():
yield i

用一道算法题目来具体示例。

Leetcode 230. Kth Smallest Element in a BST (Medium)

Given a binary search tree, write a function kthSmallest to find the kth smallest element in it.

Example 1: Input: root = [3,1,4,null,2], k = 1

1
2
3
4
5
  3
/ \
1 4
\
2
Output: 1

Example 2:

Input: root = [5,3,6,2,4,null,null,1], k = 3

1
2
3
4
5
6
7
         5
/ \
3 6
/ \
2 4
/
1
Output: 3

直觉思路上,我们只要从小到大有序遍历每个节点直至第k个。因为给定的树是Binary Search Tree,有序遍历意味着以左子树、节点本身和右子树的访问顺序递归下去就行。由于ordered_iter是generator,递归调用自己的过程就是嵌套使用generator的过程。下面是yield版本。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# AC
# Runtime: 48 ms, faster than 90.31% of Python3 online submissions for Kth Smallest Element in a BST.
# Memory Usage: 17.9 MB, less than 14.91% of Python3 online submissions for Kth Smallest Element in a BST.

class Solution:
def kthSmallest(self, root: TreeNode, k: int) -> int:
def ordered_iter(node):
if node:
for sub_node in ordered_iter(node.left):
yield sub_node
yield node
for sub_node in ordered_iter(node.right):
yield sub_node

for node in ordered_iter(root):
k -= 1
if k == 0:
return node.val

等价于如下 yield from 版本:

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# AC
# Runtime: 56 ms, faster than 63.74% of Python3 online submissions for Kth Smallest Element in a BST.
# Memory Usage: 17.7 MB, less than 73.33% of Python3 online submissions for Kth Smallest Element in a BST.

class Solution:
def kthSmallest(self, root: TreeNode, k: int) -> int:
def ordered_iter(node):
if node:
yield from ordered_iter(node.left)
yield node
yield from ordered_iter(node.right)

for node in ordered_iter(root):
k -= 1
if k == 0:
return node.val

24 点问题之函数式枚举解法

看明白了itertools.permuations,combinations,product,yield以及yield from,我们回到本篇最初的24点游戏问题。

24点游戏的本质是枚举出所有可能运算,如果有一种方式得到24返回True,否则返回Flase。进一步思考所有可能的运算,包括下面三个维度:

  1. 4个数字的所有排列,比如给定 [1, 2, 3, 4],可以用permutations([1, 2, 3, 4]) 生成这个维度的所有可能

  2. 三个位置的操作符号的全部可能,可以用 product([+, -, *, /], repeat=3) 生成,具体迭代结果为:[+, +, +],[+, +, -],...

  3. 给定了前面两个维度后,还有一个比较不容易察觉但必要的维度:运算优先级。比如在给定数字顺序 [1, 2, 3, 4]和符号顺序 [+, *, -]之后可能的四种操作树

四种运算优先级

能否算得24点只需要枚举这三个维度笛卡尔积的运算结果

(维度1:数字组合) x (维度2:符号组合) x (维度3:优先级组合)

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# AC
# Runtime: 112 ms, faster than 57.59% of Python3 online submissions for 24 Game.
# Memory Usage: 13.7 MB, less than 85.60% of Python3 online submissions for 24 Game.

import math
from itertools import permutations, product
from typing import List

class Solution:

def iter_trees(self, op1, op2, op3, a, b, c, d):
yield op1(op2(a, b), op3(c, d))
yield op1(a, op2(op3(b, c), d))
yield op1(a, op2(b, op3(c, d)))
yield op1(op2(a, op3(b, c)), d)

def judgePoint24(self, nums: List[int]) -> bool:
mul = lambda x, y: x * y
plus = lambda x, y: x + y
div = lambda x, y: x / y if y != 0 else math.inf
minus = lambda x, y: x - y

op_lst = [plus, minus, mul, div]

for ops in product(op_lst, repeat=3):
for val in permutations(nums):
for v in self.iter_trees(ops[0], ops[1], ops[2], val[0], val[1], val[2], val[3]):
if abs(v - 24) < 0.0001:
return True
return False

24 点问题之 DFS yield from 解法

一种常规的思路是,在四个数组成的集合中先选出任意两个数,枚举所有可能的计算,再将剩余的三个数组成的集合递归调用下去,直到叶子节点只剩一个数,如下图所示。

DFS 调用示例

下面的代码是这种思路的 itertools + yield from 解法,recurse方法是generator,会自我递归调用。当只剩下两个数时,用 yield 返回两个数的所有可能运算得出的值,其他非叶子情况下则自我调用使用yield from,例如4个数任选2个先计算再合成3个数的情况。这种情况下,比较麻烦的是由于4个数可能有相同值,若用 combinations(lst, 2) 先任选两个数,后续要生成剩余两个数加上第三个计算的数的集合代码会繁琐。因此,我们改成任选4个数index中的两个,剩余的indices 可以通过集合操作来完成。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
# AC
# Runtime: 116 ms, faster than 56.23% of Python3 online submissions for 24 Game.
# Memory Usage: 13.9 MB, less than 44.89% of Python3 online submissions for 24 Game.

import math
from itertools import combinations, product, permutations
from typing import List

class Solution:

def judgePoint24(self, nums: List[int]) -> bool:
mul = lambda x, y: x * y
plus = lambda x, y: x + y
div = lambda x, y: x / y if y != 0 else math.inf
minus = lambda x, y: x - y

op_lst = [plus, minus, mul, div]

def recurse(lst: List[int]):
if len(lst) == 2:
for op, values in product(op_lst, permutations(lst)):
yield op(values[0], values[1])
else:
# choose 2 indices from lst of length n
for choosen_idx_lst in combinations(list(range(len(lst))), 2):
# remaining indices not choosen (of length n-2)
idx_remaining_set = set(list(range(len(lst)))) - set(choosen_idx_lst)

# remaining values not choosen (of length n-2)
value_remaining_lst = list(map(lambda x: lst[x], idx_remaining_set))
for op, idx_lst in product(op_lst, permutations(choosen_idx_lst)):
# 2 choosen values are lst[idx_lst[0]], lst[idx_lst[1]
value_remaining_lst.append(op(lst[idx_lst[0]], lst[idx_lst[1]]))
yield from recurse(value_remaining_lst)
value_remaining_lst = value_remaining_lst[:-1]

for v in recurse(nums):
if abs(v - 24) < 0.0001:
return True

经典教材Reinforcement Learning: An Introduction 第二版由强化领域权威Richard S. Sutton 和 Andrew G. Barto 完成编写,内容深入浅出,非常适合初学者。在本篇中,引入Grid World示例,结合强化学习核心概念,并用python代码实现OpenAI Gym的模拟环境,进一步实现策略评价算法。

Grid World 问题

第四章例子4.1提出了一个简单的离散空间状态问题:Grid World,其大致意思是在4x4的网格世界中有14个格子是非终点状态,在这些非终点状态的格子中可以往上下左右四个方向走,直至走到两个终点状态格子,则游戏结束。每走一步,Agent收获reward -1,表示Agent希望在Grid World中尽早出去。另外,Agent在Grid World边缘时,无法继续往外只能呆在原地,reward也是-1。

Finite MDP 模型

先来回顾一下强化学习的建模基础:有限马尔可夫决策过程(Finite Markov Decision Process, Finite MDP)。如下图,强化学习模型将世界抽象成两个实体,强化学习解决目标的主体Agent和其他外部环境。它们之间的交互过程遵从有限马尔可夫决策过程:若Agent在t时间步骤时处于状态 \(S_t\),采取动作 \(A_t\),然后环境根据自身机制,产生Reward \(R_{t+1}\) 并将Agent状态变为 \(S_{t+1}\)

环境自身机制又称为dynamics,工程上可以看成一个输入(S, A),输出(S, R)的方法。由于MDP包含随机过程,某个输入并不能确定唯一输出,而会根据概率分布输出不同的(S, R)。Finite MDP简化了时间对于模型的影响,因为(S, R)只和(S, A)有关,不和时间t有关。另外,有限指的是S,A,R的状态数量是有限的。

数学上dynamics可以如下表示

\[ p\left(s^{\prime}, r \mid s, a\right) \doteq \operatorname{Pr}\left\{S_{t}=s^{\prime}, R_{t}=r \mid S_{t-1}=s, A_{t-1}=a\right\} \]

即是四元组作为输入的概率函数 \(p: S \times R \times S \times A \rightarrow [0, 1]\)

满足 \[ \sum_{s^{\prime} \in \mathcal{S}} \sum_{r \in \mathcal{R}} p\left(s^{\prime}, r \mid s, a\right)=1, \text { for all } s \in \mathcal{S}, a \in \mathcal{A}(s) \]

以Grid World为例,当Agent处于编号1的网格时,可以往四个方向走,往任意方向走都只产生一种 S, R,因为这个简单的游戏是确定性的,不存在某一动作导致stochastic状态。例如,在1号网格往左就到了终点网格(编号0),得到Reward -1这个规则可以如下表示 \[ p\left(s^{\prime}=0, r=-1 \mid s=1, a=\text{L}\right) = 1 \] 因此,状态s=1的所有dynamics概率映射为

\[ \begin{aligned} p\left(s^{\prime}=0, r=-1 \mid s=1, a=\text{L}\right) &=& 1 \\ p\left(s^{\prime}=2, r=-1 \mid s=1, a=\text{R}\right) &=& 1 \\ p\left(s^{\prime}=1, r=-1 \mid s=1, a=\text{U}\right) &=& 1 \\ p\left(s^{\prime}=5, r=-1 \mid s=1, a=\text{D}\right) &=& 1 \end{aligned} \]

强化学习的目的

在给定了问题以及定义了强化学习的模型之后,强化学习的目的当然是通过学习让Agent能够学到最佳策略\(\pi_{*}\),也就是在某个状态下的行动分布,记成 \(\pi(a|s)\)。对应在数值上的优化目标是Agent在一系列过程中采取某种策略的reward总和的期望(Expected Return)。下面公式定义了t步往后的reward总和,其中 \(\gamma\) 为discount factor,用于权衡短期和长期reward对于当前Agent的效用影响。等式最后一步的意义是t步后的reward总和等价于t步所获的立即reward \(R_{t+1}\),加上t+1步后的reward总和 \(\gamma G_{t+1}\)

\[ \begin{aligned} G_{t} & \doteq R_{t+1}+\gamma R_{t+2}+\gamma^{2} R_{t+3}+\gamma^{3} R_{t+4}+\cdots \\ &=R_{t+1}+\gamma\left(R_{t+2}+\gamma R_{t+3}+\gamma^{2} R_{t+4}+\cdots\right) \\ &=R_{t+1}+\gamma G_{t+1} \end{aligned} \]

有了reward总和的定义,评价Agent策略 \(\pi\) 就可以定义成Agent在状态 s 时采用此策略的Expected Return。

\[ v_{\pi}(s) \doteq \mathbb{E}_{\pi}\left[G_{t} \mid S_{t}=s\right] \]

下面公式推导了 \(v_{\pi}(s)\) 数值上和相关状态 \(s{\prime}\) 的关系:

\[ \begin{aligned} v_{\pi}(s) &\doteq \mathbb{E}_{\pi}\left[G_{t} \mid S_{t}=s\right] \\ &=\mathbb{E}_{\pi}\left[\sum_{k=0}^{\infty} \gamma^{k} R_{t+k+1} \mid S_{t}=s\right]\\ &=\mathbb{E}_{\pi}\left[R_{t+1}+\gamma G_{t+1} \mid S_{t}=s\right] \\ &=\sum_{a} \pi(a \mid s) \sum_{s^{\prime}} \sum_{r} p\left(s^{\prime}, r \mid s, a\right)\left[r+\gamma \mathbb{E}_{\pi}\left[G_{t+1} \mid S_{t+1}=s^{\prime}\right]\right] \\ &=\sum_{a} \pi(a \mid s) \sum_{s^{\prime}, r} p\left(s^{\prime}, r \mid s, a\right)\left[r+\gamma v_{\pi}\left(s^{\prime}\right)\right] \quad \text { for all } s \in \mathcal{S} \end{aligned} \]

注意到如果将 \(v_{\pi}(s)\) 看成未知数,上式即形成 \(\mid \mathcal{S} \mid\) 个未知变量的方程组,可以在数值上解得各个 \(v_{\pi}(s)\)

书中用Backup Diagram来表示递推关系,下图是\(v_{\pi}(s)\)的backup diagram。

尽管v值可以来衡量策略,但由于\(v_{\pi}(s)\) 是Agent在策略\(\pi(a|s)\)的Expected Return,将不同的action拆出来单独计算Expected Return,这样的做法有时更为直接,这就是著名的Q Learning中的q 值,记成\(q_{\pi}(s, a)\)

\[ q_{\pi}(s, a) \doteq \mathbb{E}_{\pi}\left[G_{t} \mid S_{t}=s, A_{t}=a\right] \]

下面是 $q_{}(s, a) $ 的递推 backup diagram。

Bellman 最佳原则

对于所有状态集合\(\mathcal{S}\),策略\({\pi}\)的评价指标 \(v_{\pi}(s)\) 是一个向量,本质上是无法相互比较的。但由于存在Bellman 最佳原则(Bellman's principle of optimality):在有限状态情况下,一定存在一个或者多个最好的策略 \({\pi}_{*}\),它在所有状态下的v值都是最好的,即 \(v_{\pi_{*}}(s) \ge v_{\pi^{\prime}}(s) \text { for all } s \in \mathcal{S}\)

因此,最佳v值定义为最佳策略 \({\pi}_{*}\) 对应的 v 值

\[ v_{*}(s) \doteq \max_{\pi} v_{\pi}(s) \]

同理,也存在最佳q值,记为 \[ \begin{aligned} q_{*}(s, a) &\doteq \max_{\pi} q_{\pi}(s,a) \end{aligned} \]

\(v_{*}(s)\) 改写成递推形式,称为 Bellman Optimality Equation,推导如下

\[ \begin{aligned} v_{*}(s) &=\max _{a \in \mathcal{A}(s)} q_{\pi_{*}}(s, a) \\ &=\max _{a} \mathbb{E}_{\pi_{*}}\left[G_{t} \mid S_{t}=s, A_{t}=a\right] \\ &=\max _{a} \mathbb{E}_{\pi_{*}}\left[R_{t+1}+\gamma G_{t+1} \mid S_{t}=s, A_{t}=a\right] \\ &=\max _{a} \mathbb{E}\left[R_{t+1}+\gamma v_{*}\left(S_{t+1}\right) \mid S_{t}=s, A_{t}=a\right] \\ &=\max _{a} \sum_{s^{\prime}, r} p\left(s^{\prime}, r \mid s, a\right)\left[r+\gamma v_{*}\left(s^{\prime}\right)\right] \end{aligned} \]

直觉上可以理解为状态 s 对应的最佳v值是只采取此状态下的最佳动作后的Expected Return。

最佳q值递归形式的意义为最佳策略下状态s时采取行动 a 的Expected Return,等于所有可能后续状态 s' 下采取最优行动的Expected Return的均值。推导如下:

\[ \begin{aligned} q_{*}(s, a) &=\mathbb{E}\left[R_{t+1}+\gamma \max _{a^{\prime}} q_{*}\left(S_{t+1}, a^{\prime}\right) \mid S_{t}=s, A_{t}=a\right] \\ &=\sum_{s^{\prime}, r} p\left(s^{\prime}, r \mid s, a\right)\left[r+\gamma \max _{a^{\prime}} q_{*}\left(s^{\prime}, a^{\prime}\right)\right] \end{aligned} \]

\(v_{*}(s), q_{*}(s, a)\) 的backup diagram 如下图

Grid World 最佳策略和V值

Grid World 的最佳策略如下:尽可能快的走出去

Grid World最佳策略

上面的2D图中不同颜色表示不同V值,终点格子的红色表示0,隔着一步的黄色为-1,隔两步的绿色为-2,最远的紫色为-3。下面是立体图示。

Grid World最佳策略V值

Grid World OpenAI Gym 环境

下面是OpenAI Gym框架下Grid World环境的代码实现。本质是在GridWorldEnv构造函数中构建MDP,类型定义如下

1
2
3
4
5
6
MDP = Dict[State, Dict[Action, List[Tuple[Prob, State, Reward, bool]]]]

# P[state][action] = [
# (prob1, next_state1, reward1, is_done),
# (prob2, next_state2, reward2, is_done), ...]

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
class Action(Enum):
UP = 0
DOWN = 1
LEFT = 2
RIGHT = 3

State = int
Reward = float
Prob = float
Policy = Dict[State, Dict[Action, Prob]]
Value = List[float]
StateSet = Set[int]
NonTerminalStateSet = Set[int]
MDP = Dict[State, Dict[Action, List[Tuple[Prob, State, Reward, bool]]]]
# P[s][a] = [(prob, next_state, reward, is_done), ...]

class GridWorldEnv(discrete.DiscreteEnv):
"""
Grid World environment described in Sutton and Barto Reinforcement Learning 2nd, chapter 4.
"""

def __init__(self, shape=[4,4]):
self.shape = shape
nS = np.prod(shape)
nA = len(list(Action))
MAX_R = shape[0]
MAX_C = shape[1]
self.grid = np.arange(nS).reshape(shape)
isd = np.ones(nS) / nS

# P[s][a] = [(prob, next_state, reward, is_done), ...]
P: MDP = {}
action_delta = {Action.UP: (-1, 0), Action.DOWN: (1, 0), Action.LEFT: (0, -1), Action.RIGHT: (0, 1)}
for s in range(0, MAX_R * MAX_C):
P[s] = {a.value : [] for a in list(Action)}
is_terminal = self.is_terminal(s)
if is_terminal:
for a in list(Action):
P[s][a.value] = [(1.0, s, 0, True)]
else:
r = s // MAX_R
c = s % MAX_R
for a in list(Action):
neighbor_r = min(MAX_R-1, max(0, r + action_delta[a][0]))
neighbor_c = min(MAX_C-1, max(0, c + action_delta[a][1]))
s_ = neighbor_r * MAX_R + neighbor_c
P[s][a.value] = [(1.0, s_, -1, False)]

super(GridWorldEnv, self).__init__(nS, nA, P, isd)

策略评估(Policy Evaluation)

策略评估需要解决在给定环境dynamics和Agent策略 \(\pi\)下,计算策略的v值 \(v_{\pi}\)。由于所有数量关系都已知,可以通过解方程组的方式求得,但通常会通过数值迭代的方式来计算,即通过一系列 \(v_{0}, v_{1}, ..., v_{k}\) 收敛至 \(v_{\pi}\)。如下迭代方式已经得到证明,当 \(k \rightarrow \infty\) 一定收敛至 \(v_{\pi}\)

\[ \begin{aligned} v_{k+1}(s) & \doteq \mathbb{E}_{\pi}\left[R_{t+1}+\gamma v_{k}\left(S_{t+1}\right) \mid S_{t}=s\right] \\ &=\sum_{a} \pi(a \mid s) \sum_{s^{\prime}, r} p\left(s^{\prime}, r \mid s, a\right)\left[r+\gamma v_{k}\left(s^{\prime}\right)\right] \end{aligned} \]

书中具体伪代码如下

\[ \begin{align*} &\textbf{Iterative Policy Evaluation, for estimating } V\approx v_{\pi} \\ & \text{Input } {\pi}, \text{the policy to be evaluated} \\ & \text{Algorithm parameter: a small threshold } \theta > 0 \text{ determining accuracy of estimation} \\ & \text{Initialize } V(s), \text{for all } s \in \mathcal{S}^{+} \text{, arbitrarily except that } V (terminal) = 0\\ & \\ &1: \text{Loop:}\\ &2: \quad \quad \Delta \leftarrow 0\\ &3: \quad \quad \text{Loop for each } s \in \mathcal{S}:\\ &4: \quad \quad \quad \quad v \leftarrow V(s) \\ &5: \quad \quad \quad \quad V(s) \leftarrow \sum_{a} \pi(a \mid s) \sum_{s^{\prime}, r} p\left(s^{\prime}, r \mid s, a\right)\left[r+\gamma V\left(s^{\prime}\right)\right] \\ &6: \quad \quad \quad \quad \Delta \leftarrow \max(\Delta, |v-V(s)|) \\ &7: \text{until } \Delta < \theta \end{align*} \]

下面是python 代码实现,注意这里单run迭代时,新的v值直接覆盖数组里的旧v值,这种做法在书中被证明不仅有效,甚至更为高效。这种做法称为原地(in place)更新。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
def policy_evaluate(policy: Policy, env: GridWorldEnv, gamma=1.0, theta=0.0001):
V = np.zeros(env.nS)
while True:
delta = 0
for s in range(env.nS):
v = 0
for a, action_prob in enumerate(policy[s]):
for prob, next_state, reward, done in env.P[s][a]:
v += action_prob * prob * (reward + gamma * V[next_state])
delta = max(delta, np.abs(v - V[s]))
V[s] = v
if delta < theta:
break
return np.array(V)

输入策略为随机选择方向,运行上面的policy_evaluate最终多轮收敛后的V值输出为

{linenos
1
2
3
4
[[  0.         -13.99931242 -19.99901152 -21.99891199]
[-13.99931242 -17.99915625 -19.99908389 -19.99909436]
[-19.99901152 -19.99908389 -17.99922697 -13.99942284]
[-21.99891199 -19.99909436 -13.99942284 0. ]]
在3D V值图中可以发现,由于是随机选择方向的策略, Agent在每个格子的V值绝对数值要比最佳V值大,意味着随机策略下Agent在Grid World会得到更多的负reward。
Grid World随机策略V值

旅行商问题(TSP)是计算机算法中经典的NP hard 问题。 在本系列文章中,我们将首先使用动态规划 AC aizu中的TSP问题,然后再利用深度学习求大规模下的近似解。深度学习应用解决问题时先以PyTorch实现监督学习算法 Pointer Network,进而结合强化学习来无监督学习,提高数据使用效率。 本系列完整列表如下:

  • 第一篇: 递归DP方法 AC AIZU TSP问题

  • 第二篇: 二维空间TSP数据集及其DP解法

  • 第三篇: 深度学习 Pointer Networks 的 Pytorch实现

  • 第四篇: 搜寻最有可能路径:Viterbi算法和其他

  • 第五篇: 深度强化学习无监督算法的 Pytorch实现

TSP 问题回顾

TSP可以用图模型来表达,无论有向图或无向图,无论全连通图或者部分连通的图都可以作为TSP问题。 Wikipedia TSP 中举了一个无向全连通的TSP例子。如下图所示,四个顶点A,B,C,D构成无向全连通图。TSP问题要求在所有遍历所有点后返回初始点的回路中找到最短的回路。例如,\(A \rightarrow B \rightarrow C \rightarrow D \rightarrow A\)\(A \rightarrow C \rightarrow B \rightarrow D \rightarrow A\) 都是有效的回路,但是TSP需要返回这些回路中的最短回路(注意,最短回路可能会有多条)。

Wikipedia 4个顶点组成的图

无论是哪种类型的图,我们都能用邻接矩阵表示出一个图。上面的Wikipedia中的图可以用下面的矩阵来描述。

\[ \begin{matrix} & \begin{matrix}A&B&C&D\end{matrix} \\\\ \begin{matrix}A\\\\B\\\\C\\\\D\end{matrix} & \begin{bmatrix}-&20&42&35\\\\20&-&30&34\\\\42&30&-&12\\\\35&34&12&-\end{bmatrix}\\\\ \end{matrix} \]

当然,大多数情况下,TSP问题会被限定在欧氏空间,即二维地图中的全连通无向图。因为,如果将顶点表示一个地理位置,一般来说它可以和其他所有顶点连通,回来的距离相同,由此构成无向图。

AIZU TSP 问题

AIZU在线题库 有一道有向不完全连通图的TSP问题。给定V个顶点和E条边,输出最小回路值。例如,题目里的例子如下所示,由4个顶点和6条单向边构成。

这个示例的答案是16,对应的回路是 \(0\rightarrow1\rightarrow3\rightarrow2\rightarrow0\),由下图的红色边构成。注意,这个题目可能不存在合法解,原因是无回路存在,此时返回-1,可以合理地理解成无穷大。

暴力解法

一种暴力方法是枚举所有可能的从某一顶点的回路,取其中的最小值即可。下面的 Python 示例如何枚举4个顶点构成的图中从顶点0出发的所有回路。

{linenos
1
2
3
4
5
from itertools import permutations
v = [1,2,3]
p = permutations(v)
for t in list(p):
print([0] + list(t) + [0])

所有从顶点0出发的回路如下:

{linenos
1
2
3
4
5
6
[0, 1, 2, 3, 0]
[0, 1, 3, 2, 0]
[0, 2, 1, 3, 0]
[0, 2, 3, 1, 0]
[0, 3, 1, 2, 0]
[0, 3, 2, 1, 0]

很显然,这种方式的时间复杂度是 O(\(n!\)),无法通过AIZU。

动态规划求解

我们可以使用位状态压缩的动态规划来AC这道题。 首先,需要将回路过程中的状态编码成二进制的表示。例如,在四顶点的例子中,如果顶点2和1都被访问过,并且此时停留在顶点1。将已经访问的顶点对应的位置1,那么编码成0110,此外,还需要保存当前顶点的位置,因此我们将代表状态的数组扩展成二维,第一维是位状态,第二维是顶点所在位置,即 \(dp[bitstate][v]\)。这个例子的状态表示就是 \(dp["0110"][1]\)

状态转移方程如下: \[ dp[bitstate][v] = \min ( dp[bitstate \cup \{u\}][u] + dist(v,u) \mid u \notin bitstate ) \] 这种方法对应的时间复杂度是 O(\(n^2*2^n\) ),因为总共有 \(2^n * n\) 个状态,而每个状态又需要一次遍历。虽然都是指数级复杂度,但是它们的巨大区别由下面可以看出区别。

\(n!\) \(n^2*2^n\)
n=8 40320 16384
n=10 3628800 102400
n=12 479001600 589824
n=14 87178291200 3211264

暂停思考一下为什么状态压缩DP能工作。注意到之前暴力解法中其实是有很多重复计算,下面红圈表示重复的计算节点。

在本篇中,我们将会用Python 3和Java 8 实现自顶向下的DP 缓存版本。这种方式比较符合直觉,因为我们不需要预先考虑计算节点的依赖关系。在Java中我们使用了一个小技巧,dp数组初始化成Integer.MAX_VALUE,如此只需要一条语句就能完成更新dp值。

1
res = Math.min(res, s + g.edges[v][u]);

当然,为了AC 这道题,我们需要区分出真正无法到达的情况并返回-1。 在Python实现中,也可以使用同样的技巧,但是这次示例一般的实现方法:将dp数组初始化成-1并通过 if-else 来区分不同情况。

1
2
3
4
5
6
7
INT_INF = -1

if s != INT_INF and edges[v][u] != INT_INF:
if ret == INT_INF:
ret = s + edges[v][u]
else:
ret = min(ret, s + edges[v][u])

下面附完整的Python 3和Java 8的AC代码,同步在 github

AIZU Java 8 递归DP版本

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
// passed http://judge.u-aizu.ac.jp/onlinejudge/description.jsp?id=DPL_2_A
import java.util.Arrays;
import java.util.Scanner;

public class Main {
public static class Graph {
public final int V_NUM;
public final int[][] edges;

public Graph(int V_NUM) {
this.V_NUM = V_NUM;
this.edges = new int[V_NUM][V_NUM];
for (int i = 0; i < V_NUM; i++) {
Arrays.fill(this.edges[i], Integer.MAX_VALUE);
}
}

public void setDist(int src, int dest, int dist) {
this.edges[src][dest] = dist;
}

}

public static class TSP {
public final Graph g;
long[][] dp;

public TSP(Graph g) {
this.g = g;
}

public long solve() {
int N = g.V_NUM;
dp = new long[1 << N][N];
for (int i = 0; i < dp.length; i++) {
Arrays.fill(dp[i], -1);
}

long ret = recurse(0, 0);
return ret == Integer.MAX_VALUE ? -1 : ret;
}

private long recurse(int state, int v) {
int ALL = (1 << g.V_NUM) - 1;
if (dp[state][v] >= 0) {
return dp[state][v];
}
if (state == ALL && v == 0) {
dp[state][v] = 0;
return 0;
}
long res = Integer.MAX_VALUE;
for (int u = 0; u < g.V_NUM; u++) {
if ((state & (1 << u)) == 0) {
long s = recurse(state | 1 << u, u);
res = Math.min(res, s + g.edges[v][u]);
}
}
dp[state][v] = res;
return res;

}

}

public static void main(String[] args) {

Scanner in = new Scanner(System.in);
int V = in.nextInt();
int E = in.nextInt();
Graph g = new Graph(V);
while (E > 0) {
int src = in.nextInt();
int dest = in.nextInt();
int dist = in.nextInt();
g.setDist(src, dest, dist);
E--;
}
System.out.println(new TSP(g).solve());
}
}

AIZU Python 3 递归DP版本

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from typing import List

INT_INF = -1

class Graph:
v_num: int
edges: List[List[int]]

def __init__(self, v_num: int):
self.v_num = v_num
self.edges = [[INT_INF for c in range(v_num)] for r in range(v_num)]

def setDist(self, src: int, dest: int, dist: int):
self.edges[src][dest] = dist


class TSPSolver:
g: Graph
dp: List[List[int]]

def __init__(self, g: Graph):
self.g = g
self.dp = [[None for c in range(g.v_num)] for r in range(1 << g.v_num)]

def solve(self) -> int:
return self._recurse(0, 0)

def _recurse(self, v: int, state: int) -> int:
"""

:param v:
:param state:
:return: -1 means INF
"""
dp = self.dp
edges = self.g.edges

if dp[state][v] is not None:
return dp[state][v]

if (state == (1 << self.g.v_num) - 1) and (v == 0):
dp[state][v] = 0
return dp[state][v]

ret: int = INT_INF
for u in range(self.g.v_num):
if (state & (1 << u)) == 0:
s: int = self._recurse(u, state | 1 << u)
if s != INT_INF and edges[v][u] != INT_INF:
if ret == INT_INF:
ret = s + edges[v][u]
else:
ret = min(ret, s + edges[v][u])
dp[state][v] = ret
return ret


def main():
V, E = map(int, input().split())
g: Graph = Graph(V)
for _ in range(E):
src, dest, dist = map(int, input().split())
g.setDist(src, dest, dist)

tsp: TSPSolver = TSPSolver(g)
print(tsp.solve())


if __name__ == "__main__":
main()

Leetcode 1227 是一道有意思的概率题,本篇将从多个角度来讨论这道题。题目如下

有 n 位乘客即将登机,飞机正好有 n 个座位。第一位乘客的票丢了,他随便选了一个座位坐下。 剩下的乘客将会: 如果他们自己的座位还空着,就坐到自己的座位上, 当他们自己的座位被占用时,随机选择其他座位,第 n 位乘客坐在自己的座位上的概率是多少?

示例 1: 输入:n = 1 输出:1.00000 解释:第一个人只会坐在自己的位置上。

示例 2: 输入: n = 2 输出: 0.50000 解释:在第一个人选好座位坐下后,第二个人坐在自己的座位上的概率是 0.5。

提示: 1 <= n <= 10^5

假设规模为n时答案为f(n),一般来说,这种递推问题在数学形式上可能有关于n的简单数学表达式(closed form),或者肯定有f(n)关于f(n-k)的递推表达式。工程上,我们可以通过通过多次模拟即蒙特卡罗模拟来算得近似的数值解。

Monte Carlo 模拟发现规律

首先,我们先来看看如何高效的用代码来模拟。根据题意的描述过程,直接可以写出下面代码。seats为n大小的bool 数组,每个位置表示此位置是否已经被占据。然后依次给第i个人按题意分配座位。注意,每次参数随机数范围在[0,n-1],因此,会出现已经被占据的情况,此时需要再次随机,直至分配到空位。

暴力直接模拟

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
def simulate_bruteforce(n: int) -> bool:
"""
Simulates one round. Unbounded time complexity.
:param n: total number of seats
:return: True if last one has last seat, otherwise False
"""

seats = [False for _ in range(n)]

for i in range(n-1):
if i == 0: # first one, always random
seats[random.randint(0, n - 1)] = True
else:
if not seats[i]: # i-th has his seat
seats[i] = True
else:
while True:
rnd = random.randint(0, n - 1) # random until no conflicts
if not seats[rnd]:
seats[rnd] = True
break
return not seats[n-1]

运行上面的代码来模拟 n 从 2 到10 的情况,每种情况跑500次模拟,输出如下

1
2
3
4
5
6
7
8
9
10
1 => 1.0
2 => 0.55
3 => 0.54
4 => 0.486
5 => 0.488
6 => 0.498
7 => 0.526
8 => 0.504
9 => 0.482
10 => 0.494

发现当 n>=2 时,似乎概率都是0.5。

标准答案

其实,这道题的标准答案就是 n=1 为1,n>=2 为0.5。下面是 python 3 标准答案。本篇后面会从多个角度来探讨为什么是0.5 。

{linenos
1
2
3
class Solution:
def nthPersonGetsNthSeat(self, n: int) -> float:
return 1.0 if n == 1 else 0.5

O(n) 改进算法

上面的暴力直接模拟版本有个最大的问题是当n很大时,随机分配座位会产生大量冲突,因此,最坏复杂度是没有任何上限的。解决方法是每次发生随机分配时保证不冲突,能直接选到空位。下面是一种最坏复杂度O(n)的模拟过程,seats数组初始话成 0,1,...,n-1,表示座位号。当第i个人登机时,seats[i:n] 的值为他可以选择的座位集合,而seats[0:i]为已经被占据的座位集合。由于[i: n]是连续空间,产生随机数就能保证不冲突。当第i个人选完座位时,将他选中的seats[k]和seats[i] 交换,保证第i+i个人面临的seats[i+1:n]依然为可选座位集合。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
def simulate_online(n: int) -> bool:
"""
Simulates one round of complexity O(N).
:param n: total number of seats
:return: True if last one has last seat, otherwise False
"""

seats = [i for i in range(n)]

def swap(i, j):
tmp = seats[i]
seats[i] = seats[j]
seats[j] = tmp

# for each person, the seats array idx available are [i, n-1]
for i in range(n-1):
if i == 0: # first one, always random
rnd = random.randint(0, n - 1)
swap(rnd, 0)
else:
if seats[i] == i: # i-th still has his seat
pass
else:
rnd = random.randint(i, n - 1) # selects idx from [i, n-1]
swap(rnd, i)
return seats[n-1] == n - 1

递推思维

这一节我们用数学递推思维来解释0.5的解。令f(n) 为第 n 位乘客坐在自己的座位上的概率,考察第一个人的情况(first step analysis),有三种可能

  1. 第一个人选了第一个即自己的座位,那么最后一个人一定能保证坐在自己的座位。
  2. 第一个人选了最后一个人的座位,无论中间什么过程,最后一个人无法坐到自己座位
  3. 第一个人选了第i个座位,(1<i<n),那么第i个人前面的除了第一个外的人都会坐在自己位置上,第i个人由于没有自己座位,随机在剩余的座位1,座位 [i+1,n] 中随机选择,此时,问题转变为f(n-i+1),如下图所示。
第一个人选了位置i
第i个人将问题转换成f(n-i+1)

通过上面分析,得到概率递推关系如下

\[ f(n) = \begin{align*} \left\lbrace \begin{array}{r@{}l} 1 & & p=\frac{1}{n} \quad \text{选了第一个位置} \\\\\\ f(n-i+1) & & p=\frac{1}{n} \quad \text{选了第i个位置,1<i<n} \\\\\\ 0 & & p=\frac{1}{n} \quad \text{选了第n个位置} \end{array} \right. \end{align*} \]

即f(n)的递推式为: \[ f(n) = \frac{1}{n} + \frac{1}{n} \times [ f(n-1) + f(n-2) + ...+ f(2)], \quad n>=2 \] 同理,f(n+1)递推式如下 \[ f(n+1) = \frac{1}{n+1} + \frac{1}{n+1} \times [ f(n) + f(n-1) + ...+ f(2)] \] \((n+1)f(n+1) - nf(n)\) 抵消 \(f(n-1) + ...f(2)\) 项,可得 \[ (n+1)f(n+1) - nf(n) = f(n) \]\[ f(n+1) = f(n) = \frac{1}{2} \quad n>=2 \]

用数学归纳法也可以证明 n>=2 时 f(n)=0.5。

简化的思考方式

我们再仔细思考一下上面的第三种情况,就是第一个人坐了第i个座位,1<i<n,此时,程序继续,不产生结果,直至产生结局1或者2,也就是case 1和2是真正的结局节点,它们产生的概率相同,因此答案是1/2。

从调用图可以看出这种关系,由于中间节点 f(4),f(3),f(2)生成Case 1和2的概率一样,因此无论它们之间是什么关系,最后结果都是1/2.

知乎上有个很形象的类比理解方式

考虑一枚硬币,正面向上的概率为 1/n,反面也是,立起来的概率为 (n-2)/n 。我们规定硬币立起来重新抛,但重新抛时,n会至少减小1。求结果为反面的概率。这样很显然结果为 1/2 。

这里,正面向上对应Case 2,反面对应Case 1。

这种思想可以写出如下代码,seats为 n 大小的bool 数组,当第i个人(0<i<n)发现自己座位被占的话,此时必然seats[0]没有被占,同时seats[i+1:]都是空的。假设seats[0]被占的话,要么是第一个人占的,要么是第p个人(p<i)坐了,两种情况下乱序都已经恢复了,此时第i个座位一定是空的。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
def simulate(n: int) -> bool:
"""
Simulates one round of complexity O(N).
:param n: total number of seats
:return: True if last one has last seat, otherwise False
"""

seats = [False for _ in range(n)]

for i in range(n-1):
if i == 0: # first one, always random
rnd = random.randint(0, n - 1)
seats[rnd] = True
else:
if not seats[i]: # i-th still has his seat
seats[i] = True
else:
# 0 must not be available, now we have 0 and [i+1, n-1],
rnd = random.randint(i, n - 1)
if rnd == i:
seats[0] = True
else:
seats[rnd] = True
return not seats[n-1]

继上一篇完成了井字棋(N子棋)的minimax 最佳策略后,我们基于Pygame来创造一个图形游戏环境,可供人机和机器对弈,为后续模拟AlphaGo的自我强化学习算法做环境准备。OpenAI Gym 在强化学习领域是事实标准,我们最终封装成OpenAI Gym的接口。本篇所有代码都在github.com/MyEncyclopedia/ConnectNGym

井字棋、五子棋 Pygame 实现

Pygame 井字棋玩家对弈效果

Python 上有Tkinter,PyQt等跨平台GUI类库,主要用于桌面程序编程,但此类库容量较大,编程也相对麻烦。Pygame具有代码少,开发快的优势,比较适合快速开发五子棋这类桌面小游戏。 ### Pygame 极简入门

与所有的GUI开发相同,Pygame也是基于事件的单线程编程模型。下面的例子包含了显示一个最简单GUI窗口,操作系统产生事件并发送到Pygame窗口,while True 控制了python主线程永远轮询事件。我们在这里仅仅判断了当前是否是关闭应用程序事件,如果是则退出进程。此外,clock 用于控制FPS。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
import sys
import pygame
pygame.init()
display = pygame.display.set_mode((800,600))
clock = pygame.time.Clock()

while True:
for event in pygame.event.get():
if event.type == pygame.QUIT:
sys.exit(0)
else:
pygame.display.update()
clock.tick(1)

PyGameBoard 主体代码

PyGameBoard类封装了Pygame实现游戏交互和显示的逻辑。上一篇中,我们完成了ConnectNGame逻辑,这里PyGameBoard需要在初始化时,指定传入ConnectNGame 实例(见下图),支持通过API 方式改变其状态,也支持GUI交互方式等待人类玩家的输入。next_user_input(self)实现了等待人类玩家输入的逻辑,本质上是循环检查GUI事件直到有合法的落子产生。
PyGameBoard Class Diagram
{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
class PyGameBoard:

def __init__(self, connectNGame: ConnectNGame):
self.connectNGame = connectNGame
pygame.init()

def next_user_input(self) -> Tuple[int, int]:
self.action = None
while not self.action:
self.check_event()
self._render()
self.clock.tick(60)
return self.action

def move(self, r: int, c: int) -> int:
return self.connectNGame.move(r, c)

if __name__ == '__main__':
connectNGame = ConnectNGame()
pygameBoard = PyGameBoard(connectNGame)
while not pygameBoard.isGameOver():
pos = pygameBoard.next_user_input()
pygameBoard.move(*pos)

pygame.quit()

check_event 较之极简版本增加了处理用户输入事件,这里我们仅支持人类玩家鼠标输入。方法_handle_user_input 将鼠标点击事件转换成棋盘行列值,并判断点击位置是否合法,合法则返回落子位置,类型为Tuple[int, int],例如(0, 0)表示棋盘最左上角位置。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
def check_event(self):
for e in pygame.event.get():
if e.type == pygame.QUIT:
pygame.quit()
sys.exit(0)
elif e.type == pygame.MOUSEBUTTONDOWN:
self._handle_user_input(e)

def _handle_user_input(self, e: Event) -> Tuple[int, int]:
origin_x = self.start_x - self.edge_size
origin_y = self.start_y - self.edge_size
size = (self.board_size - 1) * self.grid_size + self.edge_size * 2
pos = e.pos
if origin_x <= pos[0] <= origin_x + size and origin_y <= pos[1] <= origin_y + size:
if not self.connectNGame.gameOver:
x = pos[0] - origin_x
y = pos[1] - origin_y
r = int(y // self.grid_size)
c = int(x // self.grid_size)
valid = self.connectNGame.checkAction(r, c)
if valid:
self.action = (r, c)
return self.action

OpenAI Gym 接口规范

OpenAI Gym规范了Agent和环境(Env)之间的互动,核心抽象接口类是gym.Env,自定义的游戏环境需要继承Env,并实现 reset、step和render方法。下面我们看一下如何具体实现ConnectNGym的这几个方法:

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
class ConnectNGym(gym.Env):

def reset(self) -> ConnectNGame:
"""Resets the state of the environment and returns an initial observation.

Returns:
observation (object): the initial observation.
"""
raise NotImplementedError


def step(self, action: Tuple[int, int]) -> Tuple[ConnectNGame, int, bool, None]:
"""Run one timestep of the environment's dynamics. When end of
episode is reached, you are responsible for calling `reset()`
to reset this environment's state.

Accepts an action and returns a tuple (observation, reward, done, info).

Args:
action (object): an action provided by the agent

Returns:
observation (object): agent's observation of the current environment
reward (float) : amount of reward returned after previous action
done (bool): whether the episode has ended, in which case further step() calls will return undefined results
info (dict): contains auxiliary diagnostic information (helpful for debugging, and sometimes learning)
"""
raise NotImplementedError



def render(self, mode='human'):
"""
Renders the environment.

The set of supported modes varies per environment. (And some
environments do not support rendering at all.) By convention,
if mode is:

- human: render to the current display or terminal and
return nothing. Usually for human consumption.
- rgb_array: Return an numpy.ndarray with shape (x, y, 3),
representing RGB values for an x-by-y pixel image, suitable
for turning into a video.
- ansi: Return a string (str) or StringIO.StringIO containing a
terminal-style text representation. The text can include newlines
and ANSI escape sequences (e.g. for colors).

Note:
Make sure that your class's metadata 'render.modes' key includes
the list of supported modes. It's recommended to call super()
in implementations to use the functionality of this method.

Args:
mode (str): the mode to render with
"""
raise NotImplementedError

reset 方法

1
def reset(self) -> ConnectNGame

重置环境状态,并返回给Agent重置后环境下观察到的状态。ConnectNGym内部维护了ConnectNGame实例作为自身状态,每个agent落子后会更新这个实例。由于棋类游戏对于玩家来说是完全信息的,我们直接返回ConnectNGame的deepcopy。

step 方法

1
def step(self, action: Tuple[int, int]) -> Tuple[ConnectNGame, int, bool, None]

Agent 选择了某一action后,由环境来执行这个action并返回4个值:1. 执行后的环境Agent观察到的状态;2. 环境执行了这个action回馈给agent的reward;3. 环境是否结束;4. 其余信息。

step方法是最核心的接口,因此举例来说明ConnectNGym中的输入和输出:

初始状态
状态 ((0, 0, 0), (0, 0, 0), (0, 0, 0))

Agent A 选择action = (0, 0),执行ConnectNGym.step 后返回值:status = ((1, 0, 0), (0, 0, 0), (0, 0, 0)),reward = 0,game_end = False

状态 ((1, 0, 0), (0, 0, 0), (0, 0, 0))

Agent B 选择action = (1, 1),执行ConnectNGym.step 后返回值:status = ((1, 0, 0), (0, -1, 0), (0, 0, 0)),reward = 0,game_end = False

状态 ((1, 0, 0), (0, -1, 0), (0, 0, 0))
重复此过程直至游戏结束,下面是5步后游戏可能达到的最终状态
终结状态 ((1, 1, 1), (-1, -1, 0), (0, 0, 0))

此时step的返回值为:status = ((1, 1, 1), (-1, -1, 0), (0, 0, 0)),reward = 1,game_end = True

render 方法

1
def render(self, mode='human')

展现环境,通过mode区分是否是人类玩家。

ConnectNGym 代码

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
class ConnectNGym(gym.Env):

def __init__(self, pygameBoard: PyGameBoard, isGUI=True, displaySec=2):
self.pygameBoard = pygameBoard
self.isGUI = isGUI
self.displaySec = displaySec
self.action_space = spaces.Discrete(pygameBoard.board_size * pygameBoard.board_size)
self.observation_space = spaces.Discrete(pygameBoard.board_size * pygameBoard.board_size)
self.seed()
self.reset()

def reset(self) -> ConnectNGame:
self.pygameBoard.connectNGame.reset()
return copy.deepcopy(self.pygameBoard.connectNGame)

def step(self, action: Tuple[int, int]) -> Tuple[ConnectNGame, int, bool, None]:
# assert self.action_space.contains(action)

r, c = action
reward = REWARD_NONE
result = self.pygameBoard.move(r, c)
if self.pygameBoard.isGameOver():
reward = result

return copy.deepcopy(self.pygameBoard.connectNGame), reward, not result is None, None

def render(self, mode='human'):
if not self.isGUI:
self.pygameBoard.connectNGame.drawText()
time.sleep(self.displaySec)
else:
self.pygameBoard.display(sec=self.displaySec)

def get_available_actions(self) -> List[Tuple[int, int]]:
return self.pygameBoard.getAvailablePositions()

井字棋(N子棋)Minimax策略玩家

图中当k=3,m=n=3即井字棋游戏中,两个minimax策略玩家的对弈效果,游戏结局符合已知的结论:井字棋的解是先手被对方逼平。

Minimax策略AI对弈

镜像游戏状态的DP处理

上一篇中,我们确认了井字棋的总状态数是5478。当k=3, m=n=4时是6035992,k=4, m=n=4时是9722011,总的来说游戏状态数是以指数级增长的。上一版minimax DP策略还有改善的空间,第一种是旋转格局的处理。对于任意一种棋盘格局可以得到90度旋转后的另外三种格局,它们的最佳结局是一致的。因此,我们在递归过程中解得某一棋盘格局后,将其另外三种旋转后格局的解也一起缓存起来。例如:

游戏状态1
旋转后的三种游戏状态
{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
def similarStatus(self, status: Tuple[Tuple[int, ...]]) -> List[Tuple[Tuple[int, ...]]]:
ret = []
rotatedS = status
for _ in range(4):
rotatedS = self.rotate(rotatedS)
ret.append(rotatedS)
return ret

def rotate(self, status: Tuple[Tuple[int, ...]]) -> Tuple[Tuple[int, ...]]:
N = len(status)
board = [[ConnectNGame.AVAILABLE] * N for _ in range(N)]

for r in range(N):
for c in range(N):
board[c][N - 1 - r] = status[r][c]

return tuple([tuple(board[i]) for i in range(N)])

Minimax 策略预计算

之前我们对每个棋局去计算最佳的下一步,并在此过程中做了剪枝,即当已经找到当前玩家必胜落子时直接返回。这对于单一局面的计算是较优的,但是AI Agent 需要在每一步都重复这个过程,当棋盘大小>3时运算非常耗时,因此我们来做第二种优化。初始空棋盘时使用Minimax来保证遍历所有状态,缓存所有棋局的最佳结果。对于AI Agent面临的每个棋局只需查找此棋局下所有的可能落子位置,并返回最佳决定,这样大大减少了每次棋局下重复的minimax递归计算。相关代码如下。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
class PlannedMinimaxStrategy(Strategy):
def __init__(self, game: ConnectNGame):
super().__init__()
self.game = copy.deepcopy(game)
self.dpMap = {} # game_status => result, move
self.result = self.minimax(game.getStatus())


def action(self, game: ConnectNGame) -> Tuple[int, Tuple[int, int]]:
game = copy.deepcopy(game)

player = game.currentPlayer
bestResult = player * -1 # assume opponent win as worst result
bestMove = None
for move in game.getAvailablePositions():
game.move(*move)
status = game.getStatus()
game.undo()

result = self.dpMap[status]

if player == ConnectNGame.PLAYER_A:
bestResult = max(bestResult, result)
else:
bestResult = min(bestResult, result)
# update bestMove if any improvement
bestMove = move if bestResult == result else bestMove
print(f'move {move} => {result}')

return bestResult, bestMove

Agent 类和对弈逻辑

Agent 类的抽象并不是 OpenAI Gym的规范,出于代码扩展性,我们也封装了Agent基类及其子类,包括AI玩家和人类玩家。BaseAgent需要子类实现 act方法,默认实现为随机决定。

{linenos
1
2
3
4
5
6
class BaseAgent(object):
def __init__(self):
pass

def act(self, game: PyGameBoard, available_actions):
return random.choice(available_actions)

AIAgent 实现act并代理给 strategy 的action方法。

{linenos
1
2
3
4
5
6
7
8
class AIAgent(BaseAgent):
def __init__(self, strategy: Strategy):
self.strategy = strategy

def act(self, game: PyGameBoard, available_actions):
result, move = self.strategy.action(game.connectNGame)
assert move in available_actions
return move

HumanAgent 实现act并代理给 PyGameBoard 的next_user_input方法。

{linenos
1
2
3
4
5
6
class HumanAgent(BaseAgent):
def __init__(self):
pass

def act(self, game: PyGameBoard, available_actions):
return game.next_user_input()
Agent Class Diagram

下面代码展示如何将Agent,ConnectNGym,PyGameBoard 等所有上述类串联起来,完成人人对弈,人机对弈。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
def play_ai_vs_ai(env: ConnectNGym):
plannedMinimaxAgent = AIAgent(PlannedMinimaxStrategy(env.pygameBoard.connectNGame))
play(env, plannedMinimaxAgent, plannedMinimaxAgent)


def play(env: ConnectNGym, agent1: BaseAgent, agent2: BaseAgent):
agents = [agent1, agent2]

while True:
env.reset()
done = False
agent_id = -1
while not done:
agent_id = (agent_id + 1) % 2
available_actions = env.get_available_actions()
agent = agents[agent_id]
action = agent.act(pygameBoard, available_actions)
_, reward, done, info = env.step(action)
env.render(True)

if done:
print(f'result={reward}')
time.sleep(3)
break


if __name__ == '__main__':
pygameBoard = PyGameBoard(connectNGame=ConnectNGame(board_size=3, N=3))
env = ConnectNGym(pygameBoard)
env.render(True)

play_ai_vs_ai(env)
Class Diagram 总览

继上一篇介绍了Minimax 和Alpha Beta 剪枝算法之后,本篇选择了Leetcode中的井字棋游戏题目,积累相关代码后实现井字棋游戏并扩展到五子棋和N子棋(战略井字棋),随后用Minimax和Alpha Beta剪枝算法解得小规模下N子棋的游戏结局,并分析其状态数量和每一步的最佳策略。后续篇章中,我们基于本篇代码完成一个N子棋的OpenAI Gym 图形环境,可用于人机对战或机器对战,并最终实现棋盘规模稍大的五子棋或者N子棋中的蒙特卡洛树搜索(MCTS)算法。

Leetcode 上的井字棋系列

Leetcode 1275. 找出井字棋的获胜者 (简单)

A 和 B 在一个 3 x 3 的网格上玩井字棋。
井字棋游戏的规则如下:
玩家轮流将棋子放在空方格 (" ") 上。
第一个玩家 A 总是用 "X" 作为棋子,而第二个玩家 B 总是用 "O" 作为棋子。
"X" 和 "O" 只能放在空方格中,而不能放在已经被占用的方格上。
只要有 3 个相同的(非空)棋子排成一条直线(行、列、对角线)时,游戏结束。
如果所有方块都放满棋子(不为空),游戏也会结束。
游戏结束后,棋子无法再进行任何移动。
给你一个数组 moves,其中每个元素是大小为 2 的另一个数组(元素分别对应网格的行和列),它按照 A 和 B 的行动顺序(先 A 后 B)记录了两人各自的棋子位置。
如果游戏存在获胜者(A 或 B),就返回该游戏的获胜者;如果游戏以平局结束,则返回 "Draw";如果仍会有行动(游戏未结束),则返回 "Pending"。
你可以假设 moves 都 有效(遵循井字棋规则),网格最初是空的,A 将先行动。

示例 1:
输入:moves = [[0,0],[2,0],[1,1],[2,1],[2,2]]
输出:"A"
解释:"A" 获胜,他总是先走。
"X " "X " "X " "X " "X "
" " -> " " -> " X " -> " X " -> " X "
" " "O " "O " "OO " "OOX"

示例 2: 输入:moves = [[0,0],[1,1],[0,1],[0,2],[1,0],[2,0]]
输出:"B"
解释:"B" 获胜。
"X " "X " "XX " "XXO" "XXO" "XXO"
" " -> " O " -> " O " -> " O " -> "XO " -> "XO "
" " " " " " " " " " "O "

第一种解法,检查A或者B赢的所有可能情况:某玩家占据8种连线的任意一种情况则胜利,我们使用八个变量来保存所有情况。下面的代码使用了一个小技巧,将moves转换成3x3的棋盘状态数组,元素的值为1,-1和0。1,-1代表两个玩家,0代表空的棋盘格子,其优势在于后续我们只需累加棋盘的值到八个变量中关联的若干个,再检查这八个变量是否满足取胜条件。例如,row[0]表示第一行的状态,当遍历一次所有棋盘格局后,row[0]为第一行的3个格子的总和,只有当row[0] == 3 才表明玩家A占据了第一行,-3表明玩家B占据了第一行。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# AC
from typing import List

class Solution:
def tictactoe(self, moves: List[List[int]]) -> str:
board = [[0] * 3 for _ in range(3)]
for idx, xy in enumerate(moves):
player = 1 if idx % 2 == 0 else -1
board[xy[0]][xy[1]] = player

turn = 0
row, col = [0, 0, 0], [0, 0, 0]
diag1, diag2 = False, False
for r in range(3):
for c in range(3):
turn += board[r][c]
row[r] += board[r][c]
col[c] += board[r][c]
if r == c:
diag1 += board[r][c]
if r + c == 2:
diag2 += board[r][c]

oWin = any(row[r] == 3 for r in range(3)) or any(col[c] == 3 for c in range(3)) or diag1 == 3 or diag2 == 3
xWin = any(row[r] == -3 for r in range(3)) or any(col[c] == -3 for c in range(3)) or diag1 == -3 or diag2 == -3

return "A" if oWin else "B" if xWin else "Draw" if len(moves) == 9 else "Pending"

下面我们给出另一种解法,这种解法虽然代码较多,但可以不必遍历棋盘每个格子,比上一种严格遍历一次棋盘的解法略为高效。原理如下,题目保证了moves过程中不会产生输赢结果,因此我们直接检查最后一个棋子向外的八个方向,若任意方向有三连子,则此玩家获胜。这种解法主要是为后续井字棋扩展到五子棋时判断每个落子是否产生输赢做代码准备。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
# AC
from typing import List

class Solution:
def checkWin(self, r: int, c: int) -> bool:
north = self.getConnectedNum(r, c, -1, 0)
south = self.getConnectedNum(r, c, 1, 0)

east = self.getConnectedNum(r, c, 0, 1)
west = self.getConnectedNum(r, c, 0, -1)

south_east = self.getConnectedNum(r, c, 1, 1)
north_west = self.getConnectedNum(r, c, -1, -1)

north_east = self.getConnectedNum(r, c, -1, 1)
south_west = self.getConnectedNum(r, c, 1, -1)

if (north + south + 1 >= 3) or (east + west + 1 >= 3) or \
(south_east + north_west + 1 >= 3) or (north_east + south_west + 1 >= 3):
return True
return False

def getConnectedNum(self, r: int, c: int, dr: int, dc: int) -> int:
player = self.board[r][c]
result = 0
i = 1
while True:
new_r = r + dr * i
new_c = c + dc * i
if 0 <= new_r < 3 and 0 <= new_c < 3:
if self.board[new_r][new_c] == player:
result += 1
else:
break
else:
break
i += 1
return result

def tictactoe(self, moves: List[List[int]]) -> str:
self.board = [[0] * 3 for _ in range(3)]
for idx, xy in enumerate(moves):
player = 1 if idx % 2 == 0 else -1
self.board[xy[0]][xy[1]] = player

# only check last move
r, c = moves[-1]
win = self.checkWin(r, c)
if win:
return "A" if len(moves) % 2 == 1 else "B"

return "Draw" if len(moves) == 9 else "Pending"

Leetcode 794. 有效的井字游戏 (中等)

用字符串数组作为井字游戏的游戏板 board。当且仅当在井字游戏过程中,玩家有可能将字符放置成游戏板所显示的状态时,才返回 true。
该游戏板是一个 3 x 3 数组,由字符 " ","X" 和 "O" 组成。字符 " " 代表一个空位。
以下是井字游戏的规则:
玩家轮流将字符放入空位(" ")中。
第一个玩家总是放字符 “X”,且第二个玩家总是放字符 “O”。
“X” 和 “O” 只允许放置在空位中,不允许对已放有字符的位置进行填充。
当有 3 个相同(且非空)的字符填充任何行、列或对角线时,游戏结束。
当所有位置非空时,也算为游戏结束。
如果游戏结束,玩家不允许再放置字符。

示例 1:
输入: board = ["O ", " ", " "]
输出: false
解释: 第一个玩家总是放置“X”。

示例 2:
输入: board = ["XOX", " X ", " "]
输出: false
解释: 玩家应该是轮流放置的。

示例 3:
输入: board = ["XXX", " ", "OOO"]
输出: false

示例 4:
输入: board = ["XOX", "O O", "XOX"]
输出: true
说明:

游戏板 board 是长度为 3 的字符串数组,其中每个字符串 board[i] 的长度为 3。 board[i][j] 是集合 {" ", "X", "O"} 中的一个字符。

这道题第一反应是需要DFS来判断给定状态是否可达,但其实可以用上面1275的思路,即通过检验最终棋盘的一些特点来判断给定状态是否合法。比如,X和O的数量只有可能相同,或X比O多一个。其关键在于需要找到判断状态合法的充要条件,就可以在\(O(1)\) 时间复杂度完成判断。 此外,这道题给了我们井字棋所有可能状态数量的启示。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# AC
from typing import List

class Solution:

def convertCell(self, c:str):
return 1 if c == 'X' else -1 if c == 'O' else 0

def validTicTacToe(self, board: List[str]) -> bool:
turn = 0
row, col = [0, 0, 0], [0, 0, 0]
diag1, diag2 = False, False
for r in range(3):
for c in range(3):
turn += self.convertCell(board[r][c])
row[r] += self.convertCell(board[r][c])
col[c] += self.convertCell(board[r][c])
if r == c:
diag1 += self.convertCell(board[r][c])
if r + c == 2:
diag2 += self.convertCell(board[r][c])

xWin = any(row[r] == 3 for r in range(3)) or any(col[c] == 3 for c in range(3)) or diag1 == 3 or diag2 == 3
oWin = any(row[r] == -3 for r in range(3)) or any(col[c] == -3 for c in range(3)) or diag1 == -3 or diag2 == -3
if (xWin and turn == 0) or (oWin and turn == 1):
return False
return (turn == 0 or turn == 1) and (not xWin or not oWin)

Leetcode 348. 判定井字棋胜负 (中等,加锁)

请在 n × n 的棋盘上,实现一个判定井字棋(Tic-Tac-Toe)胜负的神器,判断每一次玩家落子后,是否有胜出的玩家。
在这个井字棋游戏中,会有 2 名玩家,他们将轮流在棋盘上放置自己的棋子。
在实现这个判定器的过程中,你可以假设以下这些规则一定成立:
每一步棋都是在棋盘内的,并且只能被放置在一个空的格子里;
一旦游戏中有一名玩家胜出的话,游戏将不能再继续;
一个玩家如果在同一行、同一列或者同一斜对角线上都放置了自己的棋子,那么他便获得胜利。

示例: 给定棋盘边长 n = 3, 玩家 1 的棋子符号是 "X",玩家 2 的棋子符号是 "O"。
TicTacToe toe = new TicTacToe(3);
toe.move(0, 0, 1); -> 函数返回 0 (此时,暂时没有玩家赢得这场对决)
|X| | |
| | | | // 玩家 1 在 (0, 0) 落子。
| | | |

toe.move(0, 2, 2); -> 函数返回 0 (暂时没有玩家赢得本场比赛)
|X| |O|
| | | | // 玩家 2 在 (0, 2) 落子。
| | | |

toe.move(2, 2, 1); -> 函数返回 0 (暂时没有玩家赢得比赛)
|X| |O|
| | | | // 玩家 1 在 (2, 2) 落子。
| | |X|

toe.move(1, 1, 2); -> 函数返回 0 (暂没有玩家赢得比赛)
|X| |O|
| |O| | // 玩家 2 在 (1, 1) 落子。
| | |X|

toe.move(2, 0, 1); -> 函数返回 0 (暂无玩家赢得比赛)
|X| |O|
| |O| | // 玩家 1 在 (2, 0) 落子。
|X| |X|

toe.move(1, 0, 2); -> 函数返回 0 (没有玩家赢得比赛)
|X| |O|
|O|O| | // 玩家 2 在 (1, 0) 落子.
|X| |X|

toe.move(2, 1, 1); -> 函数返回 1 (此时,玩家 1 赢得了该场比赛)
|X| |O|
|O|O| | // 玩家 1 在 (2, 1) 落子。
|X|X|X|

348 是道加锁题,对于每次玩家的move,可以用1275第二种解法中的checkWin 函数。下面代码给出了另一种基于1275解法一的方法:保存八个关键变量,每次落子后更新这个子所关联的某几个变量。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
# AC
class TicTacToe:

def __init__(self, n:int):
"""
Initialize your data structure here.
:type n: int
"""
self.row, self.col, self.diag1, self.diag2, self.n = [0] * n, [0] * n, 0, 0, n

def move(self, row:int, col:int, player:int) -> int:
"""
Player {player} makes a move at ({row}, {col}).
@param row The row of the board.
@param col The column of the board.
@param player The player, can be either 1 or 2.
@return The current winning condition, can be either:
0: No one wins.
1: Player 1 wins.
2: Player 2 wins.
"""
if player == 2:
player = -1

self.row[row] += player
self.col[col] += player
if row == col:
self.diag1 += player
if row + col == self.n - 1:
self.diag2 += player

if self.n in [self.row[row], self.col[col], self.diag1, self.diag2]:
return 1
if -self.n in [self.row[row], self.col[col], self.diag1, self.diag2]:
return 2
return 0


井字棋最佳策略

井字棋的规模可以很自然的扩展成四子棋或五子棋等,区别在于棋盘大小和胜利时的连子数量。这类游戏最一般的形式为 M,n,k-game,中文可能翻译为战略井字游戏,表示棋盘大小为M x N,当k连子时获胜。 下面的ConnectNGame类实现了战略井字游戏(M=N)中,两个玩家轮流下子、更新棋盘状态和判断每次落子输赢等逻辑封装。其中undo方法用于撤销最后一个落子,方便在后续寻找最佳策略时回溯。

ConnectNGame

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
class ConnectNGame:

PLAYER_A = 1
PLAYER_B = -1
AVAILABLE = 0
RESULT_TIE = 0
RESULT_A_WIN = 1
RESULT_B_WIN = -1

def __init__(self, N:int = 3, board_size:int = 3):
assert N <= board_size
self.N = N
self.board_size = board_size
self.board = [[ConnectNGame.AVAILABLE] * board_size for _ in range(board_size)]
self.gameOver = False
self.gameResult = None
self.currentPlayer = ConnectNGame.PLAYER_A
self.remainingPosNum = board_size * board_size
self.actionStack = []

def move(self, r: int, c: int) -> int:
"""

:param r:
:param c:
:return: None: game ongoing
"""
assert self.board[r][c] == ConnectNGame.AVAILABLE
self.board[r][c] = self.currentPlayer
self.actionStack.append((r, c))
self.remainingPosNum -= 1
if self.checkWin(r, c):
self.gameOver = True
self.gameResult = self.currentPlayer
return self.currentPlayer
if self.remainingPosNum == 0:
self.gameOver = True
self.gameResult = ConnectNGame.RESULT_TIE
return ConnectNGame.RESULT_TIE
self.currentPlayer *= -1

def undo(self):
if len(self.actionStack) > 0:
lastAction = self.actionStack.pop()
r, c = lastAction
self.board[r][c] = ConnectNGame.AVAILABLE
self.currentPlayer = ConnectNGame.PLAYER_A if len(self.actionStack) % 2 == 0 else ConnectNGame.PLAYER_B
self.remainingPosNum += 1
self.gameOver = False
self.gameResult = None
else:
raise Exception('No lastAction')

def getAvailablePositions(self) -> List[Tuple[int, int]]:
return [(i,j) for i in range(self.board_size) for j in range(self.board_size) if self.board[i][j] == ConnectNGame.AVAILABLE]

def getStatus(self) -> Tuple[Tuple[int, ...]]:
return tuple([tuple(self.board[i]) for i in range(self.board_size)])

其中checkWin和1275解法二中的逻辑一致。

Minimax 算法

此战略井字游戏的逻辑代码,结合之前的minimax算法,可以实现游戏最佳策略。

先定义一个通用的策略基类和抽象方法 action。action表示给定一个棋盘状态,返回一个动作决定。返回Tuple的第一个int值表示估计走这一步的结局,第二个值类型是Tuple[int, int],表示这次落子的位置,例如(1,1)。

{linenos
1
2
3
4
5
6
7
8
class Strategy(ABC):

def __init__(self):
super().__init__()

@abstractmethod
def action(self, game: ConnectNGame) -> Tuple[int, Tuple[int, int]]:
pass
MinimaxStrategy 的逻辑和之前的minimax模版算法大致相同,多了保存最佳move对应的动作,用于最后返回。
{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
class MinimaxStrategy(Strategy):
def action(self, game: ConnectNGame) -> Tuple[int, Tuple[int, int]]:
self.game = copy.deepcopy(game)
result, move = self.minimax()
return result, move

def minimax(self) -> Tuple[int, Tuple[int, int]]:
game = self.game
bestMove = None
assert not game.gameOver
if game.currentPlayer == ConnectNGame.PLAYER_A:
ret = -math.inf
for pos in game.getAvailablePositions():
move = pos
result = game.move(*pos)
if result is None:
assert not game.gameOver
result, oppMove = self.minimax()
game.undo()
ret = max(ret, result)
bestMove = move if ret == result else bestMove
if ret == 1:
return 1, move
return ret, bestMove
else:
ret = math.inf
for pos in game.getAvailablePositions():
move = pos
result = game.move(*pos)
if result is None:
assert not game.gameOver
result, oppMove = self.minimax()
game.undo()
ret = min(ret, result)
bestMove = move if ret == result else bestMove
if ret == -1:
return -1, move
return ret, bestMove
通过上面的代码可以画出初始两步的井字棋最终结局。对于先手O来说可以落9个位置,排除对称位置后只有三种,分别为角落,边上和正中。但无论哪一个位置作为先手,最好的结局都是被对方逼平,不存在必赢的开局。所以井字棋的结局是:如果两个玩家都采用最优策略(无失误),游戏结果为双方逼平。
井字棋第一步结局
下面分别画出三种开局后进一步的游戏结局。
井字棋角落开局
井字棋边上开局
井字棋中间开局

井字棋游戏状态数和解

有趣的是井字棋游戏的状态数量,简单的上限估算是\(3^9=19683\)。这显然是个较宽泛的上限,因为很多状态在游戏结束后无法达到。 这篇文章 Tic-Tac-Toe (Naughts and Crosses, Cheese and Crackers, etc 中列出了每一步的状态数,合计5478个。

Moves Positions Terminal Positions
0 1
1 9
2 72
3 252
4 756
5 1260 120
6 1520 148
7 1140 444
8 390 168
9 78 78
Total 5478 958

我们已经实现了井字棋的minimax策略,算法本质上遍历了所有情况,稍加改造后增加dp数组,就可以确认上面的总状态数。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58

class CountingMinimaxStrategy(Strategy):
def action(self, game: ConnectNGame) -> Tuple[int, Tuple[int, int]]:
self.game = copy.deepcopy(game)
self.dpMap = {}
result, move = self.minimax(game.getStatus())
return result, move

def minimax(self, gameStatus: Tuple[Tuple[int, ...]]) -> Tuple[int, Tuple[int, int]]:
# print(f'Current {len(strategy.dpMap)}')

if gameStatus in self.dpMap:
return self.dpMap[gameStatus]

game = self.game
bestMove = None
assert not game.gameOver
if game.currentPlayer == ConnectNGame.PLAYER_A:
ret = -math.inf
for pos in game.getAvailablePositions():
move = pos
result = game.move(*pos)
if result is None:
assert not game.gameOver
result, oppMove = self.minimax(game.getStatus())
self.dpMap[game.getStatus()] = result, oppMove
else:
self.dpMap[game.getStatus()] = result, move
game.undo()
ret = max(ret, result)
bestMove = move if ret == result else bestMove
self.dpMap[gameStatus] = ret, bestMove
return ret, bestMove
else:
ret = math.inf
for pos in game.getAvailablePositions():
move = pos
result = game.move(*pos)

if result is None:
assert not game.gameOver
result, oppMove = self.minimax(game.getStatus())
self.dpMap[game.getStatus()] = result, oppMove
else:
self.dpMap[game.getStatus()] = result, move
game.undo()
ret = min(ret, result)
bestMove = move if ret == result else bestMove
self.dpMap[gameStatus] = ret, bestMove
return ret, bestMove


if __name__ == '__main__':
tic_tac_toe = ConnectNGame(N=3, board_size=3)
strategy = CountingMinimaxStrategy()
strategy.action(tic_tac_toe)
print(f'Game States Number {len(strategy.dpMap)}')

运行程序证实了井字棋状态数为5478,下面是一些极小规模时代码运行结果:

3x3 4x4
k=3 5478 (Draw) 6035992 (Win)
k=4 9722011 (Draw)
k=5

根据 Wikipedia M,n,k-game, 列出了一些小规模下的游戏解:

3x3 4x4 5x5 6x6
k=3 Draw Win Win Win
k=4 Draw Draw Win
k=5 Draw Draw

值得一提的是,五子棋(棋盘15x15或以上)被 L. Victor Allis证明是先手赢。

Alpha-Beta剪枝策略

Alpha Beta 剪枝策略的代码如下(和之前代码比较类似,不再赘述):

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
class AlphaBetaStrategy(Strategy):
def action(self, game: ConnectNGame) -> Tuple[int, Tuple[int, int]]:
self.game = game
result, move = self.alpha_beta(self.game.getStatus(), -math.inf, math.inf)
return result, move

def alpha_beta(self, gameStatus: Tuple[Tuple[int, ...]], alpha:int=None, beta:int=None) -> Tuple[int, Tuple[int, int]]:
game = self.game
bestMove = None
assert not game.gameOver
if game.currentPlayer == ConnectNGame.PLAYER_A:
ret = -math.inf
for pos in game.getAvailablePositions():
move = pos
result = game.move(*pos)
if result is None:
assert not game.gameOver
result, oppMove = self.alpha_beta(game.getStatus(), alpha, beta)
game.undo()
alpha = max(alpha, result)
ret = max(ret, result)
bestMove = move if ret == result else bestMove
if alpha >= beta or ret == 1:
return ret, move
return ret, bestMove
else:
ret = math.inf
for pos in game.getAvailablePositions():
move = pos
result = game.move(*pos)
if result is None:
assert not game.gameOver
result, oppMove = self.alpha_beta(game.getStatus(), alpha, beta)
game.undo()
beta = min(beta, result)
ret = min(ret, result)
bestMove = move if ret == result else bestMove
if alpha >= beta or ret == -1:
return ret, move
return ret, bestMove

Alpha Beta 的DP版本中,由于lru_cache无法指定cache的有效参数,递归函数并没有传入alpha, beta。因此我们将alpha,beta参数隐式放入自己维护的栈中,并保证栈的状态和alpha_beta_dp函数调用状态一致。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
class AlphaBetaDPStrategy(Strategy):
def action(self, game: ConnectNGame) -> Tuple[int, Tuple[int, int]]:
self.game = game
self.alphaBetaStack = [(-math.inf, math.inf)]
result, move = self.alpha_beta_dp(self.game.getStatus())
return result, move

@lru_cache(maxsize=None)
def alpha_beta_dp(self, gameStatus: Tuple[Tuple[int, ...]]) -> Tuple[int, Tuple[int, int]]:
alpha, beta = self.alphaBetaStack[-1]
game = self.game
bestMove = None
assert not game.gameOver
if game.currentPlayer == ConnectNGame.PLAYER_A:
ret = -math.inf
for pos in game.getAvailablePositions():
move = pos
result = game.move(*pos)
if result is None:
assert not game.gameOver
self.alphaBetaStack.append((alpha, beta))
result, oppMove = self.alpha_beta_dp(game.getStatus())
self.alphaBetaStack.pop()
game.undo()
alpha = max(alpha, result)
ret = max(ret, result)
bestMove = move if ret == result else bestMove
if alpha >= beta or ret == 1:
return ret, move
return ret, bestMove
else:
ret = math.inf
for pos in game.getAvailablePositions():
move = pos
result = game.move(*pos)
if result is None:
assert not game.gameOver
self.alphaBetaStack.append((alpha, beta))
result, oppMove = self.alpha_beta_dp(game.getStatus())
self.alphaBetaStack.pop()
game.undo()
beta = min(beta, result)
ret = min(ret, result)
bestMove = move if ret == result else bestMove
if alpha >= beta or ret == -1:
return ret, move
return ret, bestMove

本系列,我们来看看在一种常见的组合游戏——回合制棋盘类游戏中,如何用算法来解决问题。首先,我们会介绍并解决搜索空间较小的问题,引入经典的博弈算法和相关理论,最终实现在大搜索空间中的Deep RL近似算法。在此基础上可以理解AlphaGo的原理和工作方式。 本系列的第一篇,我们介绍3个Leetcode中的零和回合制游戏,从最初的暴力解法,到动态规划最终演变成博弈论里的经典算法: minimax 以及 alpha beta 剪枝。

Leetcode 292 Nim Game (简单)

简单题 Leetcode 292 Nim Game

你和你的朋友,两个人一起玩 Nim游戏:桌子上有一堆石头,每次你们轮流拿掉 1 - 3 块石头。 拿掉最后一块石头的人就是获胜者。你作为先手。
你们是聪明人,每一步都是最优解。 编写一个函数,来判断你是否可以在给定石头数量的情况下赢得游戏。

示例:
输入: 4
输出: false
解释: 如果堆中有 4 块石头,那么你永远不会赢得比赛;因为无论你拿走 1 块、2 块 还是 3 块石头,最后一块石头总是会被你的朋友拿走。

定义 \(f(n)\) 为有\(n\)个石头并采取最优策略的游戏结果, \(f(n)\)的值只有可能是赢或者输。考察前几个结果:\(f(1) = f(2) = f(3) = Win\),然后来计算\(f(4)\)。因为玩家采取最优策略(只要有一种走法让对方必输,玩家获胜),对于4来说,玩家能走的可能是拿掉1块、2块或3块,但是无论剩余何种局面,对方都是必赢,因此,4就是必输。总的说来,递归关系如下: \[ f(n) = \neg (f(n-1) \land f(n-2) \land f(n-3)) \]

这个递归式可以直接翻译成Python 3代码
{linenos
1
2
3
4
5
6
7
8
9
10
11
# TLE
# Time Complexity: O(exponential)
class Solution_BruteForce:

def canWinNim(self, n: int) -> bool:
if n <= 3:
return True
for i in range(1, 4):
if not self.canWinNim(n - i):
return True
return False
以上的递归公式和代码很像fibonacci数的递归定义和暴力解法,因此对应的时间复杂度也是指数级的,提交代码以后会TLE。下图画出了当n=7时的递归调用,注意 5 被扩展向下重复执行了两次,4重复了4次。
292 Nim Game 暴力解法调用图 n=7
我们采用和fibonacci一样的方式来优化算法:缓存较小n的结果以此来计算较大n的结果。 Python 中,我们可以只加一行lru_cache decorator,来取得这种动态规划效果,下面的代码将复杂度降到了 \(O(N)\)
{linenos
1
2
3
4
5
6
7
8
9
10
11
12
# RecursionError: maximum recursion depth exceeded in comparison n=1348820612
# Time Complexity: O(N)
class Solution_DP:
from functools import lru_cache
@lru_cache(maxsize=None)
def canWinNim(self, n: int) -> bool:
if n <= 3:
return True
for i in range(1, 4):
if not self.canWinNim(n - i):
return True
return False
再来画出调用图:这次5和4就不再被展开重复计算,图中绿色的节点表示缓存命中。
292 Nim Game 动归解法调用图 n=7

但还是没有AC,因为当n=1348820612时,这种方式会导致栈溢出。再改成下面的循环版本,可惜还是TLE。

{linenos
1
2
3
4
5
6
7
8
9
10
11
# TLE for 1348820612
# Time Complexity: O(N)
class Solution:
def canWinNim(self, n: int) -> bool:
if n <= 3:
return True
last3, last2, last1 = True, True, True
for i in range(4, n+1):
this = not (last3 and last2 and last1)
last3, last2, last1 = last2, last1, this
return last1

由此看来,AC 版本需要低于\(O(n)\)的算法复杂度。上面的写法似乎暗示输赢有周期性的规律。事实上,如果将输赢按照顺序画出来,就马上得出规律了:只要\(n \mod 4 = 0\) 就是输,否则赢。原因如下:当面临不能被4整除的数量时 \(4k+i (i=1,2,3)\) ,一方总是可以拿走 \(i\) 个,将\(4k\) 留给对手,而对方下轮又将返回不能被4整除的数,如此循环往复,直到这一方有1, 2, 3 个,最终获胜。

输赢分布

最终AC版本,只有一句语句。

{linenos
1
2
3
4
5
# AC
# Time Complexity: O(1)
class Solution:
def canWinNim(self, n: int) -> bool:
return not (n % 4 == 0)

Leetcode 486 Predict the Winner (中等)

中等难度题目: Leetcode 486 Predict the Winner.

给定一个表示分数的非负整数数组。 玩家1从数组任意一端拿取一个分数,随后玩家2继续从剩余数组任意一端拿取分数,然后玩家1拿,……。每次一个玩家只能拿取一个分数,分数被拿取之后不再可取。直到没有剩余分数可取时游戏结束。最终获得分数总和最多的玩家获胜。
给定一个表示分数的数组,预测玩家1是否会成为赢家。你可以假设每个玩家的玩法都会使他的分数最大化。

示例 1:
输入: [1, 5, 2]
输出: False
解释: 一开始,玩家1可以从1和2中进行选择。
如果他选择2(或者1),那么玩家2可以从1(或者2)和5中进行选择。如果玩家2选择了5,那么玩家1则只剩下1(或者2)可选。
所以,玩家1的最终分数为 1 + 2 = 3,而玩家2为 5。
因此,玩家1永远不会成为赢家,返回 False。

示例 2:
输入: [1, 5, 233, 7]
输出: True
解释: 玩家1一开始选择1。然后玩家2必须从5和7中进行选择。无论玩家2选择了哪个,玩家1都可以选择233。
最终,玩家1(234分)比玩家2(12分)获得更多的分数,所以返回 True,表示玩家1可以成为赢家。

对于当前玩家,他有两种选择:左边或者右边的数。定义 maxDiff(l, r) 为剩余子数组\([l,r]\)时,当前玩家能取得的最大分差,那么

\[ \begin{equation*} \operatorname{maxDiff}(l, r) = \max \begin{cases} nums[l] - \operatorname{maxDiff}(l + 1, r)\\\\ nums[r] - \operatorname{maxDiff}(l, r - 1) \end{cases} \end{equation*} \]

对应的时间复杂度可以写出递归式,显然是指数级的: \[ f(n) = 2f(n-1) = O(2^n) \]

采用暴力解法可以AC,但运算时间很长,接近TLE边缘 (6300ms)。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# AC
# Time Complexity: O(2^N)
# Slow: 6300ms
from typing import List

class Solution:

def maxDiff(self, l: int, r:int) -> int:
if l == r:
return self.nums[l]
return max(self.nums[l] - self.maxDiff(l + 1, r), self.nums[r] - self.maxDiff(l, r - 1))

def PredictTheWinner(self, nums: List[int]) -> bool:
self.nums = nums
return self.maxDiff(0, len(nums) - 1) >= 0

从调用图也很容易看出是指数级的复杂度
486 Predict the Winner 暴力解法调用图 n=4

上图中我们有重复计算的节点,例如[1-2]节点被计算了两次。使用 lru_cache 大法,在maxDiff 上仅加了一句,就能以复杂度 \(O(n^2)\)和运行时间 43ms AC。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# AC
# Time Complexity: O(N^2)
# Fast: 43ms
from functools import lru_cache
from typing import List

class Solution:

@lru_cache(maxsize=None)
def maxDiff(self, l: int, r:int) -> int:
if l == r:
return self.nums[l]
return max(self.nums[l] - self.maxDiff(l + 1, r), self.nums[r] - self.maxDiff(l, r - 1))

def PredictTheWinner(self, nums: List[int]) -> bool:
self.nums = nums
return self.maxDiff(0, len(nums) - 1) >= 0
动态规划解法调用图可以看出节点 [1-2] 这次没有被计算两次。
486 Predict the Winner 动归解法调用图 n=4

Leetcode 464 Can I Win (中等)

类似但稍有难度的题目 Leetcode 464 Can I Win。难点在于使用了位的状态压缩。

在 "100 game" 这个游戏中,两名玩家轮流选择从 1 到 10 的任意整数,累计整数和,先使得累计整数和达到 100 的玩家,即为胜者。
如果我们将游戏规则改为 “玩家不能重复使用整数” 呢?
例如,两个玩家可以轮流从公共整数池中抽取从 1 到 15 的整数(不放回),直到累计整数和 >= 100。
给定一个整数 maxChoosableInteger (整数池中可选择的最大数)和另一个整数 desiredTotal(累计和),判断先出手的玩家是否能稳赢(假设两位玩家游戏时都表现最佳)?
你可以假设 maxChoosableInteger 不会大于 20, desiredTotal 不会大于 300。

示例:
输入:
maxChoosableInteger = 10
desiredTotal = 11
输出:
false
解释:
无论第一个玩家选择哪个整数,他都会失败。
第一个玩家可以选择从 1 到 10 的整数。
如果第一个玩家选择 1,那么第二个玩家只能选择从 2 到 10 的整数。
第二个玩家可以通过选择整数 10(那么累积和为 11 >= desiredTotal),从而取得胜利.
同样地,第一个玩家选择任意其他整数,第二个玩家都会赢。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# AC
# Time Complexity: O:(2^m*m), m: maxChoosableInteger
class Solution:
from functools import lru_cache
@lru_cache(maxsize=None)
def recurse(self, status: int, currentTotal: int) -> bool:
for i in range(1, self.maxChoosableInteger + 1):
if not (status >> i & 1):
new_status = 1 << i | status
if currentTotal + i >= self.desiredTotal:
return True
if not self.recurse(new_status, currentTotal + i):
return True
return False


def canIWin(self, maxChoosableInteger: int, desiredTotal: int) -> bool:
self.maxChoosableInteger = maxChoosableInteger
self.desiredTotal = desiredTotal

sum = maxChoosableInteger * (maxChoosableInteger + 1) / 2
if sum < desiredTotal:
return False
return self.recurse(0, 0)

上面的代码算法复杂度为\(O(m 2^m)\),m是maxChoosableInteger。由于所有状态的数量是\(2^m\),对于每个状态,最多会尝试 \(m\) 走法。

Minimax 算法

至此,我们AC了leetcode中的几道零和回合制博弈游戏。事实上,在这个领域有通用的算法:回合制博弈下的minimax。算法背景如下,两个玩家轮流玩,第一个玩家max的目的是将游戏的效用最大化,第二个玩家min则是最小化效用。比如,下面的节点表示玩家选取节点后游戏的效用,当两个玩家都能采取最优策略,Minimax 算法从底层节点来计算,游戏的结果是最终max 玩家会得到-7。

Wikipedia Minimax 例子

Minimax Python 3伪代码如下。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
def minimax(node: Node, depth: int, maximizingPlayer: bool) -> int:
if depth == 0 or is_terminal(node):
return evaluate_terminal(node)
if maximizingPlayer:
value:int = −∞
for child in node:
value = max(value, minimax(child, depth − 1, False))
return value
else: # minimizing player
value := +∞
for child in node:
value = min(value, minimax(child, depth − 1, True))
return value

Minimax: 486 Predict the Winner

我们知道486 Predict the Winner 是有minimax解法的,但如何具体实现,其难点在于如何定义合适的游戏价值或者效用。之前的解法中,我们定义maxDiff(l, r) 来表示当前玩家面临子区间 \([l, r]\) 时能取得的最大分差。对于minimax算法,max 玩家要最大化游戏价值,min玩家要最小化游戏价值。先考虑最简单情况即只有一个数x时,若定义max玩家在此局面下得到这个数时游戏价值为 +x,则min玩家为-x,即max玩家得到的所有数为正(\(+a_1 + a_2 + ... = A\)),min玩家得到的所有数为负(\(-b_1 - b_2 - ... = -B\))。至此,max玩家的目标就是 \(max(A-B)\) ,min玩家是 \(min(A-B)\)。有了精确的定义和优化目标,代码只需要套一下上面的模版。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# AC
from functools import lru_cache
from typing import List

class Solution:
# max_player: max(A - B)
# min_player: min(A - B)
@lru_cache(maxsize=None)
def minimax(self, l: int, r: int, isMaxPlayer: bool) -> int:
if l == r:
return self.nums[l] * (1 if isMaxPlayer else -1)

if isMaxPlayer:
return max(
self.nums[l] + self.minimax(l + 1, r, not isMaxPlayer),
self.nums[r] + self.minimax(l, r - 1, not isMaxPlayer))
else:
return min(
-self.nums[l] + self.minimax(l + 1, r, not isMaxPlayer),
-self.nums[r] + self.minimax(l, r - 1, not isMaxPlayer))

def PredictTheWinner(self, nums: List[int]) -> bool:
self.nums = nums
v = self.minimax(0, len(nums) - 1, True)
return v >= 0
Minimax 486 调用图 nums=[1, 5, 2, 4]

Minimax: 464 Can I Win

该题目是很典型的此类游戏,即结果为赢输平,但是中间的状态没有直接对应的游戏价值。对于这样的问题,一般定义为,max玩家胜,价值 +1,min玩家胜,价值-1,平则0。下面的AC代码实现了 Minimax 算法。算法中针对两个玩家都有剪枝(没有剪枝无法AC)。具体来说,max玩家一旦在某一节点取得胜利(value=1),就停止继续向下搜索,因为这是他能取得的最好分数。同理,min玩家一旦取得-1也直接返回上层节点。这个剪枝可以泛化成 alpha beta剪枝算法。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
# AC
class Solution:
from functools import lru_cache
@lru_cache(maxsize=None)
# currentTotal < desiredTotal
def minimax(self, status: int, currentTotal: int, isMaxPlayer: bool) -> int:
import math
if status == self.allUsed:
return 0 # draw: no winner

if isMaxPlayer:
value = -math.inf
for i in range(1, self.maxChoosableInteger + 1):
if not (status >> i & 1):
new_status = 1 << i | status
if currentTotal + i >= self.desiredTotal:
return 1 # shortcut
value = max(value, self.minimax(new_status, currentTotal + i, not isMaxPlayer))
if value == 1:
return 1
return value
else:
value = math.inf
for i in range(1, self.maxChoosableInteger + 1):
if not (status >> i & 1):
new_status = 1 << i | status
if currentTotal + i >= self.desiredTotal:
return -1 # shortcut
value = min(value, self.minimax(new_status, currentTotal + i, not isMaxPlayer))
if value == -1:
return -1
return value

Alpha-Beta 剪枝

在464 Can I Win minimax 算法代码实现中,我们发现有剪枝优化空间。对于每个节点,定义两个值alpha 和 beta,表示从根节点到目前局面时,max玩家保证能取得的最小值以及min玩家能保证取得的最大值。初始时,根节点alpha = −∞ , beta = +∞,表示游戏最终的价值在区间 [−∞, +∞]中。在向下遍历的过程中,子节点先继承父节点的 alpha beta 值进而继承区间 [alpha, beta]。当子节点在向下遍历的时候同步更新alpha 或者 beta,一旦区间[alpha, beta]非法就立即向上返回。举个Wikimedia的例子来进一步说明:

  1. 根节点初始时: alpha = −∞, beta = +∞

  2. 根节点,最左边子节点返回4后: alpha = 4, beta = +∞

  3. 根节点,中间子节点返回5后: alpha = 5, beta = +∞

  4. 最右Min节点(标1节点),初始时: alpha = 5, beta = +∞

  5. 最右Min节点(标1节点),第一个子节点返回1后: alpha = 5, beta = 1

此时,最右Min节点的alpha, beta形成了无效区间[5, 1],满足了剪枝条件,因此可以不用计算它的第二个和第三个子节点。如果剩余子节点返回值 > 1,比如2,由于这是个min节点,将会被已经到手的1替换。若其他子节点返回值 < 1,但由于min的父节点有效区间是[5, +∞],已经保证了>=5,小于5的值也会被忽略。

Wikimedia Alpha Beta 剪枝例子

Alpha Beta 剪枝 Python 3伪代码如下

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
def alpha_beta(node: Node, depth: int, α: int, β: int, maximizingPlayer: bool) -> int:
if depth == 0 or is_terminal(node):
return evaluate_terminal(node)
if maximizingPlayer:
value: int = −∞
for child in node:
value = max(value, alphabeta(child, depth − 1, α, β, False))
α = max(α, value)
if α >= β:
break # β cut-off
return value
else:
value: int = +∞
for child in node:
value = min(value, alphabeta(child, depth − 1, α, β, True))
β = min(β, value)
if β <= α:
break # α cut-off
return value

Alpha-Beta Pruning: 486 Predict the Winner

用 Alpha-Beta 剪枝 再次AC 486。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
# AC
import math
from functools import lru_cache
from typing import List

class Solution:
def alpha_beta(self, l: int, r: int, curr: int, isMaxPlayer: bool, alpha: int, beta: int) -> int:
if l == r:
return curr + self.nums[l] * (1 if isMaxPlayer else -1)

if isMaxPlayer:
ret = self.alpha_beta(l + 1, r, curr + self.nums[l], not isMaxPlayer, alpha, beta)
alpha = max(alpha, ret)
if alpha >= beta:
return alpha
ret = max(ret, self.alpha_beta(l, r - 1, curr + self.nums[r], not isMaxPlayer, alpha, beta))
return ret
else:
ret = self.alpha_beta(l + 1, r, curr - self.nums[l], not isMaxPlayer, alpha, beta)
beta = min(beta, ret)
if alpha >= beta:
return beta
ret = min(ret, self.alpha_beta(l, r - 1, curr - self.nums[r], not isMaxPlayer, alpha, beta))
return ret

def PredictTheWinner(self, nums: List[int]) -> bool:
self.nums = nums
v = self.alpha_beta(0, len(nums) - 1, 0, True, -math.inf, math.inf)
return v >= 0

Alpha-Beta Pruning: 464 Can I Win

464 Alpha-Beta 剪枝版本。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
# AC
class Solution:
from functools import lru_cache
@lru_cache(maxsize=None)
# currentTotal < desiredTotal
def alpha_beta(self, status: int, currentTotal: int, isMaxPlayer: bool, alpha: int, beta: int) -> int:
import math
if status == self.allUsed:
return 0 # draw: no winner

if isMaxPlayer:
value = -math.inf
for i in range(1, self.maxChoosableInteger + 1):
if not (status >> i & 1):
new_status = 1 << i | status
if currentTotal + i >= self.desiredTotal:
return 1 # shortcut
value = max(value, self.alpha_beta(new_status, currentTotal + i, not isMaxPlayer, alpha, beta))
alpha = max(alpha, value)
if alpha >= beta:
return value
return value
else:
value = math.inf
for i in range(1, self.maxChoosableInteger + 1):
if not (status >> i & 1):
new_status = 1 << i | status
if currentTotal + i >= self.desiredTotal:
return -1 # shortcut
value = min(value, self.alpha_beta(new_status, currentTotal + i, not isMaxPlayer, alpha, beta))
beta = min(beta, value)
if alpha >= beta:
return value
return value

C++, Java, Javascript AC 486 Predict the Winner

最后介绍一种不同的DP实现:用C++, Java, Javascript 实现自底向上的DP解法来AC leetcode 486,当然其他语言没有Python的lru_cache大法。以下实现中,注意DP解的构建顺序,先解决小规模的问题,并在此基础上计算稍大的问题。值得一提的是,以下的循环写法严格保证了 \(n^2\) 次循环,但是自顶向下的计划递归可能会少于 \(n^2\)次循环。

Java AC Code

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
// AC
class Solution {
public boolean PredictTheWinner(int[] nums) {
int n = nums.length;
int[][] dp = new int[n][n];
for (int i = 0; i < n; i++) {
dp[i][i] = nums[i];
}

for (int l = n - 1; l >= 0; l--) {
for (int r = l + 1; r < n; r++) {
dp[l][r] = Math.max(
nums[l] - dp[l + 1][r],
nums[r] - dp[l][r - 1]);
}
}
return dp[0][n - 1] >= 0;
}
}

C++ AC Code

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
// AC
class Solution {
public:
bool PredictTheWinner(vector<int>& nums) {
int n = nums.size();
vector<vector<int>> dp(n, vector<int>(n, 0));
for (int i = 0; i < n; i++) {
dp[i][i] = nums[i];
}
for (int l = n - 1; l >= 0; l--) {
for (int r = l + 1; r < n; r++) {
dp[l][r] = max(nums[l] - dp[l + 1][r], nums[r] - dp[l][r - 1]);
}
}
return dp[0][n - 1] >= 0;
}
};

Javascript AC Code

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
/**
* @param {number[]} nums
* @return {boolean}
*/
var PredictTheWinner = function(nums) {
const n = nums.length;
const dp = new Array(n).fill().map(() => new Array(n));

for (let i = 0; i < n; i++) {
dp[i][i] = nums[i];
}

for (let l = n - 1; l >=0; l--) {
for (let r = i + 1; r < n; r++) {
dp[l][r] = Math.max(nums[l] - dp[l + 1][r],nums[r] - dp[l][r - 1]);
}
}

return dp[0][n-1] >=0;
};
Your browser is out-of-date!

Update your browser to view this website correctly. Update my browser now

×