#Algorithm

在这篇文章中,我们从一道LeetCode 470 题目出发,通过系统地思考,引出拒绝采样(Reject Sampling)的概念,并探索比较三种拒绝采样地解法;接着借助状态转移图来定量计算采样效率;最后,我们利用同样的方法来解一道稍微复杂些的经典抛硬币求期望的统计面试题目。

Leetcode 470 用 Rand7() 实现 Rand10()

已有方法 rand7 可生成 1 到 7 范围内的均匀随机整数,试写一个方法 rand10 生成 1 到 10 范围内的均匀随机整数。

不要使用系统的 Math.random() 方法。

思考

  • rand7()调用次数的 期望值 是多少 ?

  • 你能否尽量少调用 rand7() ?

思维过程

我们已有 rand7() 等概率生成了 [1, 7] 中的数字,我们需要等概率生成 [1, 10] 范围内的数字。第一反应是调用一次rand7() 肯定是不够的,因为覆盖的范围不够。那么,就需要至少2次调用 rand7() 才能生成一次 rand10(),但是还要保证 [1, 10] 的数字生成概率相等,这个是难点。 现在我们先来考虑反问题,给定rand10() 生成 rand7()。这个应该很简单,调用一次 rand10() 得到 [1, 10],如果是 8, 9, 10 ,则丢弃,重新开始,否则返回。想必大家都能想到这个朴素的方法,这种思想就是统计模拟中的拒绝采样(Reject Sampling)。

有了上面反问题的思考,我们可能会想到,rand7() 可以生成 rand5(),覆盖 [1, 5]的范围,如果将区间 [1, 10] 分成两个5个值的区间 [1, 5] 和 [6, 10],那么 rand7() 可以通过先等概率选择区间 [1, 5] 或 [6, 10],再通过rand7() 生成 rand5()就可以了。这个问题就等价于先用 rand7() 生成 rand2(),决定了 [1, 5] 还是 [6, 10],再通过rand7() 生成 rand5() 。

解法一:rand2() + rand5()

我们来实现这种解法。下图为调用两次 rand7() 生成 rand10 数值的映射关系:横轴表示第一次调用,1,2,3决定选择区间 [1, 5] ,4,5,6选择区间 [6, 10]。灰色部分表示结果丢弃,重新开始 (注,若第一次得到7无需再次调用 rand7())。

有了上图,我们很容易写出如下 AC 代码。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# AC
# Runtime: 408 ms, faster than 23.80% of Python3 online submissions for Implement Rand10() Using Rand7().
# Memory Usage: 16.7 MB, less than 90.76% of Python3 online submissions for Implement Rand10() Using Rand7().
class Solution:
def rand10(self):
while True:
a = rand7()
if a <= 3:
b = rand7()
if b <= 5:
return b
elif a <= 6:
b = rand7()
if b <= 5:
return b + 5

标准解法:rand49()

从提交的结果来看,第一种解法慢于多数解法。原因是我们的调用 rand7() 的采样效率比较低,第一次有 1/7 的概率结果丢弃,第二次有 2/7的概率被丢弃。

如何在第一种解法的基础上提高采样效率呢?直觉告诉我们一种做法是降低上述 7x7 表格中灰色格子的面积。此时,会想到我们通过两次 rand7() 已经构建出来 rand49()了,那么再生成 rand10() 也规约成基本问题了。

下图为 rand49() 和 rand10() 的数字对应关系。

实现代码比较简单。注意,while True 可以去掉,用递归来代替。

{linenos
1
2
3
4
5
6
7
8
9
10
11
# AC
# Runtime: 376 ms, faster than 54.71% of Python3 online submissions for Implement Rand10() Using Rand7().
# Memory Usage: 16.9 MB, less than 38.54% of Python3 online submissions for Implement Rand10() Using Rand7().
class Solution:
def rand10(self):
while True:
a, b = rand7(), rand7()
num = (a - 1) * 7 + b
if num <= 40:
return num % 10 + 1

更快的做法

上面的提交结果发现标准解法在运行时间上有了不少提高,处于中等位置。我们继续思考,看看能否再提高采样效率。

观察发现,rand49() 有 9/49 的概率,生成的值被丢弃,原因是 [41, 49] 只有 9 个数,不足10个。倘若此时能够将这种状态保持下去,那么只需再调用一次 rand7() 而不是从新开始情况下至少调用两次 rand7(), 就可以得到 rand10()了。也就是说,当 rand49() 生成了 [41, 49] 范围内的数的话等价于我们先调用了一次 rand9(),那么依样画葫芦,我们接着调用 rand7() 得到了 rand63()。63 分成了6个10个值的区间后,剩余 3 个数。此时,又等价于 rand3(),循环往复,调用了 rand7() 得到了 rand21(),最后若rand21() 不幸得到21,等价于 rand1(),此时似乎我们走投无路,只能回到最初的状态,一切从头再来了。

改进算法代码如下。注意这次击败了 92.7%的提交。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# AC
# Runtime: 344 ms, faster than 92.72% of Python3 online submissions for Implement Rand10() Using Rand7().
# Memory Usage: 16.7 MB, less than 90.76% of Python3 online submissions for Implement Rand10() Using Rand7().
class Solution:
def rand10(self):
while True:
a, b = rand7(), rand7()
num = (a - 1) * 7 + b
if num <= 40: return num % 10 + 1
a = num - 40
b = rand7()
num = (a - 1) * 7 + b
if num <= 60: return num % 10 + 1
a = num - 60
b = rand7()
num = (a - 1) * 7 + b
if num <= 20: return num % 10 + 1

采样效率计算

通过代码提交的结果和大致的分析,我们已经知道三个解法在采样效率依次变得更快。现在我们来定量计算这三个解法。

我们考虑生成一个 rand10() 平均需要调用多少次 rand7(),作为采样效率的标准。

一种思路是可以通过模拟方法,即将上述每个解法模拟多次,然后用总的 rand7() 调用次数除以 rand10() 的生成次数即可。下面以解法三为例写出代码

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# The rand7() API is already defined for you.
rand7_c = 0
rand10_c = 0

def rand7():
global rand7_c
rand7_c += 1
import random
return random.randint(1, 7)

def rand10():
global rand10_c
rand10_c += 1
while True:
a, b = rand7(), rand7()
num = (a - 1) * 7 + b
if num <= 40: return num % 10 + 1
a = num - 40 # [1, 9]
b = rand7()
num = (a - 1) * 7 + b # [1, 63]
if num <= 60: return num % 10 + 1
a = num - 60 # [1, 3]
b = rand7()
num = (a - 1) * 7 + b # [1, 21]
if num <= 20: return num % 10 + 1

if __name__ == '__main__':
while True:
rand10()
print(f'{rand10_c} {round(rand7_c/rand10_c, 2)}')

运行代码,发现解法三的采样效率稳定在 2.19。

采样效率精确计算

计算解法二

为了精确计算三个解法的采样效率,我们通过代码得到对应的状态转移图来帮助计算。

例如,解法一可以对应到下图:初始状态 Start 节点中的 +2 表示经过此节点会产生 2次 rand7() 的代价。从 Start 节点有 40/49 的概率到达被接受状态 AC,有 9/49 概率到达拒绝状态 REJ。REJ 需要从头开始,则用虚线表示重新回到 Start节点,也就是说 REJ 的代价等价于 Start。注意,从某个节点出发的所有边之和为1。

有了上述状态转移关系图,我们令初始状态的平均代价为 \(C_2\),则可以写成递归表达式,因为其中 REJ 的代价就是 \(C_2\),即

\[ C_2 = 2 + (\frac{40}{49}\cdot0 + \frac{9}{49} C_2) \]

解得 \(C_2\)

\[ C_2 = 2.45 \]

计算解法一

同样的,对于另外两种解法,虽然略微复杂,也可以用同样的方法求得。

解法一的状态转移图为

递归方程表达式为

\[ C_1 = 1+\frac{3}{7} \cdot (1+\frac{5}{7} \cdot 0 + \frac{2}{7} \cdot C_1) \cdot2+ \frac{1}{7} \cdot (C_1) \]

解得 \(C_1\)

\[ C_1 = \frac{91}{30} \approx 3.03 \]

计算解法三

最快的解法三状态转移图为

递归方程表达式为

\[ C_3 = 2+\frac{40}{49} \cdot 0 + \frac{9}{49} (1+\frac{60}{63} \cdot 0 + \frac{3}{63} \cdot (1+\frac{20}{21} \cdot 0 + \frac{1}{21} \cdot C_3)) \]

解得 \(C_3\) \[ C_3 = \frac{329}{150} \approx 2.193 \]

至此总结一下,三个解法的平均代价为 \[ C_1 \approx 3.03 > C_2 = 2.45 > C_3 \approx 2.193 \] 这些值和我们通过模拟得到的结果一致。

稍难些的经典概率求期望题目

至此,LeetCode 470 我们已经分析透彻。现在,我们已经可以很熟练的将此类拒绝采样的问题转变成有概率的状态转移图,再写成递推公式去求平均采样的代价(即期望)。这里,如果大家感兴趣的话不妨再来看一道略微深入的经典统计概率求期望的题目。

问题:给定一枚抛正反面概率一样的硬币,求连续抛硬币直到两个正面(正面记为H,两个正面HH)的平均次数。例如:HTTHH是一个连续次数为5的第一次出现HH的序列。

分析问题画出状态转移图:我们令初始状态下得到第一个HH的平均长度记为 S,那么下一次抛硬币有 1/2 机会是 T,此时状态等价于初始状态,另有 1/2 机会是 H,我们记这个状态下第一次遇见HH的平均长度为 H(下图蓝色节点)。从此蓝色节点出发,当下一枚硬币是H则结束,是T是返回初始状态。于是构建出下图。

这个问题稍微复杂的地方在于我们有两个未知状态互相依赖,但问题的本质和方法是一样的,分别从 S 和 H 出发考虑状态的概率转移,可以写成如下两个方程式:

\[ \left\{ \begin{array}{c} S =&\frac{1}{2} \cdot(1+H) + \frac{1}{2} \cdot(1+S) \\ H =&\frac{1}{2} \cdot 1 + \frac{1}{2} \cdot(1+S) \end{array} \right. \]

解得

\[ \left\{ \begin{array}{c} H= 4 \\ S = 6 \end{array} \right. \]

因此,平均下来,需要6次抛硬币才能得到 HH,这个是否和你直觉的猜测一致呢?

这个问题还可以有另外一问,可以作为思考题让大家来练习一下:第一次得到 HT 的平均次数是多少?这个是否和 HH 一样呢?

Leetcode 1029. 两地调度 (medium)

公司计划面试 2N 人。第 i 人飞往 A 市的费用为 costs[i][0],飞往 B 市的费用为 costs[i][1]。

返回将每个人都飞到某座城市的最低费用,要求每个城市都有 N 人抵达。 

示例:

输入:[[10,20],[30,200],[400,50],[30,20]] 输出:110 解释: 第一个人去 A 市,费用为 10。 第二个人去 A 市,费用为 30。 第三个人去 B 市,费用为 50。 第四个人去 B 市,费用为 20。 最低总费用为 10 + 30 + 50 + 20 = 110,每个城市都有一半的人在面试。

提示:

1 <= costs.length <= 100 costs.length 为偶数 1 <= costs[i][0], costs[i][1] <= 1000

链接:https://leetcode-cn.com/problems/two-city-scheduling

暴力枚举法

最直接的方式是暴力枚举出所有分组的可能。因为 2N 个人平均分成两组,总数为 \({2n \choose n}\),是 n 的指数级数量。在文章24 点游戏算法题的 Python 函数式实现: 学用itertools,yield,yield from 巧刷题,我们展示如何调用 Python 的 itertools包,这里,我们也用同样的方式产生 [0, 2N] 的所有集合大小为N的可能(保存在left_set_list中),再遍历找到最小值即可。当然,这种解法会TLE,只是举个例子来体会一下暴力做法。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import math
from typing import List

class Solution:
def twoCitySchedCost(self, costs: List[List[int]]) -> int:
L = range(len(costs))
from itertools import combinations
left_set_list = [set(c) for c in combinations(list(L), len(L)//2)]

min_total = math.inf
for left_set in left_set_list:
cost = 0
for i in L:
is_left = 1 if i in left_set else 0
cost += costs[i][is_left]
min_total = min(min_total, cost)

return min_total

O(N) AC解法

对于组合优化问题来说,例如TSP问题(解法链接 TSP问题从DP算法到深度学习1:递归DP方法 AC AIZU TSP问题),一般都是 NP-Hard问题,意味着没有多项次复杂度的解法。但是这个问题比较特殊,它增加了一个特定条件:去城市A和城市B的人数相同,也就是我们已经知道两个分组的数量是一样的。我们仔细思考一下这个意味着什么?考虑只有四个人的小规模情况,如果让你来手动规划,你一定不会穷举出所有两两分组的可能,而是比较人与人相对的两个城市的cost差。举个例子,有如下四个人的costs

1
2
3
4
0 A:3,  B:1
1 A:99, B:100
2 A:2, B:2
3 A:3, B:3
虽然1号人去城市A(99)cost 很大,但是相比较他去B(100)来说,可以省下 100-99 = 1 的钱,这个钱比0号人去B不去A省下的钱 3-1 = 2 还要多,因此你一定会选择让1号人去A而让0号人去B。

有了这个想法,再整理一下,就会发现让某人去哪个城市和他去两个城市的cost 差 $ C_a - C_b$相关,如果这个值越大,那么他越应该去B。但是最后决定他是否去B取决于他的差在所有人中的排名,由于两组人数相等,因此差能大到排在前一半,则他就去B,在后一半就去A。

按照这个思路,很快能写出代码,代码写法有很多,下面略举一例。代码中由于用到排序,复杂度为 \(O(N \cdot \log(N))\) 。这里补充一点,理论上只需找数组中位数的值即可,最好的时间复杂度是 \(O(N)\)

代码实现上,cost_diff_list 将每个人的在原数组的index 和他的cost差组成 pair。即

1
[(0, cost_0), (1, cost_1), ... ]

这样我们可以将这个数组按照cost排序,排完序后前面N个元素属于B城市。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# AC
# Runtime: 36 ms, faster than 87.77% of Python3 online submissions
# Memory Usage: 14.5 MB, less than 14.84% of Python3 online
from typing import List

class Solution:
def twoCitySchedCost(self, costs: List[List[int]]) -> int:
L = range(len(costs))
cost_diff_lst = [(i, costs[i][0] - costs[i][1]) for i in L]
cost_diff_lst.sort(key=lambda x: x[1])

total_cost = 0
for c, (idx, _) in enumerate(cost_diff_lst):
is_left = 0 if c < len(L) // 2 else 1
total_cost += costs[idx][is_left]

return total_cost

转换成整数规划问题

这个问题对于略有算法经验的人来说,很类似于背包问题。它们都需要回答N个物品取或者不取,并同时最大最小化总cost。区别在它们的约束条件不一样。这道题的约束是去取(去城市A)和不取(去城市B)的数量一样。这一类问题即 integer programming,即整数规划。下面我们选取两个比较流行的优化库来展示如何调包解这道题。

首先我们先来formulate这个问题,因为需要表达两个约束条件,我们将每个人的状态分成是否去A和是否去B两个变量。

1
2
x[i-th-person][0]: boolean 表示是否去 city a
x[i-th-person][1]: boolean 表示是否去 city b

这样,问题转换成如下优化模型

\[ \begin{array}{rrclcl} \displaystyle \min_{x} & costs[i][0] \cdot x[i][0] + costs[i][1] \cdot x[i][1]\\ \textrm{s.t.} & x[i][0] + x[i][1] =1\\ &x[i][0] + x[i][1] + ... =N \\ \end{array} \]

Google OR-Tools

Google OR-Tools 是业界最好的优化库,下面为调用代码,由于直接对应于上面的数学优化问题,不做赘述。当然 Leetcode上不支持这些第三方的库,肯定也不能AC。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from ortools.sat.python import cp_model

costs = [[515,563],[451,713],[537,709],[343,819],[855,779],[457,60],[650,359],[631,42]]

I = range(len(costs))

model = cp_model.CpModel()
x = []
total_cost = model.NewIntVar(0, 10000, 'total_cost')
for i in I:
t = []
for j in range(2):
t.append(model.NewBoolVar('x[%i,%i]' % (i, j)))
x.append(t)

# Constraints
# Each person must be assigned to at exact one city
[model.Add(sum(x[i][j] for j in range(2)) == 1) for i in I]
# equal number of person assigned to two cities
model.Add(sum(x[i][0] for i in I) == (len(I) // 2))

# Total cost
model.Add(total_cost == sum(x[i][0] * costs[i][0] + x[i][1] * costs[i][1] for i in I))
model.Minimize(total_cost)

solver = cp_model.CpSolver()
status = solver.Solve(model)

if status == cp_model.OPTIMAL:
print('Total min cost = %i' % solver.ObjectiveValue())
print()
for i in I:
for j in range(2):
if solver.Value(x[i][j]) == 1:
print('People ', i, ' assigned to city ', j, ' Cost = ', costs[i][j])

完整代码可以从我的github下载。

https://github.com/MyEncyclopedia/leetcode/blob/master/1029_Two_City_Scheduling/1029_ortool.py

PuLP

类似的,另一种流行 python 优化库 PuLP 的代码为

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import pulp

costs = [[259,770],[448,54],[926,667],[184,139],[840,118],[577,469]] # 1859


I = range(len(costs))

items=[i for i in I]
city_a = pulp.LpVariable.dicts('left', items, 0, 1, pulp.LpBinary)
city_b = pulp.LpVariable.dicts('right', items, 0, 1, pulp.LpBinary)

m = pulp.LpProblem("Two Cities", pulp.LpMinimize)

m += pulp.lpSum((costs[i][0] * city_a[i] + costs[i][1] * city_b[i]) for i in items)

# Constraints
# Each person must be assigned to at exact one city
for i in I:
m += pulp.lpSum([city_a[i] + city_b[i]]) == 1
# create a binary variable to state that a table setting is used
m += pulp.lpSum(city_a[i] for i in I) == (len(I) // 2)

m.solve()

total = 0
for i in I:
if city_a[i].value() == 1.0:
total += costs[i][0]
else:
total += costs[i][1]

print("Total cost {}".format(total))

代码地址为

https://github.com/MyEncyclopedia/leetcode/blob/master/1029_Two_City_Scheduling/1029_pulp.py

本篇是TSP问题从DP算法到深度学习系列第四篇,这一篇我们会详细举例并比较在 seq-to-seq 或者Markov Chain中的一些常见的搜索概率最大的状态序列的算法。这些方法在深度学习的seq-to-seq 中被用作decoding。在第五篇中,我们使用强化学习时也会使用了本篇中讲到的方法。

马尔科夫链问题

在 seq-to-seq 问题中,我们经常会遇到需要从现有模型中找概率最大的可能状态序列。这类问题在机器学习算法和控制领域广泛存在,抽象出来可以表达成马尔可夫链模型:给定初始状态的分布和系统的状态转移方程(称为系统动力,dynamics),找寻最有可能的状态序列。

举个例子,假设系统有 \(n\) 个状态,初始状态由 $s_0 = [0.35, 0.25, 0.4] $ 指定,表示初始时三种状态的分布为 0.35,0.25和0.4。

状态转移矩阵由 \(T\) 表达,其中 $ T[i][j]$ 表示从状态 \(i\) 到状态 \(j\) 的概率。注意下面的矩阵 \(T\) 每行的和为 1.0,对应了从任意状态出发,下一时刻的所有可能转移概率和为1。 \[ T= \begin{matrix} & \begin{matrix}0&1&2\end{matrix} \\\\ \begin{matrix}0\\\\1\\\\2\end{matrix} & \begin{bmatrix}0.3&0.6&0.1\\\\0.4&0.2&0.4\\\\0.3&0.3&0.4\end{bmatrix}\\\\ \end{matrix} \]

至此,系统的所有参数都定下来了。接下去的各个时刻的状态分布可以通过矩阵乘法来算得。比如,记\(s_1\)\(t_1\) 时刻状态分布,计算方法为 \(s_0\) 乘以 \(T\),动画如下:

\(s_1\) 数值计算结果如下。

\[ s_1 = \begin{bmatrix}0.35& 0.25& 0.4\end{bmatrix} \begin{matrix} \begin{bmatrix}0.3&0.6&0.1\\\\0.4&0.2&0.4\\\\0.3&0.3&0.4\end{bmatrix}\\\\ \end{matrix} = \begin{bmatrix}0.325& 0.35& 0.255\end{bmatrix} \] 矩阵左乘行向量可以理解为矩阵每一行的线性组合,直觉上理解为下一时刻的状态分布是上一时刻初始状态分布乘以转移关系的线性组合。 \[ \begin{bmatrix}0.35& 0.25& 0.4\end{bmatrix} \begin{matrix} \begin{bmatrix}0.3&0.6&0.1\\\\0.4&0.2&0.4\\\\0.3&0.3&0.4\end{bmatrix}\\\\ \end{matrix} = 0.35 \times \begin{bmatrix}0.35& 0.6& 0.1\end{bmatrix} + 0.25 \times \begin{bmatrix}0.4& 0.2& 0.4\end{bmatrix} + 0.4 \times \begin{bmatrix}0.3& 0.3& 0.4\end{bmatrix} \] 同样的,后面每一个时刻都可以由上一个状态分布向量乘以 \(T\),当然这里我们假设每个时刻的转移矩阵是不变的。当然,问题也可以是每个时刻都有不同的转移矩阵来定义,例如深度学习 seq-to-seq 模型。当然,这个设定的变化不会影响搜索最可能状态序列的算法。出于简单考虑,本篇中我们假定所有时刻的状态转移矩阵都是 \(T\)

下面我们通过多种算法来找出由上述参数定义的系统中前三个时刻的最有可能序列,即概率最大的 \(s_0 \rightarrow s_1 \rightarrow s_2\)

\(L\) 是阶段数,\(N\) 是每个阶段的状态数,则我们的例子中 \(L=N=3\) 。并且,总共有 \(N^L\) 种不同的路径。

穷竭搜索

若给定一条路径,计算特定路径的概率是很直接的,例如,若给定路径为 \(2(s_0) \rightarrow 1(s_1) \rightarrow 2(s_2)\),则这条路径的概率为

\[ p(2 \rightarrow 1 \rightarrow 2) = s_0[2] \times T[2][1] \times T[1][2] = 0.4 \times 0.3 \times 0.4 = 0.048 \]

因此,我们可以通过枚举所有 \(N^L\) 条路径并计算每条路径的概率来找到最有可能的状态序列。

下面是Python 3的穷竭搜索代码,输出为最有可能的概率及其路径。样例问题的输出为 0.084 和状态序列 \(0 \rightarrow 1 \rightarrow 2\)

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
def search_brute_force(initial: List, transition: List, L: int) -> Tuple[float, Tuple]:
from itertools import combinations_with_replacement
v = [0, 1, 2]
path_all = combinations_with_replacement(v, L)

max_prop = 0.0
max_route = None
prob = 0.0
for path in list(path_all):
for idx, v in enumerate(path):
if idx == 0:
prob = initial[v] # reset to initial state
else:
prev_v = path[idx-1]
prob *= transition[prev_v][v]
if prob > max_prop:
max_prop = max(max_prop, prob)
max_route = path
return max_prop, max_route

贪心搜索

穷竭搜索一定会找到最有可能的状态序列,但是算法复杂度是指数级的 \(O(N^L)\)。一种最简化的策略是,每一时刻都只选取下一时刻最可能的状态,显然这种策略没有考虑全局最优,只考虑下一步最优,因此称为贪心策略。当然,贪心策略虽然牺牲全局最优解但是换取了很快的时间复杂度。贪心搜索算法动画如下。

Python 3 实现中我们利用了 numpy 类库,主要是 np.argmax() 可以让代码简洁。代码本质上是两重循环,(一层循环是np.argmax中),对应时间算法复杂度是 \(O(N\times L)\)

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def search_greedy(initial: List, transition: List, L: int) -> Tuple[float, Tuple]:
import numpy as np
max_route = []
max_prop = 0.0
states = np.array(initial)

prev_max_v = None
for l in range(0, L):
max_v = np.argmax(states)
max_route.append(max_v)
if l == 0:
max_prop = initial[max_v]
else:
max_prop = max_prop * transition[prev_max_v][max_v]
states = max_prop * states
prev_max_v = max_v

return max_prop, max_route

Beam 搜索

贪心策略只考虑了下一步的最大概率状态,若我们改进一下贪心策略,将下一步的最大 \(k\) 个状态保留下来就是beam 搜索了。具体来说, \(k\) beam search表示每个阶段保留 \(k\) 个最大概率路径,下一阶段扩展这 \(k\) 条路径至 \(k \times N\) 条路径再选取最大的top k。以上例来说,选取\(k=2\),则初始 \(s_0\)时选取最大概率的两种状态 0和 2,下一阶段 \(s_1\),计算以0和2开始的共 \(2 \times 3\) 条路径,并保留其中最大概率的两条,如此往复。显然,beam search也无法找到全局最优解,但是它能以线性时间复杂度探索更多的路径空间。

以下是Python 3 的代码实现,利用了 PriorityQueue 选取 \(k\) 路径。由于PriorityQueue 无法自定义比较关系,我们定义了 @total_ordering 标注的类来实现比较关心。时间算法复杂度是 \(O(k\times N \times L)\)

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
def search_beam(initial: List, transition: List, L: int, K: int) -> Tuple[float, Tuple]:
N = len(initial)
from queue import PriorityQueue
current_q = PriorityQueue()
next_q = PriorityQueue()

from functools import total_ordering
@total_ordering
class PQItem(object):
def __init__(self, prob, route):
self.prob = prob
self.route = route
self.last_v = int(route[-1])

def __eq__(self, other):
return self.prob == other.prob

def __lt__(self, other):
return self.prob > other.prob

for v in range(N):
next_q.put(PQItem(initial[v], str(v)))

for l in range(1, L):
current_q = next_q
next_q = PriorityQueue()
k = K
while not current_q.empty() and k > 0:
item = current_q.get()
prob, route, prev_v = item.prob, item.route, item.last_v
k -= 1
for v in range(N):
nextItem = PQItem(prob * transition[prev_v][v], route + str(v))
next_q.put(nextItem)

max_item = next_q.get()

return max_item.prob, list(map(lambda x: int(x), max_item.route))

Viterbi 动态规划

和之前TSP 动态规划算法的思想一样,最有可能状态路径问题解法有可以将指数时间复杂度 \(O(N^L)\) 降到多项式时间复杂度 \(O(L \times N \times N)\) 的算法,就是大名鼎鼎的 Viterbi 算法(维特比算法)。核心思想是在每个阶段,用数组保存每个状态结尾路径的阶段最大概率(及其对应路径)。在不考虑优化空间的情况下,我们开一个二维数组 \(dp[][]\),第一维表示阶段序号,第二维表示状态序号。例如,\(dp[1][0]\)\(s_1\) 阶段时以状态0结尾的所有路径中的最大概率,即 \[ dp[1][0] = \max \\{s_0[0] \rightarrow s_1[0], s_0[1] \rightarrow s_1[0], s_0[2] \rightarrow s_1[0]\\} \]

实现代码中没有返回路径本身而只是其概率值,目的是通过简洁的三层循环来表达算法精髓。

{linenos
1
2
3
4
5
6
7
8
9
10
11
def search_dp(initial: List, transition: List, L: int) -> float:
N = len(initial)
dp = [[0.0 for c in range(N)] for r in range(L)]
dp[0] = initial[:]

for l in range(1, L):
for v in range(N):
for prev_v in range(N):
dp[l][v] = max(dp[l][v], dp[l - 1][prev_v] * transition[prev_v][v])

return max(dp[L-1])

概率采用

以上所有的算法都是确定性的。在NLP 深度学习decoding 时候会带来一个问题:确定性容易导致生成重复的短语或者句子。比如,确定性算法很容易生成如下句子。

1
This is the best of best of best of ...
一种简单的方法是采用概率采用的方式回避这个问题。也就是我们不寻找确定的局部最优或者全局最优的解,而是通过局部路径或者全局路径的概率信息进行采样生成序列。例如,对于穷竭搜索的 \(N^L\) 条路径计算得到对应概率,转变成 \(N^L\) 个点的 categorical 分布,采样生成某条路径。也可以如下改造贪心或者beam 这类阶段性生成算法一个时刻一个时刻的输出采样的状态序列。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
def search_prob_greedy(initial: List, transition: List, L: int) -> Tuple[float, Tuple]:
import random
N = len(initial)
max_route = []
max_prop = 0.0
vertices = [i for i in range(N)]
prob = initial[:]

for l in range(0, L):
v_lst = random.choices(vertices, prob)
v = v_lst[0]
max_route.append(v)
max_prop = prob[v]
prob = [prob[v] * transition[v][v_target] for v_target in range(N)]

return max_prop, max_route

本篇是TSP问题从DP算法到深度学习系列第三篇,在这一篇中,我们会开始进入深度学习领域来求近似解法。本文会介绍并实现指针网络(Pointer Networks),一种seq-to-seq模型,它的设计目的就是为了解决TSP问题或者凸包(Convex Hull)问题。本文代码在 https://github.com/MyEncyclopedia/blog/tree/master/tsp/ptr_net_pytorch 中。

Pointer Networks

随着深度学习 seq-to-seq 模型作为概率近似模型在各领域的成功,TSP问题似乎也可以用同样的思路去解决。然而,传统的seq-to-seq 模型其输出的类别是预先固定的。例如,NLP RNN生成模型每一步会从 \(|V|\) 大的词汇表中产生一个单词。 然而,有很大一类问题,譬如TSP问题、凸包(Convex Hull)问题、Delaunay三角剖分问题,输出的类别不是事先固定的,而是随着输入而变化的。 Pointer Networks 的出现解决了这种限制:输出的类别可以通过指向某个输入,以此克服类别的问题,因此形象地取名为指针网络(Pointer Networks)。先来看看原论文中提到的三个问题。

凸包问题(Convex Hull)

如下图所示,需要在给定的10个点中找到若干个点,使得这些点包住了所有点。问题输入是不确定个数 n 个点的位置信息,输出是 k (k<=n)个点的。 这个经典的算法问题已经被证明找出精确解等价于排序问题(wikipedia 链接),因此时间复杂度为 \(O(n*log(n))\)

image info

\[ \begin{align*} &\text{Input: } \mathcal{P} &=& \left\{ P_{1}, \ldots, P_{10} \right\} \\ &\text{Output: } C^{\mathcal{P}} &=& \{2,4,3,5,6,7,2\} \end{align*} \]

TSP 问题

TSP 和凸包问题很类似,输入为不确定个数的 n 个点信息,输出为这 n 个点的某序列。在。。。中,我们可以将确定解的时间复杂度从 \(O(n!)\) 降到 \(O(n^2*2^n)\)

image info

\[ \begin{align*} &\text{Input: } \mathcal{P} &= &\left\{P_{1}, \ldots, P_{6} \right\} \\ &\text{Output: } C^{\mathcal{P}} &=& \{1,3,2,4,5,6,1\} \end{align*} \]

Delaunay三角剖分

Delaunay三角剖分问题是将平面上的散点集划分成三角形,使得在可能形成的三角剖分中,所形成的三角形的最小角最大。这个问题的输出是若干个集合,每个集合代表一个三角形,由输入点的编号表示。 image info

\[ \begin{align*} &\text{Input: } \mathcal{P} &=& \left\{P_{1}, \ldots, P_{5} \right\} \\ &\text{Output: } C^{\mathcal{P}} &=& \{(1,2,4),(1,4,5),(1,3,5),(1,2,3)\} \end{align*} \]

Seq-to-Seq 模型

现在假设n是固定的,传统基本的seq-to-seq模型(参数部分记为 \(\theta\) ),训练数据若记为\((\mathcal{P}, C^{\mathcal{P}})\),,将拟合以下条件概率:

\[ \begin{equation} p\left(\mathcal{C}^{\mathcal{P}} | \mathcal{P} ; \theta\right)=\prod_{i=1}^{m(\mathcal{P})} p\left(C_{i} | C_{1}, \ldots, C_{i-1}, \mathcal{P} ; \theta\right) \end{equation} \] 训练的方向是找到 \(\theta^{*}\) 来最大化上述联合概率,即: \[ \begin{equation} \theta^{*}=\underset{\theta}{\arg \max } \sum_{\mathcal{P}, \mathcal{C}^{\mathcal{P}}} \log p\left(\mathcal{C}^{\mathcal{P}} | \mathcal{P} ; \theta\right) \end{equation} \]

Content Based Input Attention

一种增强基本seq-to-seq模型的方法是加入attention机制。记encoder和decoder隐藏状态分别是 $ (e_{1}, , e_{n}) $ 和 $ (d_{1}, , d_{m()}) $。seq-to-seq第 i 次输出了 \(d_i\),注意力机制额外计算第i步的注意力向量 \(d_i^{\prime}\),并将其和\(d_i\)连接后作为隐藏状态。\(d_i^{\prime}\)的计算方式如下,输入 $ (e_{1}, , e_{n}) $ 和 i 对应的权重向量 $ (a_{1}^{i}, , a_{n}^{i}) $做点乘。

\[ d_{i} = \sum_{j=1}^{n} a_{j}^{i} e_{j} \]

$ (a_{1}^{i}, , a_{n}^{i}) $ 是向量 $ (u_{1}^{i}, , u_{n}^{i}) $ softmax后的值, \(u_{j}^{i}\) 表示 \(d_{i}\)\(e_{j}\)的距离,Pointer Networks论文中的距离为如下的tanh公式。

\[ \begin{eqnarray} u_{j}^{i} &=& v^{T} \tanh \left(W_{1} e_{j}+W_{2} d_\right) \quad j \in(1, \ldots, n) \\ a_{j}^{i} &=& \operatorname{softmax}\left(u_{j}^{i}\right) \quad j \in(1, \ldots, n) \end{eqnarray} \]

更多Attention计算方式

FloydHub Blog - Attention Mechanism 中,作者清楚地解释了两种经典的attention方法,第一种称为Additive Attention,由Dzmitry Bahdanau 提出,也就是Pointer Networks中通过tanh的计算方式,第二种称为 Multiplicative Attention,由Thang Luong*提出。

Luong Attention 有三种方法计算 \(d_{i}\)\(e_{j}\) 的距离(或者可以认为向量间的对齐得分)。

\[ \operatorname{score} \left( d_i, e_j \right)= \begin{cases} d_i^{\top} e_j & \text { dot } \\ d_i^{\top} W_a e_j & \text { general } \\ v_a^{\top} \tanh \left( W_a \left[ d_i ; e_j \right] \right) & \text { concat } \end{cases} \]

Pointer Networks

image info

Pointer Networks 基于Additive Attention,其创新之处在于用 \(u^i_j\) 作为第j个输入的评分,即第 i 次输出为1-n个输入中 \(u^i_j\) 得分最高的j作为输出,这样巧妙的解决了n不是预先固定的限制。

\[ \begin{eqnarray*} u_{j}^{i} &=& v^{T} \tanh \left(W_{1} e_{j}+W_{2} d_{i}\right) \quad j \in(1, \ldots, n) \\ p\left(C_{i} | C_{1}, \ldots, C_{i-1}, \mathcal{P}\right) &=& \operatorname{softmax}\left(u^{i}\right) \end{eqnarray*} \]

PyTorch 代码实现

在本系列第二篇 episode 2,中,我们说明过TSP数据集的格式,每一行字段意义如下

1
x0, y0, x1, y1, ... output 1 v1 v2 v3 ... 1

转换成PyTorch Dataset

每一个case会转换成nd.ndarray,共有五个分量,分别是 (input, input_len, output_in, output_out, output_len) 并且分装成pytorch的 Dataset类。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
from torch.utils.data import Dataset

class TSPDataset(Dataset):
"each data item of form (input, input_len, output_in, output_out, output_len)"
data: List[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]]

def __len__(self):
return len(self.data)

def __getitem__(self, index):
input, input_len, output_in, output_out, output_len = self.data[index]
return input, input_len, output_in, output_out, output_len
image info

PyTorch pad_packed_sequence 优化技巧

PyTorch 实现 seq-to-seq 模型一般会使用 pack_padded_sequence 以及 pad_packed_sequence 来减少计算量,本质上可以认为根据pad大小分批进行矩阵运算,减少被pad的矩阵元素导致的无效运算,详细的解释可以参考 https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning#decoder-1。

image info

对应代码如下:

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
class RNNEncoder(nn.Module):
rnn: Union[nn.LSTM, nn.GRU, nn.RNN]

def __init__(self, rnn_type: str, bidirectional: bool, num_layers: int, input_size: int, hidden_size: int, dropout: float):
super(RNNEncoder, self).__init__()
if bidirectional:
assert hidden_size % 2 == 0
hidden_size = hidden_size // 2
self.rnn = rnn_init(rnn_type, input_size=input_size, hidden_size=hidden_size, bidirectional=bidirectional,num_layers=num_layers, dropout=dropout)

def forward(self, src: Tensor, src_lengths: Tensor, hidden: Tensor = None) -> Tuple[Tensor, Tensor]:
lengths = src_lengths.view(-1).tolist()
packed_src = pack_padded_sequence(src, lengths)
memory_bank, hidden_final = self.rnn(packed_src, hidden)
memory_bank = pad_packed_sequence(memory_bank)[0]
return memory_bank, hidden_final

注意力机制相关代码

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
class Attention(nn.Module):
linear_out: nn.Linear

def __init__(self, dim: int):
super(Attention, self).__init__()
self.linear_out = nn.Linear(dim * 2, dim, bias=False)

def score(self, src: Tensor, target: Tensor) -> Tensor:
batch_size, src_len, dim = src.size()
_, target_len, _ = target.size()
target_ = target
src_ = src.transpose(1, 2)
return torch.bmm(target_, src_)

def forward(self, src: Tensor, target: Tensor, src_lengths: Tensor) -> Tuple[Tensor, Tensor]:
assert target.dim() == 3

batch_size, src_len, dim = src.size()
_, target_len, _ = target.size()

align_score = self.score(src, target)

mask = sequence_mask(src_lengths)
# (batch_size, max_len) -> (batch_size, 1, max_len)
mask = mask.unsqueeze(1)
align_score.data.masked_fill_(~mask, -float('inf'))
align_score = F.softmax(align_score, -1)

c = torch.bmm(align_score, src)

concat_c = torch.cat([c, target], -1)
attn_h = self.linear_out(concat_c)

return attn_h, align_score

参考资料

本篇是TSP问题从DP算法到深度学习系列第二篇。

AIZU TSP 自底向上迭代DP解

上一篇中,我们用Python 3和Java 8完成了自顶向下递归版本的DP解。我们继续改进代码,将它转换成标准DP方式:自底向上的迭代DP版本。下图是3个点TSP问题的递归调用图。

将这个图反过来检查状态的依赖关系,可以很容易发现规律:首先计算状态位含有一个1的点,接着是两个1的节点,最后是状态位三个1的点。简而言之,在计算状态位为n+1个1的节点时需要用到n个1的节点的计算结果,如果能依照这样的 topological 顺序来的话,就可以去除递归,写成迭代(循环)版本的DP。

迭代算法的Java 伪代码如下

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
for (int bitset_num = N; bitset_num >=0; bitset_num++) {
while(hasNextCombination(bitset_num)) {
int state = nextCombination(bitset_num);
// compute dp[state][v], v-th bit is set in state
for (int v = 0; v < n; v++) {
for (int u = 0; u < n; u++) {
// for each u not reached by this state
if (!include(state, u)) {
dp[state][v] = min(dp[state][v],
dp[new_state_include_u][u] + dist[v][u]);
}
}
}
}
}

举例来说,dp[00010][1] 是从顶点0出发,刚经过顶点1的最小距离 \(0 \rightarrow 1 \rightarrow ? \rightarrow ? \rightarrow ? \rightarrow 0\)

为了找到最小距离值,就必须遍历所有可能的下一个可能的顶点u (第一个问号位置)。 \[ (0 \rightarrow 1) + \begin{align*} \min \left\lbrace \begin{array}{r@{}l} 2 \rightarrow ? \rightarrow ? \rightarrow 0 + dist(1,2) \qquad\text{ new_state=[00110][2] } \qquad\\\\ 3 \rightarrow ? \rightarrow ? \rightarrow 0 + dist(1,3) \qquad\text{ new_state=[01010][3] } \qquad\\\\ 4 \rightarrow ? \rightarrow ? \rightarrow 0 + dist(1,4) \qquad\text{ new_state=[10010][4] } \qquad \end{array} \right. \end{align*} \]

迭代DP AC代码

以下是AC 的Java 算法核心代码。完整代码在 github/MyEncyclopedia 的tsp/alg_aizu/Main_loop.java

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
public long solve() {
int N = g.V_NUM;
long[][] dp = new long[1 << N][N];
// init dp[][] with MAX
for (int i = 0; i < dp.length; i++) {
Arrays.fill(dp[i], Integer.MAX_VALUE);
}
dp[(1 << N) - 1][0] = 0;

for (int state = (1 << N) - 2; state >= 0; state--) {
for (int v = 0; v < N; v++) {
for (int u = 0; u < N; u++) {
if (((state >> u) & 1) == 0) {
dp[state][v] = Math.min(dp[state][v], dp[state | 1 << u][u] + g.edges[v][u]);
}
}
}
}
return dp[0][0] == Integer.MAX_VALUE ? -1 : dp[0][0];
}

很显然,时间算法复杂度对应了三重 for 循环,为 O(\(2^n * n * n\)) = O(\(2^n*n^2\) )。

类似的,Python 3 AC 代码如下。完整代码在 github/MyEncyclopedia 的tsp/alg_aizu/TSP_loop.py

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
class TSPSolver:
g: Graph

def __init__(self, g: Graph):
self.g = g

def solve(self) -> int:
"""
:param v:
:param state:
:return: -1 means INF
"""
N = self.g.v_num
dp = [[INT_INF for c in range(N)] for r in range(1 << N)]

dp[(1 << N) - 1][0] = 0

for state in range((1 << N) - 2, -1, -1):
for v in range(N):
for u in range(N):
if ((state >> u) & 1) == 0:
if dp[state | 1 << u][u] != INT_INF and self.g.edges[v][u] != INT_INF:
if dp[state][v] == INT_INF:
dp[state][v] = dp[state | 1 << u][u] + self.g.edges[v][u]
else:
dp[state][v] = min(dp[state][v], dp[state | 1 << u][u] + self.g.edges[v][u])
return dp[0][0]

一个欧式空间TSP数据集

至此,TSP的DP解法全部讲解完毕。接下去,我们引入一个二维欧式空间的TSP数据集 PTR_NET on Google Drive ,这个数据集是 Pointer Networks 的作者 Oriol Vinyals 用于模型的训练测试而引入的。

数据集的每一行格式如下:

1
x1, y1, x2, y2, ... output 1 v1 v2 v3 ... 1

一行开始为n个点的x, y坐标,接着是 output,再接着是1,表示从顶点1出发,经v1,v2,...,返回1,注意顶点编号从1开始。

十个顶点数据集的一些数据示例如下:

1
2
3
4
0.607122 0.664447 0.953593 0.021519 0.757626 0.921024 0.586376 0.433565 0.786837 0.052959 0.016088 0.581436 0.496714 0.633571 0.227777 0.971433 0.665490 0.074331 0.383556 0.104392 output 1 3 8 6 10 9 5 2 4 7 1 
0.930534 0.747036 0.277412 0.938252 0.794592 0.794285 0.961946 0.261223 0.070796 0.384302 0.097035 0.796306 0.452332 0.412415 0.341413 0.566108 0.247172 0.890329 0.429978 0.232970 output 1 3 2 9 6 5 8 7 10 4 1
0.686712 0.087942 0.443054 0.277818 0.494769 0.985289 0.559706 0.861138 0.532884 0.351913 0.712561 0.199273 0.554681 0.657214 0.909986 0.277141 0.931064 0.639287 0.398927 0.406909 output 1 5 2 10 7 4 3 9 8 6 1

画出第一个例子的全部顶点和边。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import matplotlib.pyplot as plt
points='0.607122 0.664447 0.953593 0.021519 0.757626 0.921024 0.586376 0.433565 0.786837 0.052959 0.016088 0.581436 0.496714 0.633571 0.227777 0.971433 0.665490 0.074331 0.383556 0.104392'
float_list = list(map(lambda x: float(x), points.split(' ')))

x,y = [],[]
for idx, p in enumerate(float_list):
if idx % 2 == 0:
x.append(p)
else:
y.append(p)

for i in range(0, len(x)):
for j in range(0, len(x)):
if i == j:
continue
plt.plot((x[i],x[j]),(y[i],y[j]))

plt.show()
全连接的图

这个例子的最短TSP旅程为 \[ 1 \rightarrow 3 \rightarrow 8 \rightarrow 6 \rightarrow 10 \rightarrow 9 \rightarrow 5 \rightarrow 2 \rightarrow 4 \rightarrow 7 \rightarrow 1 \]

{linenos
1
2
3
4
5
6
7
8
tour_str = '1 3 8 6 10 9 5 2 4 7 1'
tour = list(map(lambda x: int(x), tour_str.split(' ')))

for i in range(0, len(tour)-1):
p1 = tour[i] - 1
p2 = tour[i + 1] - 1
plt.plot((x[p1],x[p2]),(y[p1],y[p2]))
plt.show()
最短路径

PTR_NET TSP 的Python代码

初始化Init Graph Edges

在之前的自顶向下的递归版本中,需要做一些改动。首先,是图的初始化,我们依然延续之前的邻接矩阵来表示,由于这次的图是无向图,对于任意两个顶点,需要初始化双向的边。

{linenos
1
2
3
4
5
6
7
8
g: Graph = Graph(N)
for v in range(N):
for u in range(N):
diff_x = coordinates[v][0] - coordinates[u][0]
diff_y = coordinates[v][1] - coordinates[u][1]
dist: float = math.sqrt(diff_x * diff_x + diff_y * diff_y)
g.setDist(u, v, dist)
g.setDist(v, u, dist)

辅助变量记录父节点

另一大改动是需要在遍历过程中保存的顶点关联信息,以便在最终找到最短路径值时可以回溯对应的完整路径。在下面代码中,使用parent[bitstate][v] 来保存此状态下最小路径对应的顶点u。

{linenos
1
2
3
4
5
6
7
8
9
10
11
ret: float = FLOAT_INF
u_min: int = -1
for u in range(self.g.v_num):
if (state & (1 << u)) == 0:
s: float = self._recurse(u, state | 1 << u)
if s + edges[v][u] < ret:
ret = s + edges[v][u]
u_min = u
dp[state][v] = ret
self.parent[state][v] = u_min

当最终最短行程确定后,根据parent的信息可以按图索骥找到完整的行程顶点信息。

{linenos
1
2
3
4
5
6
7
8
9
def _form_tour(self):
self.tour = [0]
bit = 0
v = 0
for _ in range(self.g.v_num - 1):
v = self.parent[bit][v]
self.tour.append(v)
bit = bit | (1 << v)
self.tour.append(0)

需要注意的是,有可能存在多个最短行程,它们的距离值是一致的。这种情况下,代码输出的最短路径可能和数据集output后行程路径不一致,但是的两者的总距离是一致的。下面的代码验证了这一点。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
tsp: TSPSolver = TSPSolver(g)
tsp.solve()

output_dist: float = 0.0
output_tour = list(map(lambda x: int(x) - 1, output.split(' ')))
for v in range(1, len(output_tour)):
pre_v = output_tour[v-1]
curr_v = output_tour[v]
diff_x = coordinates[pre_v][0] - coordinates[curr_v][0]
diff_y = coordinates[pre_v][1] - coordinates[curr_v][1]
dist: float = math.sqrt(diff_x * diff_x + diff_y * diff_y)
output_dist += dist

passed = abs(tsp.dist - output_dist) < 10e-5
if passed:
print(f'passed dist={tsp.tour}')
else:
print(f'Min Tour Distance = {output_dist}, Computed Tour Distance = {tsp.dist}, Expected Tour = {output_tour}, Result = {tsp.tour}')

本文所有代码在 github/MyEncyclopedia tsp/alg_plane 中。

Leetcode 679 24 Game (Hard)

先来介绍一下24点游戏题目,大家一定都玩过,就是给定4个牌面数字,用加减乘除计算24点。

本篇会用两种偏函数式的 Python 3解法来AC 24 Game。

Leetcode 679 24 Game (Hard) > You have 4 cards each containing a number from 1 to 9. You need to judge whether they could operated through *, /, +, -, (, ) to get the value of 24.

Example 1:

Input: [4, 1, 8, 7]

Output: True

Explanation: (8-4) * (7-1) = 24

Example 2:

Input: [1, 2, 1, 2]

Output: False

itertools.permutations

先来介绍一下Python itertools.permutations 的用法,正好用Leetcode 中的Permutation问题来示例。Permutations 的输入可以是List,返回是 generator 实例,用于生成所有排列。简而言之,python 的 generator 可以和List一样,用 for 语句来全部遍历产生的值。和List不同的是,generator 的所有值并不必须全部初始化,一般按需产生从而大量减少内存占用。下面在介绍 yield 时我们会看到如何合理构造 generator。

Leetcode 46 Permutations (Medium) > Given a collection of distinct integers, return all possible permutations.

Example:

Input: [1,2,3]

Output: [ [1,2,3], [1,3,2], [2,1,3], [2,3,1], [3,1,2], [3,2,1]]

用 permutations 很直白,代码只有一行。

{linenos
1
2
3
4
5
6
7
8
9
10
11
# AC
# Runtime: 36 ms, faster than 91.78% of Python3 online submissions for Permutations.
# Memory Usage: 13.9 MB, less than 66.52% of Python3 online submissions for Permutations.

from itertools import permutations
from typing import List


class Solution:
def permute(self, nums: List[int]) -> List[List[int]]:
return [p for p in permutations(nums)]

itertools.combinations

有了排列就少不了组合,itertools.combinations 可以产生给定List的k个元素组合 \(\binom{n}{k}\),用一道算法题来举例,同样也是一句语句就可以AC。

Leetcode 77 Combinations (Medium)

Given two integers n and k, return all possible combinations of k numbers out of 1 ... n. You may return the answer in any order.

Example 1:

Input: n = 4, k = 2

Output: [ [2,4], [3,4], [2,3], [1,2], [1,3], [1,4],]

Example 2:

Input: n = 1, k = 1

Output: [[1]]

{linenos
1
2
3
4
5
6
7
8
9
# AC
# Runtime: 84 ms, faster than 95.43% of Python3 online submissions for Combinations.
# Memory Usage: 15.2 MB, less than 68.98% of Python3 online submissions for Combinations.
from itertools import combinations
from typing import List

class Solution:
def combine(self, n: int, k: int) -> List[List[int]]:
return [c for c in combinations(list(range(1, n + 1)), k)]

itertools.product

当有多维度的对象需要迭代笛卡尔积时,可以用 product(iter1, iter2, ...)来生成generator,等价于多重 for 循环。

1
2
[lst for lst in product([1, 2, 3], ['a', 'b'])]
[(i, s) for i in [1, 2, 3] for s in ['a', 'b']]

这两种方式都生成了如下结果

1
[(1, 'a'), (1, 'b'), (2, 'a'), (2, 'b'), (3, 'a'), (3, 'b')]

再举一个Leetcode的例子来实战product generator。

Leetcode 17. Letter Combinations of a Phone Number (Medium)

Given a string containing digits from 2-9 inclusive, return all possible letter combinations that the number could represent. A mapping of digit to letters (just like on the telephone buttons) is given below. Note that 1 does not map to any letters.

Example:

Input: "23"

Output: ["ad", "ae", "af", "bd", "be", "bf", "cd", "ce", "cf"].

举例来说,下面的代码当输入 digits 是 '352' 时,iter_dims 的值是 ['def', 'jkl', 'abc'],再输入给 product 后会产生 'dja', 'djb', 'djc', 'eja', 共 3 x 3 x 3 = 27个组合的值。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# AC
# Runtime: 24 ms, faster than 94.50% of Python3 online submissions for Letter Combinations of a Phone Number.
# Memory Usage: 13.7 MB, less than 83.64% of Python3 online submissions for Letter Combinations of a Phone Number.

from itertools import product
from typing import List


class Solution:
def letterCombinations(self, digits: str) -> List[str]:
if digits == "":
return []
mapping = {'2':'abc', '3':'def', '4':'ghi', '5':'jkl', '6':'mno', '7':'pqrs', '8':'tuv', '9':'wxyz'}
iter_dims = [mapping[i] for i in digits]

result = []
for lst in product(*iter_dims):
result.append(''.join(lst))

return result

yield 示例

Python具有独特的itertools generator,可以花式AC代码,接下来讲解如何进一步构造 generator。Python 定义只要函数中使用了yield关键字,这个函数就是 generator。Generator 在计算机领域的标准名称是 coroutine,即协程,是一种特殊的函数:当返回上层调用时自身能够保存调用栈状态,并在上层函数处理完逻辑后跳入到这个 generator,恢复之前的状态再继续运行下去。Yield语句也举一道经典的Fibonacci 问题。

Leetcode 509. Fibonacci Number (Easy)

The Fibonacci numbers, commonly denoted F(n) form a sequence, called the Fibonacci sequence, such that each number is the sum of the two preceding ones, starting from 0 and 1. That is, F(0) = 0, F(1) = 1 F(N) = F(N - 1) + F(N - 2), for N > 1. Given N, calculate F(N).

Example 1:

Input: 2

Output: 1

Explanation: F(2) = F(1) + F(0) = 1 + 0 = 1.

Example 2:

Input: 3

Output: 2

Explanation: F(3) = F(2) + F(1) = 1 + 1 = 2.

Example 3:

Input: 4

Output: 3

Explanation: F(4) = F(3) + F(2) = 2 + 1 = 3.

Fibonacci 的一般标准解法是循环迭代方式,可以以O(n)时间复杂度和O(1) 空间复杂度来AC。下面的 yield 版本中,我们构造了fib_next generator,它保存了最后两个值作为内部迭代状态,外部每调用一次可以得到下一个fib(n),如此外部只需不断调用直到满足题目给定次数。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# AC
# Runtime: 28 ms, faster than 85.56% of Python3 online submissions for Fibonacci Number.
# Memory Usage: 13.8 MB, less than 58.41% of Python3 online submissions for Fibonacci Number.

class Solution:
def fib(self, N: int) -> int:
if N <= 1:
return N
i = 2
for fib in self.fib_next():
if i == N:
return fib
i += 1

def fib_next(self):
f_last2, f_last = 0, 1
while True:
f = f_last2 + f_last
f_last2, f_last = f_last, f
yield f

yield from 示例

上述yield用法之后,再来演示 yield from 的用法。Yield from 始于Python 3.3,用于嵌套generator时的控制转移,一种典型的用法是有多个generator嵌套时,外层的outer_generator 用 yield from 这种方式等价代替如下代码。

1
2
3
def outer_generator():
for i in inner_generator():
yield i

用一道算法题目来具体示例。

Leetcode 230. Kth Smallest Element in a BST (Medium)

Given a binary search tree, write a function kthSmallest to find the kth smallest element in it.

Example 1: Input: root = [3,1,4,null,2], k = 1

1
2
3
4
5
  3
/ \
1 4
\
2
Output: 1

Example 2:

Input: root = [5,3,6,2,4,null,null,1], k = 3

1
2
3
4
5
6
7
         5
/ \
3 6
/ \
2 4
/
1
Output: 3

直觉思路上,我们只要从小到大有序遍历每个节点直至第k个。因为给定的树是Binary Search Tree,有序遍历意味着以左子树、节点本身和右子树的访问顺序递归下去就行。由于ordered_iter是generator,递归调用自己的过程就是嵌套使用generator的过程。下面是yield版本。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# AC
# Runtime: 48 ms, faster than 90.31% of Python3 online submissions for Kth Smallest Element in a BST.
# Memory Usage: 17.9 MB, less than 14.91% of Python3 online submissions for Kth Smallest Element in a BST.

class Solution:
def kthSmallest(self, root: TreeNode, k: int) -> int:
def ordered_iter(node):
if node:
for sub_node in ordered_iter(node.left):
yield sub_node
yield node
for sub_node in ordered_iter(node.right):
yield sub_node

for node in ordered_iter(root):
k -= 1
if k == 0:
return node.val

等价于如下 yield from 版本:

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# AC
# Runtime: 56 ms, faster than 63.74% of Python3 online submissions for Kth Smallest Element in a BST.
# Memory Usage: 17.7 MB, less than 73.33% of Python3 online submissions for Kth Smallest Element in a BST.

class Solution:
def kthSmallest(self, root: TreeNode, k: int) -> int:
def ordered_iter(node):
if node:
yield from ordered_iter(node.left)
yield node
yield from ordered_iter(node.right)

for node in ordered_iter(root):
k -= 1
if k == 0:
return node.val

24 点问题之函数式枚举解法

看明白了itertools.permuations,combinations,product,yield以及yield from,我们回到本篇最初的24点游戏问题。

24点游戏的本质是枚举出所有可能运算,如果有一种方式得到24返回True,否则返回Flase。进一步思考所有可能的运算,包括下面三个维度:

  1. 4个数字的所有排列,比如给定 [1, 2, 3, 4],可以用permutations([1, 2, 3, 4]) 生成这个维度的所有可能

  2. 三个位置的操作符号的全部可能,可以用 product([+, -, *, /], repeat=3) 生成,具体迭代结果为:[+, +, +],[+, +, -],...

  3. 给定了前面两个维度后,还有一个比较不容易察觉但必要的维度:运算优先级。比如在给定数字顺序 [1, 2, 3, 4]和符号顺序 [+, *, -]之后可能的四种操作树

四种运算优先级

能否算得24点只需要枚举这三个维度笛卡尔积的运算结果

(维度1:数字组合) x (维度2:符号组合) x (维度3:优先级组合)

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# AC
# Runtime: 112 ms, faster than 57.59% of Python3 online submissions for 24 Game.
# Memory Usage: 13.7 MB, less than 85.60% of Python3 online submissions for 24 Game.

import math
from itertools import permutations, product
from typing import List

class Solution:

def iter_trees(self, op1, op2, op3, a, b, c, d):
yield op1(op2(a, b), op3(c, d))
yield op1(a, op2(op3(b, c), d))
yield op1(a, op2(b, op3(c, d)))
yield op1(op2(a, op3(b, c)), d)

def judgePoint24(self, nums: List[int]) -> bool:
mul = lambda x, y: x * y
plus = lambda x, y: x + y
div = lambda x, y: x / y if y != 0 else math.inf
minus = lambda x, y: x - y

op_lst = [plus, minus, mul, div]

for ops in product(op_lst, repeat=3):
for val in permutations(nums):
for v in self.iter_trees(ops[0], ops[1], ops[2], val[0], val[1], val[2], val[3]):
if abs(v - 24) < 0.0001:
return True
return False

24 点问题之 DFS yield from 解法

一种常规的思路是,在四个数组成的集合中先选出任意两个数,枚举所有可能的计算,再将剩余的三个数组成的集合递归调用下去,直到叶子节点只剩一个数,如下图所示。

DFS 调用示例

下面的代码是这种思路的 itertools + yield from 解法,recurse方法是generator,会自我递归调用。当只剩下两个数时,用 yield 返回两个数的所有可能运算得出的值,其他非叶子情况下则自我调用使用yield from,例如4个数任选2个先计算再合成3个数的情况。这种情况下,比较麻烦的是由于4个数可能有相同值,若用 combinations(lst, 2) 先任选两个数,后续要生成剩余两个数加上第三个计算的数的集合代码会繁琐。因此,我们改成任选4个数index中的两个,剩余的indices 可以通过集合操作来完成。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
# AC
# Runtime: 116 ms, faster than 56.23% of Python3 online submissions for 24 Game.
# Memory Usage: 13.9 MB, less than 44.89% of Python3 online submissions for 24 Game.

import math
from itertools import combinations, product, permutations
from typing import List

class Solution:

def judgePoint24(self, nums: List[int]) -> bool:
mul = lambda x, y: x * y
plus = lambda x, y: x + y
div = lambda x, y: x / y if y != 0 else math.inf
minus = lambda x, y: x - y

op_lst = [plus, minus, mul, div]

def recurse(lst: List[int]):
if len(lst) == 2:
for op, values in product(op_lst, permutations(lst)):
yield op(values[0], values[1])
else:
# choose 2 indices from lst of length n
for choosen_idx_lst in combinations(list(range(len(lst))), 2):
# remaining indices not choosen (of length n-2)
idx_remaining_set = set(list(range(len(lst)))) - set(choosen_idx_lst)

# remaining values not choosen (of length n-2)
value_remaining_lst = list(map(lambda x: lst[x], idx_remaining_set))
for op, idx_lst in product(op_lst, permutations(choosen_idx_lst)):
# 2 choosen values are lst[idx_lst[0]], lst[idx_lst[1]
value_remaining_lst.append(op(lst[idx_lst[0]], lst[idx_lst[1]]))
yield from recurse(value_remaining_lst)
value_remaining_lst = value_remaining_lst[:-1]

for v in recurse(nums):
if abs(v - 24) < 0.0001:
return True

旅行商问题(TSP)是计算机算法中经典的NP hard 问题。 在本系列文章中,我们将首先使用动态规划 AC aizu中的TSP问题,然后再利用深度学习求大规模下的近似解。深度学习应用解决问题时先以PyTorch实现监督学习算法 Pointer Network,进而结合强化学习来无监督学习,提高数据使用效率。 本系列完整列表如下:

  • 第一篇: 递归DP方法 AC AIZU TSP问题

  • 第二篇: 二维空间TSP数据集及其DP解法

  • 第三篇: 深度学习 Pointer Networks 的 Pytorch实现

  • 第四篇: 搜寻最有可能路径:Viterbi算法和其他

  • 第五篇: 深度强化学习无监督算法的 Pytorch实现

TSP 问题回顾

TSP可以用图模型来表达,无论有向图或无向图,无论全连通图或者部分连通的图都可以作为TSP问题。 Wikipedia TSP 中举了一个无向全连通的TSP例子。如下图所示,四个顶点A,B,C,D构成无向全连通图。TSP问题要求在所有遍历所有点后返回初始点的回路中找到最短的回路。例如,\(A \rightarrow B \rightarrow C \rightarrow D \rightarrow A\)\(A \rightarrow C \rightarrow B \rightarrow D \rightarrow A\) 都是有效的回路,但是TSP需要返回这些回路中的最短回路(注意,最短回路可能会有多条)。

Wikipedia 4个顶点组成的图

无论是哪种类型的图,我们都能用邻接矩阵表示出一个图。上面的Wikipedia中的图可以用下面的矩阵来描述。

\[ \begin{matrix} & \begin{matrix}A&B&C&D\end{matrix} \\\\ \begin{matrix}A\\\\B\\\\C\\\\D\end{matrix} & \begin{bmatrix}-&20&42&35\\\\20&-&30&34\\\\42&30&-&12\\\\35&34&12&-\end{bmatrix}\\\\ \end{matrix} \]

当然,大多数情况下,TSP问题会被限定在欧氏空间,即二维地图中的全连通无向图。因为,如果将顶点表示一个地理位置,一般来说它可以和其他所有顶点连通,回来的距离相同,由此构成无向图。

AIZU TSP 问题

AIZU在线题库 有一道有向不完全连通图的TSP问题。给定V个顶点和E条边,输出最小回路值。例如,题目里的例子如下所示,由4个顶点和6条单向边构成。

这个示例的答案是16,对应的回路是 \(0\rightarrow1\rightarrow3\rightarrow2\rightarrow0\),由下图的红色边构成。注意,这个题目可能不存在合法解,原因是无回路存在,此时返回-1,可以合理地理解成无穷大。

暴力解法

一种暴力方法是枚举所有可能的从某一顶点的回路,取其中的最小值即可。下面的 Python 示例如何枚举4个顶点构成的图中从顶点0出发的所有回路。

{linenos
1
2
3
4
5
from itertools import permutations
v = [1,2,3]
p = permutations(v)
for t in list(p):
print([0] + list(t) + [0])

所有从顶点0出发的回路如下:

{linenos
1
2
3
4
5
6
[0, 1, 2, 3, 0]
[0, 1, 3, 2, 0]
[0, 2, 1, 3, 0]
[0, 2, 3, 1, 0]
[0, 3, 1, 2, 0]
[0, 3, 2, 1, 0]

很显然,这种方式的时间复杂度是 O(\(n!\)),无法通过AIZU。

动态规划求解

我们可以使用位状态压缩的动态规划来AC这道题。 首先,需要将回路过程中的状态编码成二进制的表示。例如,在四顶点的例子中,如果顶点2和1都被访问过,并且此时停留在顶点1。将已经访问的顶点对应的位置1,那么编码成0110,此外,还需要保存当前顶点的位置,因此我们将代表状态的数组扩展成二维,第一维是位状态,第二维是顶点所在位置,即 \(dp[bitstate][v]\)。这个例子的状态表示就是 \(dp["0110"][1]\)

状态转移方程如下: \[ dp[bitstate][v] = \min ( dp[bitstate \cup \{u\}][u] + dist(v,u) \mid u \notin bitstate ) \] 这种方法对应的时间复杂度是 O(\(n^2*2^n\) ),因为总共有 \(2^n * n\) 个状态,而每个状态又需要一次遍历。虽然都是指数级复杂度,但是它们的巨大区别由下面可以看出区别。

\(n!\) \(n^2*2^n\)
n=8 40320 16384
n=10 3628800 102400
n=12 479001600 589824
n=14 87178291200 3211264

暂停思考一下为什么状态压缩DP能工作。注意到之前暴力解法中其实是有很多重复计算,下面红圈表示重复的计算节点。

在本篇中,我们将会用Python 3和Java 8 实现自顶向下的DP 缓存版本。这种方式比较符合直觉,因为我们不需要预先考虑计算节点的依赖关系。在Java中我们使用了一个小技巧,dp数组初始化成Integer.MAX_VALUE,如此只需要一条语句就能完成更新dp值。

1
res = Math.min(res, s + g.edges[v][u]);

当然,为了AC 这道题,我们需要区分出真正无法到达的情况并返回-1。 在Python实现中,也可以使用同样的技巧,但是这次示例一般的实现方法:将dp数组初始化成-1并通过 if-else 来区分不同情况。

1
2
3
4
5
6
7
INT_INF = -1

if s != INT_INF and edges[v][u] != INT_INF:
if ret == INT_INF:
ret = s + edges[v][u]
else:
ret = min(ret, s + edges[v][u])

下面附完整的Python 3和Java 8的AC代码,同步在 github

AIZU Java 8 递归DP版本

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
// passed http://judge.u-aizu.ac.jp/onlinejudge/description.jsp?id=DPL_2_A
import java.util.Arrays;
import java.util.Scanner;

public class Main {
public static class Graph {
public final int V_NUM;
public final int[][] edges;

public Graph(int V_NUM) {
this.V_NUM = V_NUM;
this.edges = new int[V_NUM][V_NUM];
for (int i = 0; i < V_NUM; i++) {
Arrays.fill(this.edges[i], Integer.MAX_VALUE);
}
}

public void setDist(int src, int dest, int dist) {
this.edges[src][dest] = dist;
}

}

public static class TSP {
public final Graph g;
long[][] dp;

public TSP(Graph g) {
this.g = g;
}

public long solve() {
int N = g.V_NUM;
dp = new long[1 << N][N];
for (int i = 0; i < dp.length; i++) {
Arrays.fill(dp[i], -1);
}

long ret = recurse(0, 0);
return ret == Integer.MAX_VALUE ? -1 : ret;
}

private long recurse(int state, int v) {
int ALL = (1 << g.V_NUM) - 1;
if (dp[state][v] >= 0) {
return dp[state][v];
}
if (state == ALL && v == 0) {
dp[state][v] = 0;
return 0;
}
long res = Integer.MAX_VALUE;
for (int u = 0; u < g.V_NUM; u++) {
if ((state & (1 << u)) == 0) {
long s = recurse(state | 1 << u, u);
res = Math.min(res, s + g.edges[v][u]);
}
}
dp[state][v] = res;
return res;

}

}

public static void main(String[] args) {

Scanner in = new Scanner(System.in);
int V = in.nextInt();
int E = in.nextInt();
Graph g = new Graph(V);
while (E > 0) {
int src = in.nextInt();
int dest = in.nextInt();
int dist = in.nextInt();
g.setDist(src, dest, dist);
E--;
}
System.out.println(new TSP(g).solve());
}
}

AIZU Python 3 递归DP版本

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from typing import List

INT_INF = -1

class Graph:
v_num: int
edges: List[List[int]]

def __init__(self, v_num: int):
self.v_num = v_num
self.edges = [[INT_INF for c in range(v_num)] for r in range(v_num)]

def setDist(self, src: int, dest: int, dist: int):
self.edges[src][dest] = dist


class TSPSolver:
g: Graph
dp: List[List[int]]

def __init__(self, g: Graph):
self.g = g
self.dp = [[None for c in range(g.v_num)] for r in range(1 << g.v_num)]

def solve(self) -> int:
return self._recurse(0, 0)

def _recurse(self, v: int, state: int) -> int:
"""

:param v:
:param state:
:return: -1 means INF
"""
dp = self.dp
edges = self.g.edges

if dp[state][v] is not None:
return dp[state][v]

if (state == (1 << self.g.v_num) - 1) and (v == 0):
dp[state][v] = 0
return dp[state][v]

ret: int = INT_INF
for u in range(self.g.v_num):
if (state & (1 << u)) == 0:
s: int = self._recurse(u, state | 1 << u)
if s != INT_INF and edges[v][u] != INT_INF:
if ret == INT_INF:
ret = s + edges[v][u]
else:
ret = min(ret, s + edges[v][u])
dp[state][v] = ret
return ret


def main():
V, E = map(int, input().split())
g: Graph = Graph(V)
for _ in range(E):
src, dest, dist = map(int, input().split())
g.setDist(src, dest, dist)

tsp: TSPSolver = TSPSolver(g)
print(tsp.solve())


if __name__ == "__main__":
main()

Leetcode 1227 是一道有意思的概率题,本篇将从多个角度来讨论这道题。题目如下

有 n 位乘客即将登机,飞机正好有 n 个座位。第一位乘客的票丢了,他随便选了一个座位坐下。 剩下的乘客将会: 如果他们自己的座位还空着,就坐到自己的座位上, 当他们自己的座位被占用时,随机选择其他座位,第 n 位乘客坐在自己的座位上的概率是多少?

示例 1: 输入:n = 1 输出:1.00000 解释:第一个人只会坐在自己的位置上。

示例 2: 输入: n = 2 输出: 0.50000 解释:在第一个人选好座位坐下后,第二个人坐在自己的座位上的概率是 0.5。

提示: 1 <= n <= 10^5

假设规模为n时答案为f(n),一般来说,这种递推问题在数学形式上可能有关于n的简单数学表达式(closed form),或者肯定有f(n)关于f(n-k)的递推表达式。工程上,我们可以通过通过多次模拟即蒙特卡罗模拟来算得近似的数值解。

Monte Carlo 模拟发现规律

首先,我们先来看看如何高效的用代码来模拟。根据题意的描述过程,直接可以写出下面代码。seats为n大小的bool 数组,每个位置表示此位置是否已经被占据。然后依次给第i个人按题意分配座位。注意,每次参数随机数范围在[0,n-1],因此,会出现已经被占据的情况,此时需要再次随机,直至分配到空位。

暴力直接模拟

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
def simulate_bruteforce(n: int) -> bool:
"""
Simulates one round. Unbounded time complexity.
:param n: total number of seats
:return: True if last one has last seat, otherwise False
"""

seats = [False for _ in range(n)]

for i in range(n-1):
if i == 0: # first one, always random
seats[random.randint(0, n - 1)] = True
else:
if not seats[i]: # i-th has his seat
seats[i] = True
else:
while True:
rnd = random.randint(0, n - 1) # random until no conflicts
if not seats[rnd]:
seats[rnd] = True
break
return not seats[n-1]

运行上面的代码来模拟 n 从 2 到10 的情况,每种情况跑500次模拟,输出如下

1
2
3
4
5
6
7
8
9
10
1 => 1.0
2 => 0.55
3 => 0.54
4 => 0.486
5 => 0.488
6 => 0.498
7 => 0.526
8 => 0.504
9 => 0.482
10 => 0.494

发现当 n>=2 时,似乎概率都是0.5。

标准答案

其实,这道题的标准答案就是 n=1 为1,n>=2 为0.5。下面是 python 3 标准答案。本篇后面会从多个角度来探讨为什么是0.5 。

{linenos
1
2
3
class Solution:
def nthPersonGetsNthSeat(self, n: int) -> float:
return 1.0 if n == 1 else 0.5

O(n) 改进算法

上面的暴力直接模拟版本有个最大的问题是当n很大时,随机分配座位会产生大量冲突,因此,最坏复杂度是没有任何上限的。解决方法是每次发生随机分配时保证不冲突,能直接选到空位。下面是一种最坏复杂度O(n)的模拟过程,seats数组初始话成 0,1,...,n-1,表示座位号。当第i个人登机时,seats[i:n] 的值为他可以选择的座位集合,而seats[0:i]为已经被占据的座位集合。由于[i: n]是连续空间,产生随机数就能保证不冲突。当第i个人选完座位时,将他选中的seats[k]和seats[i] 交换,保证第i+i个人面临的seats[i+1:n]依然为可选座位集合。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
def simulate_online(n: int) -> bool:
"""
Simulates one round of complexity O(N).
:param n: total number of seats
:return: True if last one has last seat, otherwise False
"""

seats = [i for i in range(n)]

def swap(i, j):
tmp = seats[i]
seats[i] = seats[j]
seats[j] = tmp

# for each person, the seats array idx available are [i, n-1]
for i in range(n-1):
if i == 0: # first one, always random
rnd = random.randint(0, n - 1)
swap(rnd, 0)
else:
if seats[i] == i: # i-th still has his seat
pass
else:
rnd = random.randint(i, n - 1) # selects idx from [i, n-1]
swap(rnd, i)
return seats[n-1] == n - 1

递推思维

这一节我们用数学递推思维来解释0.5的解。令f(n) 为第 n 位乘客坐在自己的座位上的概率,考察第一个人的情况(first step analysis),有三种可能

  1. 第一个人选了第一个即自己的座位,那么最后一个人一定能保证坐在自己的座位。
  2. 第一个人选了最后一个人的座位,无论中间什么过程,最后一个人无法坐到自己座位
  3. 第一个人选了第i个座位,(1<i<n),那么第i个人前面的除了第一个外的人都会坐在自己位置上,第i个人由于没有自己座位,随机在剩余的座位1,座位 [i+1,n] 中随机选择,此时,问题转变为f(n-i+1),如下图所示。
第一个人选了位置i
第i个人将问题转换成f(n-i+1)

通过上面分析,得到概率递推关系如下

\[ f(n) = \begin{align*} \left\lbrace \begin{array}{r@{}l} 1 & & p=\frac{1}{n} \quad \text{选了第一个位置} \\\\\\ f(n-i+1) & & p=\frac{1}{n} \quad \text{选了第i个位置,1<i<n} \\\\\\ 0 & & p=\frac{1}{n} \quad \text{选了第n个位置} \end{array} \right. \end{align*} \]

即f(n)的递推式为: \[ f(n) = \frac{1}{n} + \frac{1}{n} \times [ f(n-1) + f(n-2) + ...+ f(2)], \quad n>=2 \] 同理,f(n+1)递推式如下 \[ f(n+1) = \frac{1}{n+1} + \frac{1}{n+1} \times [ f(n) + f(n-1) + ...+ f(2)] \] \((n+1)f(n+1) - nf(n)\) 抵消 \(f(n-1) + ...f(2)\) 项,可得 \[ (n+1)f(n+1) - nf(n) = f(n) \]\[ f(n+1) = f(n) = \frac{1}{2} \quad n>=2 \]

用数学归纳法也可以证明 n>=2 时 f(n)=0.5。

简化的思考方式

我们再仔细思考一下上面的第三种情况,就是第一个人坐了第i个座位,1<i<n,此时,程序继续,不产生结果,直至产生结局1或者2,也就是case 1和2是真正的结局节点,它们产生的概率相同,因此答案是1/2。

从调用图可以看出这种关系,由于中间节点 f(4),f(3),f(2)生成Case 1和2的概率一样,因此无论它们之间是什么关系,最后结果都是1/2.

知乎上有个很形象的类比理解方式

考虑一枚硬币,正面向上的概率为 1/n,反面也是,立起来的概率为 (n-2)/n 。我们规定硬币立起来重新抛,但重新抛时,n会至少减小1。求结果为反面的概率。这样很显然结果为 1/2 。

这里,正面向上对应Case 2,反面对应Case 1。

这种思想可以写出如下代码,seats为 n 大小的bool 数组,当第i个人(0<i<n)发现自己座位被占的话,此时必然seats[0]没有被占,同时seats[i+1:]都是空的。假设seats[0]被占的话,要么是第一个人占的,要么是第p个人(p<i)坐了,两种情况下乱序都已经恢复了,此时第i个座位一定是空的。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
def simulate(n: int) -> bool:
"""
Simulates one round of complexity O(N).
:param n: total number of seats
:return: True if last one has last seat, otherwise False
"""

seats = [False for _ in range(n)]

for i in range(n-1):
if i == 0: # first one, always random
rnd = random.randint(0, n - 1)
seats[rnd] = True
else:
if not seats[i]: # i-th still has his seat
seats[i] = True
else:
# 0 must not be available, now we have 0 and [i+1, n-1],
rnd = random.randint(i, n - 1)
if rnd == i:
seats[0] = True
else:
seats[rnd] = True
return not seats[n-1]

上一篇我们从原理层面解析了AlphaGo Zero如何改进MCTS算法,通过不断自我对弈,最终实现从零棋力开始训练直至能够打败任何高手。在本篇中,我们在已有的N子棋OpenAI Gym 环境中用Pytorch实现一个简化版的AlphaGo Zero算法。本篇所有代码在 github MyEncyclopedia/ConnectNGym 中,其中部分参考了SongXiaoJun 的 AlphaZero_Gomoku

AlphaGo Zero MCTS 树节点

上一篇中,我们知道AlphaGo Zero 的MCTS树搜索是基于传统MCTS 的UCT (UCB for Tree)的改进版PUCT(Polynomial Upper Confidence Trees)。局面节点的PUCT值由两部分组成,分别是代表Exploitation的action value Q值,和代表Exploration的U值。 \[ PUCT(s, a) =Q(s,a) + U(s,a) \] U值计算由这些参数决定:系数\(c_{puct}\),节点先验概率P(s, a) ,父节点访问次数,本节点的访问次数。具体公式如下 \[ U(s, a)=c_{p u c t} \cdot P(s, a) \cdot \frac{\sqrt{\Sigma_{b} N(s, b)}}{1+N(s, a)} \]

因此在实现过程中,对于一个树节点来说,需要保存其Q值、节点访问次数 _visit_num和先验概率 _prior。其中,_prior在节点初始化后不变,Q值和 visit_num随着游戏MCTS模拟进程而改变。此外,节点保存了 parent和_children变量,用于维护父子关系。c_puct为class variable,作为全局参数。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
class TreeNode:
"""
MCTS Tree Node
"""

c_puct: ClassVar[int] = 5 # class-wise global param c_puct, exploration weight factor.

_parent: TreeNode
_children: Dict[int, TreeNode] # map from action to TreeNode
_visit_num: int
_Q: float # Q value of the node, which is the mean action value.
_prior: float

和上面的计算公式相对应,下列代码根据节点状态计算PUCT(s, a)。

{linenos
1
2
3
4
5
6
7
8
9
10
class TreeNode:

def get_puct(self) -> float:
"""
Computes AlphaGo Zero PUCT (polynomial upper confidence trees) of the node.

:return: Node PUCT value.
"""
U = (TreeNode.c_puct * self._prior * np.sqrt(self._parent._visit_num) / (1 + self._visit_num))
return self._Q + U

AlphaGo Zero MCTS在playout时遇到已经被展开的节点,会根据selection规则选择子节点,该规则本质上是在所有子节点中选择最大的PUCT值的节点。

\[ a=\operatorname{argmax}_a(PUCT(s, a))=\operatorname{argmax}_a(Q(s,a) + U(s,a)) \]

{linenos
1
2
3
4
5
6
7
8
9
class TreeNode:

def select(self) -> Tuple[Pos, TreeNode]:
"""
Selects an action(Pos) having max UCB value.

:return: Action and corresponding node
"""
return max(self._children.items(), key=lambda act_node: act_node[1].get_puct())

新的叶节点一旦在playout时产生,关联的 v 值会一路向上更新至根节点,具体新节点的v值将在下一节中解释。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
class TreeNode:

def propagate_to_root(self, leaf_value: float):
"""
Updates current node with observed leaf_value and propagates to root node.

:param leaf_value:
:return:
"""
if self._parent:
self._parent.propagate_to_root(-leaf_value)
self._update(leaf_value)

def _update(self, leaf_value: float):
"""
Updates the node by newly observed leaf_value.

:param leaf_value:
:return:
"""
self._visit_num += 1
# new Q is updated towards deviation from existing Q
self._Q += 0.5 * (leaf_value - self._Q)

AlphaGo Zero MCTS Player 实现

AlphaGo Zero MCTS 在训练阶段分为如下几个步骤。游戏初始局面下,整个局面树的建立由子节点的不断被探索而丰富起来。AlphaGo Zero对弈一次即产生了一次完整的游戏开始到结束的动作系列。在对弈过程中的某一游戏局面,需要采样海量的playout,又称MCTS模拟,以此来决定此局面的下一步动作。一次playout可视为在真实游戏状态树的一种特定采样,playout可能会产生游戏结局,生成真实的v值;也可能explore 到新的叶子节点,此时v值依赖策略价值网络的输出,目的是利用训练的神经网络来产生高质量的游戏对战局面。每次playout会从当前给定局面递归向下,向下的过程中会遇到下面三种节点情况。

  • 若局面节点是游戏结局(叶子节点),可以得到游戏的真实价值 z。从底部节点带着z向上更新沿途节点的Q值,直至根节点(初始局面)。
  • 若局面节点从未被扩展过(叶子节点),此时会将局面编码输入到策略价值双头网络,输出结果为网络预估的action分布和v值。Action分布作为节点先验概率P(s, a)来初始化子节点,预估的v值和上面真实游戏价值z一样,从叶子节点向上沿途更新到根节点。
  • 若局面节点已经被扩展过,则根据PUCT的select规则继续选择下一节点。

海量的playout模拟后,建立了游戏状态树的节点信息。但至此,AI玩家只是收集了信息,还仍未给定局面落子,而落子的决定由Play规则产生。下图展示了给定局面(Current节点)下,MCST模拟进行的多次playout探索后生成的局面树,play规则根据这些节点信息,产生Current 节点的动作分布 \(\pi\) ,确定下一步落子。

MCTS Playout和Play关系

Play 给定局面

对于当前需要做落子决定的某游戏局面\(s_0\),根据如下play公式生成落子分布 $$ ,子局面的落子概率正比于其访问次数的某次方。其中,某次方的倒数称为温度参数(Temperature)。

\[ \pi\left(a \mid s_{0}\right)=\frac{N\left(s_{0}, a\right)^{1 / \tau}}{\sum_{b} N\left(s_{0}, b\right)^{1 / \tau}} \]

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
class MCTSAlphaGoZeroPlayer(BaseAgent):

def _next_step_play_act_probs(self, game: ConnectNGame) -> Tuple[List[Pos], ActionProbs]:
"""
For the given game status, run playouts number of times specified by self._playout_num.
Returns the action distribution according to AlphaGo Zero MCTS play formula.

:param game:
:return: actions and their probability
"""

for n in range(self._playout_num):
self._playout(copy.deepcopy(game))

act_visits = [(act, node._visit_num) for act, node in self._current_root._children.items()]
acts, visits = zip(*act_visits)
act_probs = softmax(1.0 / MCTSAlphaGoZeroPlayer.temperature * np.log(np.array(visits) + 1e-10))

return acts, act_probs

在训练模式时,考虑到偏向exploration的目的,在\(\pi\) 落子分布的基础上增加了 Dirichlet 分布。

\[ P(s,a) = (1-\epsilon)*\pi(a \mid s) + \epsilon * \boldsymbol{\eta} \quad (\boldsymbol{\eta} \sim \operatorname{Dir}(0.3)) \]

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
class MCTSAlphaGoZeroPlayer(BaseAgent):

def get_action(self, board: PyGameBoard) -> Pos:
"""
Method defined in BaseAgent.

:param board:
:return: next move for the given game board.
"""
return self._get_action(copy.deepcopy(board.connect_n_game))[0]

def _get_action(self, game: ConnectNGame) -> Tuple[MoveWithProb]:
epsilon = 0.25
avail_pos = game.get_avail_pos()
move_probs: ActionProbs = np.zeros(game.board_size * game.board_size)
assert len(avail_pos) > 0

# the pi defined in AlphaGo Zero paper
acts, act_probs = self._next_step_play_act_probs(game)
move_probs[list(acts)] = act_probs
if self._is_training:
# add Dirichlet Noise when training in favour of exploration
p_ = (1-epsilon) * act_probs + epsilon * np.random.dirichlet(0.3 * np.ones(len(act_probs)))
move = np.random.choice(acts, p=p_)
assert move in game.get_avail_pos()
else:
move = np.random.choice(acts, p=act_probs)

self.reset()
return move, move_probs

一次完整的对弈

一次完整的AI对弈就是从初始局面迭代play直至游戏结束,对弈生成的数据是一系列的 $(s, , z) $。

如下图 s0 到 s5 是某次井字棋的对弈。最终结局是先手黑棋玩家赢,即对于黑棋玩家 z = +1。需要注意的是:z = +1 是对于所有黑棋面临的局面,即s0, s2, s4,而对应的其余白棋玩家来说 z = -1。

一局完整对弈

\[ \begin{align*} &0: (s_0, \vec{\pi_0}, +1) \\ &1: (s_1, \vec{\pi_1}, -1) \\ &2: (s_2, \vec{\pi_2}, +1) \\ &3: (s_3, \vec{\pi_3}, -1) \\ &4: (s_4, \vec{\pi_4}, +1) \end{align*} \]

以下代码展示如何在AI对弈时收集数据 $(s, , z) $

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
class MCTSAlphaGoZeroPlayer(BaseAgent):

def self_play_one_game(self, game: ConnectNGame) \
-> List[Tuple[NetGameState, ActionProbs, NDArray[(Any), np.float]]]:
"""

:param game:
:return:
Sequence of (s, pi, z) of a complete game play. The number of list is the game play length.
"""

states: List[NetGameState] = []
probs: List[ActionProbs] = []
current_players: List[np.float] = []

while not game.game_over:
move, move_probs = self._get_action(game)
states.append(convert_game_state(game))
probs.append(move_probs)
current_players.append(game.current_player)
game.move(move)

current_player_z = np.zeros(len(current_players))
current_player_z[np.array(current_players) == game.game_result] = 1.0
current_player_z[np.array(current_players) == -game.game_result] = -1.0
self.reset()

return list(zip(states, probs, current_player_z))

Playout 代码实现

一次playout会从当前局面根据PUCT selection规则下沉到叶子节点,如果此叶子节点非游戏终结点,则会扩展当前节点生成下一层新节点,其先验分布由策略价值网络输出的action分布决定。一次playout最终会得到叶子节点的 v 值,并沿着MCTS树向上更新沿途的所有父节点 Q值。 从上一篇文章已知,游戏节点的数量随着参数而指数级增长,举例来说,井字棋(k=3,m=n=3)的状态数量是5478,k=3,m=n=4时是6035992 ,k=m=n=4时是9722011 。如果我们将初始局面节点作为根节点,同时保存海量playout探索得到的局面节点,实现时会发现我们无法将所有探索到的局面节点都保存在内存中。这里的一种解决方法是在一次self play中每轮playout之后,将根节点重置成落子的节点,从而有效控制整颗局面树中的节点数量。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
class MCTSAlphaGoZeroPlayer(BaseAgent):

def _playout(self, game: ConnectNGame):
"""
From current game status, run a sequence down to a leaf node, either because game ends or unexplored node.
Get the leaf value of the leaf node, either the actual reward of game or action value returned by policy net.
And propagate upwards to root node.

:param game:
"""
player_id = game.current_player

node = self._current_root
while True:
if node.is_leaf():
break
act, node = node.select()
game.move(act)

# now game state is a leaf node in the tree, either a terminal node or an unexplored node
act_and_probs: Iterator[MoveWithProb]
act_and_probs, leaf_value = self._policy_value_net.policy_value_fn(game)

if not game.game_over:
# case where encountering an unexplored leaf node, update leaf_value estimated by policy net to root
for act, prob in act_and_probs:
game.move(act)
child_node = node.expand(act, prob)
game.undo()
else:
# case where game ends, update actual leaf_value to root
if game.game_result == ConnectNGame.RESULT_TIE:
leaf_value = ConnectNGame.RESULT_TIE
else:
leaf_value = 1 if game.game_result == player_id else -1
leaf_value = float(leaf_value)

# Update leaf_value and propagate up to root node
node.propagate_to_root(-leaf_value)

编码游戏局面

为了将信息有效的传递给策略神经网络,必须从当前玩家的角度编码游戏局面。局面不仅要反映棋盘上黑白棋子的位置,也需要考虑最后一个落子的位置以及是否为当前玩家棋局。因此,我们将某局面按照当前玩家来编码,返回类型为4个棋盘大小组成的ndarray,即shape [4, board_size, board_size],其中

  1. 第一个数组编码当前玩家的棋子位置
  2. 第二个数组编码对手玩家棋子位置
  3. 第三个表示最后落子位置
  4. 第四个全1表示此局面为先手(黑棋)局面,全0表示白棋局面

例如之前游戏对弈中的前四步:

s1->s2 后局面s2的编码:当前玩家为黑棋玩家,编码局面s2 返回如下ndarray,数组[0] 为s2黑子位置,[1]为白子位置,[2]表示最后一个落子(1, 1) ,[3] 全1表示当前是黑棋落子的局面。

编码黑棋玩家局面 s2
s2->s3 后局面s3的编码:当前玩家为白棋玩家,编码返回如下,数组[0] 为s3白子位置,[1]为黑子位置,[2]表示最后一个落子(1, 0) ,[3] 全0表示当前是白棋落子的局面。
编码白棋玩家局面 s3

具体代码实现如下。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
NetGameState = NDArray[(4, Any, Any), np.int]


def convert_game_state(game: ConnectNGame) -> NetGameState:
"""
Converts game state to type NetGameState as ndarray.

:param game:
:return:
Of shape 4 * board_size * board_size.
[0] is current player positions.
[1] is opponent positions.
[2] is last move location.
[3] all 1 meaning move by black player, all 0 meaning move by white.
"""
state_matrix = np.zeros((4, game.board_size, game.board_size))

if game.action_stack:
actions = np.array(game.action_stack)
move_curr = actions[::2]
move_oppo = actions[1::2]
for move in move_curr:
state_matrix[0][move] = 1.0
for move in move_oppo:
state_matrix[1][move] = 1.0
# indicate the last move location
state_matrix[2][actions[-1]] = 1.0
if len(game.action_stack) % 2 == 0:
state_matrix[3][:, :] = 1.0 # indicate the colour to play
return state_matrix[:, ::-1, :]

策略价值网络训练

策略价值网络是一个共享参数 \(\theta\) 的双头网络,给定上面的游戏局面编码会产生预估的p和v。

\[ \vec{p_{\theta}}, v_{\theta}=f_{\theta}(s) \] 结合真实游戏对弈后产生三元组数据 $(s, , z) $ ,按照论文中的loss 来训练神经网络。 \[ l=\sum_{t}\left(v_{\theta}\left(s_{t}\right)-z_{t}\right)^{2}-\vec{\pi_{t}} \cdot \log \left(\vec{p_{\theta}}\left(s_{t}\right)\right) + c {\lVert \theta \rVert}^2 \]

下面代码为Pytorch backward部分。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
def backward_step(self, state_batch: List[NetGameState], probs_batch: List[ActionProbs],
value_batch: List[NDArray[(Any), np.float]], lr) -> Tuple[float, float]:
if self.use_gpu:
state_batch = Variable(torch.FloatTensor(state_batch).cuda())
probs_batch = Variable(torch.FloatTensor(probs_batch).cuda())
value_batch = Variable(torch.FloatTensor(value_batch).cuda())
else:
state_batch = Variable(torch.FloatTensor(state_batch))
probs_batch = Variable(torch.FloatTensor(probs_batch))
value_batch = Variable(torch.FloatTensor(value_batch))

self.optimizer.zero_grad()
for param_group in self.optimizer.param_groups:
param_group['lr'] = lr

log_act_probs, value = self.policy_value_net(state_batch)
# loss = (z - v)^2 - pi*T * log(p) + c||theta||^2
value_loss = F.mse_loss(value.view(-1), value_batch)
policy_loss = -torch.mean(torch.sum(probs_batch * log_act_probs, 1))
loss = value_loss + policy_loss
loss.backward()
self.optimizer.step()
entropy = -torch.mean(torch.sum(torch.exp(log_act_probs) * log_act_probs, 1))
return loss.item(), entropy.item()

参考资料

AlphaGo Zero是Deepmind 最后一代AI围棋算法,因为已经达到了棋类游戏AI的终极目的:给定任何游戏规则,AI从零出发只通过自我对弈的方式提高,最终可以取得超越任何对手(包括顶级人类棋手和上一代AlphaGo)的能力。换种方式说,当给定足够多的时间和计算资源,可以取得无限逼近游戏真实解的能力。这一篇,我们深入分析AlphaGo Zero的设计理念和关键组件的细节并解释组件之间的关联。下一篇中,我们将在已有的N子棋OpenAI Gym 环境中用Pytorch实现一个简化版的AlphaGo Zero算法。

AlphaGo Zero 综述

AlphaGo Zero 作为Deepmind在围棋领域的最后一代AI Agent,已经可以达到棋类游戏的终极目标:在只给定游戏规则的情况下,AI 棋手从最初始的随机状态开始,通过不断的自我对弈的强化学习来实现超越以往任何人类棋手和上一代Alpha的能力,并且同样的算法和模型应用到了其他棋类也得出相同的效果。这一篇,从原理上来解析AlphaGo Zero的运行方式。

AlphaGo Zero 算法由三种元素构成:强化学习(RL)、深度学习(DL)和蒙特卡洛树搜索(MCTS,Monte Carlo Tree Search)。核心思想是基于神经网络的Policy Iteration强化学习,即最终学的是一个深度学习的policy network,输入是某棋盘局面 s,输出是此局面下可走位的概率分布:\(p(a|s)\)

在第一代AlphaGo算法中,这个初始policy network通过收集专业人类棋手的海量棋局训练得来,再采用传统RL 的Monte Carlo Tree Search Rollout 技术来强化现有的AI对于局面落子(Policy Network)的判断。Monte Carlo Tree Search Rollout 简单说来就是海量棋局模拟,AI Agent在通过现有的Policy Network策略完成一次从某局面节点到最终游戏胜负结束的对弈,这个完整的对弈叫做rollout,又称playout。完成一次rollout之后,通过局面树层层回溯到初始局面节点,并在回溯过程中同步修订所有经过的局面节点的统计指标,修正原先policy network对于落子导致输赢的判断。通过海量并发的棋局模拟来提升基准policy network,即在各种局面下提高好的落子的\(p(a_{win}|s)\),降低坏的落子的\(p(a_{lose}|s)\)

举例如下井字棋局面:
局面s

基准policy network返回 p(s) 如下 \[ p(a|s) = \begin{align*} \left\lbrace \begin{array}{r@{}l} 0.1, & & a = (0,2) \\ 0.05, & & a = (1,0) \\ 0.5, & & a = (1,1) \\ 0.05, & & a = (1,2)\\ 0.2, & & a = (2,0) \\ 0.05, & & a = (2,1) \\ 0.05, & & a = (2,2) \end{array} \right. \end{align*} \] 通过海量并发模拟后,修订成如下的action概率分布,然后通过policy iteration迭代新的网络来逼近 \(p'\) 就提高了棋力。 \[ p'(a|s) = \begin{align*} \left\lbrace \begin{array}{r@{}l} 0, & & a = (0,2) \\ 0, & & a = (1,0) \\ 0.9, & & a = (1,1) \\ 0, & & a = (1,2)\\ 0, & & a = (2,0) \\ 0, & & a = (2,1) \\ 0.1, & & a = (2,2) \end{array} \right. \end{align*} \]

蒙特卡洛树搜索(MCTS)概述

Monte Carlo Tree Search 是Monte Carlo 在棋类游戏中的变种,棋类游戏的一大特点是可以用动作(move)联系的决策树来表示,树的节点数量取决于分支的数量和树的深度。MCTS的目的是在树节点非常多的情况下,通过实验模拟(rollout, playout)的方式来收集尽可能多的局面输赢情况,并基于这些统计信息,将搜索资源的重点均衡地放在未被探索的节点和值得探索的节点上,减少在大概率输的节点上的模拟资源投入。传统MCTS有四个过程:Selection, Expansion, Simulation 和Backpropagation。下图是Wikipedia 的例子:

  • Selection:从根节点出发,根据现有统计的信息和selection规则,选择子节点递归向下做决定,后面我们会详细介绍AlphaGo的UCB规则。图中节点的数字,例如根节点11/21,分别代表赢的次数和总模拟次数。从根节点一路向下分别选择节点 7/10, 1/6直到叶子节点3/3,叶子节点表示它未被探索过。
  • Expansion:由于3/3节点未被探索过,初始化其所有子节点为0/0,图中3/3只有一个子节点。后面我们会看到神经网络在初始化子节点的时候起到的指导作用,即所有子节点初始权重并非相同,而是由神经网络给出估计。
  • Simulation:重复selection和expansion,根据游戏规则递归向下直至游戏结束。
  • Backpropagation:游戏结束在终点节点产生游戏真实的价值,回溯向上调整所有父节点的统计状态。

权衡 Exploration 和 Exploitation

在不断扩张决策树并收集节点统计信息的同时,MCTS根据规则来权衡探索目的(采样不足)或利用目的来做决策,这个权衡规则叫做Upper Confidence Bound(UCB)。典型的UCB公式如下:w表示通过节点的赢的次数,n表示通过节点的总次数,N是父节点的访问次数,c是调节Exploration 和 Exploitation权重的超参。

\[ {\frac{w_i}{n_i}} + c \sqrt{\frac{\ln N_i}{n_i}} \]

假设某节点有两个子节点s1, s2,它们的统计指标为 s1: w/n = 3/4,s2: w/n = 6/8,由于两者输赢比率一样,因此根据公式,访问次数少的节点出于Exploration的目的胜出,MCTS最终决定从s局面走向s1。

从第一性原理来理解AlphaGo Zero

前一代的AlphaGo已经战胜了世界冠军,取得了空前的成就,AlphaGo Zero 的设计目标变得更加General,去除围棋相关的处理和知识,用统一的框架和算法来解决棋类问题。 1. 无人工先验数据

改进之前需要专家棋手对弈数据来冷启动初始棋力

  1. 无特定游戏特征工程

    无需围棋特定技巧,只包含下棋规则,可以适用到所有棋类游戏

  2. 单一神经网络

    统一Policy Network和Value Network,使用一个共享参数的双头神经网络

  3. 简单树搜索

    去除传统MCTS的Rollout 方式,用神经网络来指导MCTS更有效产生搜索策略

搜索空间的两个优化原则

尽管理论上围棋是有解的,即先手必赢、被逼平或必输,通过遍历所有可能局面可以求得解。同理,通过海量模拟所有可能游戏局面,也可以无限逼近所有局面下的真实输赢概率,直至收敛于局面落子的确切最佳结果。但由于围棋棋局的数目远远大于宇宙原子数目,3^361 >> 10^80,因此需要将计算资源有效的去模拟值得探索的局面,例如对于显然的被动局面减小模拟次数,所以如何有效地减小搜索空间是AlphaGo Zero 需要解决的重大问题。David Silver 在Deepmind AlphaZero - Mastering Games Without Human Knowledge中提到AlphaGo Zero 采用两个原则来有效减小搜索空间。

原则1: 通过Value Network减少搜索的深度

Value Network 通过预测给定局面的value来直接预测最终结果,思想和上一期Minimax DP 策略中直接缓存当前局面的胜负状态一样,减少每次必须靠模拟到最后才能知道当前局面的输赢概率,或者需要多层树搜索才能知道输赢概率。

原则2: 通过Policy Network减少搜索的宽度

搜索广度的减少是由Policy Network预估来达成的,将下一步搜索局限在高概率的动作上,大幅度提升原先MCTS新节点生成后冷启动的搜索宽度。

神经网络结构

AlphaGo Zero 使用一个单一的深度神经网络来完成policy 和value的预测。具体实现方式是将policy network和value network合并成一个共享参数 $ $ 的双头网络。其中z是真实游戏结局的效用,范围为[-1, 1] 。

\[ (p, v)=f_{\theta}(s) \] \[ p_{a}=\operatorname{Pr}(a \mid s) \] \[ v = \mathop{\mathbb{E}}[z|s] \]

Monte Carlo Tree Search (MCTS) 建立了棋局搜索树,节点的初始状态由神经网络输出的p和v值来估计,由此初始的动作策略和价值预判就会建立在高手的水平之上。模拟一局游戏之后向上回溯,会同步更新路径上节点的统计数值并生成更好的MCTS搜索策略 \(\vec{\pi}\)。进一步来看,MCTS和神经网络互相形成了正循环。神经网络指导了未知节点的MCTS初始搜索策略,产生自我对弈游戏结局后,通过减小 \(\vec{p}\)\(\vec{\pi}\)的 Loss ,最终又提高了神经网络对于局面的估计能力。神经网络value network的提升也是通过不断减小网络预测的结果和最终结果的差异来提升。 因此,具体神经网络的Loss函数由三部分组成,value network的损失,policy network的损失以及正则项。 \[ l=\sum_{t}\left(v_{\theta}\left(s_{t}\right)-z_{t}\right)^{2}-\vec{\pi}_{t} \cdot \log \left(\vec{p}_{\theta}\left(s_{t}\right)\right) + c {\lVert \theta \rVert}^2 \]

AlphaGo Zero MCTS 具体过程

AlphaGo Plays Games Against Itself

AlphaGo Zero的MCTS和传统MCTS都有相似的四个过程,但AlphaGo Zero的MCTS步骤相对更复杂。 首先,除了W/N统计指标之外,AlphaGo Zero的MCTS保存了决策边 a|s 的Q(s,a):Action Value,也就是Q-Learning中的Q值,其初始值由神经网络给出。此外,Q 值也用于串联自底向上更新节点的Value值。具体说来,当某个新节点被Explore后,会将网络给出的Q值向上传递,并逐层更新父节点的Q值。当游戏结局产生时,也会向上更新所有父节点的Q值。 此外对于某一游戏局面s进行多次模拟,每次在局面s出发向下探索,每次探索在已知节点按Selection规则深入一步,直至达到未探索的局面或者游戏结束,产生Q值后向上回溯到最初局面s,回溯过程中更新路径上的局面的统计值或者Q值。在多次模拟结束后根据Play的算法,决定局面s的下一步行动。尽管每次模拟探索可能会深入多层,但最终play阶段的算法规则仅决定给定局面s的下一层落子动作。多次向下探索的优势在于:

  1. 探索和采样更多的叶子节点,在更多信息下做决策。

  2. 通过average out多次模拟下一层落子决定,尽可能提升MCTS策略的下一步判断能力,提高 \(\pi\) 能力,更有效指导神经网络,提高其学习效率。

New Policy Network V' is Trained to Predict Winner
  1. Selection:

从游戏局面s开始,选择a向下递归,直至未展开的节点(搜索树中的叶子节点)或者游戏结局。具体在局面s下选择a的规则由以下UCB(Upper Confidence Bound)决定
\[ a=\operatorname{argmax}_a(Q(s,a) + u(s,a)) \]

其中,Q(s,a) 和u(s,a) 项分别代表Exploitation 和Exploration。两项相加来均衡Exploitation和Exploration,保证初始时每个节点被explore,在有足够多的信息时逐渐偏向exploitation。

\[ u(s, a)=c_{p u c t} \cdot P(s, a) \cdot \frac{\sqrt{\Sigma_{b} N(s, b)}}{1+N(s, a)} \]

  1. Expand

当遇到一个未展开的节点(搜索树中的叶子节点)时,对其每个子节点使用现有网络进行预估,即

\[ (p(s), v(s))=f_{\theta}(s) \]

  1. Backup

当新的叶子节点展开时或者到达终点局面时,向上更新父节点的Q值,具体公式为 \[ Q(s, a)=\frac{1}{N(s, a)} \sum_{s^{\prime} \mid s, a \rightarrow s^{\prime}} V\left(s^{\prime}\right) \]

  1. Play

多次模拟结束后,使用得到搜索概率分布 ${a} $来确定最终的落子动作。正比于访问次数的某次方 $ {a} N(s, a)^{1 / }\(,其中\)$为温度参数(temperature parameter)。

New Policy Network V' is Trained to Predict Winner

参考资料

Your browser is out-of-date!

Update your browser to view this website correctly. Update my browser now

×