MACD

World Model Learning

MAMBA reconstructs local information in a decentralized manner, so each agent's latent state contains only local information. In contrast, MACD reconstructs global information in a centralized manner.

| Model | Type | Definition | Distribution Family |
| --- | --- | --- | --- |
| Recurrent Model | Generation | $h_{t}^{i} = f_{\phi}(h_{t-1}^{i},\ e_{t}^{i})$ | Deterministic |
| Communication Block | Generation | $e_{t} = g_{\phi}(z_{t-1},\ a_{t-1})$ | Deterministic |
| Representation Model | Inference | $z_{t}^{i} \sim q_{\phi}(z_{t}^{i} \mid h_{t}^{i},\ o_{t}^{i})$ | Categorical |
| Transition Predictor | Generation | $z_{t}^{i} \sim p_{\phi}(z_{t}^{i} \mid h_{t}^{i})$ | Categorical |
| State Predictor | Generation | $s_{t} \sim p_{\phi}(s_{t} \mid h_{t},\ z_{t})$ | |
| Team Reward Predictor | Generation | $r_{t} \sim p_{\phi}(r_{t} \mid h_{t},\ z_{t})$ | |
| Discount Predictor | Generation | $\gamma_{t} \sim p_{\phi}(\gamma_{t} \mid h_{t},\ z_{t})$ | |
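
For intuition, here is a minimal PyTorch sketch of how these components could fit together in one world-model step. The module sizes, the linear communication block, and the GRU cell are illustrative assumptions, not MACD's reported architecture.

```python
import torch
import torch.nn as nn


class WorldModelStep(nn.Module):
    """One world-model step for n agents (sizes/modules are illustrative placeholders)."""

    def __init__(self, n_agents, obs_dim, act_dim, h_dim=256, z_classes=32, z_cats=32):
        super().__init__()
        z_dim = z_classes * z_cats
        # Communication block g_phi: joint (z_{t-1}, a_{t-1}) -> per-agent messages e_t.
        self.comm = nn.Linear(n_agents * (z_dim + act_dim), n_agents * h_dim)
        # Recurrent model f_phi: deterministic per-agent state h_t^i.
        self.gru = nn.GRUCell(h_dim, h_dim)
        # Representation model q_phi(z_t^i | h_t^i, o_t^i) and transition predictor p_phi(z_t^i | h_t^i).
        self.post = nn.Linear(h_dim + obs_dim, z_dim)
        self.prior = nn.Linear(h_dim, z_dim)
        self.n, self.h_dim = n_agents, h_dim
        self.z_shape = (n_agents, z_classes, z_cats)

    def forward(self, h_prev, z_prev, a_prev, obs):
        # h_prev: (n, h_dim), z_prev: (n, z_dim), a_prev: (n, act_dim), obs: (n, obs_dim)
        e = self.comm(torch.cat([z_prev, a_prev], dim=-1).flatten()).view(self.n, self.h_dim)
        h = self.gru(e, h_prev)                            # h_t^i = f_phi(h_{t-1}^i, e_t^i)
        post_logits = self.post(torch.cat([h, obs], -1))   # q_phi(z_t^i | h_t^i, o_t^i)
        prior_logits = self.prior(h)                       # p_phi(z_t^i | h_t^i)
        post = torch.distributions.OneHotCategorical(logits=post_logits.view(self.z_shape))
        z = post.sample().view(self.n, -1)                 # categorical latent (straight-through omitted)
        return h, z, post_logits, prior_logits
```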

Following DreamerV3, the twohot symlog module is adopted as the final output head for the reward predictor.
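
As a reference for how such a head works, here is a minimal sketch of the symlog transform and twohot encoding; the bin count and range are illustrative choices, not necessarily those used by MACD.

```python
import torch

def symlog(x):
    # symlog(x) = sign(x) * ln(1 + |x|): compresses large reward magnitudes.
    return torch.sign(x) * torch.log1p(torch.abs(x))

def symexp(x):
    # Inverse of symlog, used to decode the head's expected value back to reward scale.
    return torch.sign(x) * torch.expm1(torch.abs(x))

def twohot(y, bins):
    # Encode scalar targets y (shape (B,)) as weights over two adjacent bins,
    # so the reward head can be trained with a categorical cross-entropy.
    y = y.clamp(min=bins[0].item(), max=bins[-1].item())
    idx = torch.searchsorted(bins, y).clamp(1, len(bins) - 1)
    lo, hi = bins[idx - 1], bins[idx]
    w_hi = (y - lo) / (hi - lo + 1e-8)
    enc = torch.zeros(y.shape[0], len(bins))
    enc.scatter_(1, (idx - 1).unsqueeze(1), (1.0 - w_hi).unsqueeze(1))
    enc.scatter_(1, idx.unsqueeze(1), w_hi.unsqueeze(1))
    return enc

# Example: 255 equally spaced bins in symlog space (illustrative).
bins = torch.linspace(-20.0, 20.0, 255)
targets = twohot(symlog(torch.tensor([0.5, -3.0, 12.0])), bins)   # shape (3, 255)
```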

The overall loss function for the world model is

$$
\mathcal{L}(\phi) = \mathcal{L}_{\mathrm{pred}}(\phi) + \beta_{\mathrm{dyn}} \mathcal{L}_{\mathrm{dyn}}(\phi) + \beta_{\mathrm{rep}} \mathcal{L}_{\mathrm{rep}}(\phi)
$$

where

$$
\begin{aligned}
\mathcal{L}_{\mathrm{pred}}(\phi) &= -\ln p_{\phi}(s_{t} \mid h_{t},\ z_{t}) - \ln p_{\phi}(r_{t} \mid h_{t},\ z_{t}) - \ln p_{\phi}(\gamma_{t} \mid h_{t},\ z_{t}) \\
\mathcal{L}_{\mathrm{dyn}}(\phi) &= \sum_{i = 1}^{n} D_{\mathrm{KL}} \Big( \operatorname{sg}\big(q_{\phi}(z_{t}^{i} \mid h_{t}^{i},\ o_{t}^{i})\big)\ \big\|\ p_{\phi}(z_{t}^{i} \mid h_{t}^{i}) \Big) \\
\mathcal{L}_{\mathrm{rep}}(\phi) &= \sum_{i = 1}^{n} D_{\mathrm{KL}} \Big( q_{\phi}(z_{t}^{i} \mid h_{t}^{i},\ o_{t}^{i})\ \big\|\ \operatorname{sg}\big(p_{\phi}(z_{t}^{i} \mid h_{t}^{i})\big) \Big)
\end{aligned}
$$
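
A compact sketch of how this loss could be assembled with `torch.distributions` is given below. The stop-gradients are realized with `detach`, the $\beta$ coefficients are placeholder values, and refinements such as free-bits clipping are omitted.

```python
import torch
import torch.distributions as td

def world_model_loss(state_logp, reward_logp, discount_logp,
                     post_logits, prior_logits, beta_dyn=0.5, beta_rep=0.1):
    # state_logp / reward_logp / discount_logp: log-likelihoods of the three predictor heads.
    # post_logits / prior_logits: (n_agents, classes, categories) logits for one time step.
    l_pred = -(state_logp + reward_logp + discount_logp)

    def categorical(logits):
        # Per-agent product of independent categorical latents.
        return td.Independent(td.Categorical(logits=logits), 1)

    # KL terms summed over agents, with stop-gradients (detach) as in the equations above.
    l_dyn = td.kl_divergence(categorical(post_logits.detach()), categorical(prior_logits)).sum()
    l_rep = td.kl_divergence(categorical(post_logits), categorical(prior_logits.detach())).sum()

    return l_pred + beta_dyn * l_dyn + beta_rep * l_rep
```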

Behavior Learning

The agent-wise policy $\pi_{i}(a_{t}^{i} \mid h_{t}^{i},\ z_{t}^{i};\ \theta_{i})$ is optimized by PPO with a counterfactual advantage:

$$
\mathcal{J}_{\mathrm{actor}}^{\mathrm{MACD}}(\theta_{i} \mid \theta_{i}^{\mathrm{old}}) = \sum_{t = 0}^{T} \min \left[ \frac{\pi_{i}(a_{t}^{i} \mid h_{t}^{i},\ z_{t}^{i};\ \theta_{i})}{\pi_{i}(a_{t}^{i} \mid h_{t}^{i},\ z_{t}^{i};\ \theta_{i}^{\mathrm{old}})} A_{t}^{i},\ \operatorname{clip} \left( \frac{\pi_{i}(a_{t}^{i} \mid h_{t}^{i},\ z_{t}^{i};\ \theta_{i})}{\pi_{i}(a_{t}^{i} \mid h_{t}^{i},\ z_{t}^{i};\ \theta_{i}^{\mathrm{old}})},\ 1 - \epsilon,\ 1 + \epsilon \right) A_{t}^{i} \right]
$$
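
The surrogate itself is standard PPO clipping; a minimal per-agent sketch (with an illustrative $\epsilon = 0.2$) looks as follows.

```python
import torch

def clipped_actor_objective(logp_new, logp_old, advantage, eps=0.2):
    # logp_new / logp_old: log pi_i(a_t^i | h_t^i, z_t^i; theta_i) under the new / old
    # parameters, shape (T,); advantage: the counterfactual advantage A_t^i, shape (T,).
    ratio = torch.exp(logp_new - logp_old.detach())
    unclipped = ratio * advantage
    clipped = torch.clamp(ratio, 1.0 - eps, 1.0 + eps) * advantage
    # Objective to maximize; negate it when minimizing with a standard optimizer.
    return torch.min(unclipped, clipped).sum()
```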

The counterfactual baseline in $A_{t}^{i}$ is calculated through an $H_{c}$-step rollout under the default action $d$:

$$
A_{t}^{i} = q_{\pi}(h_{t},\ z_{t},\ a_{t}^{i},\ a_{t}^{-i}) - q_{\pi}(h_{t},\ z_{t},\ d,\ a_{t}^{-i}) \approx \sum_{k = 0}^{H_{c} - 1} \gamma^{k} r_{t + k + 1} + \gamma^{H_{c}} v_{\psi}(h_{t + H_{c}},\ z_{t + H_{c}}) - \sum_{k = 0}^{H_{c} - 1} \gamma^{k} r_{t + k + 1}^{\mathrm{MACD}}
$$
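
A sketch of this advantage estimate, assuming the two imagined reward sequences have already been produced by the world model, follows the formula as written (only the policy rollout is bootstrapped); the function and argument names are illustrative.

```python
import torch

def counterfactual_advantage(rewards, rewards_default, h_end, z_end, value_fn, gamma=0.99):
    # rewards         : (H_c,) predicted team rewards along the imagined rollout where agent i
    #                   follows its policy (the other agents' actions a_t^{-i} are fixed).
    # rewards_default : (H_c,) predicted team rewards along the rollout where agent i plays the
    #                   default action d instead, with a_t^{-i} unchanged.
    # h_end, z_end    : joint model state at the end of the policy rollout, for bootstrapping.
    H_c = rewards.shape[0]
    discounts = gamma ** torch.arange(H_c, dtype=rewards.dtype)
    ret = (discounts * rewards).sum() + gamma ** H_c * value_fn(h_end, z_end)
    baseline = (discounts * rewards_default).sum()   # no bootstrap term, per the formula above
    return ret - baseline
```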

The critic $v_{\psi}(h_{t},\ z_{t})$ also uses the twohot symlog module and is optimized to minimize the $H$-step TD error.
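
A minimal sketch of the $H$-step bootstrapped target and a cross-entropy critic loss is shown below. It reuses the `symlog` and `twohot` helpers from the earlier sketch, and the use of a separate target critic for bootstrapping is an assumption.

```python
import torch

def h_step_td_target(rewards, h_end, z_end, target_value_fn, gamma=0.99):
    # rewards: (H,) imagined team rewards; bootstrap with a (target) critic at the final state.
    H = rewards.shape[0]
    discounts = gamma ** torch.arange(H, dtype=rewards.dtype)
    return (discounts * rewards).sum() + gamma ** H * target_value_fn(h_end, z_end)

def critic_loss(value_logits, td_target, bins):
    # Cross-entropy between the symlog twohot encoding of the (detached) TD target and the
    # critic's categorical output; symlog/twohot are the helpers sketched earlier.
    target = twohot(symlog(td_target.detach()).reshape(1), bins)
    return -(target * torch.log_softmax(value_logits.reshape(1, -1), dim=-1)).sum()
```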

