MBVD

World Model Learning

The latent dynamics $p(\hat{s}_{t} \mid \hat{s}_{t - 1},\ a_{t - 1})$ and the posterior $q(\hat{s}_{t} \mid \hat{s}_{t - 1},\ a_{t - 1},\ o_{t})$ are optimized to maximize the ELBO

$$
\begin{aligned}
\ln p(o_{1:\mathrm{T}} \mid a_{0:\mathrm{T}},\ s_{0}) &= \ln \mathcal{E}_{s_{1} \sim q(\cdot \mid o_{1},\ a_{0})} \mathcal{E}_{s_{2} \sim q(\cdot \mid o_{1},\ o_{2},\ a_{0},\ a_{1})} \cdots \mathcal{E}_{s_{\mathrm{T}} \sim q(\cdot \mid o_{1:\mathrm{T}},\ a_{0:\mathrm{T} - 1})} \left[ \prod_{t = 1}^{\mathrm{T}} \frac{p(o_{t} \mid s_{t})\, p(s_{t} \mid s_{t - 1},\ a_{t - 1})}{q(s_{t} \mid o_{\le t},\ a_{< t})} \right] \\[7mm]
&\ge \sum_{t = 1}^{\mathrm{T}} \mathcal{E}_{s_{1}} \mathcal{E}_{s_{2}} \cdots \mathcal{E}_{s_{t}} \ln p(o_{t} \mid s_{t}) - \mathcal{E}_{s_{1}} \mathcal{E}_{s_{2}} \cdots \mathcal{E}_{s_{t - 1}} D_{\mathrm{KL}} \Big( q(\cdot \mid o_{\le t},\ a_{< t})\ \|\ p(\cdot \mid s_{t - 1},\ a_{t - 1}) \Big)
\end{aligned}
$$
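As a concrete illustration, here is a minimal sketch of how the prior (latent dynamics) and the posterior could be parameterized as diagonal Gaussians. The module names, layer sizes, and dimensions are assumptions for illustration only, not the exact architecture used in MBVD.

```python
# Sketch (assumed architecture): both the prior p(s_t | s_{t-1}, a_{t-1}) and the
# posterior q(s_t | h_t) are diagonal Gaussians whose mean and log-std come from small MLPs.
import torch
import torch.nn as nn
import torch.distributions as td

class GaussianHead(nn.Module):
    def __init__(self, in_dim, latent_dim, hidden=128):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(in_dim, hidden), nn.ReLU(),
                                 nn.Linear(hidden, 2 * latent_dim))

    def forward(self, x):
        mean, log_std = self.net(x).chunk(2, dim=-1)
        return td.Normal(mean, log_std.clamp(-5, 2).exp())

latent_dim, action_dim, hidden_dim = 32, 10, 64          # illustrative sizes
prior = GaussianHead(latent_dim + action_dim, latent_dim)  # p(s_t | s_{t-1}, a_{t-1})
posterior = GaussianHead(hidden_dim, latent_dim)           # q(s_t | h_t), h_t aggregated over agents

s_prev = torch.zeros(1, latent_dim)
a_prev = torch.zeros(1, action_dim)
h_t = torch.zeros(1, hidden_dim)

p_dist = prior(torch.cat([s_prev, a_prev], dim=-1))
q_dist = posterior(h_t)
s_t = q_dist.rsample()   # reparameterized sample used during training
```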

The posterior distribution can also be used for optional auxiliary tasks, e.g. feasible action prediction in StarCraft.

The observation history $\tau_{t} = (o_{\le t},\ a_{< t})$ can be replaced by the joint hidden state $h_{t}^{1:n}$, since it encodes the former through each agent's recurrent network. The loss function of the world model in MBVD consists of the following terms (a minimal sketch of how they combine follows the table):

| Loss | Definition |
| --- | --- |
| Posterior Reconstruction | $\mathcal{L}_{\mathrm{RC}} = \operatorname{MSE}(\tau_{t},\ \tau_{t}')$ |
| Prior Reconstruction | $\mathcal{L}_{\mathrm{RC}}^{\mathrm{Prior}} = \operatorname{MSE} \big( (s_{t - 1},\ a_{t - 1}),\ (s_{t - 1}',\ a_{t - 1}') \big)$ |
| KL Divergence (Balance) | $\mathcal{L}_{\mathrm{KL}} = \alpha D_{\mathrm{KL}} \big( \operatorname{sg}[q(\cdot \mid h_{t})]\ \|\ p(\cdot \mid s_{t - 1},\ a_{t - 1}) \big) + (1 - \alpha) D_{\mathrm{KL}} \big( q(\cdot \mid h_{t})\ \|\ \operatorname{sg}[p(\cdot \mid s_{t - 1},\ a_{t - 1})] \big)$ |
| Prior Regularizer | $\mathcal{L}_{\mathrm{KL}}^{\mathrm{reg}} = D_{\mathrm{KL}} \big( p(\cdot \mid s_{t - 1},\ a_{t - 1})\ \|\ \mathcal{N}(0,\ \boldsymbol{I}) \big)$ |
| Feasible Action BCE Loss | $\mathcal{L}_{\mathrm{FA}} = \operatorname{BCE}(\mathcal{A}_{t},\ \mathcal{A}_{t}')$ |
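The sketch below shows one way these terms could be computed, assuming the Gaussian prior/posterior heads from the earlier sketch. The stop-gradient $\operatorname{sg}[\cdot]$ is implemented with `.detach()`; the decoders, $\alpha$, and the equal loss weighting are illustrative assumptions rather than the paper's exact choices.

```python
# Sketch of the world-model loss terms (illustrative weights, not MBVD's exact implementation).
import torch
import torch.nn.functional as F
import torch.distributions as td

def kl_balance(q_dist, p_dist, alpha=0.8):
    # alpha * KL(sg[q] || p) + (1 - alpha) * KL(q || sg[p])
    q_sg = td.Normal(q_dist.mean.detach(), q_dist.stddev.detach())
    p_sg = td.Normal(p_dist.mean.detach(), p_dist.stddev.detach())
    return (alpha * td.kl_divergence(q_sg, p_dist)
            + (1 - alpha) * td.kl_divergence(q_dist, p_sg)).sum(-1).mean()

def world_model_loss(q_dist, p_dist,
                     tau_t, tau_recon,            # history tau_t and its reconstruction (posterior branch)
                     sa_prev, sa_prev_recon,      # (s_{t-1}, a_{t-1}) and its reconstruction (prior branch)
                     avail_logits, avail_target): # feasible-action prediction, e.g. in StarCraft
    std_normal = td.Normal(torch.zeros_like(p_dist.mean), torch.ones_like(p_dist.stddev))
    l_rc = F.mse_loss(tau_recon, tau_t)                          # posterior reconstruction
    l_rc_prior = F.mse_loss(sa_prev_recon, sa_prev)              # prior reconstruction
    l_kl = kl_balance(q_dist, p_dist)                            # balanced KL divergence
    l_reg = td.kl_divergence(p_dist, std_normal).sum(-1).mean()  # prior regularizer toward N(0, I)
    l_fa = F.binary_cross_entropy_with_logits(avail_logits, avail_target)
    return l_rc + l_rc_prior + l_kl + l_reg + l_fa               # equal weighting assumed here
```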

Value Learning

The imagined states contain information about possible future states, which helps the evaluation of the global value. MBVD takes the current state and an extra multi-step latent rollout state as input to the mixing network

$$
Q_{\mathrm{tot}}(\tau_{t},\ a_{t},\ s_{t},\ \hat{s}_{t}^{\mathrm{Rollout}} \mid \psi) = f \Big[ Q_{1}(\tau_{t}^{1},\ a_{t}^{1}),\ Q_{2}(\tau_{t}^{2},\ a_{t}^{2}),\ \cdots,\ Q_{n}(\tau_{t}^{n},\ a_{t}^{n}) \mid s_{t},\ \hat{s}_{t}^{\mathrm{Rollout}} \Big]
$$
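For intuition, here is a minimal QMIX-style mixing network whose hypernetworks are conditioned on the concatenation $[s_{t},\ \hat{s}_{t}^{\mathrm{Rollout}}]$ instead of $s_{t}$ alone. The class name, embedding size, and layer layout are assumptions for illustration.

```python
# Sketch of a mixing network conditioned on [s_t, s_hat_rollout] (sizes are illustrative).
import torch
import torch.nn as nn
import torch.nn.functional as F

class RolloutConditionedMixer(nn.Module):
    def __init__(self, n_agents, state_dim, rollout_dim, embed_dim=32):
        super().__init__()
        cond_dim = state_dim + rollout_dim
        self.n_agents, self.embed_dim = n_agents, embed_dim
        self.hyper_w1 = nn.Linear(cond_dim, n_agents * embed_dim)
        self.hyper_b1 = nn.Linear(cond_dim, embed_dim)
        self.hyper_w2 = nn.Linear(cond_dim, embed_dim)
        self.hyper_b2 = nn.Sequential(nn.Linear(cond_dim, embed_dim), nn.ReLU(),
                                      nn.Linear(embed_dim, 1))

    def forward(self, agent_qs, state, rollout_state):
        # agent_qs: (batch, n_agents); state: (batch, state_dim); rollout_state: (batch, rollout_dim)
        cond = torch.cat([state, rollout_state], dim=-1)
        w1 = torch.abs(self.hyper_w1(cond)).view(-1, self.n_agents, self.embed_dim)  # monotonic mixing
        b1 = self.hyper_b1(cond).view(-1, 1, self.embed_dim)
        hidden = F.elu(torch.bmm(agent_qs.unsqueeze(1), w1) + b1)
        w2 = torch.abs(self.hyper_w2(cond)).view(-1, self.embed_dim, 1)
        b2 = self.hyper_b2(cond).view(-1, 1, 1)
        return (torch.bmm(hidden, w2) + b2).squeeze(-1).squeeze(-1)                  # Q_tot: (batch,)
```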

where $\hat{s}_{t}^{\mathrm{Rollout}}$ is the recurrent encoding of the $k$-step latent rollout states generated by the learned dynamics. The value function is trained to minimize the squared TD error

$$
\mathcal{L}_{\mathrm{RL}} = \Big[ r_{t} + \gamma Q_{\mathrm{tot}}(\tau_{t + 1},\ a_{t + 1},\ s_{t + 1},\ \hat{s}_{t + 1}^{\mathrm{Rollout}} \mid \psi^{-}) - Q_{\mathrm{tot}}(\tau_{t},\ a_{t},\ s_{t},\ \hat{s}_{t}^{\mathrm{Rollout}} \mid \psi) \Big]^{2}
$$
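A sketch of how the $k$-step latent rollout and the TD loss might be assembled: the learned prior imagines $k$ future latent states from the current posterior sample, a GRU summarizes them into $\hat{s}_{t}^{\mathrm{Rollout}}$, and a target mixing network $\psi^{-}$ provides the bootstrap target. The `policy` interface used during imagination and the GRU summarizer are illustrative assumptions, not the paper's exact procedure.

```python
# Sketch of the k-step latent rollout and TD loss (assumed interfaces for policy and rollout encoder).
import torch
import torch.nn as nn

@torch.no_grad()
def imagine_rollout(prior, policy, s_t, k=3):
    """Roll the learned prior forward k steps from the current latent state."""
    states, s = [], s_t
    for _ in range(k):
        a = policy(s)                                    # imagined joint action (assumed interface)
        s = prior(torch.cat([s, a], dim=-1)).rsample()   # sample next latent from p(. | s, a)
        states.append(s)
    return torch.stack(states, dim=1)                    # (batch, k, latent_dim)

class RolloutEncoder(nn.Module):
    """Recurrently encode the imagined latent states into s_hat_rollout."""
    def __init__(self, latent_dim, rollout_dim):
        super().__init__()
        self.gru = nn.GRU(latent_dim, rollout_dim, batch_first=True)

    def forward(self, imagined_states):
        _, h = self.gru(imagined_states)
        return h.squeeze(0)                              # (batch, rollout_dim)

def td_loss(mixer, target_mixer, agent_qs, next_agent_qs,
            s, s_next, rollout, rollout_next, reward, gamma=0.99):
    q_tot = mixer(agent_qs, s, rollout)
    with torch.no_grad():                                # target network psi^- is not updated here
        target = reward + gamma * target_mixer(next_agent_qs, s_next, rollout_next)
    return ((target - q_tot) ** 2).mean()
```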

