UNICORN

Theoretic Framework

The probabilistic model of context-based offline meta-reinforcement learning (COMRL) consists of the following random variables and their (conditional) probability distributions (CPD):

| Random Variable | Description | CPD |
| --- | --- | --- |
| $M$ | Task (instance of MDP) | $p(M)$ |
| $X_{b} = (s,\ a)$ | Behavior-related context | $p(X_{b})$ |
| $X_{t} = (r,\ s')$ | Task-related context | $p(X_{t} \mid X_{b},\ M)$ |
| $Z$ | Task representation | $p(Z \mid X_{b},\ X_{t})$ |

Task representation learning in COMRL aims to find a minimal sufficient statistic $Z$ of the task $M$ based on the context $X = (X_{b},\ X_{t})$:

$$
\begin{gathered}
\max_{p(z \mid x)} I(Z;\ M) \\[5mm]
\text{s.t.} \quad I(Z;\ M;\ X_{b}) \ge 0
\end{gathered}
$$

Direct optimization is intractable in practice. Under the assumption $I(Z;\ M;\ X_{b}) \ge 0$, $I(Z;\ M)$ is lower bounded as

$$
\begin{aligned}
I(Z;\ M) &= I(Z;\ M \mid X_{b}) + \underset{\ge 0}{\underbrace{I(Z;\ M;\ X_{b})}} \ge I(Z;\ M \mid X_{b}) = I(Z,\ X_{t};\ M \mid X_{b}) - I(X_{t};\ M \mid Z,\ X_{b}) \\[7mm]
&= I(X_{t};\ M \mid X_{b}) + \underset{= 0}{\underbrace{I(Z;\ M \mid X_{t},\ X_{b})}} - I(X_{t};\ M \mid Z,\ X_{b}) \quad \Leftarrow \quad (Z \perp M \mid X_{t},\ X_{b}) \\[7mm]
&= I(X_{t};\ M \mid X_{b}) - H(X_{t} \mid Z,\ X_{b}) + \underset{\ge 0}{\underbrace{H(X_{t} \mid Z,\ X_{b},\ M)}} \ge I(X_{t};\ M \mid X_{b}) - H(X_{t} \mid Z,\ X_{b}) \\[7mm]
&= I(X_{t};\ M \mid X_{b}) - H(X_{t}) + I(Z,\ X_{b};\ X_{t}) = \underset{\text{const}}{\underbrace{I(X_{t};\ M \mid X_{b})}} - \underset{\text{const}}{\underbrace{H(X_{t})}} + \underset{\text{const}}{\underbrace{I(X_{t};\ X_{b})}} + I(Z;\ X_{t} \mid X_{b})
\end{aligned}
$$
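As a side note, the three constant terms in the last line can be collapsed with standard entropy identities, which makes the final form of the lower bound easier to read. This is a derived restatement, not an extra result, and it relies on the same assumption $I(Z;\ M;\ X_{b}) \ge 0$ as the derivation itself:

$$
\begin{aligned}
I(X_{t};\ M \mid X_{b}) - H(X_{t}) + I(X_{t};\ X_{b})
&= \big[ H(X_{t} \mid X_{b}) - H(X_{t} \mid X_{b},\ M) \big] - H(X_{t}) + \big[ H(X_{t}) - H(X_{t} \mid X_{b}) \big] \\[3mm]
&= -H(X_{t} \mid X_{b},\ M) \\[3mm]
\Longrightarrow \quad I(Z;\ M) &\ge I(Z;\ X_{t} \mid X_{b}) - H(X_{t} \mid X_{b},\ M)
\end{aligned}
$$

Since $H(X_{t} \mid X_{b},\ M)$ does not depend on the encoder $p(z \mid x)$, maximizing $I(Z;\ X_{t} \mid X_{b})$ maximizes this bound.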

In addition, $I(Z;\ M)$ is upper bounded by $I(Z;\ X)$, which follows from the Markov chain $M \to X \to Z$ in the dependency graph together with Jensen's inequality:

$$
I(Z;\ M) - I(Z;\ X) = \mathcal{E}_{M,\ x,\ z} \left[ \log \frac{p(z \mid M)}{p(z \mid x)} \right] \le \log \sum_{M} \sum_{x} \sum_{z} p(M) p(x \mid M) \cancel{p(z \mid x)} \frac{p(z \mid M)}{\cancel{p(z \mid x)}} = 0
$$
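The upper bound can be checked numerically on a small discrete Markov chain $M \to X \to Z$. The sketch below is only an illustration: the probability tables and the helper `mutual_information` are made up for this example and are not part of UNICORN.

```python
import math
from itertools import product


def mutual_information(joint):
    """I(A; B) in nats for a joint table given as {(a, b): probability}."""
    pa, pb = {}, {}
    for (a, b), p in joint.items():
        pa[a] = pa.get(a, 0.0) + p
        pb[b] = pb.get(b, 0.0) + p
    return sum(
        p * math.log(p / (pa[a] * pb[b]))
        for (a, b), p in joint.items()
        if p > 0
    )


# Toy Markov chain M -> X -> Z (all numbers are illustrative assumptions).
p_m = {0: 0.5, 1: 0.5}
p_x_given_m = {0: {0: 0.7, 1: 0.2, 2: 0.1},
               1: {0: 0.1, 1: 0.3, 2: 0.6}}
p_z_given_x = {0: {0: 0.9, 1: 0.1},
               1: {0: 0.5, 1: 0.5},
               2: {0: 0.2, 1: 0.8}}

# Marginalize the full joint p(M, x, z) into p(z, M) and p(z, x).
joint_zm, joint_zx = {}, {}
for m, x, z in product(p_m, range(3), range(2)):
    p = p_m[m] * p_x_given_m[m][x] * p_z_given_x[x][z]
    joint_zm[(z, m)] = joint_zm.get((z, m), 0.0) + p
    joint_zx[(z, x)] = joint_zx.get((z, x), 0.0) + p

i_zm = mutual_information(joint_zm)
i_zx = mutual_information(joint_zx)
print(f"I(Z; M) = {i_zm:.4f} nats, I(Z; X) = {i_zx:.4f} nats")
```

Because $Z$ only sees $M$ through $X$, the first printed value should never exceed the second, whatever tables are plugged in.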

Given these two bounds on $I(Z;\ M)$, several existing COMRL algorithms can be interpreted as follows:

| Algorithm | Essential Optimization Objective | Description |
| --- | --- | --- |
| FOCAL | $\max I(Z;\ X)$ | Upper Bound |
| CORRO | $\max I(Z;\ X_{t} \mid X_{b})$ | Lower Bound |
| CSRO | $\max \lambda I(Z;\ X_{t} \mid X_{b}) + (1 - \lambda) I(Z;\ X)$ | Hybrid |

FOCAL

FOCAL maximizes the upper bound $I(Z;\ X)$, which amounts to minimizing its distance metric loss:

$$
\begin{aligned}
I(Z;\ X) &= \mathcal{E}_{x,\ z} \bigg[ \log \frac{p(z,\ x)}{p(z) p(x)} \bigg] = \mathcal{E}_{x,\ z} \bigg[ \log \underset{h(x,\ z)}{\underbrace{\frac{p(z \mid x)}{p(z)}}} \bigg/ |\mathcal{M}| \underset{= 1}{\underbrace{\mathcal{E}_{x'} \left[ \frac{p(z \mid x')}{p(z)} \right]}} \bigg] + \underset{\text{const}}{\underbrace{\log |\mathcal{M}|}} \\[7mm]
&\approx \sum_{M \in \mathcal{M}} \sum_{x \in \mathcal{D}(\mathcal{M})} \log \frac{h(x,\ z)}{\sum_{M' \in \mathcal{M}} \sum_{x' \in \mathcal{D}(M')} h(x',\ z)} \quad \Leftarrow \quad z \sim f_{\phi}(\cdot \mid x)
\end{aligned}
$$
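To make the "distance metric loss" concrete, here is a minimal sketch of a pairwise loss that pulls same-task embeddings together and pushes different-task embeddings apart. The exact functional form (squared distance for same-task pairs, an inverse-power penalty for different-task pairs) and the names `beta`, `power`, and `eps` are assumptions made for illustration, not a copy of the FOCAL reference implementation.

```python
import math


def focal_distance_metric_loss(embeddings, task_ids, beta=1.0, power=2, eps=0.1):
    """Average pairwise loss over all embedding pairs.

    Assumed form (a sketch, not the reference implementation):
      same task      -> ||z_i - z_j||^2
      different task -> beta / (||z_i - z_j||^power + eps)
    """
    total, pairs = 0.0, 0
    n = len(embeddings)
    for i in range(n):
        for j in range(i + 1, n):
            dist = math.dist(embeddings[i], embeddings[j])
            if task_ids[i] == task_ids[j]:
                total += dist ** 2                     # attract same-task pairs
            else:
                total += beta / (dist ** power + eps)  # repel different-task pairs
            pairs += 1
    return total / pairs


# Four 2-D embeddings forming two tight clusters.
embeddings = [(0.0, 0.0), (0.1, 0.0), (3.0, 3.0), (3.1, 3.0)]
clustered = focal_distance_metric_loss(embeddings, task_ids=[0, 0, 1, 1])
mixed = focal_distance_metric_loss(embeddings, task_ids=[0, 1, 0, 1])
print(f"clustered: {clustered:.3f}, mixed: {mixed:.3f}")
```

When the task labels agree with the clusters, the loss is much smaller than when each cluster mixes two tasks, which is the behaviour a task encoder trained with such a loss is pushed towards.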

However, this objective may lead to spurious correlations under distribution shift of $X_{b}$ (for example, when $Z$ ends up depending solely on $X_{b}$).

CORRO

To alleviate the degradation caused by distribution shift, CORRO proposes to maximize the lower bound $I(Z;\ X_{t} \mid X_{b})$:

$$
\begin{aligned}
I(Z;\ X_{t} \mid X_{b}) &= \mathcal{E}_{x,\ z} \bigg[ \log \underset{h(x_{b},\ x_{t},\ z)}{\underbrace{\frac{p(z \mid x_{t},\ x_{b})}{p(z \mid x_{b})}}} \bigg/ |\mathcal{M}| \underset{= 1}{\underbrace{\mathcal{E}_{M^{*} \sim p(M)} \mathcal{E}_{x_{t}^{*} \sim p(x_{t} \mid M^{*},\ x_{b})} \left[ \frac{p(z \mid x_{t}^{*},\ x_{b})}{p(z \mid x_{b})} \right]}} \bigg] + \underset{\text{const}}{\underbrace{\log |\mathcal{M}|}} \\[7mm]
&\approx \sum_{M \in \mathcal{M}} \sum_{x \in \mathcal{D}(\mathcal{M})} \log \frac{h(x_{b},\ x_{t},\ z)}{\sum_{M^{*} \sim \mathcal{D}} h(x_{b},\ x_{t}^{*},\ z)} \quad \Leftarrow \quad z \sim f_{\phi}(\cdot \mid x_{b},\ x_{t}) \quad x_{t}^{*} \sim g_{\psi}(\cdot \mid x_{b},\ M^{*})
\end{aligned}
$$
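The ratio inside the logarithm has the shape of an InfoNCE-style contrastive loss: one positive score in the numerator and a sum over sampled alternatives in the denominator. The sketch below shows that shape only. It assumes a dot-product score with a temperature as a stand-in for $\log h$, assumes the denominator includes the positive, and takes the negative keys as ready-made vectors (in the setting above they would come from $(x_{b},\ x_{t}^{*})$ with generated $x_{t}^{*}$). The function name `info_nce` and all inputs are hypothetical.

```python
import math


def dot(u, v):
    """Plain dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def info_nce(query, positive, negatives, temperature=1.0):
    """-log( exp(s+) / (exp(s+) + sum_k exp(s-_k)) ) with dot-product scores."""
    logits = [dot(query, positive) / temperature]
    logits += [dot(query, key) / temperature for key in negatives]
    # Log-sum-exp with the max subtracted for numerical stability.
    top = max(logits)
    log_norm = top + math.log(sum(math.exp(l - top) for l in logits))
    return log_norm - logits[0]


query = [1.0, 0.0]  # embedding of (x_b, x_t)
aligned = info_nce(query, positive=[1.0, 0.0],
                   negatives=[[0.0, 1.0], [-1.0, 0.0]])
misaligned = info_nce(query, positive=[0.0, 1.0],
                      negatives=[[1.0, 0.0], [-1.0, 0.0]])
print(f"aligned: {aligned:.3f}, misaligned: {misaligned:.3f}")
```

The loss is small when the query is closer to its positive than to any negative, and grows when a negative scores higher than the positive.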

CSRO

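CSRO maximizes $I(Z;\ X)$ while minimizing the contrastive log-ratio upper bound (CLUB) of $I(Z;\ X_{b})$ to alleviate the distribution shift of the context. Because CLUB upper bounds $I(Z;\ X_{b})$, the CSRO objective is itself bounded above by the hybrid objective: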

$$
\begin{aligned}
&I(Z;\ X) - \lambda I_{\text{CLUB}}(Z;\ X_{b}) = I(Z;\ X) - \lambda \Big[ \mathcal{E}_{(z,\ x_{b}) \sim p(z,\ x_{b})} \log p(z \mid x_{b}) - \mathcal{E}_{z \sim p(z)} \mathcal{E}_{x_{b} \sim p(x_{b})} \log p(z \mid x_{b}) \Big] \\[5mm]
\le\ &I(Z;\ X) - \lambda I(Z;\ X_{b}) = I(Z;\ X_{t},\ X_{b}) - \lambda \Big[ I(Z;\ X_{t},\ X_{b}) - I(Z;\ X_{t} \mid X_{b}) \Big] = \lambda I(Z;\ X_{t} \mid X_{b}) + (1 - \lambda) I(Z;\ X)
\end{aligned}
$$
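The bracketed CLUB term can be estimated from samples by comparing the log-likelihood of matched pairs $(z_{i},\ x_{b,i})$ with that of all cross pairs $(z_{i},\ x_{b,j})$. The sketch below assumes scalar variables and a unit-variance Gaussian $q(z \mid x_{b})$ whose mean comes from a hypothetical `predict_mean` function; in practice that role would be played by a learned network.

```python
import math


def gaussian_log_prob(z, mean, std=1.0):
    """Log-density of a scalar Gaussian N(mean, std^2) evaluated at z."""
    return -0.5 * ((z - mean) / std) ** 2 - math.log(std * math.sqrt(2 * math.pi))


def club_estimate(z_samples, xb_samples, predict_mean, std=1.0):
    """Sample-based CLUB: E_{p(z, x_b)}[log q] - E_{p(z) p(x_b)}[log q]."""
    n = len(z_samples)
    means = [predict_mean(xb) for xb in xb_samples]
    # Matched pairs (z_i, x_b_i) approximate the joint distribution.
    paired = sum(gaussian_log_prob(z_samples[i], means[i], std)
                 for i in range(n)) / n
    # All cross pairs (z_i, x_b_j) approximate the product of marginals.
    shuffled = sum(gaussian_log_prob(z_samples[i], means[j], std)
                   for i in range(n) for j in range(n)) / (n * n)
    return paired - shuffled


xb = [-1.0, 0.0, 1.0, 2.0]
dependent = club_estimate(xb, xb, predict_mean=lambda v: v)            # z copies x_b
independent = club_estimate([0.5] * 4, xb, predict_mean=lambda v: v)  # z ignores x_b
print(f"dependent: {dependent:.3f}, independent: {independent:.3f}")
```

A representation that copies $x_{b}$ gets a clearly positive estimate, while a representation that ignores $x_{b}$ gets an estimate of zero. Minimizing this term therefore discourages $Z$ from encoding the behavior-related context.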

General Implementation

Building on this theoretical framework, UNICORN formulates its optimization objective based on the information bottleneck principle:

$$
\min_{p(z \mid x)} I(Z;\ X) \quad \text{s.t.} \quad I(Z;\ M) \ge I_{c} \quad \Longrightarrow \quad \min_{p(z \mid x)} \mathcal{L}_{\text{IB}} = I(Z;\ X) - \beta I(Z;\ M)
$$

The first term $I(Z;\ X)$ is implemented as the FOCAL objective, and the second term is approximated as

$$
I(Z;\ M) \approx \alpha I(Z;\ X) + (1 - \alpha) I(Z;\ X_{t} \mid X_{b}) \qquad \alpha \in [0,\ 1]
$$

which, like CSRO, is a convex combination of the FOCAL and CORRO objectives. Substituting this approximation into $\mathcal{L}_{\text{IB}}$ and rescaling by $(1 - \alpha) \beta$ (assuming $\alpha < 1$ and $\beta > 0$) gives

$$
\mathcal{L}_{\text{IB}} = I(Z;\ X) - \alpha \beta I(Z;\ X) - (1 - \alpha) \beta I(Z;\ X_{t} \mid X_{b}) \Rightarrow -\left[ I(Z;\ X_{t} \mid X_{b}) + \frac{\alpha \beta - 1}{(1 - \alpha) \beta} I(Z;\ X) \right]
$$
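The rescaling step is plain arithmetic and can be sanity-checked with a few numbers. In the sketch below, the function names and the sample values of $\alpha$, $\beta$, and the two information terms are arbitrary choices for illustration. The check confirms that dividing $\mathcal{L}_{\text{IB}}$ by $(1 - \alpha) \beta$ reproduces the bracketed form, which requires $\alpha < 1$ and $\beta > 0$.

```python
def ib_loss(i_zx, i_cond, alpha, beta):
    """L_IB with I(Z; M) replaced by its convex-combination approximation."""
    return i_zx - alpha * beta * i_zx - (1 - alpha) * beta * i_cond


def focal_weight(alpha, beta):
    """Coefficient of I(Z; X) after dividing L_IB by (1 - alpha) * beta."""
    if not 0.0 <= alpha < 1.0:
        raise ValueError("alpha must lie in [0, 1) for the rescaling to be defined")
    if beta <= 0.0:
        raise ValueError("beta must be positive")
    return (alpha * beta - 1.0) / ((1.0 - alpha) * beta)


def scaled_loss(i_zx, i_cond, alpha, beta):
    """-[ I(Z; X_t | X_b) + weight * I(Z; X) ], the rescaled objective."""
    return -(i_cond + focal_weight(alpha, beta) * i_zx)


alpha, beta = 0.5, 4.0      # illustrative hyperparameters
i_zx, i_cond = 1.3, 0.7     # illustrative values of I(Z; X) and I(Z; X_t | X_b)

weight = focal_weight(alpha, beta)
lhs = ib_loss(i_zx, i_cond, alpha, beta) / ((1 - alpha) * beta)
rhs = scaled_loss(i_zx, i_cond, alpha, beta)
print(f"weight = {weight}, rescaled L_IB = {lhs}, bracket form = {rhs}")
```

One thing the formula makes visible is the sign of the weight: when $\alpha \beta < 1$ the coefficient of $I(Z;\ X)$ is negative, so that term is penalized, and when $\alpha \beta > 1$ it is rewarded.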

Instead of using the contrastive objective of CORRO or the CLUB bound of CSRO to approximate $I(Z;\ X_{t} \mid X_{b})$, UNICORN proposes to rewrite it as

$$
I(Z;\ X_{t} \mid X_{b}) = I(Z,\ X_{b};\ X_{t}) - \underset{\text{const}}{\underbrace{I(X_{t};\ X_{b})}} \approx \underset{-\mathcal{L}_{\text{recon}}}{\underbrace{\mathcal{E}_{(x_{t},\ x_{b}) \sim p(x_{t},\ x_{b})} \mathcal{E}_{z \sim q_{\phi}(z \mid x_{t},\ x_{b})} \log p_{\theta}(x_{t} \mid z,\ x_{b})}} + \text{const}
$$

where the decoder $p_{\theta}(x_{t} \mid z,\ x_{b})$ is introduced as a parametric estimate of $p(x_{t} \mid z,\ x_{b})$ and can also serve as a world model.
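A minimal sketch of the reconstruction term, under the assumption that $p_{\theta}(x_{t} \mid z,\ x_{b})$ is a unit-variance Gaussian, so that $-\log p_{\theta}$ reduces to half the squared error up to an additive constant. The `toy_decoder`, its dynamics, and the batch layout `(z, x_b, x_t)` are all hypothetical and only stand in for a learned network and real transition data.

```python
def reconstruction_loss(decoder, batch):
    """Mean of 0.5 * ||x_t - decoder(z, x_b)||^2 over the batch.

    Equals -log p_theta(x_t | z, x_b) up to a constant when the decoder
    is interpreted as the mean of a unit-variance Gaussian (an assumption).
    """
    total = 0.0
    for z, x_b, x_t in batch:
        prediction = decoder(z, x_b)
        total += 0.5 * sum((t - p) ** 2 for t, p in zip(x_t, prediction))
    return total / len(batch)


def toy_decoder(z, x_b):
    """Hypothetical world model: reward = z * a, next state = s + a."""
    s, a = x_b
    return (z * a, s + a)  # predicted x_t = (r, s')


# Each sample is (z, x_b = (s, a), x_t = (r, s')), generated with task parameter 2.
transitions = [((0.0, 1.0), (2.0, 1.0)),
               ((1.0, -1.0), (-2.0, 0.0))]
good_batch = [(2.0, x_b, x_t) for x_b, x_t in transitions]  # z matches the task
bad_batch = [(0.0, x_b, x_t) for x_b, x_t in transitions]   # z carries no task info

good = reconstruction_loss(toy_decoder, good_batch)
bad = reconstruction_loss(toy_decoder, bad_batch)
print(f"loss with informative z: {good}, with uninformative z: {bad}")
```

The loss is lower when $z$ carries the information needed to predict $x_{t}$ from $x_{b}$, which is how minimizing $\mathcal{L}_{\text{recon}}$ encourages a large $I(Z;\ X_{t} \mid X_{b})$.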


Author: 木辛
Posted on: October 26, 2024
Source: http://example.com/2024/10/26/UNICORN/