LanGWM


Observation Representation Learning

The proposed language-grounded representation learning has the following sub-modules

  1. object instance masking (see the masking sketch after this list)
    1. randomly select an object instance, then mask its smallest rectangular bounding box
    2. add random margins to the bounding box
    3. mask up to 3 objects and stop early if the masked region reaches 75% of the image
    4. use spatial jitter, Gaussian blur, color jitter and grayscale data augmentations
  2. language description generation (see the template sketch after this list)
    1. use language templates to generate the description of the masked object, for example
      1. If you look {distance} in the {direction}, you will see {object}
      2. There is {object} in the {direction} {distance}
      3. The {object} is approximately {distance}, {direction} from here
    2. The template is parameterised by
      1. {object} ⇒ semantic class of the object
      2. {direction} ⇒ average horizontal position of the object center
      3. {distance} ⇒ average distance of the object regions
    3. use BERT to extract feature embeddings from the language descriptions
    4. feed constant token values corresponding to an empty description at evaluation and test time
  3. object instance MAE
    1. use an early convolution layer and apply the masking to the convolutional feature maps
    2. convert feature maps into a sequence of patches (tokens) with a patch size of 1
    3. use a ViT encoder to extract grounded visual features from the concatenated image + language tokens
    4. discard the concatenated language tokens after the encoder and feed forward only the visual tokens (with their positional embeddings)
    5. use a ViT decoder to reconstruct depth and predict reward from the visual tokens
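
A minimal sketch of the object instance masking in step 1, assuming the simulator provides a per-pixel instance segmentation map alongside the RGB observation; the margin size and the zero fill value are illustrative choices, not taken from the paper:

```python
import numpy as np

def mask_object_instances(image, instance_map, max_objects=3, area_cap=0.75,
                          max_margin=8, rng=None):
    """Mask up to `max_objects` randomly chosen instances with their smallest
    rectangular bounding box plus a random margin, stopping early once the
    masked region covers `area_cap` of the image.
    image: (H, W, 3) array, instance_map: (H, W) integer instance ids."""
    rng = rng or np.random.default_rng()
    h, w = instance_map.shape
    masked, mask = image.copy(), np.zeros((h, w), dtype=bool)

    instance_ids = [i for i in np.unique(instance_map) if i != 0]  # 0 = background
    rng.shuffle(instance_ids)

    for inst_id in instance_ids[:max_objects]:
        ys, xs = np.nonzero(instance_map == inst_id)
        y0, y1 = ys.min(), ys.max()                        # smallest rectangular bounding box
        x0, x1 = xs.min(), xs.max()
        my, mx = rng.integers(0, max_margin + 1, size=2)   # additional random margin
        y0, y1 = max(0, y0 - my), min(h, y1 + my + 1)
        x0, x1 = max(0, x0 - mx), min(w, x1 + mx + 1)
        mask[y0:y1, x0:x1] = True
        masked[y0:y1, x0:x1] = 0
        if mask.mean() >= area_cap:                        # stop early at 75% masked area
            break
    return masked, mask
```

The spatial jitter, Gaussian blur, color jitter, and grayscale augmentations from step 1.4 would then be applied on top of the masked image, e.g. via torchvision transforms.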
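
And a sketch of the template-based description from step 2, binning the masked object's average horizontal position into a direction word and its average depth into a distance phrase before embedding the sentence with BERT; the bin thresholds, the exact wording, and the `bert-base-uncased` checkpoint are assumptions for illustration:

```python
import numpy as np
from transformers import BertModel, BertTokenizer

def describe_masked_object(class_name, region_xs, region_depths, image_width):
    """Fill one language template with {object}, {direction}, {distance}."""
    x_frac = float(np.mean(region_xs)) / image_width          # average horizontal position
    direction = "left" if x_frac < 0.33 else "right" if x_frac > 0.66 else "front"
    mean_depth = float(np.mean(region_depths))                 # average distance of the region
    distance = "nearby" if mean_depth < 2.0 else "far away"
    return f"If you look {distance} in the {direction}, you will see a {class_name}"

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert = BertModel.from_pretrained("bert-base-uncased")

text = describe_masked_object("chair", region_xs=[40, 60], region_depths=[1.2, 1.5], image_width=128)
tokens = tokenizer(text, return_tensors="pt", padding="max_length", max_length=16, truncation=True)
h_l = bert(**tokens).last_hidden_state   # language tokens h_t^l, shape (1, 16, 768)
```

At evaluation and test time the same pipeline is fed a fixed empty description, so the token values stay constant.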

All modules are optimized using the mean squared error loss $\ell_{t}^{q}$ of the depth reconstruction and reward prediction

| Module | Definition |
| --- | --- |
| Object Masking + Language Description | $o_{t}^{m},\ o_{t}^{l} = f^{\mathrm{object-mask}}(o_{t})$ |
| Early Convolution | $h_{t}^{c} = f_{\phi}^{\mathrm{conv}}(o_{t}^{m})$ |
| Convolution Tokens Masking | $h_{t}^{c,\ m} \sim p^{\mathrm{mask}}(h_{t}^{c,\ m} \mid h_{t}^{c},\ m)$ |
| Language Embedding | $h_{t}^{l} = f^{\mathrm{bert}}(o_{t}^{l})$ |
| Tokens Concatenation | $h_{t}^{c,\ m,\ l} = \operatorname{concat}(h_{t}^{c,\ m},\ h_{t}^{l})$ |
| MAE Encoder | $z_{t}^{c,\ m,\ l} \sim p_{\phi}(z_{t}^{c,\ m,\ l} \mid h_{t}^{c,\ m,\ l})$ |
| Slice Out | $\tilde{s}_{t} \sim p^{\mathrm{slice-l}}(z_{t}^{c,\ m,\ l})$ |
| MAE Decoder | $\hat{o}_{t},\ \hat{r}_{t} \sim p_{\phi}(\hat{o}_{t},\ \hat{r}_{t} \mid \tilde{s}_{t})$ |
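
Putting the table together, a rough PyTorch sketch of the tensor flow through the object-instance MAE; the layer sizes, number of heads, the 25% random token mask (standing in for the object-driven mask), and the plain transformer stacks are illustrative assumptions rather than the paper's exact architecture:

```python
import torch
import torch.nn as nn

B, C, H, W = 8, 3, 64, 64             # batch of RGB observations o_t
L, D = 16, 256                        # language tokens and shared token dimension

conv = nn.Sequential(                 # f_phi^conv: early convolution layers
    nn.Conv2d(3, D, kernel_size=4, stride=4), nn.GELU(),
    nn.Conv2d(D, D, kernel_size=2, stride=2),
)
enc_layer = nn.TransformerEncoderLayer(d_model=D, nhead=8, batch_first=True)
encoder = nn.TransformerEncoder(enc_layer, num_layers=4)     # ViT-style MAE encoder
dec_layer = nn.TransformerEncoderLayer(d_model=D, nhead=8, batch_first=True)
decoder = nn.TransformerEncoder(dec_layer, num_layers=2)     # ViT-style MAE decoder
lang_proj = nn.Linear(768, D)         # project BERT features h_t^l to the token dim
depth_head = nn.Linear(D, 1)          # per-token depth reconstruction
reward_head = nn.Linear(D, 1)         # pooled reward prediction

o_t = torch.randn(B, C, H, W)
h_l = torch.randn(B, L, 768)

h_c = conv(o_t).flatten(2).transpose(1, 2)          # (B, N, D): patch size 1 on the feature map
# stand-in for the object-driven mask: the paper masks the tokens covered by the
# object bounding boxes, a random Bernoulli mask is used here for brevity
token_mask = torch.rand(B, h_c.size(1), 1) < 0.25
h_c_m = h_c.masked_fill(token_mask, 0.0)
h_cml = torch.cat([h_c_m, lang_proj(h_l)], dim=1)   # concatenate image + language tokens
z = encoder(h_cml)                                  # grounded features z_t^{c,m,l}
z_visual = z[:, : h_c.size(1)]                      # slice out: drop the language tokens
dec = decoder(z_visual)
depth_hat = depth_head(dec)                         # depth reconstruction
reward_hat = reward_head(dec.mean(dim=1))           # reward prediction
```

In this reading, `z_visual` plays the role of $\tilde{s}_{t}$ from the table: it is both what the MAE decoder reconstructs from and what the world model below consumes as the observation embedding.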

Predictive World Model + Behavior Learning

The future predictive world model contains the following components

| Component | Type | Definition |
| --- | --- | --- |
| Representation Model | Inference | $s_{t} \sim q_{\theta}(s_{t} \mid s_{t-1},\ a_{t-1},\ \tilde{s}_{t})$ |
| Transition Model | Generation | $s_{t} \sim p_{\theta}(s_{t} \mid s_{t-1},\ a_{t-1})$ |
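
A minimal sketch of these two heads around a shared recurrent state, in the spirit of Dreamer-style RSSMs; the GRU cell, the Gaussian parameterisation, and the dimensions are assumptions for illustration, and $\tilde{s}_{t}$ is assumed to have been pooled into a single embedding vector:

```python
import torch
import torch.nn as nn

class LatentDynamics(nn.Module):
    """Posterior q_theta(s_t | s_{t-1}, a_{t-1}, s~_t) and prior p_theta(s_t | s_{t-1}, a_{t-1})."""

    def __init__(self, state_dim=32, action_dim=4, embed_dim=256, hidden=256):
        super().__init__()
        self.cell = nn.GRUCell(state_dim + action_dim, hidden)
        self.prior_net = nn.Linear(hidden, 2 * state_dim)             # transition model (generation)
        self.post_net = nn.Linear(hidden + embed_dim, 2 * state_dim)  # representation model (inference)

    @staticmethod
    def _gaussian(stats):
        mean, std = stats.chunk(2, dim=-1)
        return torch.distributions.Normal(mean, nn.functional.softplus(std) + 1e-3)

    def forward(self, s_prev, a_prev, h_prev, s_tilde=None):
        h = self.cell(torch.cat([s_prev, a_prev], dim=-1), h_prev)
        prior = self._gaussian(self.prior_net(h))
        if s_tilde is None:                          # imagination: no observation embedding
            return prior.rsample(), h, prior, None
        post = self._gaussian(self.post_net(torch.cat([h, s_tilde], dim=-1)))
        return post.rsample(), h, prior, post        # training: posterior conditioned on s~_t
```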

The observation representation model and predictive world model are optimized jointly through

$$
\mathcal{L}_{\mathrm{WM}} = \mathbb{E}_{s_{t} \sim q_{\theta}(\cdot \mid s_{t-1},\ a_{t-1},\ \tilde{s}_{t})} \left[ \sum_{t} \ell_{t}^{q} - \beta \ell_{t}^{\mathrm{KL}} \right], \qquad \ell_{t}^{\mathrm{KL}} = D_{\mathrm{KL}} \Big( q_{\theta}(\cdot \mid s_{t-1},\ a_{t-1},\ \tilde{s}_{t})\ \Big\|\ p_{\theta}(\cdot \mid s_{t-1},\ a_{t-1}) \Big)
$$
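
In implementation terms this amounts to optimising the reconstruction/reward error $\ell_{t}^{q}$ together with a $\beta$-weighted KL term between the posterior and the prior; a sketch assuming the Gaussian distributions returned by the dynamics sketch above:

```python
import torch.nn.functional as F
from torch.distributions import kl_divergence

def world_model_loss(depth_hat, depth, reward_hat, reward, posterior, prior, beta=1.0):
    """l_t^q: mean squared error of depth reconstruction and reward prediction;
    l_t^KL: KL(q_theta || p_theta) between posterior and prior latent states."""
    l_q = F.mse_loss(depth_hat, depth) + F.mse_loss(reward_hat, reward)
    l_kl = kl_divergence(posterior, prior).mean()
    return l_q + beta * l_kl   # minimised jointly over the MAE and world-model parameters
```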

The actor $q_{\phi}(a_{t} \mid s_{t})$ and critic $v_{\phi}(s_{t}) = \mathbb{E}_{q(\cdot \mid s_{\tau})} \sum_{\tau = t}^{t + H} \gamma^{\tau - t} r_{\tau}$ are learned in the imagined latent space
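
A small sketch of that imagination loop, rolling the transition model forward with actions from the actor and accumulating the discounted return the critic regresses toward; plain Monte-Carlo returns are used here instead of the λ-returns common in Dreamer-style agents, and `actor` / `reward_head` are hypothetical callables mapping the latent state to an action and a scalar reward:

```python
import torch

@torch.no_grad()  # critic targets only; the actor's gradient path is omitted in this sketch
def imagined_return(dynamics, actor, reward_head, s0, h0, horizon=15, gamma=0.99):
    """Roll p_theta forward for `horizon` steps in latent space and return the
    discounted return sum_{tau=t}^{t+H} gamma^(tau-t) r_tau per batch element."""
    s, h = s0, h0
    returns, discount = torch.zeros(s0.size(0)), 1.0
    for _ in range(horizon):
        a = actor(s)                                   # a_tau ~ q_phi(a | s)
        s, h, _, _ = dynamics(s, a, h, s_tilde=None)   # imagination uses the prior only
        returns = returns + discount * reward_head(s).squeeze(-1)
        discount *= gamma
    return returns   # regression target for the critic v_phi(s_t)
```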

