TD-MPC

TD-MPC Series

TD-MPC v1

Model Predictive Path Integral (MPPI)

TD-MPC adopts MPPI as its inference (planning) algorithm, where the action trajectory is sampled from time-dependent multivariate Gaussians with diagonal covariance over a horizon of length H

\Big\{ \mathcal{N}(\mu_{\tau},\ \sigma_{\tau}^{2}) \Big\}_{\tau = 0}^{H} \Leftarrow \mu_{\tau},\ \sigma_{\tau} \in \mathbb{R}^{|\mathcal{A}|}

Sample N trajectories independently using rollouts generated by the learned environment model, and estimate the total return of each trajectory with the learned value function

\begin{aligned} \phi(a_{t : t + H}) &= \mathcal{E}_{z_{t + 1}} \mathcal{E}_{z_{t + 2}} \cdots \mathcal{E}_{z_{t + H}} \left[ \gamma^{H} Q_{\theta}(z_{t + H},\ a_{t + H}) + \sum_{\tau = 0}^{H - 1} \gamma^{\tau} R_{\theta}(z_{t + \tau},\ a_{t + \tau}) \right] \\[7mm] &= \gamma^{H} Q_{\theta}(z_{t + H},\ a_{t + H}) + \sum_{\tau = 0}^{H - 1} \gamma^{\tau} R_{\theta}(z_{t + \tau},\ a_{t + \tau}) \Leftarrow z_{t} = h_{\theta}(s_{t});\ z_{t + 1} = d_{\theta}(z_{t},\ a_{t}) \end{aligned}

The parameters of the sampling distribution are updated using the action trajectories with the top-k returns

\mu_{\tau} \leftarrow \frac{\sum_{i = 1}^{k} \Omega_{i} a_{\tau}^{(i)}}{\sum_{i = 1}^{k} \Omega_{i}} \qquad \sigma_{\tau} \leftarrow \sqrt{\frac{\sum_{i = 1}^{k} \Omega_{i} (a_{\tau}^{(i)} - \mu_{\tau})^{2}}{\sum_{i = 1}^{k} \Omega_{i}}}

where trajectories are weighted by their corresponding returns as \Omega_{i} = \exp (\kappa \phi_{i}) = \exp \left( \kappa \phi(a_{t : t + H}^{(i)}) \right), and \kappa is a temperature parameter controlling the “sharpness” of the weighting.

After a fixed number of iterations J, the action to take at the current decision step t is sampled from \mathcal{N}(\mu_{0},\ \sigma_{0}^{2}).
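Below is a minimal NumPy sketch of one MPPI iteration under these update rules; `dynamics`, `reward`, and `value` stand in for the learned d_θ, R_θ, Q_θ, and the sample count, top-k size, and temperature are placeholder values rather than the paper's settings.

```python
import numpy as np

def mppi_iteration(mu, sigma, z0, dynamics, reward, value,
                   num_samples=512, top_k=64, kappa=0.5, gamma=0.99):
    """One MPPI update: sample N action sequences from the time-dependent
    Gaussian, score them with the learned model, and refit (mu, sigma)
    to the top-k sequences."""
    H1, A = mu.shape                     # mu/sigma parameterize actions a_t .. a_{t+H}
    H = H1 - 1
    actions = mu + sigma * np.random.randn(num_samples, H1, A)

    # Roll out the latent model and accumulate discounted rewards.
    z = np.repeat(z0[None], num_samples, axis=0)
    returns = np.zeros(num_samples)
    for t in range(H):
        returns += gamma ** t * reward(z, actions[:, t])
        z = dynamics(z, actions[:, t])
    returns += gamma ** H * value(z, actions[:, H])      # terminal value bootstrap

    # Keep the top-k trajectories, weighted by exponentiated return.
    idx = np.argsort(returns)[-top_k:]
    elite = actions[idx]
    omega = np.exp(kappa * (returns[idx] - returns[idx].max()))  # shift for numerical stability
    omega /= omega.sum()

    new_mu = (omega[:, None, None] * elite).sum(axis=0)
    new_sigma = np.sqrt((omega[:, None, None] * (elite - new_mu) ** 2).sum(axis=0))
    return new_mu, new_sigma
```

Running the planner then amounts to repeating this update for J iterations and sampling the first action from \mathcal{N}(\mu_{0},\ \sigma_{0}^{2}).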

Parameter Initialization

To reduce the number of iterations required for convergence, TD-MPC warm-starts planning with the 1-step shifted mean \mu_{\tau} obtained at the previous decision step, but always uses a large initial variance to avoid local minima.
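A sketch of this warm start, assuming a fixed `init_std` for the reset standard deviation (the actual value is an implementation hyperparameter):

```python
import numpy as np

def init_plan_params(prev_mu, horizon, action_dim, init_std=0.5):
    """Shift the previous step's mean forward by one step and reset the
    standard deviation to a large value before planning."""
    mu = np.zeros((horizon + 1, action_dim))
    if prev_mu is not None:
        mu[:-1] = prev_mu[1:]                              # reuse the 1-step shifted mean
    sigma = np.full((horizon + 1, action_dim), init_std)   # large initial std
    return mu, sigma
```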

Exploration by Planning

To promote consistent exploration, TD-MPC constrains the standard deviation update as

\sigma_{\tau} \leftarrow \max \left( \sqrt{\frac{\sum_{i = 1}^{k} \Omega_{i} (a_{\tau}^{(i)} - \mu_{\tau})^{2}}{\sum_{i = 1}^{k} \Omega_{i}}},\ \epsilon \right)

where \epsilon \in \mathbb{R}^{+} is a linearly decayed constant. Likewise, because the model is initially inaccurate, the planning horizon is increased linearly from 1 to H in the early stages of training.
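A sketch of the constrained update, where the linear decay schedule for \epsilon (endpoints and duration) is an assumption:

```python
import numpy as np

def constrain_std(sigma_new, step, decay_steps=25_000, eps_max=0.5, eps_min=0.05):
    """Lower-bound the refitted std by a linearly decayed epsilon so that
    exploration noise does not collapse early in training."""
    frac = min(step / decay_steps, 1.0)
    eps = eps_max + frac * (eps_min - eps_max)   # linearly decay from eps_max to eps_min
    return np.maximum(sigma_new, eps)
```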

Policy-guided Trajectory Optimization

TD-MPC augments the sampling procedure with N_{\pi} additional trajectories rolled out from the learned policy \pi_{\theta}, as sketched below.
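For illustration, policy-guided candidates can be generated by rolling the learned policy through the latent dynamics and concatenating them with the Gaussian samples before scoring; `policy` and `dynamics` are stand-ins for \pi_{\theta} and d_{\theta}:

```python
import numpy as np

def policy_rollouts(z0, policy, dynamics, horizon, n_pi):
    """Roll out N_pi action sequences from the learned policy in latent space."""
    z = np.repeat(z0[None], n_pi, axis=0)
    seq = []
    for _ in range(horizon + 1):
        a = policy(z)              # pi_theta(z), possibly with exploration noise
        seq.append(a)
        z = dynamics(z, a)
    return np.stack(seq, axis=1)   # (N_pi, H + 1, A), appended to the sampled candidates
```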

Task-Oriented Latent Dynamics (TOLD)

TD-MPC leverages the following components of the TOLD model during inference:

Components    Definition
representation    \hat{z}_{t} = h_{\theta}(s_{t})
latent dynamics    \hat{z}_{t}' = d_{\theta}(z_{t},\ a_{t})
reward    \hat{r}_{t} = R_{\theta}(z_{t},\ a_{t})
value    \hat{q}_{t} = Q_{\theta}(z_{t},\ a_{t})
policy    \hat{a}_{t} = \pi_{\theta}(z_{t})
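A minimal PyTorch sketch of these five components is given below; layer sizes, activations, a single Q-network, and the absence of action squashing are simplifying assumptions, not the actual architecture.

```python
import torch
import torch.nn as nn

def mlp(in_dim, out_dim, hidden=256):
    return nn.Sequential(nn.Linear(in_dim, hidden), nn.ELU(),
                         nn.Linear(hidden, hidden), nn.ELU(),
                         nn.Linear(hidden, out_dim))

class TOLD(nn.Module):
    """Representation, latent dynamics, reward, value, and policy heads."""
    def __init__(self, obs_dim, act_dim, latent_dim=50):
        super().__init__()
        self.h = mlp(obs_dim, latent_dim)               # z_t = h(s_t)
        self.d = mlp(latent_dim + act_dim, latent_dim)  # z_{t+1} = d(z_t, a_t)
        self.R = mlp(latent_dim + act_dim, 1)           # r_hat = R(z_t, a_t)
        self.Q = mlp(latent_dim + act_dim, 1)           # q_hat = Q(z_t, a_t)
        self.pi = mlp(latent_dim, act_dim)              # a_hat = pi(z_t)

    def next(self, z, a):
        za = torch.cat([z, a], dim=-1)
        return self.d(za), self.R(za).squeeze(-1)       # predicted next latent and reward
```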

The TOLD model is trained to minimize a temporally weighted objective

\min_{\theta} \mathcal{J}(\theta;\ \Gamma) = \sum_{\tau = t}^{t + H} \lambda^{\tau - t} \mathcal{L}(\theta;\ \Gamma_{\tau})

where \Gamma = \Big\{ (s_{\tau},\ a_{\tau},\ r_{\tau},\ s_{\tau + 1}) \Big\}_{\tau = t}^{t + H} \sim \mathcal{B} is a trajectory sampled from the replay buffer \mathcal{B}, which consists of interaction data collected by TD-MPC during planning. The single-step loss is made up of

\mathcal{L}(\theta;\ \Gamma_{\tau}) = c_{r} \mathcal{L}_{r}(\theta;\ \Gamma_{\tau}) + c_{v} \mathcal{L}_{v}(\theta;\ \Gamma_{\tau}) + c_{\pi} \mathcal{L}_{\pi}(\theta;\ \Gamma_{\tau}) + c_{c} \mathcal{L}_{c}(\theta;\ \Gamma_{\tau})

Error Type    Definition
reward prediction error    \mathcal{L}_{r}(\theta;\ \Gamma_{\tau}) = \Big[ R_{\theta}(z_{\tau},\ a_{\tau}) - r_{\tau} \Big]^{2}
TD error of value function    \mathcal{L}_{v}(\theta;\ \Gamma_{\tau}) = \Big[ Q_{\theta}(z_{\tau},\ a_{\tau}) - r_{\tau} - \gamma Q_{\theta^{-}} \Big( z_{\tau + 1},\ \operatorname{sg} \big[ \pi_{\theta}(z_{\tau + 1}) \big] \Big) \Big]^{2}
policy loss (frozen critic)    \mathcal{L}_{\pi}(\theta;\ \Gamma_{\tau}) = -Q_{\operatorname{sg}[\theta]}(z_{\tau},\ \pi_{\theta}(z_{\tau}))
latent state consistency loss    \mathcal{L}_{c}(\theta;\ \Gamma_{\tau}) = \Big\| d_{\theta}(z_{\tau},\ a_{\tau}) - h_{\theta^{-}}(s_{\tau + 1}) \Big\|_{2}^{2}

where \theta^{-} denotes the parameters of a target network, used to improve stability during training.
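Putting the pieces together, a sketch of the unrolled objective might look as follows (built on the `TOLD` sketch above; the loss coefficients are illustrative, `target` holds the \theta^{-} networks, and the parameter-freezing trick is one way to realize the stop-gradient on the critic in the policy term):

```python
import torch

def told_loss(model, target, s, a, r, horizon, gamma=0.99, lam=0.5,
              c_r=0.5, c_v=0.1, c_pi=1.0, c_c=2.0):
    """Temporally weighted TOLD objective over one sampled trajectory.
    s: (H+2, obs_dim), a: (H+1, act_dim), r: (H+1,) tensors."""
    z = model.h(s[0:1])
    total = 0.0
    for t in range(horizon + 1):
        za = torch.cat([z, a[t:t + 1]], dim=-1)
        q_pred = model.Q(za).squeeze(-1)
        r_pred = model.R(za).squeeze(-1)
        z_next = model.d(za)

        with torch.no_grad():                            # targets use theta^-
            z_next_tgt = target.h(s[t + 1:t + 2])
            a_next = model.pi(z_next_tgt)                # sg[pi_theta(z_{tau+1})]
            td_target = r[t] + gamma * target.Q(
                torch.cat([z_next_tgt, a_next], dim=-1)).squeeze(-1)

        loss_r = (r_pred - r[t]) ** 2                    # reward prediction error
        loss_v = (q_pred - td_target) ** 2               # TD error
        loss_c = ((z_next - z_next_tgt) ** 2).sum(-1)    # latent consistency

        # Policy term: freeze Q's parameters so -Q(z, pi(z)) only trains pi.
        for p in model.Q.parameters():
            p.requires_grad_(False)
        loss_pi = -model.Q(torch.cat([z.detach(), model.pi(z.detach())], dim=-1)).squeeze(-1)
        for p in model.Q.parameters():
            p.requires_grad_(True)

        total = total + (lam ** t) * (c_r * loss_r + c_v * loss_v
                                      + c_pi * loss_pi + c_c * loss_c)
        z = z_next                                       # continue the latent rollout
    return total.mean()
```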

TD-MPC v2

TD-MPC v2 uses a learnable task embedding e (constrained by \| e \|_{2} \le 1) to represent compact task semantics. For a new task, e can be initialized with the embedding of a semantically similar task and then fine-tuned.

Components    Definition
encoder    \hat{z}_{t} = h_{\theta}(s_{t},\ e)
latent dynamics    \hat{z}_{t}' = d_{\theta}(z_{t},\ a_{t},\ e)
reward (discrete)    \hat{r}_{t} = R_{\theta}(z_{t},\ a_{t},\ e)
terminal value (discrete)    \hat{q}_{t} = Q_{\theta}(z_{t},\ a_{t},\ e)
policy prior    \hat{a}_{t} = \pi_{\theta}(z_{t},\ e)
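A sketch of the learnable task embedding described above, with the norm constraint enforced by rescaling (the embedding dimension is a placeholder); initializing a new task then amounts to copying the row of a semantically similar task before fine-tuning:

```python
import torch
import torch.nn as nn

class TaskEmbedding(nn.Module):
    """Per-task embeddings e with ||e||_2 <= 1 enforced by rescaling."""
    def __init__(self, num_tasks, dim=96):
        super().__init__()
        self.emb = nn.Embedding(num_tasks, dim)

    def forward(self, task_id):
        e = self.emb(task_id)
        norm = e.norm(dim=-1, keepdim=True).clamp(min=1.0)
        return e / norm                      # rescale only when the norm exceeds 1

# Initializing a new task from a similar one (indices are hypothetical):
# emb.emb.weight.data[new_task_id] = emb.emb.weight.data[similar_task_id].clone()
```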

The latent representation is normalized by SimNorm, which projects z into L fixed-dimensional simplices via softmax

z^{\circ} = [g_{1},\ g_{2},\ \cdots,\ g_{L}] \quad g_{i} = \operatorname{softmax}_{\tau} (z_{i:i + V})

which can naturally bias the representation towards sparsity without enforcing hard constraints.
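A sketch of SimNorm, assuming the latent dimension is divisible by the group size V and a softmax temperature of 1:

```python
import torch
import torch.nn as nn

class SimNorm(nn.Module):
    """Split z into L groups of size V and apply a softmax within each group,
    so z becomes a concatenation of L points on (V-1)-simplices."""
    def __init__(self, group_size=8):
        super().__init__()
        self.V = group_size

    def forward(self, z):
        shape = z.shape                          # (..., L * V)
        z = z.view(*shape[:-1], -1, self.V)      # (..., L, V)
        z = torch.softmax(z, dim=-1)             # one simplex per group
        return z.view(*shape)
```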

The h, d, R, Q components are jointly optimized to minimize the model objective over trajectories sampled from the replay buffer \mathcal{B}

\mathcal{L}(\theta) = \mathcal{E}_{(s,\ a,\ r,\ s')_{0:\mathrm{T}} \sim \mathcal{B}} \left[ \sum_{t = 0}^{\mathrm{T}} \lambda^{t} \left( \Big\| \hat{z}_{t}' - \operatorname{sg}(h_{\theta}(s_{t}')) \Big\|^{2} + \operatorname{CE}(\hat{r}_{t},\ r_{t}) + \operatorname{CE} \Big( \hat{q}_{t},\ r_{t} + \gamma Q_{\theta^{-}}(\hat{z}_{t}',\ \pi_{\theta}(\hat{z}_{t}')) \Big) \right) \right]
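The CE terms treat reward and value as categorical predictions over discrete bins. One common way to realize this is soft cross-entropy against a "two-hot" encoding of the scalar target on a fixed bin grid, sketched below; the bin range and count are assumptions, and the log-style squashing of targets used in practice is omitted.

```python
import torch

def two_hot(x, vmin=-10.0, vmax=10.0, num_bins=101):
    """Encode scalar targets x as two-hot vectors over a uniform bin grid."""
    x = x.clamp(vmin, vmax)
    pos = (x - vmin) / ((vmax - vmin) / (num_bins - 1))   # fractional bin index
    lo = pos.floor().long().clamp(0, num_bins - 1)
    hi = (lo + 1).clamp(max=num_bins - 1)
    w_hi = (pos - lo.float()).clamp(0.0, 1.0)
    target = torch.zeros(*x.shape, num_bins, device=x.device)
    target.scatter_(-1, lo.unsqueeze(-1), (1.0 - w_hi).unsqueeze(-1))
    target.scatter_add_(-1, hi.unsqueeze(-1), w_hi.unsqueeze(-1))
    return target

def soft_ce(logits, scalar_target):
    """Cross-entropy between predicted bin logits and the two-hot target."""
    log_probs = torch.log_softmax(logits, dim=-1)
    return -(two_hot(scalar_target) * log_probs).sum(dim=-1)
```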

The policy prior learns to maximize a maximum-entropy objective, whose gradients are taken w.r.t. the policy parameters only

\mathcal{L}_{p}(\theta) = \mathcal{E}_{(s,\ a)_{0:\mathrm{T}} \sim \mathcal{B}} \left[ \sum_{t = 0}^{\mathrm{T}} \lambda^{t} \Big( \alpha Q_{\theta}(z_{t},\ \pi_{\theta}(z_{t})) + \beta \mathcal{H}(\pi_{\theta} \mid z_{t}) \Big) \right] \quad \mathrm{s.t.}\ z_{t + 1} = d_{\theta}(z_{t},\ a_{t}),\ z_{0} = h_{\theta}(s_{0})
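A hedged sketch of this update, written as a loss to minimize (the negated objective): `policy(z)` is assumed to return a reparameterized action and its log-probability, the entropy is estimated from that single sample, `q_fn` is treated as frozen so only \pi_{\theta} is updated, and \alpha, \beta, \lambda are placeholder coefficients.

```python
def policy_prior_loss(policy, q_fn, z_seq, alpha=1.0, beta=0.01, lam=0.5):
    """Negated maximum-entropy objective over a sequence of latents z_seq
    produced by a model rollout; gradients should update pi_theta only."""
    loss = 0.0
    for t, z in enumerate(z_seq):
        a, log_prob = policy(z.detach())          # rsample and log pi(a | z)
        q = q_fn(z.detach(), a)                   # Q is assumed frozen here
        entropy_est = -log_prob                   # single-sample entropy estimate
        loss = loss - (lam ** t) * (alpha * q + beta * entropy_est).mean()
    return loss
```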

Similar to TD-MPC v1, TD-MPC v2 leverages MPPI for local trajectory optimization with a terminal value estimate

\mu^{\star},\ \sigma^{\star} = \argmax_{\mu,\ \sigma} \mathcal{E}_{a_{t:t + \mathrm{H}} \sim \mathcal{N}(\mu,\ \sigma)} \left[ \sum_{h = t}^{\mathrm{H} - 1} \gamma^{h} R_{\theta}(z_{h},\ a_{h}) + \gamma^{\mathrm{H}} Q_{\theta}(z_{t + \mathrm{H}},\ a_{t + \mathrm{H}}) \right]

To accelerate convergence of planning, a fraction of the sampled action sequences originates from the policy prior, and the 1-step shifted parameters from the previous decision step are reused to warm-start planning.

