IRIS

World Model Learning

The world model is composed of a discrete autoencoder $(E,\ D)$ that learns a representation of observations and a GPT-like autoregressive Transformer $G$ that captures the environment dynamics.

| Component | Type | Definition | Distribution Family |
| --- | --- | --- | --- |
| Observation Encoder | Representation | $E : \mathbb{R}^{h \times w \times 3} \mapsto \{ 1,\ 2,\ \cdots,\ N \}^{K}$ | Deterministic |
| Observation Decoder | Representation | $D : \{ 1,\ 2,\ \cdots,\ N \}^{K} \mapsto \mathbb{R}^{h \times w \times 3}$ | Deterministic |
| Transition Predictor | Dynamics | $z_{t + 1} \sim p_{G}(z_{t + 1} \mid z_{\le t},\ a_{\le t})$ <br> $z_{t + 1}^{k} \sim p_{G}(z_{t + 1}^{k} \mid z_{\le t},\ a_{\le t},\ z_{t + 1}^{< k})$ | Categorical |
| Reward Predictor | Dynamics | $r_{t} \sim p_{G}(r_{t} \mid z_{\le t},\ a_{\le t})$ | Categorical / Deterministic |
| Termination Predictor | Dynamics | $d_{t} \sim p_{G}(d_{t} \mid z_{\le t},\ a_{\le t})$ | Bernoulli |
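
The table can be read as the following set of component signatures. This is a minimal PyTorch-style sketch under assumed names and shapes, not the reference IRIS API.

```python
from typing import Tuple

import torch


class WorldModel:
    """Signatures of the three components in the table above (illustrative only)."""

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """E: map a batch of frames (B, 3, h, w) to K discrete tokens (B, K)."""
        raise NotImplementedError

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """D: map K tokens (B, K) back to reconstructed frames (B, 3, h, w)."""
        raise NotImplementedError

    def predict(self, z_hist: torch.Tensor, a_hist: torch.Tensor
                ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """G: given token / action histories, return categorical logits over the
        next frame's tokens (B, K, N), a reward prediction (B,), and Bernoulli
        termination logits (B,)."""
        raise NotImplementedError
```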

Representation

The representation state $z_{t}$ consists of $K$ tokens from a vocabulary of size $N$. The encoder $E$ first produces a group of vectors $z_{e}(x_{t}) \in \mathbb{R}^{K \times d}$, and then obtains the output tokens through a codebook $\mathcal{E} = \{e_{i} \in \mathbb{R}^{d}\}_{i = 1}^{N}$:

$$
z_{t} = (z_{t}^{1},\ z_{t}^{2},\ \cdots,\ z_{t}^{K}) = \left[ \argmin_{i} \Big\| z_{e}^{k}(x_{t}) - e_{i} \Big\|_{2} \right]_{k = 1}^{K}
$$
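
Concretely, this lookup is a nearest-neighbour search over the codebook. Below is a minimal PyTorch sketch in which `z_e` and `codebook` stand for the encoder output $z_{e}(x_{t})$ and the embedding table $\mathcal{E}$; names and shapes are assumptions.

```python
import torch


def tokenize(z_e: torch.Tensor, codebook: torch.Tensor) -> torch.Tensor:
    """z_e: (B, K, d) encoder outputs; codebook: (N, d) embeddings -> (B, K) token ids."""
    # Squared L2 distance between every encoder vector and every codebook entry: (B, K, N).
    dists = ((z_e.unsqueeze(2) - codebook.view(1, 1, *codebook.shape)) ** 2).sum(-1)
    # z_t^k = argmin_i || z_e^k(x_t) - e_i ||_2 (the argmin is unchanged by squaring).
    return dists.argmin(dim=-1)
```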

The discrete autoencoder $(E,\ D)$ is trained to maximize the ELBO of the log-likelihood:

$$
\ln p(x) = \ln \mathbb{E}_{z \sim q(z \mid x)} \left[ p(x \mid z) \frac{\prod_{k = 1}^{K} p(z^{k})}{\prod_{k = 1}^{K} q(z^{k} \mid x)} \right] \ge \mathbb{E}_{z \sim q(z \mid x)} \ln p(x \mid z) - \sum_{k = 1}^{K} D_{\mathrm{KL}} \Big( q(z^{k} \mid x)\ \|\ p(z^{k}) \Big)
$$

where the posterior $q(z^{k} \mid x)$ is a one-hot distribution and the prior $p(z^{k})$ is assumed to be uniform. Since each one-hot posterior has zero entropy, every KL term equals $\ln N$, so the KL penalty reduces to a constant:

$$
\ln p(x) \ge \ln p(x \mid z_{q}(x)) - \underset{\mathrm{const}}{\underbrace{K \ln N}} \qquad \mathrm{s.t.} \quad z_{q}^{k}(x) = e_{i} \quad i = \argmin_{j} \Big\| z_{e}^{k}(x) - e_{j} \Big\|_{2}
$$

where $z_{q}(x)$ is computed in practice as $z_{q}^{k}(x) = z_{e}^{k}(x) + \operatorname{sg}(e_{i} - z_{e}^{k}(x))$, so that gradients pass straight through the non-differentiable quantization. The overall objective for $(E,\ D,\ \mathcal{E})$ combines the ELBO above with several additional terms:
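
A minimal sketch of this straight-through trick, continuing the assumed tensor names from the previous snippet: the forward pass uses the selected codebook vectors, while the backward pass copies gradients to $z_{e}$ unchanged.

```python
import torch


def quantize_st(z_e: torch.Tensor, codebook: torch.Tensor) -> torch.Tensor:
    """z_e: (B, K, d); codebook: (N, d) -> straight-through quantized vectors (B, K, d)."""
    dists = ((z_e.unsqueeze(2) - codebook.view(1, 1, *codebook.shape)) ** 2).sum(-1)
    e = codebook[dists.argmin(dim=-1)]      # nearest codebook vectors e_i, shape (B, K, d)
    # z_q^k = z_e^k + sg(e_i - z_e^k): equal to e_i in the forward pass,
    # but the gradient w.r.t. z_q flows entirely into z_e.
    return z_e + (e - z_e).detach()
```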

| Loss | Definition | Target |
| --- | --- | --- |
| Reconstruction Loss | $\mathcal{L}_{\mathrm{rec}} = -\log p(x \mid z_{q}(x)) \Rightarrow \Vert x - D(z) \Vert_{1}$ | encoder + decoder |
| Codebook Loss | $\mathcal{L}_{\mathrm{code}} = \sum_{k = 1}^{K} \Big\Vert \operatorname{sg}(z_{e}^{k}(x)) - \mathcal{E}(z^{k}) \Big\Vert_{2}^{2}$ | codebook |
| Commitment Loss | $\mathcal{L}_{\mathrm{com}} = \sum_{k = 1}^{K} \Big\Vert z_{e}^{k}(x) - \operatorname{sg}(\mathcal{E}(z^{k})) \Big\Vert_{2}^{2}$ | encoder |
| Perceptual Loss | $\mathcal{L}_{\mathrm{perceptual}}(x,\ D(z))$ | encoder + decoder |
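
The four terms could be combined as sketched below. This is only an illustration under assumed names: the perceptual term is left as a placeholder since its feature network (e.g. an LPIPS-style loss) and the relative weighting of the terms are implementation choices.

```python
import torch
import torch.nn.functional as F


def autoencoder_loss(x, x_rec, z_e, e, perceptual_fn, beta: float = 1.0):
    """x, x_rec: images; z_e: encoder outputs (B, K, d); e: selected codebook vectors (B, K, d)."""
    l_rec = (x - x_rec).abs().mean()              # L1 reconstruction loss
    l_code = F.mse_loss(e, z_e.detach())          # moves the codebook toward sg(z_e)
    l_commit = F.mse_loss(z_e, e.detach())        # commits the encoder to sg(e)
    l_perceptual = perceptual_fn(x, x_rec)        # perceptual loss (placeholder)
    return l_rec + l_code + beta * l_commit + l_perceptual
```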

Dynamics

The autoregressive Transformer $G$ is trained in a self-supervised manner on segments sampled from past experience to minimize the difference between its predictions and the ground truth. The overall objective includes:

| Loss | Target |
| --- | --- |
| Cross-Entropy Loss | Transition Predictor |
| Cross-Entropy Loss / MSE Loss | Reward Predictor |
| Cross-Entropy Loss | Termination Predictor |
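
A minimal sketch of these three terms, assuming $G$ outputs per-step logits for the next tokens, discretized reward bins, and termination, with targets taken from the sampled segments; names, shapes, and the equal weighting are assumptions.

```python
import torch
import torch.nn.functional as F


def dynamics_loss(token_logits, reward_logits, done_logits,
                  target_tokens, target_reward_bins, target_dones):
    """token_logits: (B, T, K, N); reward_logits: (B, T, R); done_logits: (B, T)."""
    l_trans = F.cross_entropy(token_logits.flatten(0, 2), target_tokens.flatten())
    l_reward = F.cross_entropy(reward_logits.flatten(0, 1), target_reward_bins.flatten())
    l_done = F.binary_cross_entropy_with_logits(done_logits, target_dones.float())
    return l_trans + l_reward + l_done
```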

Behavior Learning

Following Dreamer, the critic network $v(x_{t})$ is optimized with the $\lambda$-return, which is recursively defined as:

$$
\Lambda_{t} = \begin{cases} r_{t} + \gamma (1 - d_{t}) \Big[ (1 - \lambda) v(x_{t + 1}) + \lambda \Lambda_{t + 1} \Big] & t < H \\ v(x_{H}) & t = H \end{cases}
$$
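
A minimal sketch of this recursion, computed backward over an imagined trajectory of horizon $H$; tensor names and the default $\gamma$, $\lambda$ values are assumptions.

```python
import torch


def lambda_returns(rewards, dones, values, gamma: float = 0.995, lam: float = 0.95):
    """rewards, dones: (H,); values: (H + 1,) with values[H] = v(x_H) -> (H,) λ-returns."""
    H = rewards.shape[0]
    returns = torch.empty(H + 1)
    returns[H] = values[H]                                    # Λ_H = v(x_H)
    for t in reversed(range(H)):                              # t = H-1, ..., 0
        bootstrap = (1 - lam) * values[t + 1] + lam * returns[t + 1]
        returns[t] = rewards[t] + gamma * (1 - dones[t]) * bootstrap
    return returns[:H]                                        # Λ_0, ..., Λ_{H-1}
```

The critic is then regressed toward these targets, as in the loss below.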

$$
\mathcal{L}_{\mathrm{critic}} = \mathbb{E}_{\pi} \left[ \sum_{t = 0}^{H - 1} \Big( v(x_{t}) - \operatorname{sg}(\Lambda_{t}) \Big)^{2} \right]
$$

The actor network $\pi(a_{t} \mid x_{\le t})$ is trained to minimize the REINFORCE objective over imagined trajectories:

$$
\mathcal{L}_{\pi} = - \mathbb{E}_{\pi} \left[ \sum_{t = 0}^{H - 1} \ln \pi(a_{t} \mid x_{\le t}) \operatorname{sg}(\Lambda_{t} - v(x_{t})) + \eta \mathcal{H}(\pi \mid x_{\le t}) \right]
$$
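
Putting the two objectives together, here is a minimal sketch given precomputed λ-returns; `log_probs` and `entropies` are assumed to come from the policy's action distribution at each imagined step, and the default entropy weight $\eta$ is an assumption.

```python
import torch


def actor_critic_losses(values, lam_returns, log_probs, entropies, eta: float = 0.001):
    """All inputs are tensors of shape (H,); returns (critic_loss, actor_loss)."""
    critic_loss = (values - lam_returns.detach()).pow(2).mean()      # (v(x_t) - sg(Λ_t))^2
    advantage = (lam_returns - values).detach()                      # sg(Λ_t - v(x_t))
    actor_loss = -(log_probs * advantage + eta * entropies).mean()   # REINFORCE + entropy bonus
    return critic_loss, actor_loss
```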


IRIS
http://example.com/2024/09/04/IRIS/
Author
木辛
Posted on
September 4, 2024
Licensed under