MACD
World Model Learning
MAMBA reconstructs local information in a decentralized manner; consequently, each agent's latent state contains only local information. MACD, in contrast, reconstructs global information in a centralized manner, so the latent states can encode information about the whole team.
| Model | Type | Definition | Distribution Family |
| --- | --- | --- | --- |
| Recurrent Model | Generation | $h_t^i = f_\phi(h_{t-1}^i,\, e_t^i)$ | Deterministic |
| Communication Block | Generation | $e_t = g_\phi(z_{t-1},\, a_{t-1})$ | Deterministic |
| Representation Model | Inference | $z_t^i \sim q_\phi(z_t^i \mid h_t^i,\, o_t^i)$ | Categorical |
| Transition Predictor | Generation | $z_t^i \sim p_\phi(z_t^i \mid h_t^i)$ | Categorical |
| State Predictor | Generation | $s_t \sim p_\phi(s_t \mid h_t,\, z_t)$ | |
| Team Reward Predictor | Generation | $r_t \sim p_\phi(r_t \mid h_t,\, z_t)$ | |
| Discount Predictor | Generation | $\gamma_t \sim p_\phi(\gamma_t \mid h_t,\, z_t)$ | |
Following DreamerV3, a twohot symlog module is adopted as the final output head of the reward predictor.
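As a reference for that head, here is a minimal symlog/twohot encoding; the 255-bin grid over $[-20, 20]$ mirrors the DreamerV3 defaults and is an assumption here.

```python
import torch


def symlog(x):
    """Symmetric log transform applied to targets before discretization."""
    return torch.sign(x) * torch.log1p(torch.abs(x))


def symexp(x):
    """Inverse of symlog; decodes predictions back to the reward scale."""
    return torch.sign(x) * (torch.exp(torch.abs(x)) - 1)


def twohot(y, bins):
    """Encode scalars y as a 'twohot' distribution over fixed bins: all mass is
    placed on the two bins neighbouring symlog(y)."""
    y = symlog(y).clamp(bins[0].item(), bins[-1].item())
    idx = torch.searchsorted(bins, y).clamp(1, len(bins) - 1)   # right neighbour
    lo, hi = bins[idx - 1], bins[idx]
    w_hi = (y - lo) / (hi - lo)
    target = torch.zeros(*y.shape, len(bins))
    target.scatter_(-1, (idx - 1).unsqueeze(-1), (1 - w_hi).unsqueeze(-1))
    target.scatter_(-1, idx.unsqueeze(-1), w_hi.unsqueeze(-1))
    return target


# The reward head outputs logits over the bins; its loss is the cross-entropy
# against twohot(r_t), and the decoded prediction is symexp of the expected bin.
bins = torch.linspace(-20.0, 20.0, 255)
print(twohot(torch.tensor([0.0, 1.5, -3.0]), bins).sum(-1))      # rows sum to 1
```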
The overall loss function for the world model is
$$
\mathcal{L}(\phi) = \mathcal{L}_{\mathrm{pred}}(\phi) + \beta_{\mathrm{dyn}} \mathcal{L}_{\mathrm{dyn}}(\phi) + \beta_{\mathrm{rep}} \mathcal{L}_{\mathrm{rep}}(\phi)
$$
where
$$
\begin{aligned}
\mathcal{L}_{\mathrm{pred}}(\phi) &= -\ln p_\phi(s_t \mid h_t, z_t) - \ln p_\phi(r_t \mid h_t, z_t) - \ln p_\phi(\gamma_t \mid h_t, z_t) \\
\mathcal{L}_{\mathrm{dyn}}(\phi) &= \sum_{i=1}^{n} D_{\mathrm{KL}}\big( \mathrm{sg}(q_\phi(z_t^i \mid h_t^i, o_t^i)) \,\|\, p_\phi(z_t^i \mid h_t^i) \big) \\
\mathcal{L}_{\mathrm{rep}}(\phi) &= \sum_{i=1}^{n} D_{\mathrm{KL}}\big( q_\phi(z_t^i \mid h_t^i, o_t^i) \,\|\, \mathrm{sg}(p_\phi(z_t^i \mid h_t^i)) \big)
\end{aligned}
$$
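Given posterior and prior logits from the representation model and transition predictor, the two KL terms differ only in where the stop-gradient is applied. A minimal sketch for a single time step follows; the weights $\beta_{\mathrm{dyn}}=0.5$, $\beta_{\mathrm{rep}}=0.1$ are the DreamerV3 defaults used as placeholders, and the free-bits clipping that DreamerV3 also applies is omitted.

```python
import torch


def categorical_kl(logits_p, logits_q):
    """KL( Cat(logits_p) || Cat(logits_q) ), summed over agents and latent groups."""
    p = logits_p.softmax(-1)
    return (p * (logits_p.log_softmax(-1) - logits_q.log_softmax(-1))).sum()


def world_model_loss(post_logits, prior_logits, logp_state, logp_reward, logp_discount,
                     beta_dyn=0.5, beta_rep=0.1):
    """post_logits, prior_logits: (n_agents, z_cat, z_cls) for one time step.
    logp_*: log-likelihoods of s_t, r_t, gamma_t under the predictor heads."""
    loss_pred = -(logp_state + logp_reward + logp_discount)
    # L_dyn: KL( sg(posterior) || prior ) -- trains the transition predictor.
    loss_dyn = categorical_kl(post_logits.detach(), prior_logits)
    # L_rep: KL( posterior || sg(prior) ) -- regularizes the representation model.
    loss_rep = categorical_kl(post_logits, prior_logits.detach())
    return loss_pred + beta_dyn * loss_dyn + beta_rep * loss_rep
```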
Behavior Learning
The agent-wise policy $\pi_i(a_t^i \mid h_t^i, z_t^i; \theta_i)$ is optimized with PPO using a counterfactual advantage:
$$
J_{\mathrm{actor}}^{\mathrm{MACD}}(\theta_i \mid \theta_i^{\mathrm{old}}) = \sum_{t=0}^{T} \min\!\left[ \frac{\pi_i(a_t^i \mid h_t^i, z_t^i; \theta_i)}{\pi_i(a_t^i \mid h_t^i, z_t^i; \theta_i^{\mathrm{old}})} A_t^i,\; \mathrm{clip}\!\left( \frac{\pi_i(a_t^i \mid h_t^i, z_t^i; \theta_i)}{\pi_i(a_t^i \mid h_t^i, z_t^i; \theta_i^{\mathrm{old}})},\, 1-\epsilon,\, 1+\epsilon \right) A_t^i \right]
$$
where the counterfactual baseline is computed through an $H_c$-step rollout under the default action $d$:
$$
A_t^i = q^{\pi}(h_t, z_t, a_t^i, a_t^{-i}) - q^{\pi}(h_t, z_t, d, a_t^{-i}) \approx \left[ \sum_{k=0}^{H_c-1} \gamma^k r_{t+k+1} + \gamma^{H_c} v_\psi(h_{t+H_c}, z_{t+H_c}) \right] - \sum_{k=0}^{H_c-1} \gamma^k \bar{r}_{t+k+1}
$$
where $\bar{r}_{t+k+1}$ denotes the rewards predicted along the counterfactual rollout in which agent $i$'s action is replaced by the default action $d$.
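Putting the two together, a minimal PyTorch sketch of the actor update follows; the `imagine` rollout helper, the treatment of the baseline's bootstrap term, and all tensor shapes are assumptions for illustration.

```python
import torch


def ppo_actor_loss(logp_new, logp_old, advantage, eps=0.2):
    """Clipped PPO surrogate for one agent, returned negated for gradient descent.
    logp_new: log pi_i(a_t^i | h_t^i, z_t^i; theta_i)      -- current parameters
    logp_old: log pi_i(a_t^i | h_t^i, z_t^i; theta_i_old)  -- rollout parameters
    advantage: counterfactual advantage A_t^i, treated as a constant."""
    ratio = torch.exp(logp_new - logp_old.detach())
    unclipped = ratio * advantage
    clipped = torch.clamp(ratio, 1.0 - eps, 1.0 + eps) * advantage
    return -torch.minimum(unclipped, clipped).sum()


@torch.no_grad()
def counterfactual_advantage(imagine, critic, h_t, z_t, joint_action,
                             agent_i, default_action, horizon_c, gamma=0.99):
    """Estimate A_t^i from two imagined rollouts that differ only in agent i's action.
    `imagine(h, z, first_joint_action)` is an assumed helper that rolls the world
    model forward for `horizon_c` steps, following the current policies after the
    first step, and returns (rewards, h_end, z_end)."""
    # Rollout under the joint action actually taken.
    rewards, h_end, z_end = imagine(h_t, z_t, joint_action)
    # Counterfactual rollout: agent i's action replaced by the default action d.
    cf_action = joint_action.clone()
    cf_action[agent_i] = default_action
    cf_rewards, _, _ = imagine(h_t, z_t, cf_action)

    discounts = gamma ** torch.arange(horizon_c, dtype=rewards.dtype)
    q_taken = (discounts * rewards).sum() + gamma ** horizon_c * critic(h_end, z_end)
    # Baseline return; its bootstrap term is omitted, mirroring the expression above.
    q_default = (discounts * cf_rewards).sum()
    return q_taken - q_default
```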
The critic $v_\psi(h_t, z_t)$ also uses the twohot symlog module and is optimized to minimize the $H$-step TD error.
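A sketch of the $H$-step bootstrapped targets the critic could be regressed toward; whether plain $H$-step returns or $\lambda$-returns are used, and the fixed discount in place of the predicted $\gamma_t$, are assumptions here.

```python
import torch


def h_step_td_targets(rewards, values, gamma=0.99, H=5):
    """Bootstrapped H-step return targets along an imagined trajectory.
    rewards: (T,) predicted team rewards; values: (T+1,) critic outputs decoded
    from the twohot symlog head. Returns targets for t = 0 .. T-H."""
    T = rewards.shape[0]
    discounts = gamma ** torch.arange(H, dtype=rewards.dtype)
    targets = []
    for t in range(T - H + 1):
        ret = (discounts * rewards[t:t + H]).sum() + gamma ** H * values[t + H]
        targets.append(ret)
    # The critic is trained by cross-entropy between its twohot logits and
    # twohot(symlog(target)), i.e. minimizing the H-step TD error.
    return torch.stack(targets)
```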