#Leetcode

在这篇文章中,我们从一道LeetCode 470 题目出发,通过系统地思考,引出拒绝采样(Reject Sampling)的概念,并探索比较三种拒绝采样地解法;接着借助状态转移图来定量计算采样效率;最后,我们利用同样的方法来解一道稍微复杂些的经典抛硬币求期望的统计面试题目。

Leetcode 470 用 Rand7() 实现 Rand10()

已有方法 rand7 可生成 1 到 7 范围内的均匀随机整数,试写一个方法 rand10 生成 1 到 10 范围内的均匀随机整数。

不要使用系统的 Math.random() 方法。

思考

  • rand7()调用次数的 期望值 是多少 ?

  • 你能否尽量少调用 rand7() ?

思维过程

我们已有 rand7() 等概率生成了 [1, 7] 中的数字,我们需要等概率生成 [1, 10] 范围内的数字。第一反应是调用一次rand7() 肯定是不够的,因为覆盖的范围不够。那么,就需要至少2次调用 rand7() 才能生成一次 rand10(),但是还要保证 [1, 10] 的数字生成概率相等,这个是难点。 现在我们先来考虑反问题,给定rand10() 生成 rand7()。这个应该很简单,调用一次 rand10() 得到 [1, 10],如果是 8, 9, 10 ,则丢弃,重新开始,否则返回。想必大家都能想到这个朴素的方法,这种思想就是统计模拟中的拒绝采样(Reject Sampling)。

有了上面反问题的思考,我们可能会想到,rand7() 可以生成 rand5(),覆盖 [1, 5]的范围,如果将区间 [1, 10] 分成两个5个值的区间 [1, 5] 和 [6, 10],那么 rand7() 可以通过先等概率选择区间 [1, 5] 或 [6, 10],再通过rand7() 生成 rand5()就可以了。这个问题就等价于先用 rand7() 生成 rand2(),决定了 [1, 5] 还是 [6, 10],再通过rand7() 生成 rand5() 。

解法一:rand2() + rand5()

我们来实现这种解法。下图为调用两次 rand7() 生成 rand10 数值的映射关系:横轴表示第一次调用,1,2,3决定选择区间 [1, 5] ,4,5,6选择区间 [6, 10]。灰色部分表示结果丢弃,重新开始 (注,若第一次得到7无需再次调用 rand7())。

有了上图,我们很容易写出如下 AC 代码。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# AC
# Runtime: 408 ms, faster than 23.80% of Python3 online submissions for Implement Rand10() Using Rand7().
# Memory Usage: 16.7 MB, less than 90.76% of Python3 online submissions for Implement Rand10() Using Rand7().
class Solution:
def rand10(self):
while True:
a = rand7()
if a <= 3:
b = rand7()
if b <= 5:
return b
elif a <= 6:
b = rand7()
if b <= 5:
return b + 5

标准解法:rand49()

从提交的结果来看,第一种解法慢于多数解法。原因是我们的调用 rand7() 的采样效率比较低,第一次有 1/7 的概率结果丢弃,第二次有 2/7的概率被丢弃。

如何在第一种解法的基础上提高采样效率呢?直觉告诉我们一种做法是降低上述 7x7 表格中灰色格子的面积。此时,会想到我们通过两次 rand7() 已经构建出来 rand49()了,那么再生成 rand10() 也规约成基本问题了。

下图为 rand49() 和 rand10() 的数字对应关系。

实现代码比较简单。注意,while True 可以去掉,用递归来代替。

{linenos
1
2
3
4
5
6
7
8
9
10
11
# AC
# Runtime: 376 ms, faster than 54.71% of Python3 online submissions for Implement Rand10() Using Rand7().
# Memory Usage: 16.9 MB, less than 38.54% of Python3 online submissions for Implement Rand10() Using Rand7().
class Solution:
def rand10(self):
while True:
a, b = rand7(), rand7()
num = (a - 1) * 7 + b
if num <= 40:
return num % 10 + 1

更快的做法

上面的提交结果发现标准解法在运行时间上有了不少提高,处于中等位置。我们继续思考,看看能否再提高采样效率。

观察发现,rand49() 有 9/49 的概率,生成的值被丢弃,原因是 [41, 49] 只有 9 个数,不足10个。倘若此时能够将这种状态保持下去,那么只需再调用一次 rand7() 而不是从新开始情况下至少调用两次 rand7(), 就可以得到 rand10()了。也就是说,当 rand49() 生成了 [41, 49] 范围内的数的话等价于我们先调用了一次 rand9(),那么依样画葫芦,我们接着调用 rand7() 得到了 rand63()。63 分成了6个10个值的区间后,剩余 3 个数。此时,又等价于 rand3(),循环往复,调用了 rand7() 得到了 rand21(),最后若rand21() 不幸得到21,等价于 rand1(),此时似乎我们走投无路,只能回到最初的状态,一切从头再来了。

改进算法代码如下。注意这次击败了 92.7%的提交。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# AC
# Runtime: 344 ms, faster than 92.72% of Python3 online submissions for Implement Rand10() Using Rand7().
# Memory Usage: 16.7 MB, less than 90.76% of Python3 online submissions for Implement Rand10() Using Rand7().
class Solution:
def rand10(self):
while True:
a, b = rand7(), rand7()
num = (a - 1) * 7 + b
if num <= 40: return num % 10 + 1
a = num - 40
b = rand7()
num = (a - 1) * 7 + b
if num <= 60: return num % 10 + 1
a = num - 60
b = rand7()
num = (a - 1) * 7 + b
if num <= 20: return num % 10 + 1

采样效率计算

通过代码提交的结果和大致的分析,我们已经知道三个解法在采样效率依次变得更快。现在我们来定量计算这三个解法。

我们考虑生成一个 rand10() 平均需要调用多少次 rand7(),作为采样效率的标准。

一种思路是可以通过模拟方法,即将上述每个解法模拟多次,然后用总的 rand7() 调用次数除以 rand10() 的生成次数即可。下面以解法三为例写出代码

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# The rand7() API is already defined for you.
rand7_c = 0
rand10_c = 0

def rand7():
global rand7_c
rand7_c += 1
import random
return random.randint(1, 7)

def rand10():
global rand10_c
rand10_c += 1
while True:
a, b = rand7(), rand7()
num = (a - 1) * 7 + b
if num <= 40: return num % 10 + 1
a = num - 40 # [1, 9]
b = rand7()
num = (a - 1) * 7 + b # [1, 63]
if num <= 60: return num % 10 + 1
a = num - 60 # [1, 3]
b = rand7()
num = (a - 1) * 7 + b # [1, 21]
if num <= 20: return num % 10 + 1

if __name__ == '__main__':
while True:
rand10()
print(f'{rand10_c} {round(rand7_c/rand10_c, 2)}')

运行代码,发现解法三的采样效率稳定在 2.19。

采样效率精确计算

计算解法二

为了精确计算三个解法的采样效率,我们通过代码得到对应的状态转移图来帮助计算。

例如,解法一可以对应到下图:初始状态 Start 节点中的 +2 表示经过此节点会产生 2次 rand7() 的代价。从 Start 节点有 40/49 的概率到达被接受状态 AC,有 9/49 概率到达拒绝状态 REJ。REJ 需要从头开始,则用虚线表示重新回到 Start节点,也就是说 REJ 的代价等价于 Start。注意,从某个节点出发的所有边之和为1。

有了上述状态转移关系图,我们令初始状态的平均代价为 \(C_2\),则可以写成递归表达式,因为其中 REJ 的代价就是 \(C_2\),即

\[ C_2 = 2 + (\frac{40}{49}\cdot0 + \frac{9}{49} C_2) \]

解得 \(C_2\)

\[ C_2 = 2.45 \]

计算解法一

同样的,对于另外两种解法,虽然略微复杂,也可以用同样的方法求得。

解法一的状态转移图为

递归方程表达式为

\[ C_1 = 1+\frac{3}{7} \cdot (1+\frac{5}{7} \cdot 0 + \frac{2}{7} \cdot C_1) \cdot2+ \frac{1}{7} \cdot (C_1) \]

解得 \(C_1\)

\[ C_1 = \frac{91}{30} \approx 3.03 \]

计算解法三

最快的解法三状态转移图为

递归方程表达式为

\[ C_3 = 2+\frac{40}{49} \cdot 0 + \frac{9}{49} (1+\frac{60}{63} \cdot 0 + \frac{3}{63} \cdot (1+\frac{20}{21} \cdot 0 + \frac{1}{21} \cdot C_3)) \]

解得 \(C_3\) \[ C_3 = \frac{329}{150} \approx 2.193 \]

至此总结一下,三个解法的平均代价为 \[ C_1 \approx 3.03 > C_2 = 2.45 > C_3 \approx 2.193 \] 这些值和我们通过模拟得到的结果一致。

稍难些的经典概率求期望题目

至此,LeetCode 470 我们已经分析透彻。现在,我们已经可以很熟练的将此类拒绝采样的问题转变成有概率的状态转移图,再写成递推公式去求平均采样的代价(即期望)。这里,如果大家感兴趣的话不妨再来看一道略微深入的经典统计概率求期望的题目。

问题:给定一枚抛正反面概率一样的硬币,求连续抛硬币直到两个正面(正面记为H,两个正面HH)的平均次数。例如:HTTHH是一个连续次数为5的第一次出现HH的序列。

分析问题画出状态转移图:我们令初始状态下得到第一个HH的平均长度记为 S,那么下一次抛硬币有 1/2 机会是 T,此时状态等价于初始状态,另有 1/2 机会是 H,我们记这个状态下第一次遇见HH的平均长度为 H(下图蓝色节点)。从此蓝色节点出发,当下一枚硬币是H则结束,是T是返回初始状态。于是构建出下图。

这个问题稍微复杂的地方在于我们有两个未知状态互相依赖,但问题的本质和方法是一样的,分别从 S 和 H 出发考虑状态的概率转移,可以写成如下两个方程式:

\[ \left\{ \begin{array}{c} S =&\frac{1}{2} \cdot(1+H) + \frac{1}{2} \cdot(1+S) \\ H =&\frac{1}{2} \cdot 1 + \frac{1}{2} \cdot(1+S) \end{array} \right. \]

解得

\[ \left\{ \begin{array}{c} H= 4 \\ S = 6 \end{array} \right. \]

因此,平均下来,需要6次抛硬币才能得到 HH,这个是否和你直觉的猜测一致呢?

这个问题还可以有另外一问,可以作为思考题让大家来练习一下:第一次得到 HT 的平均次数是多少?这个是否和 HH 一样呢?

Leetcode 1029. 两地调度 (medium)

公司计划面试 2N 人。第 i 人飞往 A 市的费用为 costs[i][0],飞往 B 市的费用为 costs[i][1]。

返回将每个人都飞到某座城市的最低费用,要求每个城市都有 N 人抵达。 

示例:

输入:[[10,20],[30,200],[400,50],[30,20]] 输出:110 解释: 第一个人去 A 市,费用为 10。 第二个人去 A 市,费用为 30。 第三个人去 B 市,费用为 50。 第四个人去 B 市,费用为 20。 最低总费用为 10 + 30 + 50 + 20 = 110,每个城市都有一半的人在面试。

提示:

1 <= costs.length <= 100 costs.length 为偶数 1 <= costs[i][0], costs[i][1] <= 1000

链接:https://leetcode-cn.com/problems/two-city-scheduling

暴力枚举法

最直接的方式是暴力枚举出所有分组的可能。因为 2N 个人平均分成两组,总数为 \({2n \choose n}\),是 n 的指数级数量。在文章24 点游戏算法题的 Python 函数式实现: 学用itertools,yield,yield from 巧刷题,我们展示如何调用 Python 的 itertools包,这里,我们也用同样的方式产生 [0, 2N] 的所有集合大小为N的可能(保存在left_set_list中),再遍历找到最小值即可。当然,这种解法会TLE,只是举个例子来体会一下暴力做法。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import math
from typing import List

class Solution:
def twoCitySchedCost(self, costs: List[List[int]]) -> int:
L = range(len(costs))
from itertools import combinations
left_set_list = [set(c) for c in combinations(list(L), len(L)//2)]

min_total = math.inf
for left_set in left_set_list:
cost = 0
for i in L:
is_left = 1 if i in left_set else 0
cost += costs[i][is_left]
min_total = min(min_total, cost)

return min_total

O(N) AC解法

对于组合优化问题来说,例如TSP问题(解法链接 TSP问题从DP算法到深度学习1:递归DP方法 AC AIZU TSP问题),一般都是 NP-Hard问题,意味着没有多项次复杂度的解法。但是这个问题比较特殊,它增加了一个特定条件:去城市A和城市B的人数相同,也就是我们已经知道两个分组的数量是一样的。我们仔细思考一下这个意味着什么?考虑只有四个人的小规模情况,如果让你来手动规划,你一定不会穷举出所有两两分组的可能,而是比较人与人相对的两个城市的cost差。举个例子,有如下四个人的costs

1
2
3
4
0 A:3,  B:1
1 A:99, B:100
2 A:2, B:2
3 A:3, B:3
虽然1号人去城市A(99)cost 很大,但是相比较他去B(100)来说,可以省下 100-99 = 1 的钱,这个钱比0号人去B不去A省下的钱 3-1 = 2 还要多,因此你一定会选择让1号人去A而让0号人去B。

有了这个想法,再整理一下,就会发现让某人去哪个城市和他去两个城市的cost 差 $ C_a - C_b$相关,如果这个值越大,那么他越应该去B。但是最后决定他是否去B取决于他的差在所有人中的排名,由于两组人数相等,因此差能大到排在前一半,则他就去B,在后一半就去A。

按照这个思路,很快能写出代码,代码写法有很多,下面略举一例。代码中由于用到排序,复杂度为 \(O(N \cdot \log(N))\) 。这里补充一点,理论上只需找数组中位数的值即可,最好的时间复杂度是 \(O(N)\)

代码实现上,cost_diff_list 将每个人的在原数组的index 和他的cost差组成 pair。即

1
[(0, cost_0), (1, cost_1), ... ]

这样我们可以将这个数组按照cost排序,排完序后前面N个元素属于B城市。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# AC
# Runtime: 36 ms, faster than 87.77% of Python3 online submissions
# Memory Usage: 14.5 MB, less than 14.84% of Python3 online
from typing import List

class Solution:
def twoCitySchedCost(self, costs: List[List[int]]) -> int:
L = range(len(costs))
cost_diff_lst = [(i, costs[i][0] - costs[i][1]) for i in L]
cost_diff_lst.sort(key=lambda x: x[1])

total_cost = 0
for c, (idx, _) in enumerate(cost_diff_lst):
is_left = 0 if c < len(L) // 2 else 1
total_cost += costs[idx][is_left]

return total_cost

转换成整数规划问题

这个问题对于略有算法经验的人来说,很类似于背包问题。它们都需要回答N个物品取或者不取,并同时最大最小化总cost。区别在它们的约束条件不一样。这道题的约束是去取(去城市A)和不取(去城市B)的数量一样。这一类问题即 integer programming,即整数规划。下面我们选取两个比较流行的优化库来展示如何调包解这道题。

首先我们先来formulate这个问题,因为需要表达两个约束条件,我们将每个人的状态分成是否去A和是否去B两个变量。

1
2
x[i-th-person][0]: boolean 表示是否去 city a
x[i-th-person][1]: boolean 表示是否去 city b

这样,问题转换成如下优化模型

\[ \begin{array}{rrclcl} \displaystyle \min_{x} & costs[i][0] \cdot x[i][0] + costs[i][1] \cdot x[i][1]\\ \textrm{s.t.} & x[i][0] + x[i][1] =1\\ &x[i][0] + x[i][1] + ... =N \\ \end{array} \]

Google OR-Tools

Google OR-Tools 是业界最好的优化库,下面为调用代码,由于直接对应于上面的数学优化问题,不做赘述。当然 Leetcode上不支持这些第三方的库,肯定也不能AC。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from ortools.sat.python import cp_model

costs = [[515,563],[451,713],[537,709],[343,819],[855,779],[457,60],[650,359],[631,42]]

I = range(len(costs))

model = cp_model.CpModel()
x = []
total_cost = model.NewIntVar(0, 10000, 'total_cost')
for i in I:
t = []
for j in range(2):
t.append(model.NewBoolVar('x[%i,%i]' % (i, j)))
x.append(t)

# Constraints
# Each person must be assigned to at exact one city
[model.Add(sum(x[i][j] for j in range(2)) == 1) for i in I]
# equal number of person assigned to two cities
model.Add(sum(x[i][0] for i in I) == (len(I) // 2))

# Total cost
model.Add(total_cost == sum(x[i][0] * costs[i][0] + x[i][1] * costs[i][1] for i in I))
model.Minimize(total_cost)

solver = cp_model.CpSolver()
status = solver.Solve(model)

if status == cp_model.OPTIMAL:
print('Total min cost = %i' % solver.ObjectiveValue())
print()
for i in I:
for j in range(2):
if solver.Value(x[i][j]) == 1:
print('People ', i, ' assigned to city ', j, ' Cost = ', costs[i][j])

完整代码可以从我的github下载。

https://github.com/MyEncyclopedia/leetcode/blob/master/1029_Two_City_Scheduling/1029_ortool.py

PuLP

类似的,另一种流行 python 优化库 PuLP 的代码为

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import pulp

costs = [[259,770],[448,54],[926,667],[184,139],[840,118],[577,469]] # 1859


I = range(len(costs))

items=[i for i in I]
city_a = pulp.LpVariable.dicts('left', items, 0, 1, pulp.LpBinary)
city_b = pulp.LpVariable.dicts('right', items, 0, 1, pulp.LpBinary)

m = pulp.LpProblem("Two Cities", pulp.LpMinimize)

m += pulp.lpSum((costs[i][0] * city_a[i] + costs[i][1] * city_b[i]) for i in items)

# Constraints
# Each person must be assigned to at exact one city
for i in I:
m += pulp.lpSum([city_a[i] + city_b[i]]) == 1
# create a binary variable to state that a table setting is used
m += pulp.lpSum(city_a[i] for i in I) == (len(I) // 2)

m.solve()

total = 0
for i in I:
if city_a[i].value() == 1.0:
total += costs[i][0]
else:
total += costs[i][1]

print("Total cost {}".format(total))

代码地址为

https://github.com/MyEncyclopedia/leetcode/blob/master/1029_Two_City_Scheduling/1029_pulp.py

快速幂运算是一种利用位运算和DP思想求的\(x^n\)的数值算法,它将时间复杂度\(O(n)\)降到\(O(log(n))\)。快速幂运算结合矩阵乘法,可以巧解不少DP问题。本篇会由浅入深,从最基本的快速幂运算算法,到应用矩阵快速幂运算解DP问题,结合三道Leetcode题目来具体讲解。

Leetcode 50. Pow(x, n) (Medium)

Leetcode 50. Pow(x, n) 是实数的快速幂运算问题,题目如下。

Implement pow(x, n), which calculates x raised to the power n (i.e. \(x^n\)).

Example 1:

1
2
Input: x = 2.00000, n = 10
Output: 1024.00000

Example 2:

1
2
Input: x = 2.10000, n = 3
Output: 9.26100

Example 3:

1
2
3
Input: x = 2.00000, n = -2
Output: 0.25000
Explanation: 2-2 = 1/22 = 1/4 = 0.25

快速幂运算解法分析

假设n是32位的int类型,将n写成二进制形式,那么n可以写成最多32个某位为 1(第k位为1则值为\(2^k\))的和。那么\(x^n\)最多可以由32个 \(x^{2^k}\)的乘积组合,例如:

\[ x^{\text{10011101}_{2}} = x^{1} \times x^{\text{100}_{2}} \times x^{\text{1000}_{2}} \times x^{\text{10000}_{2}} \times x^{\text{10000000}_{2}} \]

快速幂运算的特点就是通过32次循环,每次循环根据上轮\(x^{2^k}\)的值进行平方后得出这一轮的值:\(x^{2^k} \times x^{2^k} = x^{2^{k+1}}\),即循环计算出如下数列

\[ x^{1}, x^2=x^{\text{10}_{2}}, x^4=x^{\text{100}_{2}}, x^8=x^{\text{1000}_{2}}, x^{16}=x^{\text{10000}_{2}}, ..., x^{128} = x^{\text{10000000}_{2}} \]

在循环时,如果n的二进制形式在本轮对应的位的值是1,则将这次结果累乘计入最终结果。

下面是python 3 的代码,由于循环为32次,所以容易看出算法复杂度为 \(O(log(n))\)

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# AC
# Runtime: 32 ms, faster than 54.28% of Python3 online submissions for Pow(x, n).
# Memory Usage: 14.2 MB, less than 5.04% of Python3 online submissions for Pow(x, n).

class Solution:
def myPow(self, x: float, n: int) -> float:
ret = 1.0
i = abs(n)
while i != 0:
if i & 1:
ret *= x
x *= x
i = i >> 1
return 1.0 / ret if n < 0 else ret

对应的 Java 的代码。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
// AC
// Runtime: 1 ms, faster than 42.98% of Java online submissions for Pow(x, n).
// Memory Usage: 38.7 MB, less than 48.31% of Java online submissions for Pow(x, n).

class Solution {
public double myPow(double x, int n) {
double ret = 1.0;
long i = Math.abs((long) n);
while (i != 0) {
if ((i & 1) > 0) {
ret *= x;
}
x *= x;
i = i >> 1;
}

return n < 0 ? 1.0 / ret : ret;
}
}

矩阵快速幂运算

快速幂运算也可以应用到计算矩阵的幂,即上面的x从实数变为方形矩阵。实现上,矩阵的幂需要矩阵乘法:$ A_{r c} B_{c p}$ ,Python中可以用numpy的 np.matmul(A, B)来完成,而Java版本中我们手动实现简单的矩阵相乘算法,从三重循环看出其算法复杂度为\(O(r \times c \times p)\)

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
public int[][] matrixProd(int[][] A, int[][] B) {
int R = A.length;
int C = B[0].length;
int P = A[0].length;
int[][] ret = new int[R][C];
for (int r = 0; r < R; r++) {
for (int c = 0; c < C; c++) {
for (int p = 0; p < P; p++) {
ret[r][c] += A[r][p] * B[p][c];
}
}
}
return ret;
}

Leetcode 509. Fibonacci Number (Easy)

有了快速矩阵幂运算,我们来看看如何具体解题。Fibonacci问题作为最基本的DP问题,在上一篇Leetcode 679 24 Game 的 Python 函数式实现中我们用python独有的yield来巧解,这次再拿它来做演示。

The Fibonacci numbers, commonly denoted F(n) form a sequence, called the Fibonacci sequence, such that each number is the sum of the two preceding ones, starting from 0 and 1. That is,

1
2
F(0) = 0,   F(1) = 1
F(N) = F(N - 1) + F(N - 2), for N > 1.

Given N, calculate F(N).

Example 1:

1
2
3
Input: 2
Output: 1
Explanation: F(2) = F(1) + F(0) = 1 + 0 = 1.

Example 2:

1
2
3
Input: 3
Output: 2
Explanation: F(3) = F(2) + F(1) = 1 + 1 = 2.

Example 3:

1
2
3
Input: 4
Output: 3
Explanation: F(4) = F(3) + F(2) = 2 + 1 = 3.

转换为矩阵幂运算

Fibonacci的二阶递推式如下:

\[ \begin{align*} F(n) =& F(n-1) + F(n-2) \\ F(n-1) =& F(n-1) \end{align*} \]

等价的矩阵递推形式为:

\[ \begin{bmatrix}F(n)\\F(n-1)\end{bmatrix} = \begin{bmatrix}1 & 1\\1 & 0\end{bmatrix} \begin{bmatrix}F(n-1)\\F(n-2)\end{bmatrix} \]

也就是每轮左乘一个2维矩阵。其循环形式为,即矩阵幂的形式:

\[ \begin{bmatrix}F(n)\\F(n-1)\end{bmatrix} = \begin{bmatrix}1 & 1\\1 & 0\end{bmatrix}^{n-1} \begin{bmatrix}F(1)\\F(0)\end{bmatrix} \]

AC代码

有了上面的矩阵幂公式,代码稍作改动即可。Java 版本代码。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
/**
* AC
* Runtime: 0 ms, faster than 100.00% of Java online submissions for Fibonacci Number.
* Memory Usage: 37.9 MB, less than 18.62% of Java online submissions for Fibonacci Number.
*
* Method: Matrix Fast Power Exponentiation
* Time Complexity: O(log(N))
**/
class Solution {
public int fib(int N) {
if (N <= 1) {
return N;
}
int[][] M = {{1, 1}, {1, 0}};
// powers = M^(N-1)
N--;
int[][] powerDouble = M;
int[][] powers = {{1, 0}, {0, 1}};
while (N > 0) {
if (N % 2 == 1) {
powers = matrixProd(powers, powerDouble);
}
powerDouble = matrixProd(powerDouble, powerDouble);
N = N / 2;
}

return powers[0][0];
}

public int[][] matrixProd(int[][] A, int[][] B) {
int R = A.length;
int C = B[0].length;
int P = A[0].length;
int[][] ret = new int[R][C];
for (int r = 0; r < R; r++) {
for (int c = 0; c < C; c++) {
for (int p = 0; p < P; p++) {
ret[r][c] += A[r][p] * B[p][c];
}
}
}
return ret;
}

}

Python 3的numpy.matmul() 版本代码。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# AC
# Runtime: 256 ms, faster than 26.21% of Python3 online submissions for Fibonacci Number.
# Memory Usage: 29.4 MB, less than 5.25% of Python3 online submissions for Fibonacci Number.

class Solution:

def fib(self, N: int) -> int:
if N <= 1:
return N

import numpy as np
F = np.array([[1, 1], [1, 0]])

N -= 1
powerDouble = F
powers = np.array([[1, 0], [0, 1]])
while N > 0:
if N % 2 == 1:
powers = np.matmul(powers, powerDouble)
powerDouble = np.matmul(powerDouble, powerDouble)
N = N // 2

return powers[0][0]

或者也可以直接调用numpy.matrix_power() 代替手动的快速矩阵幂运算。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# AC
# Runtime: 116 ms, faster than 26.25% of Python3 online submissions for Fibonacci Number.
# Memory Usage: 29.2 MB, less than 5.25% of Python3 online submissions for Fibonacci Number.

class Solution:

def fib(self, N: int) -> int:
if N <= 1:
return N

from numpy.linalg import matrix_power
import numpy as np
F = np.array([[1, 1], [1, 0]])
F = matrix_power(F, N - 1)

return F[0][0]

Leetcode 1411. Number of Ways to Paint N × 3 Grid (Hard)

下面来看一道稍难一点的DP问题,1411. Number of Ways to Paint N × 3 Grid

You have a grid of size n x 3 and you want to paint each cell of the grid with exactly one of the three colours: Red, Yellow or Green while making sure that no two adjacent cells have the same colour (i.e no two cells that share vertical or horizontal sides have the same colour).

You are given n the number of rows of the grid.

Return the number of ways you can paint this grid. As the answer may grow large, the answer must be computed modulo 10^9 + 7.

Example 1:

1
2
3
Input: n = 1
Output: 12
Explanation: There are 12 possible way to paint the grid as shown:

Example 2:

1
2
Input: n = 2
Output: 54

Example 3:

1
2
Input: n = 3
Output: 246

Example 4:

1
2
Input: n = 7
Output: 106494

Example 5:

1
2
Input: n = 5000
Output: 30228214

标准DP解法

分析题目容易发现第i行的状态只取决于第i-1行的状态,第i行会有两种不同状态:三种颜色都有或者只有两种颜色。这个问题容易识别出是经典的双状态DP问题,那么我们定义dp2[i]为第i行只有两种颜色的数量,dp3[i]为第i行有三种颜色的数量。

先考虑dp3[i]和i-1行的关系。假设第i行包含3种颜色,即dp3[i],假设具体颜色为红,绿,黄,若i-1行包含两种颜色(即dp2[i-1]),此时dp2[i-1]只有以下2种可能:

dp2[i-1] -> dp3[i]
还是dp3[i] 红,绿,黄情况,若i-1行包含三种颜色(从dp3[i-1]转移过来),此时dp3[i-1]也只有以下2种可能:
dp3[i-1] -> dp3[i]

因此,dp3[i]= dp2[i-1] * 2 + dp3[i-1] * 2。

同理,若第i行包含两种颜色,即dp2[i],假设具体颜色为绿,黄,绿,若i-1行是两种颜色(dp2[i-1]),此时dp2[i-1]有如下3种可能:

dp2[i-1] -> dp2[i]
dp2[i]的另一种情况是由dp3[i-1]转移过来,则dp3[i-1]有2种可能,枚举如下:
dp3[i-1] -> dp2[i]

因此,dp2[i] = dp2[i-1] * 3 + dp3[i-1] * 2。 初始值dp2[1] = 6,dp3[1] = 6,最终答案为dp2[i] + dp3[i]。

很容易写出普通DP版本的Python 3代码,时间复杂度为\(O(n)\)

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
# AC
# Runtime: 36 ms, faster than 98.88% of Python3 online submissions for Number of Ways to Paint N × 3 Grid.
# Memory Usage: 13.9 MB, less than 58.66% of Python3 online submissions for Number of Ways to Paint N × 3 Grid.

class Solution:
def numOfWays(self, n: int) -> int:
MOD = 10 ** 9 + 7
dp2, dp3 = 6, 6
n -= 1
while n > 0:
dp2, dp3 = (dp2 * 3 + dp3 * 2) % MOD, (dp2 * 2 + dp3 * 2) % MOD
n -= 1
return (dp2 + dp3) % MOD

快速矩阵幂运算解法

和Fibonacci一样,我们将DP状态转移方程转换成矩阵乘法:

\[ \begin{bmatrix}dp2(n)\\dp3(n)\end{bmatrix} = \begin{bmatrix}3 & 2\\2 & 2\end{bmatrix} \begin{bmatrix}dp2(n-1)\\dp3(n-1)\end{bmatrix} \]

代入初始值,转换成矩阵幂形式

\[ \begin{bmatrix}dp2(n)\\dp3(n)\end{bmatrix} = \begin{bmatrix}3 & 2\\2 & 2\end{bmatrix}^{n-1}\begin{bmatrix}6\\6\end{bmatrix} \]

代码几乎和Fibonacci一模一样,仅仅多了mod 计算。下面是Java版本。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41

/**
AC
Runtime: 0 ms, faster than 100.00% of Java online submissions for Number of Ways to Paint N × 3 Grid.
Memory Usage: 35.7 MB, less than 97.21% of Java online submissions for Number of Ways to Paint N × 3 Grid.
**/

class Solution {
public int numOfWays(int n) {
long MOD = (long) (1e9 + 7);
long[][] ret = {{6, 6}};
long[][] m = {{3, 2}, {2, 2}};
n -= 1;
while(n > 0) {
if ((n & 1) > 0) {
ret = matrixProd(ret, m, MOD);
}
m = matrixProd(m, m, MOD);
n >>= 1;
}
return (int) ((ret[0][0] + ret[0][1]) % MOD);

}

public long[][] matrixProd(long[][] A, long[][] B, long MOD) {
int R = A.length;
int C = B[0].length;
int P = A[0].length;
long[][] ret = new long[R][C];
for (int r = 0; r < R; r++) {
for (int c = 0; c < C; c++) {
for (int p = 0; p < P; p++) {
ret[r][c] += A[r][p] * B[p][c];
ret[r][c] = ret[r][c] % MOD;
}
}
}
return ret;
}

}

Python 3实现为

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# AC
# Runtime: 88 ms, faster than 39.07% of Python3 online submissions for Number of Ways to Paint N × 3 Grid.
# Memory Usage: 30.2 MB, less than 11.59% of Python3 online submissions for Number of Ways to Paint N × 3 Grid.

class Solution:
def numOfWays(self, n: int) -> int:
import numpy as np

MOD = int(1e9 + 7)
ret = np.array([[6, 6]])
m = np.array([[3, 2], [2, 2]])

n -= 1
while n > 0:
if n % 2 == 1:
ret = np.matmul(ret, m) % MOD
m = np.matmul(m, m) % MOD
n = n // 2
return int((ret[0][0] + ret[0][1]) % MOD)

Leetcode 679 24 Game (Hard)

先来介绍一下24点游戏题目,大家一定都玩过,就是给定4个牌面数字,用加减乘除计算24点。

本篇会用两种偏函数式的 Python 3解法来AC 24 Game。

Leetcode 679 24 Game (Hard) > You have 4 cards each containing a number from 1 to 9. You need to judge whether they could operated through *, /, +, -, (, ) to get the value of 24.

Example 1:

Input: [4, 1, 8, 7]

Output: True

Explanation: (8-4) * (7-1) = 24

Example 2:

Input: [1, 2, 1, 2]

Output: False

itertools.permutations

先来介绍一下Python itertools.permutations 的用法,正好用Leetcode 中的Permutation问题来示例。Permutations 的输入可以是List,返回是 generator 实例,用于生成所有排列。简而言之,python 的 generator 可以和List一样,用 for 语句来全部遍历产生的值。和List不同的是,generator 的所有值并不必须全部初始化,一般按需产生从而大量减少内存占用。下面在介绍 yield 时我们会看到如何合理构造 generator。

Leetcode 46 Permutations (Medium) > Given a collection of distinct integers, return all possible permutations.

Example:

Input: [1,2,3]

Output: [ [1,2,3], [1,3,2], [2,1,3], [2,3,1], [3,1,2], [3,2,1]]

用 permutations 很直白,代码只有一行。

{linenos
1
2
3
4
5
6
7
8
9
10
11
# AC
# Runtime: 36 ms, faster than 91.78% of Python3 online submissions for Permutations.
# Memory Usage: 13.9 MB, less than 66.52% of Python3 online submissions for Permutations.

from itertools import permutations
from typing import List


class Solution:
def permute(self, nums: List[int]) -> List[List[int]]:
return [p for p in permutations(nums)]

itertools.combinations

有了排列就少不了组合,itertools.combinations 可以产生给定List的k个元素组合 \(\binom{n}{k}\),用一道算法题来举例,同样也是一句语句就可以AC。

Leetcode 77 Combinations (Medium)

Given two integers n and k, return all possible combinations of k numbers out of 1 ... n. You may return the answer in any order.

Example 1:

Input: n = 4, k = 2

Output: [ [2,4], [3,4], [2,3], [1,2], [1,3], [1,4],]

Example 2:

Input: n = 1, k = 1

Output: [[1]]

{linenos
1
2
3
4
5
6
7
8
9
# AC
# Runtime: 84 ms, faster than 95.43% of Python3 online submissions for Combinations.
# Memory Usage: 15.2 MB, less than 68.98% of Python3 online submissions for Combinations.
from itertools import combinations
from typing import List

class Solution:
def combine(self, n: int, k: int) -> List[List[int]]:
return [c for c in combinations(list(range(1, n + 1)), k)]

itertools.product

当有多维度的对象需要迭代笛卡尔积时,可以用 product(iter1, iter2, ...)来生成generator,等价于多重 for 循环。

1
2
[lst for lst in product([1, 2, 3], ['a', 'b'])]
[(i, s) for i in [1, 2, 3] for s in ['a', 'b']]

这两种方式都生成了如下结果

1
[(1, 'a'), (1, 'b'), (2, 'a'), (2, 'b'), (3, 'a'), (3, 'b')]

再举一个Leetcode的例子来实战product generator。

Leetcode 17. Letter Combinations of a Phone Number (Medium)

Given a string containing digits from 2-9 inclusive, return all possible letter combinations that the number could represent. A mapping of digit to letters (just like on the telephone buttons) is given below. Note that 1 does not map to any letters.

Example:

Input: "23"

Output: ["ad", "ae", "af", "bd", "be", "bf", "cd", "ce", "cf"].

举例来说,下面的代码当输入 digits 是 '352' 时,iter_dims 的值是 ['def', 'jkl', 'abc'],再输入给 product 后会产生 'dja', 'djb', 'djc', 'eja', 共 3 x 3 x 3 = 27个组合的值。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# AC
# Runtime: 24 ms, faster than 94.50% of Python3 online submissions for Letter Combinations of a Phone Number.
# Memory Usage: 13.7 MB, less than 83.64% of Python3 online submissions for Letter Combinations of a Phone Number.

from itertools import product
from typing import List


class Solution:
def letterCombinations(self, digits: str) -> List[str]:
if digits == "":
return []
mapping = {'2':'abc', '3':'def', '4':'ghi', '5':'jkl', '6':'mno', '7':'pqrs', '8':'tuv', '9':'wxyz'}
iter_dims = [mapping[i] for i in digits]

result = []
for lst in product(*iter_dims):
result.append(''.join(lst))

return result

yield 示例

Python具有独特的itertools generator,可以花式AC代码,接下来讲解如何进一步构造 generator。Python 定义只要函数中使用了yield关键字,这个函数就是 generator。Generator 在计算机领域的标准名称是 coroutine,即协程,是一种特殊的函数:当返回上层调用时自身能够保存调用栈状态,并在上层函数处理完逻辑后跳入到这个 generator,恢复之前的状态再继续运行下去。Yield语句也举一道经典的Fibonacci 问题。

Leetcode 509. Fibonacci Number (Easy)

The Fibonacci numbers, commonly denoted F(n) form a sequence, called the Fibonacci sequence, such that each number is the sum of the two preceding ones, starting from 0 and 1. That is, F(0) = 0, F(1) = 1 F(N) = F(N - 1) + F(N - 2), for N > 1. Given N, calculate F(N).

Example 1:

Input: 2

Output: 1

Explanation: F(2) = F(1) + F(0) = 1 + 0 = 1.

Example 2:

Input: 3

Output: 2

Explanation: F(3) = F(2) + F(1) = 1 + 1 = 2.

Example 3:

Input: 4

Output: 3

Explanation: F(4) = F(3) + F(2) = 2 + 1 = 3.

Fibonacci 的一般标准解法是循环迭代方式,可以以O(n)时间复杂度和O(1) 空间复杂度来AC。下面的 yield 版本中,我们构造了fib_next generator,它保存了最后两个值作为内部迭代状态,外部每调用一次可以得到下一个fib(n),如此外部只需不断调用直到满足题目给定次数。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# AC
# Runtime: 28 ms, faster than 85.56% of Python3 online submissions for Fibonacci Number.
# Memory Usage: 13.8 MB, less than 58.41% of Python3 online submissions for Fibonacci Number.

class Solution:
def fib(self, N: int) -> int:
if N <= 1:
return N
i = 2
for fib in self.fib_next():
if i == N:
return fib
i += 1

def fib_next(self):
f_last2, f_last = 0, 1
while True:
f = f_last2 + f_last
f_last2, f_last = f_last, f
yield f

yield from 示例

上述yield用法之后,再来演示 yield from 的用法。Yield from 始于Python 3.3,用于嵌套generator时的控制转移,一种典型的用法是有多个generator嵌套时,外层的outer_generator 用 yield from 这种方式等价代替如下代码。

1
2
3
def outer_generator():
for i in inner_generator():
yield i

用一道算法题目来具体示例。

Leetcode 230. Kth Smallest Element in a BST (Medium)

Given a binary search tree, write a function kthSmallest to find the kth smallest element in it.

Example 1: Input: root = [3,1,4,null,2], k = 1

1
2
3
4
5
  3
/ \
1 4
\
2
Output: 1

Example 2:

Input: root = [5,3,6,2,4,null,null,1], k = 3

1
2
3
4
5
6
7
         5
/ \
3 6
/ \
2 4
/
1
Output: 3

直觉思路上,我们只要从小到大有序遍历每个节点直至第k个。因为给定的树是Binary Search Tree,有序遍历意味着以左子树、节点本身和右子树的访问顺序递归下去就行。由于ordered_iter是generator,递归调用自己的过程就是嵌套使用generator的过程。下面是yield版本。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# AC
# Runtime: 48 ms, faster than 90.31% of Python3 online submissions for Kth Smallest Element in a BST.
# Memory Usage: 17.9 MB, less than 14.91% of Python3 online submissions for Kth Smallest Element in a BST.

class Solution:
def kthSmallest(self, root: TreeNode, k: int) -> int:
def ordered_iter(node):
if node:
for sub_node in ordered_iter(node.left):
yield sub_node
yield node
for sub_node in ordered_iter(node.right):
yield sub_node

for node in ordered_iter(root):
k -= 1
if k == 0:
return node.val

等价于如下 yield from 版本:

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# AC
# Runtime: 56 ms, faster than 63.74% of Python3 online submissions for Kth Smallest Element in a BST.
# Memory Usage: 17.7 MB, less than 73.33% of Python3 online submissions for Kth Smallest Element in a BST.

class Solution:
def kthSmallest(self, root: TreeNode, k: int) -> int:
def ordered_iter(node):
if node:
yield from ordered_iter(node.left)
yield node
yield from ordered_iter(node.right)

for node in ordered_iter(root):
k -= 1
if k == 0:
return node.val

24 点问题之函数式枚举解法

看明白了itertools.permuations,combinations,product,yield以及yield from,我们回到本篇最初的24点游戏问题。

24点游戏的本质是枚举出所有可能运算,如果有一种方式得到24返回True,否则返回Flase。进一步思考所有可能的运算,包括下面三个维度:

  1. 4个数字的所有排列,比如给定 [1, 2, 3, 4],可以用permutations([1, 2, 3, 4]) 生成这个维度的所有可能

  2. 三个位置的操作符号的全部可能,可以用 product([+, -, *, /], repeat=3) 生成,具体迭代结果为:[+, +, +],[+, +, -],...

  3. 给定了前面两个维度后,还有一个比较不容易察觉但必要的维度:运算优先级。比如在给定数字顺序 [1, 2, 3, 4]和符号顺序 [+, *, -]之后可能的四种操作树

四种运算优先级

能否算得24点只需要枚举这三个维度笛卡尔积的运算结果

(维度1:数字组合) x (维度2:符号组合) x (维度3:优先级组合)

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# AC
# Runtime: 112 ms, faster than 57.59% of Python3 online submissions for 24 Game.
# Memory Usage: 13.7 MB, less than 85.60% of Python3 online submissions for 24 Game.

import math
from itertools import permutations, product
from typing import List

class Solution:

def iter_trees(self, op1, op2, op3, a, b, c, d):
yield op1(op2(a, b), op3(c, d))
yield op1(a, op2(op3(b, c), d))
yield op1(a, op2(b, op3(c, d)))
yield op1(op2(a, op3(b, c)), d)

def judgePoint24(self, nums: List[int]) -> bool:
mul = lambda x, y: x * y
plus = lambda x, y: x + y
div = lambda x, y: x / y if y != 0 else math.inf
minus = lambda x, y: x - y

op_lst = [plus, minus, mul, div]

for ops in product(op_lst, repeat=3):
for val in permutations(nums):
for v in self.iter_trees(ops[0], ops[1], ops[2], val[0], val[1], val[2], val[3]):
if abs(v - 24) < 0.0001:
return True
return False

24 点问题之 DFS yield from 解法

一种常规的思路是,在四个数组成的集合中先选出任意两个数,枚举所有可能的计算,再将剩余的三个数组成的集合递归调用下去,直到叶子节点只剩一个数,如下图所示。

DFS 调用示例

下面的代码是这种思路的 itertools + yield from 解法,recurse方法是generator,会自我递归调用。当只剩下两个数时,用 yield 返回两个数的所有可能运算得出的值,其他非叶子情况下则自我调用使用yield from,例如4个数任选2个先计算再合成3个数的情况。这种情况下,比较麻烦的是由于4个数可能有相同值,若用 combinations(lst, 2) 先任选两个数,后续要生成剩余两个数加上第三个计算的数的集合代码会繁琐。因此,我们改成任选4个数index中的两个,剩余的indices 可以通过集合操作来完成。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
# AC
# Runtime: 116 ms, faster than 56.23% of Python3 online submissions for 24 Game.
# Memory Usage: 13.9 MB, less than 44.89% of Python3 online submissions for 24 Game.

import math
from itertools import combinations, product, permutations
from typing import List

class Solution:

def judgePoint24(self, nums: List[int]) -> bool:
mul = lambda x, y: x * y
plus = lambda x, y: x + y
div = lambda x, y: x / y if y != 0 else math.inf
minus = lambda x, y: x - y

op_lst = [plus, minus, mul, div]

def recurse(lst: List[int]):
if len(lst) == 2:
for op, values in product(op_lst, permutations(lst)):
yield op(values[0], values[1])
else:
# choose 2 indices from lst of length n
for choosen_idx_lst in combinations(list(range(len(lst))), 2):
# remaining indices not choosen (of length n-2)
idx_remaining_set = set(list(range(len(lst)))) - set(choosen_idx_lst)

# remaining values not choosen (of length n-2)
value_remaining_lst = list(map(lambda x: lst[x], idx_remaining_set))
for op, idx_lst in product(op_lst, permutations(choosen_idx_lst)):
# 2 choosen values are lst[idx_lst[0]], lst[idx_lst[1]
value_remaining_lst.append(op(lst[idx_lst[0]], lst[idx_lst[1]]))
yield from recurse(value_remaining_lst)
value_remaining_lst = value_remaining_lst[:-1]

for v in recurse(nums):
if abs(v - 24) < 0.0001:
return True

Leetcode 1227 是一道有意思的概率题,本篇将从多个角度来讨论这道题。题目如下

有 n 位乘客即将登机,飞机正好有 n 个座位。第一位乘客的票丢了,他随便选了一个座位坐下。 剩下的乘客将会: 如果他们自己的座位还空着,就坐到自己的座位上, 当他们自己的座位被占用时,随机选择其他座位,第 n 位乘客坐在自己的座位上的概率是多少?

示例 1: 输入:n = 1 输出:1.00000 解释:第一个人只会坐在自己的位置上。

示例 2: 输入: n = 2 输出: 0.50000 解释:在第一个人选好座位坐下后,第二个人坐在自己的座位上的概率是 0.5。

提示: 1 <= n <= 10^5

假设规模为n时答案为f(n),一般来说,这种递推问题在数学形式上可能有关于n的简单数学表达式(closed form),或者肯定有f(n)关于f(n-k)的递推表达式。工程上,我们可以通过通过多次模拟即蒙特卡罗模拟来算得近似的数值解。

Monte Carlo 模拟发现规律

首先,我们先来看看如何高效的用代码来模拟。根据题意的描述过程,直接可以写出下面代码。seats为n大小的bool 数组,每个位置表示此位置是否已经被占据。然后依次给第i个人按题意分配座位。注意,每次参数随机数范围在[0,n-1],因此,会出现已经被占据的情况,此时需要再次随机,直至分配到空位。

暴力直接模拟

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
def simulate_bruteforce(n: int) -> bool:
"""
Simulates one round. Unbounded time complexity.
:param n: total number of seats
:return: True if last one has last seat, otherwise False
"""

seats = [False for _ in range(n)]

for i in range(n-1):
if i == 0: # first one, always random
seats[random.randint(0, n - 1)] = True
else:
if not seats[i]: # i-th has his seat
seats[i] = True
else:
while True:
rnd = random.randint(0, n - 1) # random until no conflicts
if not seats[rnd]:
seats[rnd] = True
break
return not seats[n-1]

运行上面的代码来模拟 n 从 2 到10 的情况,每种情况跑500次模拟,输出如下

1
2
3
4
5
6
7
8
9
10
1 => 1.0
2 => 0.55
3 => 0.54
4 => 0.486
5 => 0.488
6 => 0.498
7 => 0.526
8 => 0.504
9 => 0.482
10 => 0.494

发现当 n>=2 时,似乎概率都是0.5。

标准答案

其实,这道题的标准答案就是 n=1 为1,n>=2 为0.5。下面是 python 3 标准答案。本篇后面会从多个角度来探讨为什么是0.5 。

{linenos
1
2
3
class Solution:
def nthPersonGetsNthSeat(self, n: int) -> float:
return 1.0 if n == 1 else 0.5

O(n) 改进算法

上面的暴力直接模拟版本有个最大的问题是当n很大时,随机分配座位会产生大量冲突,因此,最坏复杂度是没有任何上限的。解决方法是每次发生随机分配时保证不冲突,能直接选到空位。下面是一种最坏复杂度O(n)的模拟过程,seats数组初始话成 0,1,...,n-1,表示座位号。当第i个人登机时,seats[i:n] 的值为他可以选择的座位集合,而seats[0:i]为已经被占据的座位集合。由于[i: n]是连续空间,产生随机数就能保证不冲突。当第i个人选完座位时,将他选中的seats[k]和seats[i] 交换,保证第i+i个人面临的seats[i+1:n]依然为可选座位集合。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
def simulate_online(n: int) -> bool:
"""
Simulates one round of complexity O(N).
:param n: total number of seats
:return: True if last one has last seat, otherwise False
"""

seats = [i for i in range(n)]

def swap(i, j):
tmp = seats[i]
seats[i] = seats[j]
seats[j] = tmp

# for each person, the seats array idx available are [i, n-1]
for i in range(n-1):
if i == 0: # first one, always random
rnd = random.randint(0, n - 1)
swap(rnd, 0)
else:
if seats[i] == i: # i-th still has his seat
pass
else:
rnd = random.randint(i, n - 1) # selects idx from [i, n-1]
swap(rnd, i)
return seats[n-1] == n - 1

递推思维

这一节我们用数学递推思维来解释0.5的解。令f(n) 为第 n 位乘客坐在自己的座位上的概率,考察第一个人的情况(first step analysis),有三种可能

  1. 第一个人选了第一个即自己的座位,那么最后一个人一定能保证坐在自己的座位。
  2. 第一个人选了最后一个人的座位,无论中间什么过程,最后一个人无法坐到自己座位
  3. 第一个人选了第i个座位,(1<i<n),那么第i个人前面的除了第一个外的人都会坐在自己位置上,第i个人由于没有自己座位,随机在剩余的座位1,座位 [i+1,n] 中随机选择,此时,问题转变为f(n-i+1),如下图所示。
第一个人选了位置i
第i个人将问题转换成f(n-i+1)

通过上面分析,得到概率递推关系如下

\[ f(n) = \begin{align*} \left\lbrace \begin{array}{r@{}l} 1 & & p=\frac{1}{n} \quad \text{选了第一个位置} \\\\\\ f(n-i+1) & & p=\frac{1}{n} \quad \text{选了第i个位置,1<i<n} \\\\\\ 0 & & p=\frac{1}{n} \quad \text{选了第n个位置} \end{array} \right. \end{align*} \]

即f(n)的递推式为: \[ f(n) = \frac{1}{n} + \frac{1}{n} \times [ f(n-1) + f(n-2) + ...+ f(2)], \quad n>=2 \] 同理,f(n+1)递推式如下 \[ f(n+1) = \frac{1}{n+1} + \frac{1}{n+1} \times [ f(n) + f(n-1) + ...+ f(2)] \] \((n+1)f(n+1) - nf(n)\) 抵消 \(f(n-1) + ...f(2)\) 项,可得 \[ (n+1)f(n+1) - nf(n) = f(n) \]\[ f(n+1) = f(n) = \frac{1}{2} \quad n>=2 \]

用数学归纳法也可以证明 n>=2 时 f(n)=0.5。

简化的思考方式

我们再仔细思考一下上面的第三种情况,就是第一个人坐了第i个座位,1<i<n,此时,程序继续,不产生结果,直至产生结局1或者2,也就是case 1和2是真正的结局节点,它们产生的概率相同,因此答案是1/2。

从调用图可以看出这种关系,由于中间节点 f(4),f(3),f(2)生成Case 1和2的概率一样,因此无论它们之间是什么关系,最后结果都是1/2.

知乎上有个很形象的类比理解方式

考虑一枚硬币,正面向上的概率为 1/n,反面也是,立起来的概率为 (n-2)/n 。我们规定硬币立起来重新抛,但重新抛时,n会至少减小1。求结果为反面的概率。这样很显然结果为 1/2 。

这里,正面向上对应Case 2,反面对应Case 1。

这种思想可以写出如下代码,seats为 n 大小的bool 数组,当第i个人(0<i<n)发现自己座位被占的话,此时必然seats[0]没有被占,同时seats[i+1:]都是空的。假设seats[0]被占的话,要么是第一个人占的,要么是第p个人(p<i)坐了,两种情况下乱序都已经恢复了,此时第i个座位一定是空的。

{linenos
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
def simulate(n: int) -> bool:
"""
Simulates one round of complexity O(N).
:param n: total number of seats
:return: True if last one has last seat, otherwise False
"""

seats = [False for _ in range(n)]

for i in range(n-1):
if i == 0: # first one, always random
rnd = random.randint(0, n - 1)
seats[rnd] = True
else:
if not seats[i]: # i-th still has his seat
seats[i] = True
else:
# 0 must not be available, now we have 0 and [i+1, n-1],
rnd = random.randint(i, n - 1)
if rnd == i:
seats[0] = True
else:
seats[rnd] = True
return not seats[n-1]
Your browser is out-of-date!

Update your browser to view this website correctly. Update my browser now

×