MuZero

Planning

MuZero makes decisions via MCTS with an upper confidence bound (UCB), using a learned representation model, dynamics model, and prediction (policy and value) model

Each internal state $s$ in the search tree maintains a set of statistics for each action

| Statistics | Definition | Initialized by |
| --- | --- | --- |
| visit count | $N(s,\ a)$ | 0 |
| mean value | $Q(s,\ a)$ | 0 |
| policy | $P(s,\ a)$ | prediction model |
| reward | $R(s,\ a)$ | dynamics model |
| state transition | $S(s,\ a)$ | dynamics model |
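
For concreteness, the statistics above can be kept in a small dict-based node structure, as in the sketch below; the layout and names are illustrative and not taken from the official pseudocode

```python
def new_node(hidden_state, priors):
    """Create a search-tree node for an internal state s.

    `hidden_state` is the latent state produced by the representation or
    dynamics model; `priors` maps each legal action a to P(s, a) from the
    prediction model. All other statistics start at their initial values.
    """
    return {
        "state": hidden_state,           # latent state s
        "N": {a: 0 for a in priors},     # visit counts N(s, a), initialized to 0
        "Q": {a: 0.0 for a in priors},   # mean values Q(s, a), initialized to 0
        "P": dict(priors),               # priors P(s, a) from the prediction model
        "R": {},                         # rewards R(s, a), filled in by the dynamics model
        "S": {},                         # state transitions S(s, a) -> child node
    }
```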

Each simulation starts from the current state $s^{0} = s_{t}$ and finishes at a leaf node $s^{l}$. The search is divided into three phases:

Selection

For each hypothetical step $k = 1,\ 2,\ \cdots,\ l$, an action is selected by maximizing the UCB score (pUCT)

$$
a^{k} = \argmax_{a} \left[ Q(s^{k - 1},\ a) + P(s^{k - 1},\ a) \frac{\sqrt{\sum_{b} N(s^{k - 1},\ b)}}{1 + N(s^{k - 1},\ a)} \left( c_{1} + \log \frac{\sum_{b} N(s^{k - 1},\ b) + c_{2} + 1}{c_{2}} \right) \right]
$$

To allow the combination of value and policy in the pUCT rule, the mean value is min-max normalized over the whole search tree as

$$
\bar{Q}(s^{k - 1},\ a) = \frac{Q(s^{k - 1},\ a) - \min_{s,\ a \in \mathrm{Tree}} Q(s,\ a)}{\max_{s,\ a \in \mathrm{Tree}} Q(s,\ a) - \min_{s,\ a \in \mathrm{Tree}} Q(s,\ a)}
$$

When $k < l$, the next state $s^{k}$ and reward $r^{k}$ are looked up in the state transition and reward tables of state $s^{k - 1}$, i.e. $s^{k} = S(s^{k - 1},\ a^{k})$ and $r^{k} = R(s^{k - 1},\ a^{k})$
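
A minimal sketch of this selection step, reusing the node layout from the earlier sketch; $c_1$ and $c_2$ are the exploration constants of the pUCT formula (the MuZero paper reports $c_1 = 1.25$ and $c_2 = 19652$), and the min/max used for normalization are tracked over the whole tree

```python
import math

def normalize_q(q, q_min, q_max):
    """Min-max normalize a mean value with the extremes observed in the tree."""
    return (q - q_min) / (q_max - q_min) if q_max > q_min else q

def select_action(node, q_min, q_max, c1=1.25, c2=19652.0):
    """One pUCT selection step at `node` (dict layout from the earlier sketch)."""
    total_visits = sum(node["N"].values())  # sum_b N(s, b)

    def ucb(a):
        prior_term = (node["P"][a] * math.sqrt(total_visits) / (1 + node["N"][a])
                      * (c1 + math.log((total_visits + c2 + 1) / c2)))
        return normalize_q(node["Q"][a], q_min, q_max) + prior_term

    return max(node["P"], key=ucb)

# Usage with a hand-built three-action node.
node = {"state": None, "N": {0: 4, 1: 1, 2: 0}, "Q": {0: 0.6, 1: 0.2, 2: 0.0},
        "P": {0: 0.5, 1: 0.3, 2: 0.2}, "R": {}, "S": {}}
print(select_action(node, q_min=0.0, q_max=0.6))  # picks the well-visited, high-value action 0
```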

Expansion

At the final time-step $l$ of the simulation, the state and reward are computed by the learned dynamics model

$$
s^{l},\ r^{l} = g_{\theta}(s^{l - 1},\ a^{l})
$$

The state transition and reward tables of state $s^{l - 1}$ are updated as $S(s^{l - 1},\ a^{l}) = s^{l}$ and $R(s^{l - 1},\ a^{l}) = r^{l}$. State $s^{l}$ is added to the search tree with its policy table initialized by the prediction model, $P(s^{l},\ a) = \boldsymbol{p}^{l}$, where $\boldsymbol{p}^{l},\ v^{l} = f_{\theta}(s^{l})$
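
A sketch of the expansion step under the same node layout; `dynamics_fn` and `prediction_fn` are placeholder callables standing in for $g_{\theta}$ and $f_{\theta}$

```python
def expand(parent, action, dynamics_fn, prediction_fn):
    """Expand the leaf edge (s^{l-1}, a^l) with one call to each learned model."""
    child_state, reward = dynamics_fn(parent["state"], action)  # s^l, r^l = g_theta(s^{l-1}, a^l)
    priors, value = prediction_fn(child_state)                  # p^l, v^l = f_theta(s^l)

    child = {
        "state": child_state,
        "N": {a: 0 for a in priors},     # fresh statistics for the new node
        "Q": {a: 0.0 for a in priors},
        "P": dict(priors),               # P(s^l, a) = p^l
        "R": {},
        "S": {},
    }
    parent["S"][action] = child          # S(s^{l-1}, a^l) = s^l
    parent["R"][action] = reward         # R(s^{l-1}, a^l) = r^l
    return child, value                  # v^l is needed in the backup phase
```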

Backup

For $k = l,\ l - 1,\ \cdots,\ 1$, the mean value and visit count of each edge $(s^{k - 1},\ a^{k})$ on the simulated trajectory are updated as

$$
\begin{gathered} N(s^{k - 1},\ a^{k}) \leftarrow N(s^{k - 1},\ a^{k}) + 1 \\[5mm] Q(s^{k - 1},\ a^{k}) \leftarrow Q(s^{k - 1},\ a^{k}) + \frac{1}{N(s^{k - 1},\ a^{k})} \Big[ G^{k} - Q(s^{k - 1},\ a^{k}) \Big] \end{gathered}
$$

where $G^{k}$ is the $(l - k)$-step estimate of the cumulative discounted reward, bootstrapped from the value $v^{l}$ of the prediction model

$$
G^{k} = \sum_{\tau = 0}^{l - k - 1} \gamma^{\tau} R(s^{k + \tau},\ a^{k + \tau + 1}) + \gamma^{l - k} v^{l}
$$
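
The formula above gives the recursion $G^{l} = v^{l}$ and $G^{k} = R(s^{k},\ a^{k + 1}) + \gamma G^{k + 1}$, so the backup can be written as a single backward pass over the simulated path; the sketch below assumes `path` holds the $(s^{k - 1},\ a^{k})$ pairs from the root to the leaf and `leaf_value` is $v^{l}$

```python
def backup(path, leaf_value, discount):
    """Update N and Q for every edge (s^{k-1}, a^k) on the simulated path.

    `path` is the list of (node, action) pairs visited from the root to the
    leaf (node layout as in the earlier sketches) and `discount` is gamma.
    """
    g = leaf_value            # G^l = v^l
    deeper_reward = None      # reward of the edge processed in the previous iteration
    for node, action in reversed(path):                   # k = l, l-1, ..., 1
        if deeper_reward is not None:
            g = deeper_reward + discount * g              # G^k = r^{k+1} + gamma * G^{k+1}
        node["N"][action] += 1                            # N(s^{k-1}, a^k) += 1
        node["Q"][action] += (g - node["Q"][action]) / node["N"][action]
        deeper_reward = node["R"][action]                 # r^k, used by the edge above
```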

After a certain number of simulations, MCTS outputs an estimated value $\nu_{t}$ and a recommended policy $\pi_{t}(\cdot)$ based on the visit counts of the root node

$$
\pi_{t}(a) = \frac{N(s^{0},\ a)^{1 / T}}{\sum_{b} N(s^{0},\ b)^{1 / T}}
$$

where the temperature parameter $T$ is used during training and decays from 1 with the number of training steps. This ensures that the action selection becomes greedier as training progresses
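
As a short sketch, the root visit counts can be turned into $\pi_{t}$ and an action sampled from it as follows, with `temperature` playing the role of $T$

```python
import random

def root_policy(visit_counts, temperature):
    """Compute pi_t(a) proportional to N(s^0, a)^(1/T)."""
    if temperature == 0:                              # greedy limit: argmax of visit counts
        best = max(visit_counts, key=visit_counts.get)
        return {a: float(a == best) for a in visit_counts}
    weights = {a: n ** (1.0 / temperature) for a, n in visit_counts.items()}
    total = sum(weights.values())
    return {a: w / total for a, w in weights.items()}

# T = 1 early in training (proportional to the counts), smaller T later (greedier).
pi = root_policy({0: 20, 1: 8, 2: 2}, temperature=1.0)
action = random.choices(list(pi), weights=list(pi.values()))[0]
```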

Training

The model of MuZero $\mu_{\theta}$ consists of a representation model, a dynamics model and a prediction model

| SubPart | Type | Description | Definition |
| --- | --- | --- | --- |
| representation | world model | encodes the past observations | $s^{0} = h_{\theta}(o_{1},\ o_{2},\ \cdots,\ o_{t})$ |
| dynamics | world model | dynamics and reward on internal state | $s^{k},\ r^{k} = g_{\theta}(s^{k - 1},\ a^{k})$ |
| prediction | policy | policy and value on internal state | $\boldsymbol{p}^{k},\ v^{k} = f_{\theta}(s^{k})$ |

A trajectory is sampled from the replay buffer for training and the model is unrolled recurrently for $K$ steps

$$
s^{0} = h_{\theta}(o_{1},\ o_{2},\ \cdots,\ o_{t}) \qquad s^{k},\ r^{k} = g_{\theta}(s^{k - 1},\ a^{k}) \qquad \boldsymbol{p}^{k},\ v^{k} = f_{\theta}(s^{k})
$$
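
This $K$-step unroll can be sketched as below; the three callables are placeholders for $h_{\theta}$, $g_{\theta}$ and $f_{\theta}$, and the returned per-step predictions are the quantities the loss below is applied to

```python
def unroll(representation_fn, dynamics_fn, prediction_fn, observations, actions):
    """Unroll the model for K = len(actions) steps along a sampled trajectory."""
    state = representation_fn(observations)          # s^0 = h_theta(o_1, ..., o_t)
    policy, value = prediction_fn(state)             # p^0, v^0 = f_theta(s^0)
    outputs = [(policy, value, None)]                # no reward prediction at k = 0
    for action in actions:                           # actions a_{t+1}, ..., a_{t+K} from the trajectory
        state, reward = dynamics_fn(state, action)   # s^k, r^k = g_theta(s^{k-1}, a^k)
        policy, value = prediction_fn(state)         # p^k, v^k = f_theta(s^k)
        outputs.append((policy, value, reward))      # (p^k, v^k, r^k)
    return outputs
```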

The model $\mu_{\theta} = \{ h_{\theta},\ g_{\theta},\ f_{\theta} \}$ is trained jointly to accurately match the policy, value, and reward on a trajectory

$$
\ell(\theta) = \sum_{k = 0}^{K} \ell^{r}(u_{t + k},\ r_{t}^{k}) + \ell^{v}(z_{t + k},\ v_{t}^{k}) + \ell^{p}(\pi_{t + k},\ \boldsymbol{p}_{t}^{k}) + c \| \theta \|_{2}^{2}
$$

where the value target $z_{t}$ is computed from intermediate rewards and $n$-step bootstrapping of the search value

$$
z_{t} = u_{t + 1} + \gamma u_{t + 2} + \cdots + \gamma^{n - 1} u_{t + n} + \gamma^{n} \nu_{t + n}
$$
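
A sketch of this target; the indexing convention (`rewards[i]` holds $u_{i + 1}$, `root_values[i]` holds $\nu_{i}$) and the end-of-episode truncation are illustrative assumptions

```python
def n_step_target(rewards, root_values, t, n, discount):
    """z_t = u_{t+1} + gamma*u_{t+2} + ... + gamma^{n-1}*u_{t+n} + gamma^n * nu_{t+n}."""
    horizon = min(n, len(rewards) - t)               # truncate near the end of the episode
    target = sum(discount ** k * rewards[t + k] for k in range(horizon))
    if t + horizon < len(root_values):               # bootstrap with nu_{t+n} if it exists
        target += discount ** horizon * root_values[t + horizon]
    return target
```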

The latest checkpoint of the network is used to play games with MCTS, generating the training data stored in the replay buffer
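
A minimal sketch of this data-generation loop, assuming a gym-style environment (`reset`/`step`) and a `run_mcts` callable that returns the search policy $\pi_{t}$ and root value $\nu_{t}$; both names are placeholders, not part of MuZero's published interface

```python
import random

def self_play_game(env, run_mcts, max_moves=500):
    """Play one game with the latest checkpoint and record replay-buffer data."""
    trajectory = []                                    # (o_t, a_t, pi_t, u_{t+1}, nu_t)
    observation = env.reset()
    for _ in range(max_moves):
        pi, nu = run_mcts(observation)                 # search at the current state
        action = random.choices(list(pi), weights=list(pi.values()))[0]
        next_observation, reward, done = env.step(action)
        trajectory.append((observation, action, pi, reward, nu))
        observation = next_observation
        if done:
            break
    return trajectory
```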

