GenRL


World Model Learning

GenRL adopts a variant of the RSSM structure, with a categorical stochastic latent state $s_t$ as the world model's representation

| Component | Type | Definition |
| --- | --- | --- |
| Encoder | Inference | $s_t \sim q_\phi(s_t \mid x_t)$ |
| Decoder | Generation | $x_t \sim p_\phi(x_t \mid s_t)$ |
| Sequence | Generation | $h_t = f_\phi(h_{t-1},\ s_{t-1},\ a_{t-1})$ |
| Dynamics | Generation | $s_t \sim p_\phi(s_t \mid h_t)$ |
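
As a concrete illustration, these four components could be sketched in PyTorch roughly as follows. The module architectures, sizes, and the 32×32 categorical parameterization are illustrative assumptions, not GenRL's actual implementation

```python
import torch
import torch.nn as nn
from torch.distributions import OneHotCategorical

# Illustrative sizes only; GenRL's actual dimensions differ
OBS_DIM, ACT_DIM, HID_DIM, N_CATS, N_CLASSES = 64, 6, 256, 32, 32
STOCH_DIM = N_CATS * N_CLASSES  # flattened categorical latent s_t

class Encoder(nn.Module):           # q_phi(s_t | x_t)
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(OBS_DIM, STOCH_DIM)
    def forward(self, x):
        logits = self.net(x).view(-1, N_CATS, N_CLASSES)
        return OneHotCategorical(logits=logits)

class Decoder(nn.Module):           # p_phi(x_t | s_t)
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(STOCH_DIM, OBS_DIM)
    def forward(self, s):
        return self.net(s.reshape(-1, STOCH_DIM))   # mean of the reconstruction

class SequenceModel(nn.Module):     # h_t = f_phi(h_{t-1}, s_{t-1}, a_{t-1})
    def __init__(self):
        super().__init__()
        self.cell = nn.GRUCell(STOCH_DIM + ACT_DIM, HID_DIM)
    def forward(self, h, s, a):
        return self.cell(torch.cat([s.reshape(-1, STOCH_DIM), a], dim=-1), h)

class Dynamics(nn.Module):          # p_phi(s_t | h_t)
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(HID_DIM, STOCH_DIM)
    def forward(self, h):
        logits = self.net(h).view(-1, N_CATS, N_CLASSES)
        return OneHotCategorical(logits=logits)
```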

The world model is trained to maximize the ELBO of the log-likelihood $p(x_{1:T} \mid a_{1:T})$ on sampled trajectories

$$\max_{\phi} \mathbb{E}_{s_t \sim q_\phi(\cdot \mid x_t)} \left[ \sum_{t=0}^{T} \ln p_\phi(x_t \mid s_t) - D_{\mathrm{KL}} \Big( q_\phi(\cdot \mid x_t)\ \|\ p_\phi(s_t \mid h_t) \Big) \right] \quad \text{s.t.} \quad h_t = f_\phi(h_{t-1},\ s_{t-1},\ a_{t-1})$$
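
Continuing the sketch above (same imports and constants), one training step of this objective might look roughly like the following. A real categorical RSSM would also use straight-through gradients and KL balancing, which are omitted here for brevity

```python
import torch.nn.functional as F
from torch.distributions import kl_divergence

def world_model_loss(encoder, decoder, seq_model, dynamics, xs, acts):
    """Negative ELBO on a trajectory batch.
       xs: (T, B, OBS_DIM) observations; acts: (T, B, ACT_DIM), where acts[t] plays the role of a_{t-1}."""
    T, B = xs.shape[:2]
    h = torch.zeros(B, HID_DIM)                  # h_0
    s = torch.zeros(B, N_CATS, N_CLASSES)        # s_0
    recon_loss, kl_loss = 0.0, 0.0
    for t in range(T):
        h = seq_model(h, s, acts[t])             # h_t = f_phi(h_{t-1}, s_{t-1}, a_{t-1})
        posterior = encoder(xs[t])               # q_phi(s_t | x_t)
        prior = dynamics(h)                      # p_phi(s_t | h_t)
        s = posterior.sample()                   # s_t ~ q_phi(. | x_t)  (straight-through omitted)
        recon_loss += F.mse_loss(decoder(s), xs[t])              # stands in for -ln p_phi(x_t | s_t)
        kl_loss += kl_divergence(posterior, prior).sum(-1).mean()
    return recon_loss + kl_loss                  # minimizing this maximizes the ELBO
```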

To utilize the knowledge of a pretrained VLM, GenRL connects the representation spaces of the VLM and the world model,

where the connector $p_\psi(s_t \mid s_{t-1},\ e)$ learns to predict the latent states $s_{t:t+k}$ from the VLM embedding of the observations $x_{t:t+k}$

$$\mathcal{L}_{\mathrm{conn}} = \sum_{\tau=t}^{t+k} D_{\mathrm{KL}} \Big( p_\psi(s_\tau \mid s_{\tau-1},\ e)\ \|\ \operatorname{sg}\big(q_\phi(s_\tau \mid x_\tau)\big) \Big) \quad \text{s.t.} \quad e = e^{(v)} = f_{\mathrm{VLM}}^{(v)}(x_{t:t+k})$$
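
A rough sketch of the connector and its loss, continuing the code above. The connector architecture, embedding size, and the teacher-forcing scheme are assumptions for illustration

```python
class Connector(nn.Module):         # p_psi(s_t | s_{t-1}, e)
    def __init__(self, emb_dim=512):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(STOCH_DIM + emb_dim, HID_DIM), nn.ELU(),
            nn.Linear(HID_DIM, STOCH_DIM),
        )
    def forward(self, s_prev, e):
        logits = self.net(torch.cat([s_prev.reshape(-1, STOCH_DIM), e], dim=-1))
        return OneHotCategorical(logits=logits.view(-1, N_CATS, N_CLASSES))

def connector_loss(connector, encoder, xs, e_v):
    """xs: (k+1, B, OBS_DIM) observation window; e_v: (B, emb_dim) VLM video embedding."""
    s_prev = torch.zeros(xs.shape[1], N_CATS, N_CLASSES)
    loss = 0.0
    for x in xs:
        with torch.no_grad():                    # sg(.): no gradient flows into the world model
            target = encoder(x)                  # q_phi(s_tau | x_tau)
        pred = connector(s_prev, e_v)            # p_psi(s_tau | s_{tau-1}, e)
        loss += kl_divergence(pred, target).sum(-1).mean()
        s_prev = target.sample()                 # condition the next step on the encoder's sample
    return loss
```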

and the aligner $f_\psi(e^{(l)})$ is used to map language embeddings into the vision-embedding space, bridging the multimodality gap caused by contrastive pretraining

$$\mathcal{L}_{\mathrm{align}} = \left\| e^{(v)} - f_\psi(e^{(l)}) \right\|_2^2$$

As vision-language data is typically unavailable in embodied domains, the aligner can be trained in a language-free manner

$$\mathcal{L}_{\mathrm{align}} = \left\| e^{(v)} - f_\psi(e^{(l)}) \right\|_2^2 \approx \left\| e^{(v)} - f_\psi(e^{(v)} + \epsilon) \right\|_2^2$$
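
A minimal sketch of this language-free training, where additive Gaussian noise plays the role of the corruption $\epsilon$. The aligner architecture and noise scale are assumptions

```python
class Aligner(nn.Module):           # f_psi: maps (noisy or language) embeddings to the vision space
    def __init__(self, emb_dim=512):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(emb_dim, emb_dim), nn.ELU(),
            nn.Linear(emb_dim, emb_dim),
        )
    def forward(self, e):
        return self.net(e)

def language_free_align_loss(aligner, e_v, noise_std=0.1):
    """Train f_psi to denoise corrupted vision embeddings, as a proxy for language embeddings."""
    e_noisy = e_v + noise_std * torch.randn_like(e_v)     # stands in for e^(l) ~ e^(v) + eps
    return F.mse_loss(aligner(e_noisy), e_v)              # || e^(v) - f_psi(e^(v) + eps) ||^2
```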

which assumes that language embeddings can be treated as a corrupted version of their vision counterparts

Multi-Task Behavior Learning

GenRL adopts a trajectory-matching reward for behavior learning on latent states under user-prompted tasks

$$\min_{\theta} \mathbb{E}_{s_t \sim p_\phi(\cdot \mid h_t),\ a_t \sim \pi_\theta(\cdot \mid s_t)} \left[ \sum_{t=0}^{T} \gamma^t \operatorname{distance} \Big( p_\phi(\cdot \mid h_t)\ \|\ p_\psi(\cdot \mid s_{t-1},\ e_{\mathrm{task}}) \Big) \right] \quad \text{s.t.} \quad h_t = f_\phi(h_{t-1},\ s_{t-1},\ a_{t-1})$$

where $e_{\mathrm{task}}$ is the VLM embedding of the task prompt, and the distribution distance can be the KL divergence or cosine distance
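
Continuing the sketch, a per-step matching reward could be computed as the negative distance between the dynamics distribution and the connector's task-conditioned prediction. Both distance options mentioned above are shown; applying cosine distance to the class-probability vectors is an assumption about how the comparison is made

```python
def matching_reward(dynamics, connector, h, s_prev, e_task, mode="kl"):
    """Reward = -distance( p_phi(. | h_t) || p_psi(. | s_{t-1}, e_task) )."""
    imagined = dynamics(h)                       # distribution over s_t from the world model
    target = connector(s_prev, e_task)           # distribution over s_t suggested by the task prompt
    if mode == "kl":
        dist = kl_divergence(imagined, target).sum(-1)
    else:                                        # cosine distance between class-probability vectors
        p, q = imagined.probs.flatten(1), target.probs.flatten(1)
        dist = 1.0 - F.cosine_similarity(p, q, dim=-1)
    return -dist                                 # higher reward when the trajectories match
```

Here $e_{\mathrm{task}}$ would presumably be the task-prompt embedding after passing through the aligner, i.e. roughly $f_\psi(e^{(l)})$, and the policy $\pi_\theta$ is trained in imagination to maximize the discounted sum of these rewards.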

In addition, the initial state of the target trajectory suggested by the VLM from the task prompt may differ from that of the trajectories generated by the policy and world model, causing misalignment in the reward. GenRL performs the following steps to address this issue (a code sketch follows the list)

  1. compares the similarity between the initial $b$ states of the target trajectory and each sliding window of $b$ states in the imagined trajectory
  2. finds the timestep $t_a$ with the highest similarity as the aligned initial timestep
  3. for timesteps before $t_a$, calculates the matching reward against the initial state of the target trajectory
  4. for timesteps after $t_a$, calculates the matching reward against the target state at the corresponding timestep
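
A rough sketch of this alignment procedure, continuing the code above and operating on a single flattened latent trajectory. Using cosine similarity as the per-step match score, the window size $b$, and the function shape are all assumptions

```python
def aligned_matching_rewards(imagined, target, b=4):
    """imagined: (T, D) latent states from the policy/world-model rollout.
       target:   (T, D) latent states the connector predicts from the task prompt.
       Returns per-step matching rewards of shape (T,)."""
    T = imagined.shape[0]
    # 1. similarity between the first b target states and every window of b imagined states
    ref = target[:b].flatten()
    sims = torch.stack([F.cosine_similarity(imagined[t:t + b].flatten(), ref, dim=0)
                        for t in range(T - b + 1)])
    # 2. the best-matching window gives the aligned initial timestep t_a
    t_a = int(sims.argmax())
    rewards = torch.empty(T)
    # 3. before t_a: match against the initial state of the target trajectory
    rewards[:t_a] = F.cosine_similarity(imagined[:t_a], target[:1], dim=-1)
    # 4. from t_a on: match against the target state at the corresponding (shifted) timestep
    rewards[t_a:] = F.cosine_similarity(imagined[t_a:], target[:T - t_a], dim=-1)
    return rewards
```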
