PlaNet


Recurrent State Space Model

| Model Type | Definition | Distribution Family |
| --- | --- | --- |
| Transition Generation | $s_{t} \sim p(s_{t} \mid s_{t - 1},\ a_{t - 1})$ | $\mathcal{N} \big( \mu(s_{t - 1},\ a_{t - 1}),\ \operatorname{diag}(\sigma^{2}(s_{t - 1},\ a_{t - 1})) \big)$ |
| Observation Generation | $o_{t} \sim p(o_{t} \mid s_{t})$ | $\mathcal{N} \big( \mu(s_{t}),\ \boldsymbol{I} \big)$ |
| Reward Generation | $r_{t} \sim p(r_{t} \mid s_{t})$ | $\mathcal{N} \big( \mu(s_{t}),\ 1 \big)$ |
| Posterior Inference | $s_{t} \sim q(s_{t} \mid s_{t - 1},\ a_{t - 1},\ o_{t})$ | $\mathcal{N} \big( \mu(s_{t - 1},\ a_{t - 1},\ o_{t}),\ \operatorname{diag}(\sigma^{2}(s_{t - 1},\ a_{t - 1},\ o_{t})) \big)$ |

Conditioned on the action sequence, the POMDP reduces to a non-stationary Markov process.

All components are trained jointly by maximizing a variational lower bound (ELBO) in place of the intractable log-likelihood:

$$
\begin{aligned}
&\ln p(o_{1:\mathrm{T}}, r_{1:\mathrm{T}} \mid a_{1:\mathrm{T}}) \\[7mm]
=\ &\ln \sum_{s_{1:\mathrm{T}}} p(o_{1:\mathrm{T}},\ r_{1:\mathrm{T}},\ s_{1:\mathrm{T}} \mid a_{1:\mathrm{T}}) = \ln \sum_{s_{1:\mathrm{T}}} p(o_{1:\mathrm{T}},\ r_{1:\mathrm{T}} \mid s_{1:\mathrm{T}},\ a_{1:\mathrm{T}}) p(s_{1:\mathrm{T}} \mid a_{1:\mathrm{T}}) \frac{q(s_{1:\mathrm{T}} \mid o_{1:\mathrm{T}},\ a_{1:\mathrm{T}})}{q(s_{1:\mathrm{T}} \mid o_{1:\mathrm{T}},\ a_{1:\mathrm{T}})} \\[7mm]
=\ &\ln \sum_{s_{1}} \sum_{s_{2}} \cdots \sum_{s_{\mathrm{T}}} \prod_{t = 1}^{\mathrm{T}} p(o_{t} \mid s_{t}) p(r_{t} \mid s_{t}) p(s_{t} \mid s_{t - 1},\ a_{t - 1}) \prod_{t = 1}^{\mathrm{T}} \frac{q(s_{t} \mid s_{t - 1},\ a_{t - 1},\ o_{t})}{q(s_{t} \mid s_{t - 1},\ a_{t - 1},\ o_{t})} \\[7mm]
=\ &\ln \mathcal{E}_{s_{1} \sim q(\cdot \mid o_{1})} \mathcal{E}_{s_{2} \sim q(\cdot \mid s_{1},\ a_{1},\ o_{2})} \cdots \mathcal{E}_{s_{\mathrm{T}} \sim q(\cdot \mid s_{\mathrm{T} - 1},\ a_{\mathrm{T} - 1},\ o_{\mathrm{T}})} \left[ \prod_{t = 1}^{\mathrm{T}} \frac{p(o_{t} \mid s_{t}) p(r_{t} \mid s_{t}) p(s_{t} \mid s_{t - 1},\ a_{t - 1})}{q(s_{t} \mid s_{t - 1},\ a_{t - 1},\ o_{t})} \right] \\[7mm]
\ge\ &\mathcal{E}_{s_{1}} \mathcal{E}_{s_{2}} \cdots \mathcal{E}_{s_{\mathrm{T}}} \left[ \sum_{t = 1}^{\mathrm{T}} \ln p(o_{t} \mid s_{t}) + \ln p(r_{t} \mid s_{t}) + \ln p(s_{t} \mid s_{t - 1},\ a_{t - 1}) - \ln q(s_{t} \mid s_{t - 1},\ a_{t - 1},\ o_{t}) \right] \\[7mm]
=\ &\sum_{t = 1}^{\mathrm{T}} \mathcal{E}_{s_{1}} \mathcal{E}_{s_{2}} \cdots \mathcal{E}_{s_{t}} \Big[ \ln p(o_{t} \mid s_{t}) + \ln p(r_{t} \mid s_{t}) \Big] - \mathcal{E}_{s_{1}} \mathcal{E}_{s_{2}} \cdots \mathcal{E}_{s_{t} \sim q(\cdot \mid s_{t - 1},\ a_{t - 1},\ o_{t})} \ln \frac{q(s_{t} \mid s_{t - 1},\ a_{t - 1},\ o_{t})}{p(s_{t} \mid s_{t - 1},\ a_{t - 1})} \\[7mm]
=\ &\sum_{t = 1}^{\mathrm{T}} \mathcal{E}_{s_{1}} \mathcal{E}_{s_{2}} \cdots \mathcal{E}_{s_{t}} \Big[ \ln p(o_{t} \mid s_{t}) + \ln p(r_{t} \mid s_{t}) \Big] - \mathcal{E}_{s_{1}} \mathcal{E}_{s_{2}} \cdots \mathcal{E}_{s_{t - 1}} D_{\mathrm{KL}} \Big( q(\cdot \mid s_{t - 1},\ a_{t - 1},\ o_{t})\ \|\ p(\cdot \mid s_{t - 1},\ a_{t - 1}) \Big)
\end{aligned}
$$
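Both distributions inside the KL term are diagonal Gaussians, so that term has a closed form and needs no sampling. A minimal sketch (the function name `kl_diag_gaussians` is illustrative, not from the paper):

```python
import math

def kl_diag_gaussians(mu_q, sigma_q, mu_p, sigma_p):
    """Closed-form KL( N(mu_q, diag(sigma_q^2)) || N(mu_p, diag(sigma_p^2)) ),
    summed over the independent dimensions."""
    kl = 0.0
    for mq, sq, mp, sp in zip(mu_q, sigma_q, mu_p, sigma_p):
        kl += math.log(sp / sq) + (sq ** 2 + (mq - mp) ** 2) / (2 * sp ** 2) - 0.5
    return kl

# KL between identical distributions is zero
print(kl_diag_gaussians([0.0, 1.0], [1.0, 0.5], [0.0, 1.0], [1.0, 0.5]))  # → 0.0
```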

Replacing the posterior with one that conditions directly on past observations and actions (rather than on the previous latent state), the objective can be rewritten as

$$
\begin{aligned}
&\ln p(o_{1:\mathrm{T}}, r_{1:\mathrm{T}} \mid a_{1:\mathrm{T}}) \\[7mm]
=\ &\ln \sum_{s_{1:\mathrm{T}}} p(o_{1:\mathrm{T}},\ r_{1:\mathrm{T}},\ s_{1:\mathrm{T}} \mid a_{1:\mathrm{T}}) = \ln \sum_{s_{1:\mathrm{T}}} p(o_{1:\mathrm{T}},\ r_{1:\mathrm{T}} \mid s_{1:\mathrm{T}},\ a_{1:\mathrm{T}}) p(s_{1:\mathrm{T}} \mid a_{1:\mathrm{T}}) \prod_{t = 1}^{\mathrm{T}} \frac{q(s_{t} \mid o_{\le t},\ a_{< t})}{q(s_{t} \mid o_{\le t},\ a_{< t})} \\[7mm]
=\ &\ln \mathcal{E}_{s_{1} \sim q(\cdot \mid o_{1})} \mathcal{E}_{s_{2} \sim q(\cdot \mid o_{1},\ o_{2},\ a_{1})} \cdots \mathcal{E}_{s_{\mathrm{T}} \sim q(\cdot \mid o_{1:\mathrm{T}},\ a_{1:\mathrm{T} - 1})} \left[ \prod_{t = 1}^{\mathrm{T}} \frac{p(o_{t} \mid s_{t}) p(r_{t} \mid s_{t}) p(s_{t} \mid s_{t - 1},\ a_{t - 1})}{q(s_{t} \mid o_{\le t},\ a_{< t})} \right] \\[7mm]
\ge\ &\mathcal{E}_{s_{1}} \mathcal{E}_{s_{2}} \cdots \mathcal{E}_{s_{\mathrm{T}}} \left[ \sum_{t = 1}^{\mathrm{T}} \ln p(o_{t} \mid s_{t}) + \ln p(r_{t} \mid s_{t}) + \ln p(s_{t} \mid s_{t - 1},\ a_{t - 1}) - \ln q(s_{t} \mid o_{\le t},\ a_{< t}) \right] \\[7mm]
=\ &\sum_{t = 1}^{\mathrm{T}} \mathcal{E}_{s_{1}} \mathcal{E}_{s_{2}} \cdots \mathcal{E}_{s_{t}} \Big[ \ln p(o_{t} \mid s_{t}) + \ln p(r_{t} \mid s_{t}) \Big] - \mathcal{E}_{s_{1}} \mathcal{E}_{s_{2}} \cdots \mathcal{E}_{s_{t - 1}} D_{\mathrm{KL}} \Big( q(\cdot \mid o_{\le t},\ a_{< t})\ \|\ p(\cdot \mid s_{t - 1},\ a_{t - 1}) \Big)
\end{aligned}
$$

The parameters of the probabilistic model can be optimized with reparameterized sampling and gradient descent.
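As a sketch of the reparameterization trick: writing the sample as $s = \mu + \sigma \epsilon$ with $\epsilon \sim \mathcal{N}(0, 1)$ makes it a differentiable function of the parameters, so gradients can flow through the sampling step. The toy below estimates $\partial_{\mu} \mathcal{E}[s^{2}] = 2\mu$ with a hand-derived pathwise gradient (pure Python; names and setup are illustrative, not the paper's code):

```python
import random

random.seed(0)

def reparam_grad_mu(mu, sigma, n=100_000):
    """Monte-Carlo estimate of d/dmu E[f(s)] for f(s) = s^2, using the
    reparameterized sample s = mu + sigma * eps. Since ds/dmu = 1,
    the pathwise derivative of s**2 w.r.t. mu is simply 2 * s."""
    total = 0.0
    for _ in range(n):
        eps = random.gauss(0.0, 1.0)
        s = mu + sigma * eps
        total += 2.0 * s
    return total / n

# True gradient of E[s^2] = mu^2 + sigma^2 w.r.t. mu is 2 * mu
print(reparam_grad_mu(1.5, 0.3))  # ≈ 3.0
```

In practice an autodiff framework performs the same computation automatically once the sample is expressed this way.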

| | RNN | SSM | RSSM |
| --- | --- | --- | --- |
| Generation | $\begin{gathered} h_{t} = f(h_{t - 1},\ a_{t - 1}) \\ o_{t} \sim p(o_{t} \mid h_{t}) \\ r_{t} \sim p(r_{t} \mid h_{t}) \end{gathered}$ | $\begin{gathered} s_{t} \sim p(s_{t} \mid s_{t - 1},\ a_{t - 1}) \\ o_{t} \sim p(o_{t} \mid s_{t}) \\ r_{t} \sim p(r_{t} \mid s_{t}) \end{gathered}$ | $\begin{gathered} h_{t} = f(h_{t - 1},\ s_{t - 1},\ a_{t - 1}) \\ s_{t} \sim p(s_{t} \mid h_{t}) \\ o_{t} \sim p(o_{t} \mid h_{t},\ s_{t}) \\ r_{t} \sim p(r_{t} \mid h_{t},\ s_{t}) \end{gathered}$ |
| Inference | — | $\begin{gathered} s_{t} \sim q(s_{t} \mid s_{t - 1},\ a_{t - 1},\ o_{t}) \\ s_{t} \sim q(s_{t} \mid s_{t - 1},\ a_{t - 1},\ r_{t}) \end{gathered}$ | $\begin{gathered} s_{t} \sim q(s_{t} \mid h_{t},\ o_{t}) \\ s_{t} \sim q(s_{t} \mid h_{t},\ r_{t}) \end{gathered}$ |

Transitions in the SSM are purely stochastic, which makes it difficult to retain information over many time steps. The RSSM combines the RNN and the SSM, splitting the state into stochastic and deterministic parts.
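A minimal sketch of one RSSM transition step in these terms, with toy dimensions and untrained random weights; the observation and reward heads are omitted, and all names here are illustrative, not the paper's implementation:

```python
import math
import random

random.seed(0)
H, S, A = 4, 3, 2  # deterministic, stochastic, and action dimensions (toy sizes)

def make_linear(n_in, n_out):
    """A toy linear layer with fixed random weights (learned in practice)."""
    w = [[random.uniform(-0.5, 0.5) for _ in range(n_in)] for _ in range(n_out)]
    return lambda x: [sum(wij * xj for wij, xj in zip(row, x)) for row in w]

f_det = make_linear(H + S + A, H)                      # deterministic path
f_mu, f_log_sigma = make_linear(H, S), make_linear(H, S)  # stochastic path

def rssm_step(h, s, a):
    # Deterministic state h_t carries information reliably over many steps
    h_next = [math.tanh(v) for v in f_det(h + s + a)]
    # Stochastic state s_t is sampled from the diagonal-Gaussian prior p(s_t | h_t)
    mu, log_sigma = f_mu(h_next), f_log_sigma(h_next)
    s_next = [m + math.exp(ls) * random.gauss(0.0, 1.0)
              for m, ls in zip(mu, log_sigma)]
    return h_next, s_next

h, s = [0.0] * H, [0.0] * S
for t in range(5):
    h, s = rssm_step(h, s, [1.0, -1.0])
print(len(h), len(s))  # → 4 3
```

The key design choice this illustrates: only $s_t$ is sampled, while $h_t$ is a deterministic function of the past, so gradients and memory are not forced through noise at every step.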

Latent Overshooting

Because of the model's limited capacity and restricted distribution family, the model trained to convergence on one-step predictions alone is in general not the model that is best at multi-step predictions.

| | Standard | Observation Overshooting | Latent Overshooting |
| --- | --- | --- | --- |
| Description | One-Step Prediction | Multi-Step Reconstruction | Multi-Step Prior Prediction |

Generalize the standard latent variational lower bound from one-step prior prediction to $d$-step prediction:

$$
\begin{aligned}
&\ln p_{d}(o_{1:\mathrm{T}} \mid a_{1:\mathrm{T}}) \not \Leftarrow \ln p(o_{1:\mathrm{T}} \mid a_{1:\mathrm{T}}) \\[5mm]
=\ &\ln \sum_{s_{1:\mathrm{T}}} p(o_{1:\mathrm{T}} \mid s_{1:\mathrm{T}}) p(s_{1:\mathrm{T}} \mid a_{1:\mathrm{T}}) = \ln \sum_{s_{1:\mathrm{T}}} \prod_{t = 1}^{\mathrm{T}} p(o_{t} \mid s_{t}) p(s_{t} \mid s_{t - d},\ a_{t - d : t - 1}) \\[7mm]
=\ &\ln \sum_{s_{1:\mathrm{T}}} \prod_{t = 1}^{\mathrm{T}} p(o_{t} \mid s_{t}) \left( \sum_{s_{t - d + 1}} \cdots \sum_{s_{t - 1}} \prod_{\tau = t - d + 1}^{t} p(s_{\tau} \mid s_{\tau - 1},\ a_{\tau - 1}) \right) \frac{q(s_{t} \mid o_{\le t},\ a_{< t})}{q(s_{t} \mid o_{\le t},\ a_{< t})} \\[7mm]
=\ &\ln \mathcal{E}_{s_{1} \sim q(\cdot \mid o_{1})} \mathcal{E}_{s_{2} \sim q(\cdot \mid o_{1},\ o_{2},\ a_{1})} \cdots \mathcal{E}_{s_{\mathrm{T}} \sim q(\cdot \mid o_{\le \mathrm{T}},\ a_{< \mathrm{T}})} \left[ \prod_{t = 1}^{\mathrm{T}} p(o_{t} \mid s_{t}) \sum_{s_{t - 1}} p(s_{t - 1} \mid s_{t - d},\ a_{t - d : t - 2}) \frac{p(s_{t} \mid s_{t - 1},\ a_{t - 1})}{q(s_{t} \mid o_{\le t},\ a_{< t})} \right] \\[7mm]
\ge\ &\mathcal{E}_{s_{1}} \mathcal{E}_{s_{2}} \cdots \mathcal{E}_{s_{\mathrm{T}}} \left[ \sum_{t = 1}^{\mathrm{T}} \ln p(o_{t} \mid s_{t}) + \ln \mathcal{E}_{s_{t - 1} \sim p(\cdot \mid s_{t - d},\ a_{t - d : t - 2})} \frac{p(s_{t} \mid s_{t - 1},\ a_{t - 1})}{q(s_{t} \mid o_{\le t},\ a_{< t})} \right] \\[7mm]
=\ &\sum_{t = 1}^{\mathrm{T}} \mathcal{E}_{s_{t} \sim q(\cdot \mid o_{\le t},\ a_{< t})} \ln p(o_{t} \mid s_{t}) + \mathcal{E}_{s_{t - d} \sim q(\cdot \mid o_{\le t - d},\ a_{< t - d})} \mathcal{E}_{s_{t} \sim q(\cdot \mid o_{\le t},\ a_{< t})} \ln \mathcal{E}_{s_{t - 1} \sim p(\cdot \mid s_{t - d},\ a_{t - d : t - 2})} \frac{p(s_{t} \mid s_{t - 1},\ a_{t - 1})}{q(s_{t} \mid o_{\le t},\ a_{< t})} \\[7mm]
\ge\ &\sum_{t = 1}^{\mathrm{T}} \mathcal{E}_{s_{t} \sim q(\cdot \mid o_{\le t},\ a_{< t})} \ln p(o_{t} \mid s_{t}) - \mathcal{E}_{s_{t - d} \sim q(\cdot \mid o_{\le t - d},\ a_{< t - d})} \mathcal{E}_{s_{t - 1} \sim p(\cdot \mid s_{t - d},\ a_{t - d : t - 2})} \mathcal{E}_{s_{t} \sim q(\cdot \mid o_{\le t},\ a_{< t})} \ln \frac{q(s_{t} \mid o_{\le t},\ a_{< t})}{p(s_{t} \mid s_{t - 1},\ a_{t - 1})} \\[7mm]
=\ &\sum_{t = 1}^{\mathrm{T}} \mathcal{E}_{s_{t} \sim q(\cdot \mid o_{\le t},\ a_{< t})} \ln p(o_{t} \mid s_{t}) - \mathcal{E}_{s_{t - d} \sim q(\cdot \mid o_{\le t - d},\ a_{< t - d})} \mathcal{E}_{s_{t - 1} \sim p(\cdot \mid s_{t - d},\ a_{t - d : t - 2})} D_{\mathrm{KL}} \Big( q(\cdot \mid o_{\le t},\ a_{< t})\ \|\ p(\cdot \mid s_{t - 1},\ a_{t - 1}) \Big)
\end{aligned}
$$

The latent overshooting objective trains the model on multi-step predictions at all distances $1 \le d \le D$:

$$
\frac{1}{D} \sum_{d = 1}^{D} \ln p_{d}(o_{1:\mathrm{T}} \mid a_{1:\mathrm{T}}) \ge \sum_{t = 1}^{\mathrm{T}} \left[ \mathcal{E}_{s_{t}} \ln p(o_{t} \mid s_{t}) - \frac{1}{D} \sum_{d = 1}^{D} \beta_{d}\, \mathcal{E}_{s_{t - d}} \mathcal{E}_{s_{t - 1}} D_{\mathrm{KL}} \Big( q(\cdot \mid o_{\le t},\ a_{< t})\ \|\ p(\cdot \mid s_{t - 1},\ a_{t - 1}) \Big) \right]
$$

where $\{ \beta_{d} \}_{d = 1}^{D}$ are weighting factors for the multi-step predictions, analogous to the β-VAE.
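To make the regularizer concrete, the toy sketch below computes the averaged overshooting KL for a hypothetical 1-D linear-Gaussian latent, where the $d$-step prior can be rolled forward analytically by propagating its moments. Everything here (the scalar dynamics, the coefficients, the function names) is an illustrative assumption, not the paper's model:

```python
import math

def kl_gauss(mq, sq, mp, sp):
    """Closed-form KL( N(mq, sq^2) || N(mp, sp^2) ) for univariate Gaussians."""
    return math.log(sp / sq) + (sq ** 2 + (mq - mp) ** 2) / (2 * sp ** 2) - 0.5

# Hypothetical 1-D latent with linear-Gaussian transition s_t = A * s_{t-1} + noise
A_COEF, TRANS_STD = 0.9, 0.1

def overshoot_kl(post_mu, post_std, D, betas):
    """Averaged overshooting regularizer: for each distance d, roll the prior
    forward d steps from the posterior at time t - d (moments propagated
    analytically) and penalize its KL to the posterior at time t."""
    T = len(post_mu)
    total = 0.0
    for d in range(1, D + 1):
        for t in range(d, T):
            mu, var = post_mu[t - d], post_std[t - d] ** 2
            for _ in range(d):  # d-step prior prediction
                mu, var = A_COEF * mu, A_COEF ** 2 * var + TRANS_STD ** 2
            total += betas[d - 1] * kl_gauss(post_mu[t], post_std[t],
                                             mu, math.sqrt(var))
    return total / D

print(overshoot_kl([0.0, 0.0, 0.0], [0.1, 0.1, 0.1], D=2, betas=[1.0, 1.0]))
```

Note that the prior's variance grows with the distance $d$, so longer-range predictions are penalized more softly for the same mean error.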

Learning and Planning

PlaNet fits the model by maximizing the ELBO on the collected dataset.

With the learned generative model, a local action sequence can be optimized by model-predictive control (MPC) using the cross-entropy method (CEM), with short-horizon rollouts that start from a state sampled from the current posterior (belief).
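A minimal CEM sketch on a hypothetical 1-D problem (dynamics $s' = s + a$, reward $-s'^2$) standing in for the learned latent model; the function name and hyperparameters are illustrative:

```python
import math
import random

random.seed(0)

def cem_plan(init_state, horizon=5, pop=100, elites=10, iters=10):
    """Cross-entropy method over action sequences for a toy 1-D problem:
    dynamics s' = s + a, reward -s'^2 (drive the state to zero).
    In PlaNet the rollout would instead use the learned latent model."""
    mu = [0.0] * horizon
    sigma = [1.0] * horizon
    for _ in range(iters):
        candidates = []
        for _ in range(pop):
            seq = [random.gauss(m, sd) for m, sd in zip(mu, sigma)]
            s, ret = init_state, 0.0
            for a in seq:
                s = s + a          # toy dynamics
                ret += -s * s      # toy reward
            candidates.append((ret, seq))
        candidates.sort(key=lambda c: c[0], reverse=True)
        top = [seq for _, seq in candidates[:elites]]
        # refit the sampling distribution to the elite set
        mu = [sum(seq[t] for seq in top) / elites for t in range(horizon)]
        sigma = [max(1e-3, math.sqrt(sum((seq[t] - mu[t]) ** 2 for seq in top)
                                     / elites)) for t in range(horizon)]
    return mu

plan = cem_plan(2.0)
print(plan[0])  # close to -2: the first action drives the state to the target
```

In MPC fashion, only the first action of the returned mean sequence is executed before replanning from the next belief.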


PlaNet
http://example.com/2024/08/03/PlaNet/
Author: 木辛
Posted on: August 3, 2024