Understanding the AlphaZero algorithm

As I'm learning the AlphaZero algorithm, I figured I might as well take some notes that may benefit others.

Note that there are (at least) three algorithms in the evolution of AlphaZero: AlphaGo (Fan and Lee versions), AlphaGo Zero, and AlphaZero (which extends to other games). We're ignoring the first algorithm.

1. Learn Monte-Carlo Tree Search (MCTS)

This is a wonderful lecture that walks you through an example run.

The basic idea is that you're building out a tree of game play from your current board state. You run N simulations, each time walking down the tree, potentially adding new child nodes and/or updating the estimated score of nodes. As you walk down the tree, you choose the child that best balances the tradeoff between exploration (of relatively unvisited moves) and exploitation (of estimated good moves). When you reach a leaf node (one with no children yet), you either estimate its value with a rollout (simulating random moves until a terminal state) or you expand it by creating child nodes from its legal actions.

Anyway, that's too quick and dirty to be of much use. Watch the video (only 8 minutes at 2x speed).

You may also wish to learn more about the "UCB1" (Upper Confidence Bound) algorithm, used to trade off exploitation and exploration (a key concept in reinforcement learning).
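If it helps to see the loop in code, here's a minimal sketch of one plain-MCTS simulation, including a toy UCB1 score. Every name here (Node, ucb1, legal_moves, apply_move, random_rollout, the constant c) is a placeholder of mine, not anything from an actual AlphaZero implementation.

```python
import math
import random

class Node:
    def __init__(self, state, parent=None):
        self.state = state          # game state at this node
        self.parent = parent
        self.children = []          # filled in when the node is expanded
        self.visits = 0             # how many simulations passed through here
        self.total_value = 0.0      # sum of simulation results seen here

def ucb1(node, c=1.4):
    # Unvisited children score infinity so each gets tried at least once.
    if node.visits == 0:
        return float("inf")
    exploit = node.total_value / node.visits
    explore = c * math.sqrt(math.log(node.parent.visits) / node.visits)
    return exploit + explore

def simulate_once(root, legal_moves, apply_move, random_rollout):
    # 1. Selection: walk down, always taking the child with the best UCB1 score.
    node = root
    while node.children:
        node = max(node.children, key=ucb1)

    # 2. Expansion: if this leaf was visited before, create its children
    #    and descend into one of them.
    if node.visits > 0:
        for move in legal_moves(node.state):
            node.children.append(Node(apply_move(node.state, move), parent=node))
        if node.children:
            node = random.choice(node.children)

    # 3. Rollout: play random moves to a terminal state to estimate a value.
    value = random_rollout(node.state)

    # 4. Backpropagation: update counts and values all the way up to the root.
    while node is not None:
        node.visits += 1
        node.total_value += value
        value = -value              # flip the sign for the opposing player
        node = node.parent
```

Run simulate_once N times from the root and then pick the child with the most visits, and that's vanilla MCTS in a nutshell.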

2. Read about AlphaGo Zero / AlphaZero

AlphaGo Zero paper: Mastering the game of Go without human knowledge
AlphaZero paper: Mastering Chess and Shogi by Self-Play with a General Reinforcement Learning Algorithm

This is a cool cheat sheet for AlphaGo Zero.

Here's a rough summary of the algorithm:


1. Create a neural network (a deep convolutional net with residual layers -- but we'll gloss over these details) that, given a board state (s), outputs two values f(s) = (p, v):
    1. p is a vector of probabilities for selecting each possible next move.
    2. v is an estimated value of the current board state, from the current player's perspective (+1 if the current player goes on to win, -1 if they lose).
Initialize this NN randomly.
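
To make the f(s) = (p, v) interface concrete, here's a hedged sketch of a tiny policy-value network in PyTorch. The board encoding, layer sizes, and all names (PolicyValueNet, board_size, num_moves) are placeholder choices of mine; the real networks are much deeper residual convnets.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class PolicyValueNet(nn.Module):
    def __init__(self, board_size=9, channels=32, num_moves=82):
        super().__init__()
        # A shallow convolutional body (the real thing stacks many residual blocks).
        self.body = nn.Sequential(
            nn.Conv2d(3, channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        flat = channels * board_size * board_size
        self.policy_head = nn.Linear(flat, num_moves)  # a score per possible move
        self.value_head = nn.Linear(flat, 1)           # one scalar for the position

    def forward(self, s):
        # s: a batch of board states, shape (batch, 3, board_size, board_size)
        x = self.body(s).flatten(start_dim=1)
        p = F.softmax(self.policy_head(x), dim=1)  # move probabilities
        v = torch.tanh(self.value_head(x))         # estimated outcome in [-1, 1]
        return p, v

net = PolicyValueNet()  # step 1: start from random weights
```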



2. Simulate a game using a modified MCTS. Three main differences:

1. In normal MCTS, you invoke UCB until you hit a leaf node, at which point you either do a rollout (on the first visit) or create all children and do a rollout on one of them (second visit). On the third "visit" it's no longer a leaf node.

Here, there are no rollouts. You invoke UCB until you hit a leaf node, and then invoke it again on your imaginary children. Whichever one wins gets created (s'). Instead of rolling it out, you query the NN for f(s') = (p, v) and store these values. The value v is taken as the result of the game, and is "back-propagated" as before. (A bit confusingly, this creation is called "expanding" s', whereas normally "expand" means to add its children.)

Because we now have p for this node, we can calculate UCB on its imaginary children. (As you'll see, we also need the child Q and N values, which we can treat as being zero initially.)

2. Instead of UCB, they use something called PUCT. (They reference another paper here, but as others have pointed out, that seems to be a wrong reference. This seems to be the paper that introduced PUCT.) In particular, they calculate PUCT(s, a) = Q(s, a) + U(s, a), where:

Q(s, a) is just like v-bar before, representing the average value for this action, and

U(s, a) = c_puct * P(s, a) * sqrt(Σ_b N(s, b)) / (1 + N(s, a)),

where c_puct is some constant, P(s, a) is the p value we fetched earlier, and the sqrt(Σ_b N(s, b)) / (1 + N(s, a)) factor balances how often this action has been picked compared to all actions from this state.

3. When it comes time to pick a move, we don't just pick the one with the highest Q or the highest N. We pick according to a distribution π(s, a) ∝ N(s, a)^(1/τ) (visit counts raised to the power 1/τ, then normalized), where τ is a "temperature" parameter. For the first 30 moves of each self-play game, τ is set to 1; after that it is made infinitesimal, which effectively picks the most-visited move. When playing competitively, it is also set to an infinitesimal value (thus picking the strongest move instead of exploring).


So that's how they use MCTS to pick a single move. Each side picks moves until the game is over.
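
Putting the three differences together, here's a hedged sketch of what one simulation of the modified search might look like. Every name (Node, net.evaluate, legal_moves, apply_move, c_puct) is a placeholder of mine; net.evaluate stands for a wrapper around the network from step 1 that returns per-move priors and a scalar value. Real implementations also differ in exactly when child nodes get created, and add details such as Dirichlet noise at the root, which I'm skipping.

```python
import math

c_puct = 1.5  # exploration constant (placeholder value)

class Node:
    def __init__(self, state, prior, parent=None):
        self.state = state
        self.parent = parent
        self.prior = prior       # P(s, a): the network's prior for the move into this node
        self.children = {}       # move -> Node
        self.N = 0               # visit count
        self.W = 0.0             # total value

    @property
    def Q(self):
        return self.W / self.N if self.N else 0.0

def puct(parent, child):
    # PUCT(s, a) = Q(s, a) + c_puct * P(s, a) * sqrt(sum_b N(s, b)) / (1 + N(s, a))
    total_n = sum(ch.N for ch in parent.children.values())
    u = c_puct * child.prior * math.sqrt(total_n) / (1 + child.N)
    return child.Q + u

def simulate_once(root, net, legal_moves, apply_move):
    # Selection: descend by PUCT until we reach a leaf (no children yet).
    node = root
    while node.children:
        node = max(node.children.values(), key=lambda ch: puct(node, ch))

    # Expansion + evaluation: no rollout; ask the network instead.
    p, v = net.evaluate(node.state)   # p: dict of move -> prior, v: scalar in [-1, 1]
    for move in legal_moves(node.state):
        node.children[move] = Node(apply_move(node.state, move), p[move], parent=node)

    # Backup: propagate v toward the root, flipping the sign at each ply.
    while node is not None:
        node.N += 1
        node.W += v
        v = -v
        node = node.parent

def pick_move(root, tau=1.0):
    # pi(s, a) is proportional to N(s, a)^(1/tau); tiny tau is effectively argmax.
    moves, counts = zip(*[(m, ch.N) for m, ch in root.children.items()])
    weights = [c ** (1.0 / tau) for c in counts]
    total = sum(weights)
    return moves, [w / total for w in weights]  # sample a move from this distribution
```

In self-play, tau would be 1 for the first 30 moves and effectively 0 afterwards, as described in point 3.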


3. Retrain the NN

After a game is completed, we annotate it. Each played action is assigned a tuple (s, π, z), where
  • s is the game state
  • π is the distribution mentioned earlier
  • z is the actual outcome of the game, from the perspective of the player to move at s (+1 for a win, -1 for a loss)
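
As a sketch of that bookkeeping, here's one way the annotation could look in Python (the history structure and the +1/-1 player labels are assumptions of mine, not something from the papers):

```python
def game_to_training_examples(history, winner):
    """Turn one finished self-play game into (s, pi, z) training tuples.

    history: list of (state, pi, player_to_move) recorded at each move,
             with players labeled +1 and -1
    winner:  +1, -1, or 0 for a draw (same labeling)
    """
    examples = []
    for state, pi, player_to_move in history:
        # z is the final result from the perspective of the player to move at s.
        z = winner * player_to_move
        examples.append((state, pi, z))
    return examples
```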

Many games can be played, and from them, some thousands of tuples are picked and used as training data for the NN. Recall that the NN is computing v and p. The loss function being optimized is:

  • (z - v)^2 tries to make the board state's value as close to the actual outcome (z) as possible.
  • -π log p computes cross-entropy loss, i.e., it tries to make p as close to π as possible.
  • An L2 regularization term on the network weights to prevent over-fitting.
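
Written out, the per-example loss is roughly l = (z - v)^2 - π·log p + c‖θ‖². Here's a hedged PyTorch sketch with placeholder names, assuming the network's policy head outputs probabilities (in practice you'd work with logits and let the optimizer's weight_decay handle the L2 term):

```python
import torch

def alphazero_loss(net, states, target_pis, target_zs, c=1e-4):
    # states: (batch, ...) boards; target_pis: (batch, num_moves); target_zs: (batch,)
    p, v = net(states)                                  # p: (batch, num_moves), v: (batch, 1)
    value_loss = ((target_zs - v.squeeze(1)) ** 2).mean()
    # Cross-entropy between the MCTS visit distribution pi and the predicted p.
    policy_loss = -(target_pis * torch.log(p + 1e-8)).sum(dim=1).mean()
    # L2 regularization over all network weights.
    l2 = sum((param ** 2).sum() for param in net.parameters())
    return value_loss + policy_loss + c * l2
```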

In AlphaGo Zero, the retrained NN is evaluated against the previous best network and becomes the new champion only if it wins at least 55% of those games. In AlphaZero, the latest network is always kept.

4. Go back to step 2.
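
Tying steps 2-4 together, the outer loop is roughly the following. All of the callables passed in (self_play_game, sample_batch, train_on, evaluate_against) are placeholders you'd have to supply; this is just the shape of the loop, not a real implementation.

```python
def training_loop(net, self_play_game, sample_batch, train_on, evaluate_against,
                  num_iterations=100, games_per_iteration=100):
    replay_buffer = []
    best_net = net
    for _ in range(num_iterations):
        # Step 2: generate self-play games with MCTS guided by the current best network.
        for _ in range(games_per_iteration):
            replay_buffer.extend(self_play_game(best_net))

        # Step 3: retrain on sampled (s, pi, z) tuples.
        candidate = train_on(best_net, sample_batch(replay_buffer))

        # AlphaGo Zero gates on a 55% win rate; AlphaZero always adopts the new net.
        if evaluate_against(candidate, best_net) >= 0.55:
            best_net = candidate
    return best_net
```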

