MAMBA

World Model Learning

MAMBA adopts the DreamerV2 architecture for world model representation and learning:

| Model | Type | Definition | Distribution Family |
| --- | --- | --- | --- |
| Recurrent Model | Generation | $h_{t}^{i} = f_{\phi}(h_{t-1}^{i},\ e_{t}^{i})$ | Deterministic |
| Communication Block | Generation | $e_{t} = g_{\phi}(z_{t-1},\ a_{t-1})$ | Deterministic |
| Representation Model | Inference | $z_{t}^{i} \sim q_{\phi}(z_{t}^{i} \mid h_{t}^{i},\ o_{t}^{i})$ | Categorical |
| Transition Predictor | Generation | $z_{t}^{i} \sim p_{\phi}(z_{t}^{i} \mid h_{t}^{i})$ | Categorical |
| Observation Predictor | Generation | $o_{t}^{i} \sim p_{\phi}(o_{t}^{i} \mid h_{t}^{i},\ z_{t}^{i})$ | |
| Reward Predictor | Generation | $r_{t}^{i} \sim p_{\phi}(r_{t}^{i} \mid h_{t}^{i},\ z_{t}^{i})$ | |
| Discount Predictor | Generation | $\gamma_{t}^{i} \sim p_{\phi}(\gamma_{t}^{i} \mid h_{t}^{i},\ z_{t}^{i})$ | Bernoulli |
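
A minimal PyTorch sketch of one agent's slice of this world model (the `AgentWorldModel` name, layer sizes, and single-linear heads are illustrative assumptions, not taken from the paper):

```python
import torch
import torch.nn as nn
import torch.distributions as td


class AgentWorldModel(nn.Module):
    """One agent's slice of the world model: deterministic h_t, categorical z_t."""

    def __init__(self, obs_dim, embed_dim=128, deter_dim=256, n_vars=32, n_classes=32):
        super().__init__()
        self.n_vars, self.n_classes = n_vars, n_classes
        stoch_dim = n_vars * n_classes
        # Recurrent model: h_t = f_phi(h_{t-1}, e_t)      (deterministic)
        self.gru = nn.GRUCell(embed_dim, deter_dim)
        # Representation model: q_phi(z_t | h_t, o_t)     (categorical)
        self.posterior = nn.Linear(deter_dim + obs_dim, stoch_dim)
        # Transition predictor: p_phi(z_t | h_t)          (categorical)
        self.prior = nn.Linear(deter_dim, stoch_dim)
        # Observation / reward / discount predictors conditioned on (h_t, z_t)
        self.obs_head = nn.Linear(deter_dim + stoch_dim, obs_dim)
        self.reward_head = nn.Linear(deter_dim + stoch_dim, 1)
        self.discount_head = nn.Linear(deter_dim + stoch_dim, 1)  # Bernoulli logit

    def _categorical(self, logits):
        logits = logits.view(-1, self.n_vars, self.n_classes)
        return td.Independent(td.OneHotCategoricalStraightThrough(logits=logits), 1)

    def step(self, h_prev, e_t, o_t):
        h_t = self.gru(e_t, h_prev)                               # h_t = f_phi(h_{t-1}, e_t)
        post = self._categorical(self.posterior(torch.cat([h_t, o_t], dim=-1)))
        prior = self._categorical(self.prior(h_t))
        z_t = post.rsample().flatten(1)                           # straight-through sample of z_t
        return h_t, z_t, post, prior
```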

All components are optimized jointly to maximize an evidence lower bound (ELBO) on the log-likelihood:

$$
\begin{aligned}
&\ln p_{\phi}(o_{0:\mathrm{T}},\ r_{0:\mathrm{T}},\ \gamma_{0:\mathrm{T}} \mid a_{0:\mathrm{T} - 1}) = \ln \sum_{z_{0:\mathrm{T}}} p_{\phi}(o_{0:\mathrm{T}},\ r_{0:\mathrm{T}},\ \gamma_{0:\mathrm{T}},\ z_{0:\mathrm{T}} \mid a_{0:\mathrm{T} - 1}) \\[7mm]
=\ &\ln \sum_{z_{0:\mathrm{T}}^{1}} \sum_{z_{0:\mathrm{T}}^{2}} \cdots \sum_{z_{0:\mathrm{T}}^{n}} \prod_{t = 0}^{\mathrm{T}} \prod_{i = 1}^{n} p_{\phi}(o_{t}^{i} \mid h_{t}^{i},\ z_{t}^{i}) p_{\phi}(r_{t}^{i} \mid h_{t}^{i},\ z_{t}^{i}) p_{\phi}(\gamma_{t}^{i} \mid h_{t}^{i},\ z_{t}^{i}) p_{\phi}(z_{t}^{i} \mid h_{t}^{i}) \frac{q_{\phi}(z_{t}^{i} \mid h_{t}^{i},\ o_{t}^{i})}{q_{\phi}(z_{t}^{i} \mid h_{t}^{i},\ o_{t}^{i})} \\[7mm]
=\ &\ln \mathcal{E}_{z_{t}^{i} \sim q_{\phi}(z_{t}^{i} \mid h_{t}^{i},\ o_{t}^{i})} \left[ \prod_{t = 0}^{\mathrm{T}} \prod_{i = 1}^{n} p_{\phi}(o_{t}^{i} \mid h_{t}^{i},\ z_{t}^{i}) p_{\phi}(r_{t}^{i} \mid h_{t}^{i},\ z_{t}^{i}) p_{\phi}(\gamma_{t}^{i} \mid h_{t}^{i},\ z_{t}^{i}) \frac{p_{\phi}(z_{t}^{i} \mid h_{t}^{i})}{q_{\phi}(z_{t}^{i} \mid h_{t}^{i},\ o_{t}^{i})} \right] \\[7mm]
\ge\ &\mathcal{E}_{z_{t}^{i} \sim q_{\phi}(z_{t}^{i} \mid h_{t}^{i},\ o_{t}^{i})} \left[ \sum_{t = 0}^{\mathrm{T}} \sum_{i = 1}^{n} \ln p_{\phi}(o_{t}^{i} \mid h_{t}^{i},\ z_{t}^{i}) + \ln p_{\phi}(r_{t}^{i} \mid h_{t}^{i},\ z_{t}^{i}) + \ln p_{\phi}(\gamma_{t}^{i} \mid h_{t}^{i},\ z_{t}^{i}) - \ln \frac{q_{\phi}(z_{t}^{i} \mid h_{t}^{i},\ o_{t}^{i})}{p_{\phi}(z_{t}^{i} \mid h_{t}^{i})} \right] \\[7mm]
=\ &\mathcal{E}_{z_{t}^{i} \sim q_{\phi}(z_{t}^{i} \mid h_{t}^{i},\ o_{t}^{i})} \left[ \sum_{t = 0}^{\mathrm{T}} \sum_{i = 1}^{n} \ln p_{\phi}(o_{t}^{i} \mid h_{t}^{i},\ z_{t}^{i}) + \ln p_{\phi}(r_{t}^{i} \mid h_{t}^{i},\ z_{t}^{i}) + \ln p_{\phi}(\gamma_{t}^{i} \mid h_{t}^{i},\ z_{t}^{i}) - D_{\mathrm{KL}} \Big( q_{\phi}(\cdot \mid h_{t}^{i},\ o_{t}^{i})\ \|\ p_{\phi}(\cdot \mid h_{t}^{i}) \Big) \right]
\end{aligned}
$$
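
In code, each term of this bound maps onto a per-step loss. A hedged sketch, reusing the `AgentWorldModel` above and assuming DreamerV2-style unit-variance Gaussian decoders for observation and reward (a choice not stated here):

```python
import torch
import torch.distributions as td


def world_model_loss(model, h_t, z_t, post, prior, o_t, r_t, gamma_t, kl_scale=1.0):
    """Negative ELBO for one agent at one time step (lower is better).
    gamma_t is the 0/1 episode-continuation flag targeted by the Bernoulli discount head."""
    feat = torch.cat([h_t, z_t], dim=-1)
    # ln p_phi(o_t | h_t, z_t): unit-variance Gaussian decoder (an assumption)
    obs_ll = td.Independent(td.Normal(model.obs_head(feat), 1.0), 1).log_prob(o_t)
    # ln p_phi(r_t | h_t, z_t)
    rew_ll = td.Normal(model.reward_head(feat).squeeze(-1), 1.0).log_prob(r_t)
    # ln p_phi(gamma_t | h_t, z_t)
    disc_ll = td.Bernoulli(logits=model.discount_head(feat).squeeze(-1)).log_prob(gamma_t)
    # D_KL( q_phi(. | h_t, o_t) || p_phi(. | h_t) )
    kl = td.kl_divergence(post, prior)
    return -(obs_ll + rew_ll + disc_ll - kl_scale * kl).mean()
```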

To encourage each agent's latent state to depend mostly on its own actions, and to further allow disentanglement of the latent states, the mutual information $I(h_{t}^{i},\ z_{t}^{i};\ a_{t-1}^{i})$ is maximized through its lower bound:

$$
I(h_{t}^{i},\ z_{t}^{i};\ a_{t-1}^{i}) \ge \mathcal{E}_{(h_{t}^{i},\ z_{t}^{i},\ a_{t-1}^{i}) \sim p_{\phi}} \ln q_{\phi}(a_{t-1}^{i} \mid h_{t}^{i},\ z_{t}^{i}) + \mathcal{H}(a_{t-1}^{i}) \;\Rightarrow\; \max_{\phi} \mathcal{E}_{(h_{t}^{i},\ z_{t}^{i},\ a_{t-1}^{i}) \sim p_{\phi}} \ln q_{\phi}(a_{t-1}^{i} \mid h_{t}^{i},\ z_{t}^{i})
$$
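
In practice this can be implemented as an action-prediction head trained with cross-entropy; the `ActionPredictor` name and the discrete-action assumption below are illustrative:

```python
import torch
import torch.nn as nn


class ActionPredictor(nn.Module):
    """q_phi(a_{t-1} | h_t, z_t): maximizing its log-likelihood is the MI lower bound."""

    def __init__(self, deter_dim, stoch_dim, n_actions):
        super().__init__()
        self.net = nn.Linear(deter_dim + stoch_dim, n_actions)

    def loss(self, h_t, z_t, a_prev):
        # -E[ln q_phi(a_{t-1} | h_t, z_t)]; a_prev holds discrete action indices
        logits = self.net(torch.cat([h_t, z_t], dim=-1))
        return nn.functional.cross_entropy(logits, a_prev)
```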

Behavior Learning

The actor $\pi_{\psi}(a_{t}^{i} \mid z_{t}^{i},\ h_{t}^{i})$ and the critic $v_{\eta}^{i}(z_{t})$ are learned with PPO and the TD($\lambda$) method on imagined rollouts.
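
For reference, a sketch of the TD($\lambda$) targets computed backwards over an imagined rollout (the function name and tensor layout are assumptions; the PPO side is the standard clipped-surrogate update and is omitted):

```python
import torch


def lambda_returns(rewards, values, discounts, lam=0.95):
    """TD(lambda) targets over an imagined rollout.
    rewards, discounts: [H, B]; values: [H + 1, B], where values[-1] bootstraps the tail."""
    H = rewards.shape[0]
    returns = [None] * H
    last = values[-1]
    for t in reversed(range(H)):
        # V^lambda_t = r_t + gamma_t * ((1 - lam) * v(s_{t+1}) + lam * V^lambda_{t+1})
        last = rewards[t] + discounts[t] * ((1.0 - lam) * values[t + 1] + lam * last)
        returns[t] = last
    return torch.stack(returns)  # [H, B] regression targets for the critic
```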

During execution, each agent broadcasts its stochastic state $z_{t-1}^{i}$ and action $a_{t-1}^{i}$ from the previous step in order to obtain the feature vector $e_{t}^{i}$ and then update its world model with the current observation.
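
Putting the pieces together, a pseudocode-level sketch of one execution step for agent $i$ (reusing the sketches above; `comm` stands for the communication block of the next section and `policy` is assumed to return a torch distribution):

```python
import torch


def execution_step(model, comm, policy, h_prev, z_prev_all, a_prev_all, o_t):
    """One decentralized step for agent i; the broadcast mechanism itself is left abstract."""
    e_t = comm(z_prev_all, a_prev_all)                    # e_t^i from broadcast (z_{t-1}, a_{t-1})
    h_t, z_t, _, _ = model.step(h_prev, e_t, o_t)         # update h_t^i, z_t^i with o_t^i
    a_t = policy(torch.cat([h_t, z_t], dim=-1)).sample()  # pi_psi(a_t^i | z_t^i, h_t^i)
    return h_t, z_t, a_t                                  # broadcast z_t^i, a_t^i for step t + 1
```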

Communication

The communication block applies a self-attention mechanism to process the agents' state-action tuples:

$$
e_{t}^{i} = \operatorname{softmax} \left( \frac{1}{\sqrt{d}} Q(z_{t}^{i},\ a_{t}^{i}) \cdot K(z_{t}^{1:n},\ a_{t}^{1:n}) \right) \cdot V(z_{t}^{1:n},\ a_{t}^{1:n})
$$
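
A single-head sketch of this attention over all $n$ agents' $(z, a)$ tuples (the linear projections and dimension names are illustrative; a multi-head module such as `torch.nn.MultiheadAttention` could be used instead):

```python
import math
import torch
import torch.nn as nn


class CommBlock(nn.Module):
    """Self-attention over the agents' (z, a) tuples: e_t = softmax(Q K^T / sqrt(d)) V."""

    def __init__(self, in_dim, d):
        super().__init__()
        self.d = d
        self.q = nn.Linear(in_dim, d)
        self.k = nn.Linear(in_dim, d)
        self.v = nn.Linear(in_dim, d)

    def forward(self, za, mask=None):
        # za: [batch, n_agents, in_dim], each row the concatenated (z, a) of one agent
        q, k, v = self.q(za), self.k(za), self.v(za)          # [B, n, d] each
        scores = q @ k.transpose(-1, -2) / math.sqrt(self.d)  # [B, n, n]
        if mask is not None:                                  # optional neighbourhood mask U(i)
            scores = scores.masked_fill(~mask, float('-inf'))
        return torch.softmax(scores, dim=-1) @ v              # e_t for all agents, [B, n, d]
```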

In the locality setting, an agent may only receive messages from its neighbours $U(i)$:

$$
e_{t}^{i} = \operatorname{softmax} \left( \frac{1}{\sqrt{d}} Q^{i}(z_{t},\ a_{t}) \cdot K(z_{t}^{U(i)},\ a_{t}^{U(i)}) \right) \cdot V(z_{t}^{U(i)},\ a_{t}^{U(i)})
$$
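
With the `CommBlock` sketch above, locality reduces to a boolean mask over the agent axis. A hypothetical example with four agents on a line (the adjacency rule and sizes are made up for illustration):

```python
import torch

# Hypothetical layout: 4 agents on a line; U(i) contains agent i and its direct neighbours.
n = 4
adjacency = torch.zeros(n, n, dtype=torch.bool)
for i in range(n):
    for j in (i - 1, i, i + 1):
        if 0 <= j < n:
            adjacency[i, j] = True                  # j is in U(i)

comm = CommBlock(in_dim=32 * 32 + 5, d=64)          # illustrative (z, a) sizes
za = torch.randn(1, n, 32 * 32 + 5)                 # concatenated (z_t, a_t) per agent
e_t = comm(za, mask=adjacency.unsqueeze(0))         # [1, n, 64]; row i only attends over U(i)
```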

The sparsity and low dimensionality of the stochastic state help alleviate communication-bandwidth limitations.
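
For a rough sense of scale, assuming a DreamerV2-style latent of 32 categorical variables with 32 classes each (a detail not stated here): broadcasting $z_t^i$ amounts to 32 class indices, i.e. $32 \times \log_2 32 = 160$ bits per agent per step, whereas broadcasting even a small $64 \times 64$ grayscale observation would take $64 \times 64 \times 8 \approx 33{,}000$ bits.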

