Depthwise Separable Convolutions

Recently(ish) there has been a development called "depthwise separable convolutions." Unfortunately, it's hard to find a resource that explains clearly what they are. You could just go off of the formulae, of course:

But those look like a pain. Here's a description in words.

Let's say our input is 128x128x16.

Regular convolution: We can think of a 3x3 convolution (with "same" padding) as follows. A single kernel has size 3x3x16 (matching the depth of the input). We walk it over our input as usual (over each of the 128x128 positions). At each step, each element in the kernel is multiplied by the corresponding element of the input, and the 3*3*16 = 144 products are summed to produce one output value. After walking over the whole input, we have a feature map of size 128x128x1. If we do this 32 times (that is, use 32 kernels), we stack the results to get an output of 128x128x32.

Pointwise convolution: This is just a regular convolution, but our kernels are always 1x1(x16).

Depthwise convolution: We use a single kernel. Say it has size 3x3x16. We walk as usual, but instead of summing all 3*3*16 products at each step, we sum the 3*3 products within each of the 16 channels separately and leave the results in their own channels. Thus we get a 128x128x16 output even though we have only one kernel.

Depthwise separable convolution: This is just a depthwise convolution followed by a pointwise convolution.
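
For concreteness, here's a minimal sketch of the three flavors using the standard Keras layers (the shapes and filter count are just the example numbers above):

 from keras.layers import Input, Conv2D, DepthwiseConv2D, SeparableConv2D

 inp = Input(shape=(128, 128, 16))

 # Regular convolution: 32 kernels of size 3x3x16 -> 3*3*16*32 = 4608 weights (plus biases).
 regular = Conv2D(32, 3, padding='same')(inp)

 # Depthwise separable convolution, built by hand:
 #   depthwise: one 3x3 filter per input channel -> 3*3*16 = 144 weights
 #   pointwise: 32 kernels of size 1x1x16        -> 1*1*16*32 = 512 weights
 depthwise = DepthwiseConv2D(3, padding='same')(inp)       # 128x128x16
 by_hand = Conv2D(32, 1)(depthwise)                        # 128x128x32

 # Or as a single built-in layer:
 separable = SeparableConv2D(32, 3, padding='same')(inp)   # 128x128x32

Same 128x128x32 output, but roughly 656 weights instead of 4,608.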

A Keras layer that fetches data by index

Suppose you have a set of training images in a numpy array with shape (num_imgs, height, width, channels), and you want your model to take as input not a batch of images, but their indices. Your model will fetch the images using those indices. What might that code look like?

First, your input (the index) is a scalar, but Keras doesn't let you use scalar inputs. Easiest might be to make it have shape (1,):
 inp = Input(shape=(1,))  
Next, you might try to use a Lambda layer to extract the image from the input:
 def fetch_img(x):  
   return x_train[x.flatten()]   
   
 fetched_imgs = Lambda(fetch_img)(inp)
If x = [[1], [2], [3]] (i.e., a bunch of arrays of shape (1,)), then we want to turn it into [1, 2, 3], so we flatten it. Recall that x_train[[1, 2, 3]] is the same as x_train[[1, 2, 3], : , : , : ], which selects images 1, 2, and 3.

Next, we don't want to hardcode the use of x_train:
 def fetch_img(x_train):    
   def _fetch_img(x):      
     return x_train[x.flatten()]     
   return _fetch_img    
   
 fetched_imgs = Lambda(fetch_img(x_train))(inp)       
But we still have a problem: x is a Tensor, not a numpy array, so it has no .flatten() method and can't be used to index into x_train directly; and whatever the function given to Lambda returns must itself be a Tensor.

Often the function you see passed to Lambda will appear to be a mathematical operation, but is really an overloaded TF op (e.g., "x + y" is tf.Tensor.__add__(x, y)). If you want to run arbitrary Python code, you have to invoke tf.py_func:
 def fetch_img(x_train):  
   def _fetch_img(x):  
     return tf.py_func(lambda x: x_train[x.flatten()],  
                       [x], tf.float32)    
   return _fetch_img   

Seems like it works! The following assert passes:
 def fetch_img(x_train):
   def _fetch_img(x):
     return tf.py_func(lambda x: x_train[x.flatten()],
                       [x], tf.float32)
   return _fetch_img

 inp = Input(shape=(1,), dtype=tf.uint16)
 fetched_imgs = Lambda(fetch_img(x_train))(inp)
 model = Model(inp, fetched_imgs)

 x_in = np.array([1, 5, 10])
 res = model.predict(x_in)
 exp = x_train[x_in]
 assert np.array_equal(res, exp)
Still one problem though. What happens if we try to evaluate fetched_imgs.shape? It's unknown, meaning that later stages will crash if they try to make use of it. You have to explicitly set the shape:
 def fetch_img(x_train):  
   def _fetch_img(x):  
     res = tf.py_func(lambda x: x_train[x.flatten()], [x], tf.float32)  
     res.set_shape((x.shape[0],) + x_train[0].shape)  
     return res  
   return _fetch_img  
Whew, that should do it.

QM and confusing terminology, redux


I made a post earlier explaining why QM terminology can be so confusing as to prevent you (well, me) from learning it. It was probably too long, so I'm going to make it simpler.

Phase shift

Recall that light is an electromagnetic wave. This means it has (or is) an electric field (E) and a magnetic field (B) that are at right angles to each other:


If you put a measuring device anywhere along the x axis, you'd find that there's an electric field pointing up (or down) with some strength, and a magnetic field pointing left (or right) with proportional strength. They're proportional because they're "in phase." (The values are also oscillating in time, but forget about that for now.)
The electric and magnetic fields in EMR waves are always in phase and at 90 degrees to each other. -- Wikipedia
Now I want you to forget about the magnetic field and focus on the electric field. As a function of x, it is oscillating in one dimension (+z, -z, +z, ...). This is called linear polarization. You could rotate the field so that it's still linearly polarized, but instead of going up, down, up, down it's going upper-right, lower-left, upper-right, lower-left. (Apologies, the axes have been renamed here so that x is now called z.)

Since the x-y plane is two-dimensional, we can break that red wave down into two components: blue for x and green for y. Looked at in this way, the blue and green waves are "in phase": when the blue (x) component is maximally to the right, the green (y) component is maximally up; when x is maximally to the left, y is maximally down; and when one is zero, so is the other. Note that the electric field's components being in phase with each other has nothing to do with the magnetic field being in phase with the electric field.

We can also push these two "out of phase", so that when x is maximized y is zero, and vice versa. If we do this, we'll actually get a circularly polarized wave:
We say that the green and blue (or x and y) components are "90 degrees out of phase," or that we "phase shifted" the components of the wave with respect to each other. This is called a "relative" phase shift.

You can also "phase shift" the entire wave, by leaving the relative phase of the components the same, and shifting the entire spiral down the z axis. We might call this "absolute" phase shift.

When no context is given, it's almost always assumed the author means a relative phase shift, since that's the only one measurable by experiment.
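
In symbols, writing the x and y components of the electric field as a pair of complex numbers $(x, y)$: a relative phase shift is $(x, y) \Rightarrow (x, e^{i\phi}y)$, while an absolute phase shift is $(x, y) \Rightarrow e^{i\phi}(x, y)$. Only the first changes the polarization.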

Now for some gotchas.


Here's a claim from someone who "worked as a physicist at the Fermi National Accelerator Laboratory and the Superconducting Super Collider Laboratory," who explains that what's changing is the relative phase of the $E$ and $B$ fields:
The circularly polarized wave can be expressed as two linearly polarized waves, shifted by 90° in phase and rotated by 90° in polarization. If you pick some direction to measure the fields along, the components of E and B along that direction have a 90° phase shift with respect to each other. A phase shift of 90° means that as E peaks B becomes zero, and as B peaks E becomes zero.
As far as I can tell, this is just wrong.


Here's another: from Wikipedia, among many other sources:
"Light waves change phase by 180° when they reflect [off a mirror]."
If the relative phase of light were to change by 180°, then it would change from upper-right, lower-left... to upper-left, lower-right.... This would be easy to detect with a polarizing filter, but it doesn't happen. That's because they're talking about absolute phase. Good luck finding anyone who will explain that.



Finally, recall the two-slit experiment.

"When the two waves are in phase... the summed intensity is maximum, and when they are in anti-phase... then the two waves cancel and the summed intensity is zero. This effect is known as interference."
Which phase are they talking about here? Neither, of course. This is referring to the quantum wave function, where the "wave" is a complex-valued function representing the "probability amplitude" that helps you figure out where the photon might be found.

Which brings us to....

Amplitude

Recall that a single complex number $c$ can be written as: $c = re^{i\theta}$.

Now, amplitude can refer to three different things here:
  1. The key innovation of QM is that we use "probability amplitudes." These refer to the complex number ($c$) itself.
  2. But complex numbers themselves have "amplitudes", usually referring to $r$ above.
  3. Some authors even use it to refer to $\theta$ ("The argument is sometimes also known as the phase or, more rarely and more confusingly, the amplitude")
To recap, a complex number $c = re^{i\theta}$ is an "amplitude," but it also has an amplitude, which can refer to either $r$ or $\theta$. One wonders if some day they'll use it to refer to $i$ and $e$, too.

Physicists can't decide whether they like (1) or (2). You commonly run into both. Here's (1):
"In quantum mechanics, a probability amplitude is a complex number used in describing the behaviour of systems." -- Wikipedia
And here's (2):
"The probability of getting any particular eigenvalue is equal to the square of the amplitude for that eigenvalue." -- Quantum Physicist Sean Carroll.
You first read (1), and when you get to (2) you think "wait, you can't just square a complex number and hope to get a real number (a probability)." Amplitude must mean length. So you look up the relevant equation for the "square of the amplitude," which turns out to be something like $P(a) = |\langle a|\psi\rangle|^2$.
Those bars are called the "norm," or "length," or the "amplitude." So now whenever you detect usage (2) you mentally replace it with "norm," a concept from vector spaces. This begins to reinforce that terrible intuition they teach you in high school, of complex numbers being Real vectors.

How did you get that "vector" again? By taking the inner product of two complex vectors. That's funny, I thought the inner product was supposed to yield a scalar. No matter, let's just internalize this rule: the inner product of two complex vectors is another vector....

From the other side, maybe you have a hard time visualizing a vector in C^2 (i.e., a pair of complex numbers). So you mentally visualize it as a real vector. What is the inner product of two real vectors? It can be thought of as the length of the projection of one onto the other. So now you reinforce the intuition of the inner product as a (real) length.

So now you can't remember whether the inner product should yield a scalar (a + bi), a vector (a, b), or the length of that vector.

This is not a recipe for success.

Conclusion

Keep your concepts straight, or else you'll end up in an abyss.

Quantum Mechanics and confusing terminology



If QM is hard, it shouldn't be because we're mixing up concepts that are distinct but go by the same name. The below is an attempt to prevent you from falling into some traps that I did.

What is phase?

Phase is one of the most important concepts in QM, and yet, when you're starting to learn QM, it can be terribly confusing. That's because what the word means depends on the context. Even people who should know better sometimes confuse it!

Let's look at what it can mean for light.

Recall that light is an electromagnetic wave. Light propagating in some direction will carry an electric field and a magnetic field, perpendicular to each other, and both perpendicular to the direction of travel.


$E$ represents the electric field, and $B$ the magnetic field. The lines represent the direction and strength of the field. Note how $E$ is vertical, moving sinusoidally from up to down and back. Also note how $B$ is doing the same thing but horizontally. We call this linearly polarized light for obvious reasons.

Let's look at how we represent this quantum mechanically. Sorry, some math will be involved here.

We will let the vector $|x\rangle$ represent horizontal polarization (using Dirac bra-ket notation), and $|y\rangle$ represent vertical polarization. Because the polarization is restricted to the y-z plane, it's two dimensional, and we only need two basis vectors. (We can also write these as (1, 0) and (0, 1) in coordinate form. Similarly, the vector $|x\rangle + |y\rangle$ can be written as (1, 1).)

We will let $|\psi\rangle$ indicate the photon's polarization state (i.e., its electrical field direction). In the above diagram, we have:
$$|\psi\rangle = |y\rangle = (0, 1)$$ But recall that in QM we use complex vectors, not real vectors. This just means that the components can take on complex values. For example: $$|\psi\rangle = i|y\rangle = (0, i)$$ What does this correspond to physically? Recall that a complex number can be written as $re^{i\phi}$, where $r$ is the "norm" (or "length") and $\phi$ is the "phase."


Now note that $i = e^{i\pi/2}$ (corresponding to the line pointing up in the above picture, where the angle is $\pi/2 = 90°$). That is, it has a phase of $\pi/2$. To understand what this means physically, imagine the EM diagram with the waves shifted left by $\pi/2$ -- that is, one quarter wavelength backward (since a whole wavelength is $2\pi$). So you can think of the phase $\phi$ as the position where the wave "starts." This would still be a linearly (in particular, vertically) polarized wave, but one that started at its maximum position instead of at zero.

How about this state? $$|\psi\rangle = \frac{\sqrt{2}}{2}(|x\rangle + |y\rangle) = (\frac{\sqrt{2}}{2}, \frac{\sqrt{2}}{2})$$ This looks a little scarier, but notice that the $\frac{\sqrt{2}}{2}$ is just there to normalize the vector so that its length is one. That is, it's just a scaled version of $|x\rangle + |y\rangle = (1,1)$.

This seems to indicate that the light is polarized in both the x and y directions. And indeed, you could call it diagonally polarized light.
Now the very important thing to note about the picture above is that the green and blue squiggles are not the $E$ and $B$ fields. They are just the x- and y-components of the $E$ field.

Notice how they touch the z-axis at the same time. In other words, they are zero at the same time. They also reach their maxima and minima at the same time. Of course, in this particular diagram they both start at their maximum instead of zero, so they both have phase $\pi/2$: $$|\psi\rangle = \frac{i\sqrt{2}}{2}(|x\rangle + |y\rangle) = (\frac{i\sqrt{2}}{2}, \frac{i\sqrt{2}}{2})$$ Also notice that for any phase, as long as the components both have the same phase, the light will be diagonally polarized. Another way of saying it is that their phase difference (or "relative phase") is zero. Also note that if their phase difference is $\pi$ so that one reaches its maximum as the other reaches its minimum, it will be diagonally polarized but in the other direction (i.e., up and to the left instead of up and to the right).

What about if the x- and y-components are out of phase by $\pi/2$? That's a little harder to visualize, but this picture should help you see what's going on when one direction reaches its maximum (or minimum) a quarter phase after the other:
Notice how the vertical arrows peak a quarter wavelength later than the horizontal ones. This is called "circularly polarized" light, and can be expressed as: $$|\psi\rangle = \frac{\sqrt{2}}{2}(|x\rangle + i|y\rangle) = (\frac{\sqrt{2}}{2}, \frac{i\sqrt{2}}{2}) = \frac{\sqrt{2}}{2}(1, i)$$

The relative phase is $\pi/2$. As long as that's true, it will be circularly polarized (in the clockwise direction). In the above, the phases were 0 (because $e^{0i} = 1$) and $\pi/2$ ($i = e^{i\pi/2}$). They could also have been, say, $\frac{3\pi}{2}$ and $2\pi$: $$|\psi\rangle = \frac{\sqrt{2}}{2}(-i|x\rangle + |y\rangle) = (\frac{-i\sqrt{2}}{2}, \frac{\sqrt{2}}{2}) = \frac{\sqrt{2}}{2}(-i, 1)$$
(Noting that $e^{\frac{3\pi i}{2}} = -i$ and $e^{2\pi i} = 1$).

Also notice that if the phase difference were $-\pi/2$ instead of $\pi/2$, then that would correspond to the counterclockwise direction: $$|\psi\rangle = \frac{\sqrt{2}}{2}(i|x\rangle + |y\rangle) = (\frac{i\sqrt{2}}{2}, \frac{\sqrt{2}}{2})$$

Phase shift

This brings us to two meanings of "phase shift."

1. We can shift both of the components' phases by some amount, so that their relative phase does not change: $$(x, y) \Rightarrow e^{i\phi}(x, y)$$ Recall from the previous section that as long as the relative phase doesn't change, the polarization doesn't change. If it was circularly polarized, it will stay circularly polarized, though it would start (and of course continue) at a different angle.

2. We can shift one component only:
$$(x, y) \Rightarrow (x, e^{i\phi}y)$$ Thus changing their relative phase. For example, if the relative phase changes from 0 to $\pi/2$, this corresponds to a change from linear to circular polarization.

The second definition is the one most commonly used, because changing relative phase results in a physically distinct system. So maybe you're really used to that definition, and then you come across this (Wikipedia, and many other sources):
"Light waves change phase by 180° when they reflect [off a mirror]."
Well, you can test for yourself that polarization doesn't change when light reflects from a mirror. That's because they're using definition one! Their relative phase doesn't change!

If you were smart, you'd quickly figure that out. But if you're me, you might resign yourself to never understanding anything about QM!

To make things worse, as I was learning this, I came across the blog of someone who "worked as a physicist at the Fermi National Accelerator Laboratory and the Superconducting Super Collider Laboratory," who explains that what's changing is the relative phase of the $E$ and $B$ fields:
The circularly polarized wave can be expressed as two linearly polarized waves, shifted by 90° in phase and rotated by 90° in polarization. If you pick some direction to measure the fields along, the components of E and B along that direction have a 90° phase shift with respect to each other. A phase shift of 90° means that as E peaks B becomes zero, and as B peaks E becomes zero.
As far as I can tell, this is just wrong.

The other other meaning of phase

One of the first experiments you come across when hearing about QM is the famous Double Slit experiment. Recall that the interference pattern at the wall is explained by waves adding up or canceling out whether they are "in phase" or "out of phase."

Which phase are we talking about here? To explain it carefully would require some heavy-duty math, so this will be a gross oversimplification.

Above, the state vector was two dimensional, corresponding to the two possible directions of polarization. But the location of a photon can be any point in 3D space. So the state vector corresponding to a photon's position is (uncountably) infinite-dimensional. We call this state vector the "wave function." It's not terribly convenient to write out uncountably many components, so sometimes you see it depicted by a diagram:

(Credit Science4all)

If it hurts your head to think of this as an infinite-dimensional vector (where each point on the x-axis is a basis vector, and the height is the corresponding component -- and we pretend it is real-valued because it is hard to draw a complex height), you're not alone. If you want to understand the math, have at it.

Because light is a plane wave, the formula for its wave function is given by:

$$\psi(r, t) = Ae^{i(kr - \omega t)}$$

This formula tells you the (complex) value for the components of the wave function at distance $r$ and time $t$ (with amplitude $A$). Ignoring time for a moment, the form $e^{ikr}$ tells us that it's a complex number that's "rotating" according to $r$. (See the circle picture near the beginning.)

Light is effectively being emitted from two slits, and (almost) any point on the screen has a different distance ($r_1$ and $r_2$) from those two slits. Thus, the two waves are at different points in their evolution, sometimes canceling out and sometimes adding up.


(Credit Openstax)
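
To make "different points in their evolution" concrete using the formula above (and ignoring the common time factor): the contributions from the two slits at a given point on the screen are $Ae^{ikr_1}$ and $Ae^{ikr_2}$, which differ only by the phase factor $e^{ik(r_2 - r_1)}$. When the path difference $r_2 - r_1$ is a whole number of wavelengths, that phase is a multiple of $2\pi$ and the contributions add ("in phase"); when it's an odd number of half-wavelengths, they cancel ("anti-phase").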

So which meaning of phase is being used here? In the previous section we were talking about the phases of different components of the same state vector. Here we're talking about contributions to the same component (the position) from different paths. Seems like a different usage to me.

Again, if you were reading all this stuff by yourself, it would be pretty hard not to become terribly confused.

But wait, there's more!

Lest you think that's the end of the difficulties, consider the usage of the word "length" above. I used it in two different ways. A single complex number $c$ can be written as: $$c = re^{i\theta}$$ In this form, $r$ is called the norm. More commonly, it's called the modulus or the amplitude. But amplitude can refer to three different things here:
  1. The "key innovation" of QM is that we use "probability amplitudes," sometimes called "complex amplitudes." These refer to the complex number ($c$) itself.
  2. But complex numbers themselves have "amplitudes", usually referring to $r$ above.
  3. Most confusingly of all, some authors use it to refer to $\theta$! ("The argument is sometimes also known as the phase or, more rarely and more confusingly, the amplitude")
To recap, a complex number $$c = re^{i\theta}$$ is an "amplitude," but it also has an amplitude, which can refer to either $r$ or $\theta$. One wonders if some day they'll use it to refer to $i$ and $e$, too.

Further, when we have a vector $v = (x, y) = (r_1e^{i\theta}, r_2e^{i\phi})$, we just called $r_1$ the "norm" of the first component. But, like all vectors in a normed vector space, $v$ itself has a norm (the thing you normally think of as a vector's length).

People also fall into the trap of thinking of a single complex number as a pair of real numbers. This makes it impossible to reason about pairs of complex numbers. So please, don't do that.

Summary

QM is hard, but in part that's because it's easy to become confused by terminology and concepts which look the same or similar. As Scott Aaronson points out:

http://www.scottaaronson.com/democritus/lec9.html

"Today, in the quantum information age, the fact that all the physicists had to learn quantum this way seems increasingly humorous. For example, I've had experts in quantum field theory -- people who've spent years calculating path integrals of mind-boggling complexity -- ask me to explain the Bell inequality to them. That's like Andrew Wiles asking me to explain the Pythagorean Theorem.

"As a direct result of this "QWERTY" approach to explaining quantum mechanics -- which you can see reflected in almost every popular book and article, down to the present -- the subject acquired an undeserved reputation for being hard. Educated people memorized the slogans -- "light is both a wave and a particle," "the cat is neither dead nor alive until you look," "you can ask about the position or the momentum, but not both," "one particle instantly learns the spin of the other through spooky action-at-a-distance," etc. -- and also learned that they shouldn't even try to understand such things without years of painstaking work."

Why does AlphaZero use Dirichlet?

In the AlphaZero paper, they add Dirichlet noise to the prior probabilities P(s, a) for the root node. Specifically:
Dirichlet noise Dir(α) was added to the prior probabilities in the root node; this was scaled in inverse proportion to the approximate number of legal moves in a typical position, to a value of α = {0.3, 0.15, 0.03} for chess, shogi and Go respectively.
As I understand it, this means that if the P(s, a) vector has n components, then α is also n-dimensional, with each value the same.

What does Dir(0.03) look like? This blog gives some hints. Dir(0.999) looks like this:

Low values are blue; high values go toward red. So this is concentrated at the corners (i.e., at the points (0, 0, 1), (0, 1, 0), and (1, 0, 0)). Unfortunately the plot is misleading, in that it makes Dir(0.999) look very tightly concentrated, when it's actually close to uniform. It's the smaller values of α that are really concentrated at the corners.

Now look what happens when α > 1. Dir(5):
This is getting warmer near (1/3, 1/3, 1/3). (Recall that this triangle represents all locations where x + y + z = 1.)

As α increases (above 1), it gets more tightly clustered. Dir(50):


From this, we can infer that Dir(0.3) is more tightly concentrated near the corners than Dir(0.999), and Dir(0.03) more still. Also note that Dir(1.0) is a uniform distribution, favoring no particular point.
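
You can check this directly by drawing a few samples (a quick numpy sketch; not from the paper or the linked blog):

 import numpy as np

 rng = np.random.default_rng(0)
 for alpha in (0.03, 0.3, 1.0, 5.0, 50.0):
     print(alpha, np.round(rng.dirichlet([alpha] * 3, size=3), 2))
 # Small alpha: each sample is nearly one-hot (almost all mass on one corner).
 # alpha = 1: uniform over the triangle.
 # Large alpha: samples cluster around (1/3, 1/3, 1/3).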

So the greater the number of available actions, and the smaller the α, the more strongly we try to bias the move toward one of them (at random).

Still, why Dirichlet and not something else? I'm not entirely sure. It does avoid favoring any particular move, but you could accomplish that with any distribution if you drew each action's noise independently. Another benefit is that the noise vector always sums to 1 (it's then mixed in with weight ε = 0.25), but of course you could accomplish that by normalizing draws from any set of distributions. But Dirichlet may be the most straightforward distribution that favors the standard basis vectors (i.e., the corners).
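
For reference, here's roughly how that noise gets applied, following the mixing rule the paper uses, P(s, a) ← (1 − ε)·p_a + ε·η_a with η ~ Dir(α) and ε = 0.25 (the function and variable names below are mine):

 import numpy as np

 def noisy_root_priors(priors, alpha=0.3, eps=0.25, rng=np.random.default_rng()):
     # Mix Dirichlet noise into the root node's prior move probabilities.
     priors = np.asarray(priors, dtype=float)
     noise = rng.dirichlet([alpha] * len(priors))
     return (1 - eps) * priors + eps * noise

 priors = np.array([0.5, 0.3, 0.15, 0.05])  # e.g., P(s, a) from the network
 print(noisy_root_priors(priors))           # still a valid distribution (sums to 1)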


The Dirichlet distribution is the conjugate prior of the multinomial distribution

Say what? Let's walk through it step by step, starting with a brief recap of Bayesian inference.

Suppose you have a population of people, each with (unknown) probability θ of having a disease. You have some prior assumptions about the distribution of θ (for example, you might say "without any information, I think all values of θ are equally likely"). Now, you draw a sample of N people, out of which x are found to have the disease. From this data, how should you adjust your assumption about the distribution of θ?

You first calculate the probability of drawing this sample (x sick people from N), given θ. This is known as the likelihood, and in this example it's given by a binomial distribution:
P(sample | θ) = (N choose x) * θ^x * (1-θ)^(N-x)
Next, you make explicit your prior assumption about the distribution of θ, simply called the prior. Suppose we think it is uniform:
P(θ) = 1
The uniform distribution happens to also be a special case of the beta distribution, which will be relevant in a moment. Beta is a distribution that takes two parameters, and in particular, Uniform = Beta(1, 1).

Using Bayes' rule, you can answer the question: how should I update my knowledge about θ, given that sample? That is, given this sample, what should the updated PDF of θ be?
P(θ | sample) = P(sample | θ) * P(θ) / P(sample)
This is called the posterior distribution. Intuitively, if x << N, our estimate of θ should go from being uniform to being skewed toward a low value.

How do we calculate P(sample)? Note that the posterior P(θ | sample) and likelihood P(sample | θ) are both functions of θ, but the former must be a PDF (and thus integrate to 1) while the latter need not be. Because the LHS of the above equation must integrate to 1, P(sample) is just the normalizing constant that makes that happen: the integral of the numerator, ∫ P(sample | θ) P(θ) dθ. In general, this is not an easy integral to take, which is why it's convenient to know about conjugate priors.

Suppose we know that "beta is the conjugate prior of binomial." What this means is that if our prior is beta and our likelihood is binomial, then the posterior will still be some beta distribution. Moreover, we can look up the details and find that the parameters of our new beta are:
Beta(a + x, b + N - x)
Where (a, b) are the parameters from the prior beta. In our case those were (1, 1), so our posterior is given by:
P(θ | sample) = Beta(1 + x, 1 + N - x)
Tada! Much simpler than taking the nasty integral. This neat trick explains why we called our prior Beta(1, 1) instead of just Uniform, even though they're the same thing. Even if we didn't have a uniform prior, choosing something in the Beta family makes our lives easier.

Now consider the multinomial distribution. In the binomial distribution, there are two outcomes (heads or tails; sick or not sick). The multinomial distribution is the generalization to many outcomes. Well, it turns out that the Dirichlet distribution is conjugate prior to multinomial. So in a problem where you know your likelihood is multinomial, it's useful to find a Dirichlet distribution that expresses your prior estimate of your parameters (e.g., the probability of each possible outcome), so that your posterior can be expressed in a simple (Dirichlet) form.


The above are heatmaps of various Dirichlet distributions of three parameters (for when there are three possible outcomes). The triangles are flattened 3d simplexes:


The green surface is all points (x, y, z) where x + y + z = 1. This is important because if x, y, and z represent the odds of getting three different outcomes, their sum should be 1.

Note that Dirichlet(1, 1, 1) (in the upper left) is constant, meaning it's just the Uniform distribution -- where all possible triplets are equally likely. When all three parameters are the same (but not 1), the distribution is symmetrical but skewed either toward the corners or the center (where all three coordinates are equal).

Of course, just because Dirichlet helps you calculate your posterior when your likelihood is multinomial, this doesn't mean you have to (or even should) use it. It's just another tool in your belt.
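
To make the conjugate update itself concrete: with a Dirichlet(a_1, ..., a_k) prior and observed counts (c_1, ..., c_k) from the multinomial, the posterior is simply Dirichlet(a_1 + c_1, ..., a_k + c_k). A tiny sketch (the example numbers are mine):

 import numpy as np

 def dirichlet_posterior(prior_alpha, counts):
     # Conjugate update: Dirichlet prior + multinomial likelihood -> Dirichlet posterior.
     return np.asarray(prior_alpha) + np.asarray(counts)

 # Three outcomes, uniform prior Dirichlet(1, 1, 1), observed counts (7, 2, 1):
 print(dirichlet_posterior([1, 1, 1], [7, 2, 1]))  # -> Dirichlet(8, 3, 2)

 # With two outcomes this reduces to the beta/binomial case above:
 # Beta(1, 1) prior, x = 3 sick out of N = 10 -> Beta(1 + 3, 1 + 7) = Beta(4, 8)
 print(dirichlet_posterior([1, 1], [3, 7]))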

Some confusion I encountered while deciphering AlphaZero

I realized I was a little confused about the details of the algorithm, and finally figured out what was confusing me.

Let's briefly revisit regular MCTS:
  • Starting at root, recursively invoke UCB on children (picking child with highest score) until you hit a leaf node.
  • On the node's first visit, perform a rollout (i.e., randomly simulate moves until the end of the game) to estimate the node's value (v).
  • On the second visit, expand (i.e., create) its children, and visit+rollout one of them.
  • There is no third visit, since it's no longer a leaf node.

In AlphaZero there is no rollout, and a modified UCB (called PUCT) is used:
  • Invoke PUCT on children until you hit a leaf node (s).
  • On the node's first visit, invoke NN to estimate v(s) (node value) and p(s) (probabilities for each next action, used in PUCT). Also expand it (i.e., create but do not visit children).
  • There is no second visit, since it's no longer a leaf node.

From this, it seems like "expand s" means "create s's children." This agrees with the text in the paper:
"Each simulation starts from the root state and iteratively selects moves that maximize an upper confidence bound Q(s, a)+U(s, a) until a leaf node s′ is encountered. This leaf position is expanded and evaluated only once by the network to generate both prior probabilities and evaluation, (P(s′, ·),V(s′))= fθ(s′)."
So the leaf node is one that already exists, and expanding it creates its children. But now look at Figure 2:

"MCTS in AlphaGo Zero.  
a (Select): Each simulation traverses the tree by selecting the edge with maximum action value Q, plus an upper confidence bound U that depends on a stored prior probability P and visit count N for that edge (which is incremented once traversed). 
b (Expand and evaluate): The leaf node is expanded and the associated position s is evaluated by the neural network (P(s, ·),V(s))=fθ(s); the vector of P values are stored in the outgoing edges from s."

The circled node is the "leaf node" described in step b. But notice that it doesn't yet exist during the tree-walking phase (a). We created it during step b, and so "expand" s here must mean to create s (not its child nodes).

So it seems the paper describes the algorithm in two slightly different ways, using the word "expand" differently:
  1. Walk nodes until a leaf node is encountered. Evaluate it and expand it (create its children).
  2. Walk edges until the edge has no node attached to it (does this even meet the definition of a graph?). Create ("expand") that node and then evaluate it.
The second way is also how this (very well-written) tutorial explains it:
A single simulation proceeds as follows. We compute the action a that maximises the upper confidence bound U(s,a). If the next state s′ (obtained by playing action a on state s) exists in our tree, we recursively call the search on s′. If it does not exist, we add the new state to our tree and initialise P(s′,·)=pθ(s′) and the value v(s′)=vθ(s′) from the neural network, and initialise Q(s′,a) and N(s′,a) to 0 for all a. Instead of performing a rollout, we then propagate v(s′) up along the path seen in the current simulation and update all Q(s,a) values.

Maybe it's obvious while reading the paper that they're the same thing, but it sure confused me.
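
Either way, the two descriptions collapse into the same code. Here's a minimal sketch of one simulation, written in the style of interpretation (1) (this is my own simplified Python, not the paper's implementation; net.evaluate and the game object are made-up stand-ins):

 import math

 class Node:
     def __init__(self, prior):
         self.prior = prior      # P(s, a), assigned when the parent was expanded
         self.children = {}      # action -> Node; empty means "leaf"
         self.visit_count = 0
         self.value_sum = 0.0

     def q(self):
         return self.value_sum / self.visit_count if self.visit_count else 0.0

 def puct(parent, child, c_puct=1.5):
     # Q(s, a) + U(s, a), with U proportional to P(s, a) * sqrt(N(s)) / (1 + N(s, a)).
     u = c_puct * child.prior * math.sqrt(parent.visit_count) / (1 + child.visit_count)
     return child.q() + u

 def simulate(root, game, net):
     # Select: walk existing nodes until we hit a leaf (a node with no children).
     node, path = root, [root]
     while node.children:
         parent = node
         action, node = max(parent.children.items(),
                            key=lambda kv: puct(parent, kv[1]))
         game.apply(action)
         path.append(node)
     # Expand and evaluate: one network call gives (p, v); "expand" = create the children.
     priors, value = net.evaluate(game.state())
     for action, p in priors.items():
         node.children[action] = Node(prior=p)
     # Back up v along the path (the per-ply sign flip is a simplification).
     for n in reversed(path):
         n.visit_count += 1
         n.value_sum += value
         value = -value

Interpretation (2) just moves the Node(prior=p) construction to the moment the edge is first traversed; the network call and the node you end up creating are the same either way.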

Maximum Likelihood Estimation for dummies

What is Maximum Likelihood Estimation (MLE)? It's simple, but there are some gotchas. First, let's recall what likelihood is. ...