# Reinforcement Learning

This episode extends the last one, where the Minimax and Alpha-Beta Pruning algorithms were introduced. We first solve several tic-tac-toe problems on LeetCode, gathering intuition and building blocks for the tic-tac-toe game logic, which can be naturally extended to the Connect-N game or Gomoku (N=5). Then we solve tic-tac-toe using Minimax and Alpha-Beta Pruning for small N and analyze the state space. In the following episodes, based on the building blocks here, we will implement a Connect-N OpenAI Gym GUI environment, where we can play against the computer visually or compare different computer algorithms. Finally, we demonstrate how to implement Monte Carlo Tree Search for the Connect-N game.

Leetcode Tic-Tac-Toe Problems

1275. Find Winner on a Tic Tac Toe Game (Easy)

Tic-tac-toe is played by two players A and B on a 3 x 3 grid.
Here are the rules of Tic-Tac-Toe:
Players take turns placing characters into empty squares (" ").
The first player A always places "X" characters, while the second player B always places "O" characters.
"X" and "O" characters are always placed into empty squares, never on filled ones.
The game ends when there are 3 of the same (non-empty) character filling any row, column, or diagonal.
The game also ends if all squares are non-empty.
No more moves can be played if the game is over. You are given an array moves where each element is another array of size 2 giving the row and column of the grid where the players mark their respective characters, in the order in which A and B play.
Return the winner of the game if it exists (A or B). If the game ends in a draw, return "Draw"; if there are still moves to play, return "Pending".
You can assume that moves is valid (It follows the rules of Tic-Tac-Toe), the grid is initially empty and A will play first.

Example 1:
Input: moves = [[0,0],[2,0],[1,1],[2,1],[2,2]]
Output: "A"
Explanation: "A" wins, he always plays first.
"X " "X " "X " "X " "X "
" " -> " " -> " X " -> " X " -> " X "
" " "O " "O " "OO " "OOX"

Example 2:
Input: moves = [[0,0],[1,1],[0,1],[0,2],[1,0],[2,0]]
Output: "B"
Explanation: "B" wins.
"X " "X " "XX " "XXO" "XXO" "XXO"
" " -> " O " -> " O " -> " O " -> "XO " -> "XO "
" " " " " " " " " " "O "

Example 3:
Input: moves = [[0,0],[1,1],[2,0],[1,0],[1,2],[2,1],[0,1],[0,2],[2,2]]
Output: "Draw"
Explanation: The game ends in a draw since there are no moves to make.
"XXO"
"OOX"
"XOX"

Example 4:
Input: moves = [[0,0],[1,1]]
Output: "Pending"
Explanation: The game has not finished yet.
"X "
" O "
" "

The intuitive solution is to enumerate all 8 possible winning lines: 3 rows, 3 columns and 2 diagonals. We keep 8 counters, one per winning line, and a simple trick is to convert the board state into a 3x3 2D array whose cells take values -1, 0 and 1. In this way, we traverse the board exactly once and, in the process, determine all 8 counters by summing the corresponding cell values. For example, row[0] corresponds to the first row's winning condition and is the sum of all 3 cells in the first row during board traversal. It indicates a win for the first player only when it equals 3 and a win for the second player when it equals -3.

```python
# AC
from typing import List

class Solution:
    def tictactoe(self, moves: List[List[int]]) -> str:
        board = [[0] * 3 for _ in range(3)]
        for idx, xy in enumerate(moves):
            player = 1 if idx % 2 == 0 else -1
            board[xy[0]][xy[1]] = player

        turn = 0
        row, col = [0, 0, 0], [0, 0, 0]
        diag1, diag2 = 0, 0
        for r in range(3):
            for c in range(3):
                turn += board[r][c]
                row[r] += board[r][c]
                col[c] += board[r][c]
                if r == c:
                    diag1 += board[r][c]
                if r + c == 2:
                    diag2 += board[r][c]

        # player A's cells are +1, player B's cells are -1
        aWin = any(row[r] == 3 for r in range(3)) or any(col[c] == 3 for c in range(3)) or diag1 == 3 or diag2 == 3
        bWin = any(row[r] == -3 for r in range(3)) or any(col[c] == -3 for c in range(3)) or diag1 == -3 or diag2 == -3

        return "A" if aWin else "B" if bWin else "Draw" if len(moves) == 9 else "Pending"
```

Below we give another AC solution. Despite having more code, it is more efficient than the previous one because, for a given game state, it does not need to visit every cell on the board. How is that achieved? The problem guarantees that each move is valid, so it is sufficient to examine the neighbours of the final move and check whether any line through the final move forms a winning condition. Later we will reuse the code in this solution to build the tic-tac-toe game logic.

```python
# AC
from typing import List

class Solution:
    def checkWin(self, r: int, c: int) -> bool:
        north = self.getConnectedNum(r, c, -1, 0)
        south = self.getConnectedNum(r, c, 1, 0)

        east = self.getConnectedNum(r, c, 0, 1)
        west = self.getConnectedNum(r, c, 0, -1)

        south_east = self.getConnectedNum(r, c, 1, 1)
        north_west = self.getConnectedNum(r, c, -1, -1)

        north_east = self.getConnectedNum(r, c, -1, 1)
        south_west = self.getConnectedNum(r, c, 1, -1)

        if (north + south + 1 >= 3) or (east + west + 1 >= 3) or \
           (south_east + north_west + 1 >= 3) or (north_east + south_west + 1 >= 3):
            return True
        return False

    def getConnectedNum(self, r: int, c: int, dr: int, dc: int) -> int:
        player = self.board[r][c]
        result = 0
        i = 1
        while True:
            new_r = r + dr * i
            new_c = c + dc * i
            if 0 <= new_r < 3 and 0 <= new_c < 3:
                if self.board[new_r][new_c] == player:
                    result += 1
                else:
                    break
            else:
                break
            i += 1
        return result

    def tictactoe(self, moves: List[List[int]]) -> str:
        self.board = [[0] * 3 for _ in range(3)]
        for idx, xy in enumerate(moves):
            player = 1 if idx % 2 == 0 else -1
            self.board[xy[0]][xy[1]] = player

        # only check the last move
        r, c = moves[-1]
        win = self.checkWin(r, c)
        if win:
            return "A" if len(moves) % 2 == 1 else "B"

        return "Draw" if len(moves) == 9 else "Pending"
```

Leetcode 794. Valid Tic-Tac-Toe State (Medium)

A Tic-Tac-Toe board is given as a string array board. Return True if and only if it is possible to reach this board position during the course of a valid tic-tac-toe game.
The board is a 3 x 3 array, and consists of characters " ", "X", and "O". The " " character represents an empty square.
Here are the rules of Tic-Tac-Toe:
Players take turns placing characters into empty squares (" ").
The first player A always places "X" characters, while the second player B always places "O" characters.
"X" and "O" characters are always placed into empty squares, never on filled ones.
The game ends when there are 3 of the same (non-empty) character filling any row, column, or diagonal.
The game also ends if all squares are non-empty.
No more moves can be played if the game is over.

Example 1:
Input: board = ["O ", " ", " "]
Output: false
Explanation: The first player always plays "X".

Example 2:
Input: board = ["XOX", " X ", " "]
Output: false
Explanation: Players take turns making moves.

Example 3:
Input: board = ["XXX", " ", "OOO"]
Output: false

Example 4:
Input: board = ["XOX", "O O", "XOX"]
Output: true

Note: board is a length-3 array of strings, where each string board[i] has length 3.
Each board[i][j] is a character in the set {" ", "X", "O"}.

Surely, it can be solved using DFS, checking whether the given state can be reached from the initial state. However, that searches a large number of states. Could we do better? There are obvious properties we can rely on: for example, the number of X's either equals the number of O's or is greater by one. If we can enumerate a combination of necessary and sufficient conditions for reachability, we can solve it in O(1) time.

```python
# AC
from typing import List

class Solution:

    def convertCell(self, c: str):
        return 1 if c == 'X' else -1 if c == 'O' else 0

    def validTicTacToe(self, board: List[str]) -> bool:
        turn = 0
        row, col = [0, 0, 0], [0, 0, 0]
        diag1, diag2 = 0, 0
        for r in range(3):
            for c in range(3):
                turn += self.convertCell(board[r][c])
                row[r] += self.convertCell(board[r][c])
                col[c] += self.convertCell(board[r][c])
                if r == c:
                    diag1 += self.convertCell(board[r][c])
                if r + c == 2:
                    diag2 += self.convertCell(board[r][c])

        xWin = any(row[r] == 3 for r in range(3)) or any(col[c] == 3 for c in range(3)) or diag1 == 3 or diag2 == 3
        oWin = any(row[r] == -3 for r in range(3)) or any(col[c] == -3 for c in range(3)) or diag1 == -3 or diag2 == -3
        if (xWin and turn == 0) or (oWin and turn == 1):
            return False
        return (turn == 0 or turn == 1) and (not xWin or not oWin)
```

Leetcode 348. Design Tic-Tac-Toe (Medium, Locked)

Design a Tic-tac-toe game that is played between two players on a n x n grid.
You may assume the following rules:
A move is guaranteed to be valid and is placed on an empty block.
Once a winning condition is reached, no more moves are allowed.
A player who succeeds in placing n of their marks in a horizontal, vertical, or diagonal row wins the game.

Example:
Given n = 3, assume that player 1 is "X" and player 2 is "O" in the board.
TicTacToe toe = new TicTacToe(3);

toe.move(0, 0, 1); -> Returns 0 (no one wins)
|X| | |
| | | | // Player 1 makes a move at (0, 0).
| | | |

toe.move(0, 2, 2); -> Returns 0 (no one wins)
|X| |O|
| | | | // Player 2 makes a move at (0, 2).
| | | |

toe.move(2, 2, 1); -> Returns 0 (no one wins)
|X| |O|
| | | | // Player 1 makes a move at (2, 2).
| | |X|

toe.move(1, 1, 2); -> Returns 0 (no one wins)
|X| |O|
| |O| | // Player 2 makes a move at (1, 1).
| | |X|

toe.move(2, 0, 1); -> Returns 0 (no one wins)
|X| |O|
| |O| | // Player 1 makes a move at (2, 0).
|X| |X|

toe.move(1, 0, 2); -> Returns 0 (no one wins)
|X| |O|
|O|O| | // Player 2 makes a move at (1, 0).
|X| |X|

toe.move(2, 1, 1); -> Returns 1 (player 1 wins)
|X| |O|
|O|O| | // Player 1 makes a move at (2, 1).
|X|X|X|

Follow up:
Could you do better than O(n^2) per move() operation?

348 is a locked problem. For each player's move, we could resort to the checkWin function from the second solution of 1275. Below we show another solution, based on the first solution of 1275, where the winning-condition counters (n rows, n columns and 2 diagonals) are kept and each move only updates the few counters it touches.

```python
# AC
class TicTacToe:

    def __init__(self, n: int):
        """
        Initialize your data structure here.
        :type n: int
        """
        self.row, self.col, self.diag1, self.diag2, self.n = [0] * n, [0] * n, 0, 0, n

    def move(self, row: int, col: int, player: int) -> int:
        """
        Player {player} makes a move at ({row}, {col}).
        @param row The row of the board.
        @param col The column of the board.
        @param player The player, can be either 1 or 2.
        @return The current winning condition, can be either:
                0: No one wins.
                1: Player 1 wins.
                2: Player 2 wins.
        """
        if player == 2:
            player = -1

        self.row[row] += player
        self.col[col] += player
        if row == col:
            self.diag1 += player
        if row + col == self.n - 1:
            self.diag2 += player

        if self.n in [self.row[row], self.col[col], self.diag1, self.diag2]:
            return 1
        if -self.n in [self.row[row], self.col[col], self.diag1, self.diag2]:
            return 2
        return 0
```
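
As a quick sanity check, the sequence of moves from the example above can be replayed against this class (a usage sketch; the expected outputs are taken from the problem statement):

```python
if __name__ == '__main__':
    toe = TicTacToe(3)
    moves = [(0, 0, 1), (0, 2, 2), (2, 2, 1), (1, 1, 2), (2, 0, 1), (1, 0, 2), (2, 1, 1)]
    results = [toe.move(r, c, p) for r, c, p in moves]
    print(results)  # [0, 0, 0, 0, 0, 0, 1] -- player 1 wins on the last move
```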


Optimal Strategy of Tic-Tac-Toe

Tic-tac-toe and Gomoku (Connect Five in a Row) share the same rules and are generalized as the m,n,k-game, played on an m x n board where the winning condition becomes k pieces in a row.

The ConnectNGame class implements the m,n,k-game on a square board of size board_size x board_size. It encapsulates the logic of checking each move and can also undo the last move, which facilitates backtracking in the game search algorithms later.

ConnectNGame

```python
from typing import List, Tuple

class ConnectNGame:

    PLAYER_A = 1
    PLAYER_B = -1
    AVAILABLE = 0
    RESULT_TIE = 0
    RESULT_A_WIN = 1
    RESULT_B_WIN = -1

    def __init__(self, N: int = 3, board_size: int = 3):
        assert N <= board_size
        self.N = N
        self.board_size = board_size
        self.board = [[ConnectNGame.AVAILABLE] * board_size for _ in range(board_size)]
        self.gameOver = False
        self.gameResult = None
        self.currentPlayer = ConnectNGame.PLAYER_A
        self.remainingPosNum = board_size * board_size
        self.actionStack = []

    def move(self, r: int, c: int) -> int:
        """
        Place the current player's piece at (r, c).

        :param r: row of the move
        :param c: column of the move
        :return: game result if the game ends (1, -1 or 0 for tie), None if the game is ongoing
        """
        assert self.board[r][c] == ConnectNGame.AVAILABLE
        self.board[r][c] = self.currentPlayer
        self.actionStack.append((r, c))
        self.remainingPosNum -= 1
        if self.checkWin(r, c):
            self.gameOver = True
            self.gameResult = self.currentPlayer
            return self.currentPlayer
        if self.remainingPosNum == 0:
            self.gameOver = True
            self.gameResult = ConnectNGame.RESULT_TIE
            return ConnectNGame.RESULT_TIE
        self.currentPlayer *= -1

    def undo(self):
        if len(self.actionStack) > 0:
            lastAction = self.actionStack.pop()
            r, c = lastAction
            self.board[r][c] = ConnectNGame.AVAILABLE
            self.currentPlayer = ConnectNGame.PLAYER_A if len(self.actionStack) % 2 == 0 else ConnectNGame.PLAYER_B
            self.remainingPosNum += 1
            self.gameOver = False
            self.gameResult = None
        else:
            raise Exception('No lastAction')

    def getAvailablePositions(self) -> List[Tuple[int, int]]:
        return [(i, j) for i in range(self.board_size) for j in range(self.board_size)
                if self.board[i][j] == ConnectNGame.AVAILABLE]

    def getStatus(self) -> Tuple[Tuple[int, ...]]:
        return tuple([tuple(self.board[i]) for i in range(self.board_size)])
```

Note that the checkWin code (and its getConnectedNum helper) is essentially the same as in the second solution of 1275, with the hard-coded 3 generalized to the winning length N and the board bounds generalized to board_size.
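
For completeness, here is a sketch of how those helpers could look when adapted for ConnectNGame: the winning length check uses self.N and the bounds check uses self.board_size (this is an adaptation sketch of the 1275 solution, not a verbatim copy of the original class).

```python
def checkWin(self, r: int, c: int) -> bool:
    # count connected pieces in both directions along each of the 4 lines through (r, c)
    north = self.getConnectedNum(r, c, -1, 0)
    south = self.getConnectedNum(r, c, 1, 0)
    east = self.getConnectedNum(r, c, 0, 1)
    west = self.getConnectedNum(r, c, 0, -1)
    south_east = self.getConnectedNum(r, c, 1, 1)
    north_west = self.getConnectedNum(r, c, -1, -1)
    north_east = self.getConnectedNum(r, c, -1, 1)
    south_west = self.getConnectedNum(r, c, 1, -1)
    return (north + south + 1 >= self.N) or (east + west + 1 >= self.N) or \
           (south_east + north_west + 1 >= self.N) or (north_east + south_west + 1 >= self.N)

def getConnectedNum(self, r: int, c: int, dr: int, dc: int) -> int:
    # walk away from (r, c) in direction (dr, dc) while cells belong to the same player
    player = self.board[r][c]
    result, i = 0, 1
    while True:
        new_r, new_c = r + dr * i, c + dc * i
        if 0 <= new_r < self.board_size and 0 <= new_c < self.board_size \
                and self.board[new_r][new_c] == player:
            result += 1
            i += 1
        else:
            break
    return result

# attach the helpers to the class defined above
ConnectNGame.checkWin = checkWin
ConnectNGame.getConnectedNum = getConnectedNum
```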

Minimax Strategy

Now that we have the Connect-N game logic, let's implement the minimax algorithm to solve the game.

We define a generic strategy base class whose action method needs to be overridden. The action method takes a ConnectNGame instance describing the current game state and returns a tuple of 2 elements: the first is the estimated or exact game result after taking the action specified by the second, and the second is a Tuple[int, int] denoting the position of the move, for instance (1, 1).

```python
from abc import ABC, abstractmethod
from typing import Tuple

class Strategy(ABC):

    def __init__(self):
        super().__init__()

    @abstractmethod
    def action(self, game: ConnectNGame) -> Tuple[int, Tuple[int, int]]:
        pass
```

The MinimaxStrategy code is very similar to the minimax algorithms in the previous episode. The only added piece is the corresponding move returned by the action method.

```python
import copy
import math

class MinimaxStrategy(Strategy):

    def action(self, game: ConnectNGame) -> Tuple[int, Tuple[int, int]]:
        self.game = copy.deepcopy(game)
        result, move = self.minimax()
        return result, move

    def minimax(self) -> Tuple[int, Tuple[int, int]]:
        game = self.game
        bestMove = None
        assert not game.gameOver
        if game.currentPlayer == ConnectNGame.PLAYER_A:
            ret = -math.inf
            for pos in game.getAvailablePositions():
                move = pos
                result = game.move(*pos)
                if result is None:
                    assert not game.gameOver
                    result, oppMove = self.minimax()
                game.undo()
                ret = max(ret, result)
                bestMove = move if ret == result else bestMove
                if ret == 1:
                    return 1, move
            return ret, bestMove
        else:
            ret = math.inf
            for pos in game.getAvailablePositions():
                move = pos
                result = game.move(*pos)
                if result is None:
                    assert not game.gameOver
                    result, oppMove = self.minimax()
                game.undo()
                ret = min(ret, result)
                bestMove = move if ret == result else bestMove
                if ret == -1:
                    return -1, move
            return ret, bestMove
```

We plot up to the first 2 moves with the code above. For the first player there are 9 possible positions but, due to symmetry, only 3 distinct kinds of moves, which we call corner, edge and center. The following graph shows that whichever of the 9 positions the first player takes, the best achievable result is a draw, so the game-theoretic value of tic-tac-toe is a draw.

[Figure: Tic-tac-toe, all 9 first moves]

The first step of each of the 3 kinds of moves is plotted one by one below.

[Figure: Tic-tac-toe first step, corner]
[Figure: Tic-tac-toe first step, edge]
[Figure: Tic-tac-toe first step, center]

Tic-tac-toe Solution and Number of States

An interesting question is the number of game states of tic-tac-toe. A loose upper bound is \(3^9=19683\), which includes many unreachable states. The article Tic-Tac-Toe (Naughts and Crosses, Cheese and Crackers, etc.) lists the number of states after each move. The total number is 5478.

| Moves | Positions | Terminal Positions |
|------:|----------:|-------------------:|
| 0     | 1         |                    |
| 1     | 9         |                    |
| 2     | 72        |                    |
| 3     | 252       |                    |
| 4     | 756       |                    |
| 5     | 1260      | 120                |
| 6     | 1520      | 148                |
| 7     | 1140      | 444                |
| 8     | 390       | 168                |
| 9     | 78        | 78                 |
| Total | 5478      | 958                |

We can verify this number by modifying the existing code slightly, as shown below.

```python
import copy
import math

class CountingMinimaxStrategy(Strategy):

    def action(self, game: ConnectNGame) -> Tuple[int, Tuple[int, int]]:
        self.game = copy.deepcopy(game)
        self.dpMap = {}  # game status -> (result, move)
        result, move = self.minimax(game.getStatus())
        return result, move

    def minimax(self, gameStatus: Tuple[Tuple[int, ...]]) -> Tuple[int, Tuple[int, int]]:
        # print(f'Current {len(strategy.dpMap)}')

        if gameStatus in self.dpMap:
            return self.dpMap[gameStatus]

        game = self.game
        bestMove = None
        assert not game.gameOver
        if game.currentPlayer == ConnectNGame.PLAYER_A:
            ret = -math.inf
            for pos in game.getAvailablePositions():
                move = pos
                result = game.move(*pos)
                if result is None:
                    assert not game.gameOver
                    result, oppMove = self.minimax(game.getStatus())
                    self.dpMap[game.getStatus()] = result, oppMove
                else:
                    self.dpMap[game.getStatus()] = result, move
                game.undo()
                ret = max(ret, result)
                bestMove = move if ret == result else bestMove
            self.dpMap[gameStatus] = ret, bestMove
            return ret, bestMove
        else:
            ret = math.inf
            for pos in game.getAvailablePositions():
                move = pos
                result = game.move(*pos)

                if result is None:
                    assert not game.gameOver
                    result, oppMove = self.minimax(game.getStatus())
                    self.dpMap[game.getStatus()] = result, oppMove
                else:
                    self.dpMap[game.getStatus()] = result, move
                game.undo()
                ret = min(ret, result)
                bestMove = move if ret == result else bestMove
            self.dpMap[gameStatus] = ret, bestMove
            return ret, bestMove


if __name__ == '__main__':
    tic_tac_toe = ConnectNGame(N=3, board_size=3)
    strategy = CountingMinimaxStrategy()
    strategy.action(tic_tac_toe)
    print(f'Game States Number {len(strategy.dpMap)}')
```

Running the code confirms that the total number is 5478. Results for some small-scale game configurations are also listed below.

|     | 3x3         | 4x4             |
|-----|-------------|-----------------|
| k=3 | 5478 (Draw) | 6035992 (Win)   |
| k=4 |             | 9722011 (Draw)  |
| k=5 |             |                 |

According to the Wikipedia article on the m,n,k-game, below are the results for some game configurations.

|     | 3x3  | 4x4  | 5x5  | 6x6  |
|-----|------|------|------|------|
| k=3 | Draw | Win  | Win  | Win  |
| k=4 |      | Draw | Draw | Win  |
| k=5 |      |      | Draw | Draw |

It is worth mentioning that Gomoku (Connect Five in a Row), on a board of size 15x15 or larger, was proved by L. Victor Allis to be a win for the first player.

Alpha-Beta Pruning Strategy

The Alpha-Beta Pruning strategy is listed below.

```python
import math

class AlphaBetaStrategy(Strategy):

    def action(self, game: ConnectNGame) -> Tuple[int, Tuple[int, int]]:
        self.game = game
        result, move = self.alpha_beta(self.game.getStatus(), -math.inf, math.inf)
        return result, move

    def alpha_beta(self, gameStatus: Tuple[Tuple[int, ...]], alpha: float = None, beta: float = None) -> Tuple[int, Tuple[int, int]]:
        game = self.game
        bestMove = None
        assert not game.gameOver
        if game.currentPlayer == ConnectNGame.PLAYER_A:
            ret = -math.inf
            for pos in game.getAvailablePositions():
                move = pos
                result = game.move(*pos)
                if result is None:
                    assert not game.gameOver
                    result, oppMove = self.alpha_beta(game.getStatus(), alpha, beta)
                game.undo()
                alpha = max(alpha, result)
                ret = max(ret, result)
                bestMove = move if ret == result else bestMove
                if alpha >= beta or ret == 1:
                    return ret, move
            return ret, bestMove
        else:
            ret = math.inf
            for pos in game.getAvailablePositions():
                move = pos
                result = game.move(*pos)
                if result is None:
                    assert not game.gameOver
                    result, oppMove = self.alpha_beta(game.getStatus(), alpha, beta)
                game.undo()
                beta = min(beta, result)
                ret = min(ret, result)
                bestMove = move if ret == result else bestMove
                if alpha >= beta or ret == -1:
                    return ret, move
            return ret, bestMove
```

Next we rewrite alpha-beta pruning with DP. We omit the alpha and beta parameters of alpha_beta_dp because lru_cache cannot exclude specific parameters from the cache key. Instead, we keep alpha and beta on a stack variable and maintain that stack in step with the alpha_beta_dp call stack.

```python
import math
from functools import lru_cache

class AlphaBetaDPStrategy(Strategy):

    def action(self, game: ConnectNGame) -> Tuple[int, Tuple[int, int]]:
        self.game = game
        self.alphaBetaStack = [(-math.inf, math.inf)]
        result, move = self.alpha_beta_dp(self.game.getStatus())
        return result, move

    @lru_cache(maxsize=None)
    def alpha_beta_dp(self, gameStatus: Tuple[Tuple[int, ...]]) -> Tuple[int, Tuple[int, int]]:
        alpha, beta = self.alphaBetaStack[-1]
        game = self.game
        bestMove = None
        assert not game.gameOver
        if game.currentPlayer == ConnectNGame.PLAYER_A:
            ret = -math.inf
            for pos in game.getAvailablePositions():
                move = pos
                result = game.move(*pos)
                if result is None:
                    assert not game.gameOver
                    self.alphaBetaStack.append((alpha, beta))
                    result, oppMove = self.alpha_beta_dp(game.getStatus())
                    self.alphaBetaStack.pop()
                game.undo()
                alpha = max(alpha, result)
                ret = max(ret, result)
                bestMove = move if ret == result else bestMove
                if alpha >= beta or ret == 1:
                    return ret, move
            return ret, bestMove
        else:
            ret = math.inf
            for pos in game.getAvailablePositions():
                move = pos
                result = game.move(*pos)
                if result is None:
                    assert not game.gameOver
                    self.alphaBetaStack.append((alpha, beta))
                    result, oppMove = self.alpha_beta_dp(game.getStatus())
                    self.alphaBetaStack.pop()
                game.undo()
                beta = min(beta, result)
                ret = min(ret, result)
                bestMove = move if ret == result else bestMove
                if alpha >= beta or ret == -1:
                    return ret, move
            return ret, bestMove
```

This is the fifth episode of the series TSP From DP to Deep Learning. In this episode we turn to reinforcement learning, in particular a model-free policy gradient method that embeds a pointer network to learn minimal tours without supervised best-tour labels in the dataset.

Pointer Network Refresher

In the previous episode, Pointer Networks in PyTorch, we implemented pointer networks in PyTorch with a 2D Euclidean dataset.

Recall that the input is a graph, given as a sequence of \(n\) cities in two-dimensional space:

\[ s=\{\mathbf{x_i}\}_{i=1}^n, \mathbf{x}_{i} \in \mathbb{R}^{2} \]

The output is a permutation \(\pi\) of the points that visits each city exactly once and returns to the starting point with minimal total distance.

Let us define the total distance of a tour \(\pi\) with respect to \(s\) as \(L\):

\[ L(\pi | s)=\left\|\mathbf{x}_{\pi(n)}-\mathbf{x}_{\pi(1)}\right\|_{2}+\sum_{i=1}^{n-1}\left\|\mathbf{x}_{\pi(i)}-\mathbf{x}_{\pi(i+1)}\right\|_{2} \]
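
As a small concrete check of this definition, the tour length can be computed directly from a set of points and a permutation (a minimal sketch; the points and the tours here are made up for illustration):

```python
import math

def tour_length(points, tour):
    """Total length L(pi | s): sum of consecutive legs plus the return leg."""
    n = len(tour)
    total = 0.0
    for i in range(n):
        x1, y1 = points[tour[i]]
        x2, y2 = points[tour[(i + 1) % n]]  # wraps around to the starting city
        total += math.hypot(x2 - x1, y2 - y1)
    return total

# four cities on a unit square; visiting them in order gives length 4
points = [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)]
print(tour_length(points, [0, 1, 2, 3]))  # 4.0
print(tour_length(points, [0, 2, 1, 3]))  # longer: crosses the diagonal twice
```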

The stochastic policy \(p(\pi | s; \theta)\), parameterized by \(\theta\), aims to assign high probabilities to short tours and low probabilities to long tours. The joint probability factorizes by the chain rule over previously chosen cities.

\[ p(\pi | s; \theta) = \prod_{i=1}^{n} p\left({\pi(i)} | {\pi(1)}, \ldots, {\pi(i-1)} , s; \theta\right) \]

In the supervised setting, the loss of the model is the cross entropy between the network's output distribution over tours and the best tour \(\hat{\pi}\) generated by a TSP solver.

The contribution of pointer networks is that they lift the fixed-vocabulary constraint: the output indices point dynamically into the particular input sequence of each test case rather than being drawn from a fixed-size vocabulary.

Reinforcement Learning

Neural Combinatorial Optimization with Reinforcement Learning combines the power of reinforcement learning (RL) and deep learning to remove the constraint, required by pointer networks, that the training dataset must contain supervised labels of the best tours. With deep RL, training instances do not need to come with solutions, which is a common pattern in deep RL. In the paper, a model-free, policy-based RL method is adopted.

Model-Free Policy Gradient Methods

In the authoritative RL book, chapter 8, Planning and Learning with Tabular Methods, two major approaches in RL are distinguished: model-based RL and model-free RL. The distinction between the two relies on the concept of a model, which is stated as follows:

> By a model of the environment we mean anything that an agent can use to predict how the environment will respond to its actions.

So model-based methods demand a model of the environment; dynamic programming and heuristic search fall into this category. With a model in hand, the utility of a state can be computed in various ways, and a planning stage that essentially builds the policy is needed before the agent can take any action. In contrast, model-free methods, without building a model, are more direct, ignoring irrelevant information and focusing on the policy, which is what is ultimately needed. Typical examples of model-free methods are Monte Carlo control and temporal-difference learning.

> Model-based methods rely on planning as their primary component, while model-free methods primarily rely on learning.

In the TSP problem, the model is fully determined by the given points, and no feedback is generated for each decision made, so it is unclear how to associate a state value with a partial tour. Therefore, we turn to model-free methods. Chapter 13, Policy Gradient Methods, describes a family of approximate model-free methods that learn a parameterized policy able to select actions without consulting a value function. This approach fits perfectly with the aforementioned pointer network, where the parameterized policy \(p(\pi | s; \theta)\) is already defined.

The training objective is the expected tour length which, given an input graph \(s\), is

\[ J(\theta | s) = \mathbb{E}_{\pi \sim p_{\theta}(\cdot | s)} L(\pi | s) \]

Monte Carlo Policy Gradient: REINFORCE with Baseline

To optimize this objective, a typical way is to adjust the parameters \(\theta\) in the direction of the gradient \(\nabla_{\theta} J(\theta | s)\):

\[ \nabla_{\theta} J(\theta | s)=\nabla_{\theta} \mathbb{E}_{\pi \sim p_{\theta}(\cdot | s)} L(\pi | s) \]

The RHS of the equation above is the derivative of an expectation, which we have no direct way to compute or approximate. Here comes the well-known REINFORCE trick, which turns it into an expectation of a derivative; this can be approximated easily with Monte Carlo sampling, where the expectation is replaced by averaging.

\[ \nabla_{\theta} J(\theta | s)=\mathbb{E}_{\pi \sim p_{\theta}(. | s)}\left[L(\pi | s) \nabla_{\theta} \log p_{\theta}(\pi | s)\right] \]

Another common trick, subtracting a baseline \(b(s)\), leads to the following form of the gradient. Note that \(b(s)\) denotes a baseline function that must not depend on \(\pi\). \[ \nabla_{\theta} J(\theta | s)=\mathbb{E}_{\pi \sim p_{\theta}(. | s)}\left[(L(\pi | s)-b(s)) \nabla_{\theta} \log p_{\theta}(\pi | s)\right] \]

The trick is explained in the book as follows:

> Because the baseline could be uniformly zero, this update is a strict generalization of REINFORCE. In general, the baseline leaves the expected value of the update unchanged, but it can have a large effect on its variance.
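
To see why subtracting \(b(s)\) leaves the expectation unchanged, note that the baseline term has zero expectation because \(p_{\theta}(\cdot | s)\) sums to one:

\[ \mathbb{E}_{\pi \sim p_{\theta}(\cdot | s)}\left[b(s) \nabla_{\theta} \log p_{\theta}(\pi | s)\right] = b(s) \sum_{\pi} p_{\theta}(\pi | s) \frac{\nabla_{\theta} p_{\theta}(\pi | s)}{p_{\theta}(\pi | s)} = b(s) \nabla_{\theta} \sum_{\pi} p_{\theta}(\pi | s) = b(s) \nabla_{\theta} 1 = 0 \]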

Finally, the gradient can be approximated with Monte Carlo sampling, drawing \(B\) i.i.d. graphs \(s_{1}, s_{2}, \ldots, s_{B} \sim \mathcal{S}\) and sampling a single tour per graph, \(\pi_{i} \sim p_{\theta}(\cdot | s_{i})\), as follows: \[ \nabla_{\theta} J(\theta) \approx \frac{1}{B} \sum_{i=1}^{B}\left(L\left(\pi_{i} | s_{i}\right)-b\left(s_{i}\right)\right) \nabla_{\theta} \log p_{\theta}\left(\pi_{i} | s_{i}\right) \]

Actor Critic Methods

REINFORCE with a baseline works quite well, but it also has disadvantages.

> REINFORCE with baseline is unbiased and will converge asymptotically to a local minimum, but like all Monte Carlo methods it tends to learn slowly (produce estimates of high variance) and to be inconvenient to implement online or for continuing problems.

A typical improvement is actor-critic methods, which not only learn an approximate policy (the actor's job) but also an approximate value function (the critic's job). The bootstrapping critic reduces variance and accelerates learning, at the cost of introducing some bias, which is often beneficial overall. The detailed algorithm from the paper is shown below.

\[ \begin{align*} &\textbf{Algorithm Actor-critic training} \\ &1: \quad \textbf{ procedure } \text{ TRAIN(training set }S \text{, training steps }T \text{, batch size } B \text{)} \\ &2: \quad \quad \text{Initialize pointer network params } \theta \\ &3: \quad \quad \text{Initialize critic network params } \theta_{v} \\ &4: \quad \quad \textbf{for }t=1 \text{ to } T \textbf{ do }\\ &5: \quad \quad \quad s_{i} \sim \operatorname{SAMPLE INPUT } (S) \text{ for } i \in\{1, \ldots, B\} \\ &6: \quad \quad \quad \pi_{i} \sim \operatorname{SAMPLE SOLUTION } \left(p_{\theta}\left(\cdot | s_{i}\right)\right) \text{ for } i \in\{1, \ldots, B\} \\ &7: \quad \quad \quad b_{i} \leftarrow b_{\theta_{v}}\left(s_{i}\right) \text{ for } i \in\{1, \ldots, B\} \\ &8: \quad \quad \quad g_{\theta} \leftarrow \frac{1}{B} \sum_{i=1}^{B}\left(L\left(\pi_{i} | s_{i}\right)-b_{i}\right) \nabla_{\theta} \log p_{\theta}\left(\pi_{i} | s_{i}\right) \\ &9: \quad \quad \quad \mathcal{L}_{v} \leftarrow \frac{1}{B} \sum_{i=1}^{B} \left\| b_{i}-L\left(\pi_{i}\right) \right\| _{2}^{2} \\ &10: \quad \quad \quad \theta \leftarrow \operatorname{ADAM} \left( \theta, g_{\theta} \right) \\ &11: \quad \quad \quad \theta_{v} \leftarrow \operatorname{ADAM}\left(\theta_{v}, \nabla_{\theta_{v}} \mathcal{L}_{v}\right) \\ &12: \quad \quad \textbf{end for} \\ &13: \quad \textbf{return } \theta \\ &14: \textbf{end procedure} \end{align*} \]
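
The following PyTorch-style sketch mirrors one iteration of the algorithm above. It assumes hypothetical actor and critic modules with the interfaces shown in the comments (a sampled tour with its log-probability, and a scalar baseline prediction per graph); it is an illustration of the update, not the paper's reference implementation.

```python
import torch

def train_step(actor, critic, actor_opt, critic_opt, batch_points, tour_length):
    """One actor-critic update (lines 5-11 of the algorithm), assuming:
    - actor(batch_points) samples a tour per graph and returns (tours, log_probs)
    - critic(batch_points) predicts the baseline tour length b(s) per graph
    - tour_length(batch_points, tours) returns L(pi | s) per graph, shape [B]
    """
    tours, log_probs = actor(batch_points)           # pi_i ~ p_theta(. | s_i)
    lengths = tour_length(batch_points, tours)       # L(pi_i | s_i), shape [B]
    baselines = critic(batch_points).squeeze(-1)     # b_i, shape [B]

    advantage = (lengths - baselines).detach()       # no gradient through L - b
    actor_loss = (advantage * log_probs).mean()      # REINFORCE gradient estimator
    critic_loss = torch.nn.functional.mse_loss(baselines, lengths.detach())

    actor_opt.zero_grad(); actor_loss.backward(); actor_opt.step()
    critic_opt.zero_grad(); critic_loss.backward(); critic_opt.step()
    return lengths.mean().item()
```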

Implementation in PyTorch

Beam Search in OpenNMT-py

In Episode 4, Search for Most Likely Sequence, a 3x3 rectangular trellis was given and several decoding methods were illustrated in plain Python. In the PyTorch ecosystem, the package OpenNMT-py supports efficient batched beam search, but its BeamSearch API is somewhat involved, so here we demonstrate the earlier problem using that API. For details, please refer to Implementing Beam Search — Part 1: A Source Code Analysis of OpenNMT-py.

```python
from copy import deepcopy
from math import exp
import torch
from onmt.translate import BeamSearch, GNMTGlobalScorer

def run_example():
    BEAM_SIZE = 2
    N_BEST = 1
    BATCH_SZ = 1
    SEQ_LEN = 3

    initial = [0.35, 0.25, 0.4]
    transition_matrix = [
        [0.3, 0.6, 0.1],
        [0.4, 0.2, 0.4],
        [0.3, 0.4, 0.4]]

    beam = BeamSearch(BEAM_SIZE, BATCH_SZ, 0, 1, 2, N_BEST,
                      GNMTGlobalScorer(0.7, 0., "avg", "none"),
                      0, 30, False, 0, set(), False, 0.)
    device_init = torch.zeros(1, 1)
    beam.initialize(device_init, torch.randint(0, 30, (BATCH_SZ,)))

    def printBestNPaths(beam: BeamSearch, step: int):
        print(f'\nstep {step} beam results:')
        for k in range(BEAM_SIZE):
            best_path = beam.alive_seq[k].squeeze().tolist()[1:]
            prob = exp(beam.topk_log_probs[0][k])
            print(f'prob {prob:.3f} with path {best_path}')

    init_scores = torch.log(torch.tensor([initial], dtype=torch.float))
    init_scores = deepcopy(init_scores.repeat(BATCH_SZ * BEAM_SIZE, 1))
    beam.advance(init_scores, None)
    printBestNPaths(beam, 0)

    for step in range(SEQ_LEN - 1):
        idx_list = beam.topk_ids.squeeze().tolist()
        beam_transition = []
        for idx in idx_list:
            beam_transition.append(transition_matrix[idx])
        beam_transition_tensor = torch.log(torch.tensor(beam_transition))

        beam.advance(beam_transition_tensor, None)
        beam.update_finished()

        printBestNPaths(beam, step + 1)


if __name__ == '__main__':
    run_example()
```

The output is as follows. With beam size \(k=2\) and 3 steps, the most likely sequence found is \(0 \rightarrow 1 \rightarrow 0\), with probability 0.084.

```
step 0 beam results:
prob 0.400 with path [2]
prob 0.350 with path [0]

step 1 beam results:
prob 0.210 with path [0, 1]
prob 0.160 with path [2, 1]

step 2 beam results:
prob 0.084 with path [0, 1, 0]
prob 0.000 with path [0, 1, 2]
```

RL with PointerNetwork

The complete code is on GitHub (TSP RL). Below is one of the core classes.

```python
from typing import List, Tuple

import torch
from torch import nn, Tensor
from torch.autograd import Variable

# PointerNet is defined in the accompanying repo (see the GitHub link above)

class CombinatorialRL(nn.Module):
    actor: PointerNet

    def __init__(self, rnn_type, use_embedding, embedding_size, hidden_size, seq_len, num_glimpse, tanh_exploration, use_tanh, attention):
        super(CombinatorialRL, self).__init__()

        self.actor = PointerNet(rnn_type, use_embedding, embedding_size, hidden_size, seq_len, num_glimpse, tanh_exploration, use_tanh, attention)

    def forward(self, batch_input: Tensor) -> Tuple[Tensor, List[Tensor], List[Tensor], List[Tensor]]:
        """
        Args:
            batch_input: [batch_size * 2 * seq_len]
        Returns:
            R: Tensor of shape [batch_size]
            action_prob_list: List of [seq_len], tensor shape [batch_size]
            action_list: List of [seq_len], tensor shape [batch_size * 2]
            action_idx_list: List of [seq_len], tensor shape [batch_size]
        """
        batch_size = batch_input.size(0)
        seq_len = batch_input.size(2)
        prob_list, action_idx_list = self.actor(batch_input)

        action_list = []
        batch_input = batch_input.transpose(1, 2)
        for action_id in action_idx_list:
            action_list.append(batch_input[[x for x in range(batch_size)], action_id.data, :])
        action_prob_list = []
        for prob, action_id in zip(prob_list, action_idx_list):
            action_prob_list.append(prob[[x for x in range(batch_size)], action_id.data])

        R = self.reward(action_list)

        return R, action_prob_list, action_list, action_idx_list

    def reward(self, sample_solution: List[Tensor]) -> Tensor:
        """
        Computes total distance of tour
        Args:
            sample_solution: list of size N, each tensor of shape [batch_size * 2]

        Returns:
            tour_len: [batch_size]
        """
        batch_size = sample_solution[0].size(0)
        n = len(sample_solution)
        tour_len = Variable(torch.zeros([batch_size]))

        for i in range(n - 1):
            tour_len += torch.norm(sample_solution[i] - sample_solution[i + 1], dim=1)
        tour_len += torch.norm(sample_solution[n - 1] - sample_solution[0], dim=1)
        return tour_len
```
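
Below is a hedged sketch of how these outputs could feed the REINFORCE update with an exponential moving-average baseline. The function name, the beta value and the surrounding objects (model, optimizer, batch_input) are placeholders for illustration, not the exact training code from the repo.

```python
import torch

def reinforce_step(model, optimizer, batch_input, baseline, beta=0.9):
    # One illustrative training step for CombinatorialRL; batch_input is a
    # [batch_size, 2, seq_len] tensor of city coordinates.
    R, action_prob_list, _, _ = model(batch_input)

    # exponential moving-average baseline b(s), updated from the batch reward
    baseline = baseline * beta + (1.0 - beta) * R.mean().item()
    advantage = R - baseline

    # log p_theta(pi | s) = sum of per-step log probabilities
    log_probs = sum(torch.log(prob + 1e-10) for prob in action_prob_list)
    loss = (advantage.detach() * log_probs).mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return baseline, R.mean().item()
```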
