VariBAD

Bayesian-Adaptive MDP

The Bayesian-Adaptive MDP (BAMDP) requires posterior inference over the task (its transition and reward model) conditioned on the observed history

The belief state over the task $b_{t}(T,\ R) = p(T,\ R \mid \tau_{t}) = b_{t}(T)\, b_{t}(R)$ (factorized by assuming $T$ and $R$ are independent given the history) updates deterministically according to Bayes' rule

$$
\begin{aligned}
b_{t + 1}(T) &= p(T \mid \tau_{t + 1}) = \frac{p(\tau_{t + 1} \mid T) p(T)}{p(\tau_{t + 1})} \\[7mm]
&= \frac{p(\tau_{t} \mid T) p(T)}{p(\tau_{t + 1})} T(s_{t + 1} \mid s_{t},\ a_{t}) \\[7mm]
&= \frac{p(\tau_{t})}{p(\tau_{t + 1})} p(T \mid \tau_{t}) \cdot T(s_{t + 1} \mid s_{t},\ a_{t}) \\[7mm]
&\propto b_{t}(T) \cdot T(s_{t + 1} \mid s_{t},\ a_{t})
\end{aligned}
$$

$$
\begin{aligned}
b_{t + 1}(R) &= p(R \mid \tau_{t + 1}) = \frac{p(\tau_{t + 1} \mid R) p(R)}{p(\tau_{t + 1})} \\[7mm]
&= \frac{p(\tau_{t} \mid R) p(R)}{p(\tau_{t + 1})} R(r_{t} \mid s_{t},\ a_{t}) \\[7mm]
&= \frac{p(\tau_{t})}{p(\tau_{t + 1})} p(R \mid \tau_{t}) \cdot R(r_{t} \mid s_{t},\ a_{t}) \\[7mm]
&\propto b_{t}(R) \cdot R(r_{t} \mid s_{t},\ a_{t})
\end{aligned}
$$

where the prior distribution is $b_{0} = p(R)\, p(T)$. The expected transition and reward functions after observing the history $\tau_{t}$ are

$$
\begin{gathered}
p(s_{t + 1} \mid s_{\le t},\ a_{\le t}) = p(s_{t + 1} \mid s_{t},\ a_{t},\ \tau_{t}) = \int_{T \in \mathcal{T}} T(s_{t + 1} \mid s_{t},\ a_{t})\, b_{t}(T)\, dT = \mathcal{E}_{T \sim b_{t}(T)} T(s_{t + 1} \mid s_{t},\ a_{t}) \\[7mm]
p(r_{t} \mid s_{\le t},\ a_{\le t},\ r_{< t}) = p(r_{t} \mid s_{t},\ a_{t},\ \tau_{t}) = \int_{R \in \mathcal{R}} R(r_{t} \mid s_{t},\ a_{t})\, b_{t}(R)\, dR = \mathcal{E}_{R \sim b_{t}(R)} R(r_{t} \mid s_{t},\ a_{t})
\end{gathered}
$$
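To make the bookkeeping concrete, the sketch below works through the belief update and the posterior-predictive transition for the (purely illustrative) special case where the unknown transition model is one of finitely many tabular candidates; the reward belief $b_{t}(R)$ would be handled the same way, and all names here are placeholders rather than anything from VariBAD itself.

```python
# A minimal sketch, assuming (for illustration only) that the true transition model is one of
# K candidate tabular models; the belief is then a categorical distribution over candidates.
import numpy as np

def update_belief(belief, candidate_T, s, a, s_next):
    """Bayes update b_{t+1}(T) ∝ b_t(T) · T(s_{t+1} | s_t, a_t); candidate_T has shape (K, S, A, S)."""
    posterior = belief * candidate_T[:, s, a, s_next]    # multiply in the observed transition's likelihood
    return posterior / posterior.sum()                   # normalise (the 1 / p(tau_{t+1}) factor)

def predictive_transition(belief, candidate_T, s, a):
    """Posterior predictive p(s_{t+1} | s_t, a_t, tau_t) = E_{T ~ b_t}[ T(· | s_t, a_t) ]."""
    return np.einsum('k,ks->s', belief, candidate_T[:, s, a, :])

# Toy example: two candidate models over 2 states and 1 action that disagree about where the action leads.
candidate_T = np.stack([
    np.array([[[0.9, 0.1]], [[0.9, 0.1]]]),    # candidate 0: the action mostly leads to state 0
    np.array([[[0.1, 0.9]], [[0.1, 0.9]]]),    # candidate 1: the action mostly leads to state 1
])
belief = np.array([0.5, 0.5])                                    # uniform prior b_0(T)
belief = update_belief(belief, candidate_T, s=0, a=0, s_next=1)  # observing s' = 1 shifts mass to candidate 1
print(belief)                                                    # -> [0.1, 0.9]
print(predictive_transition(belief, candidate_T, s=0, a=0))      # -> [0.18, 0.82]
```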

The agent’s objective in the BAMDP is to maximize the expected return under the Bayesian estimates of the transition and reward functions

$$
\max_{\pi} \mathcal{J}(\pi) = \mathcal{E}_{s_{t + 1} \sim p(s_{t + 1} \mid s_{t},\ a_{t},\ \tau_{t}),\ r_{t} \sim p(r_{t} \mid s_{t},\ a_{t},\ \tau_{t}),\ a_{t} \sim \pi(a_{t} \mid s_{t},\ b_{t})} \left[ \sum_{t = 0}^{H - 1} \gamma^{t} r_{t} \right]
$$

The Bayes-optimal agent’s exploration is therefore regulated purely by maximizing this expected return, so it is automatically balanced against exploitation
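Schematically, the return above is generated by the following interaction loop, in which the policy acts on the belief-augmented state $(s_{t},\ b_{t})$; `env_step`, `policy`, and `belief_update` are illustrative placeholders (the latter could wrap `update_belief` from the sketch above).

```python
# A schematic BAMDP rollout: the policy conditions on (s_t, b_t) and the belief is updated
# deterministically after every step. All callables here are illustrative placeholders.
def bamdp_return(env_step, policy, belief_update, s0, b0, horizon, gamma=0.99):
    s, b, ret = s0, b0, 0.0
    for t in range(horizon):
        a = policy(s, b)                        # a_t ~ pi(a_t | s_t, b_t)
        s_next, r = env_step(s, a)              # sampled from the (unknown) true task
        b = belief_update(b, s, a, r, s_next)   # deterministic Bayes update of b_t(T, R)
        ret += gamma ** t * r                   # accumulate the discounted return
        s = s_next
    return ret
```

Exploration pays off in this loop only insofar as a more informative belief leads to higher reward within the horizon, which is exactly the sense in which it is balanced against exploitation.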

Approximate Task Inference

However, solving the BAMDP directly is intractable; VariBAD instead performs approximate task inference via meta-learning

The encoder $q_{\phi}(m \mid \tau_{t})$ is trained by amortised inference, jointly with the decoder $p_{\theta}(r,\ s' \mid s,\ a,\ m)$, to maximize the ELBO

$$
\mathcal{E}_{\tau_{H} \sim M} \log p_{\theta}(\tau_{H}) \ge \operatorname{ELBO}_{t}(\phi,\ \theta \mid M) = \mathcal{E}_{\tau_{H} \sim M} \left[ \mathcal{E}_{m \sim q_{\phi}(\cdot \mid \tau_{t})} \log p_{\theta}(\tau_{H} \mid m) - D_{\text{KL}} \Big( q_{\phi}(m \mid \tau_{t})\ \big\|\ p(m) \Big) \right]
$$

where the prior $p(m)$ is set to the previous posterior $q_{\phi}(m \mid \tau_{t - 1})$ and the reconstruction term $\log p_{\theta}(\tau_{H} \mid m)$ is

$$
\begin{aligned}
\log p_{\theta}(\tau_{H} \mid m) &= \log \left[ p_{\theta}^{T}(s_{0} \mid m) \prod_{i = 0}^{H - 1} p_{\theta}^{T}(s_{i + 1} \mid s_{i},\ a_{i},\ m)\, p_{\theta}^{R}(r_{i} \mid s_{i},\ a_{i},\ m) \right] \\[7mm]
&= \log p_{\theta}^{T}(s_{0} \mid m) + \sum_{i = 0}^{H - 1} \Big[ \log p_{\theta}^{T}(s_{i + 1} \mid s_{i},\ a_{i},\ m) + \log p_{\theta}^{R}(r_{i} \mid s_{i},\ a_{i},\ m) \Big]
\end{aligned}
$$
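A condensed PyTorch sketch of how this objective can be computed is given below, assuming a GRU encoder with a Gaussian posterior over $m$ and unit-variance Gaussian decoder heads; the module layout, dimensions, and names (`TrajectoryEncoder`, `Decoder`, `elbo_t`) are illustrative assumptions, not the reference implementation.

```python
# A minimal sketch of the VariBAD ELBO in PyTorch, under the assumptions stated above.
import torch
import torch.nn as nn
from torch.distributions import Normal, kl_divergence

class TrajectoryEncoder(nn.Module):
    """q_phi(m | tau_:t): a GRU over (s, a, r) tuples emitting a Gaussian over the task embedding m."""
    def __init__(self, sar_dim, hidden_dim=64, latent_dim=5):
        super().__init__()
        self.gru = nn.GRU(sar_dim, hidden_dim, batch_first=True)
        self.mu = nn.Linear(hidden_dim, latent_dim)
        self.logvar = nn.Linear(hidden_dim, latent_dim)

    def forward(self, prefix):                        # prefix: (B, t, sar_dim)
        h, _ = self.gru(prefix)
        return self.mu(h[:, -1]), self.logvar(h[:, -1])

class Decoder(nn.Module):
    """p_theta(s', r | s, a, m): unit-variance Gaussian heads over next state and reward (an assumption here)."""
    def __init__(self, s_dim, a_dim, latent_dim=5, hidden_dim=64):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(s_dim + a_dim + latent_dim, hidden_dim), nn.ReLU())
        self.next_state = nn.Linear(hidden_dim, s_dim)
        self.reward = nn.Linear(hidden_dim, 1)

    def log_prob(self, s, a, m, s_next, r):
        h = self.body(torch.cat([s, a, m], dim=-1))
        lp_s = Normal(self.next_state(h), 1.0).log_prob(s_next).sum(-1)   # log p^T(s_{i+1} | s_i, a_i, m)
        lp_r = Normal(self.reward(h), 1.0).log_prob(r).sum(-1)            # log p^R(r_i | s_i, a_i, m)
        return lp_s + lp_r

def elbo_t(encoder, decoder, s, a, r, s_next, t):
    """ELBO_t (t >= 1): encode the prefix tau_:t, reconstruct the full trajectory, KL to the previous posterior.
    The p^T(s_0 | m) term is omitted for brevity. Shapes: s (B, H, s_dim), a (B, H, a_dim), r (B, H, 1)."""
    B, H, _ = s.shape
    prefix = torch.cat([s[:, :t], a[:, :t], r[:, :t]], dim=-1)
    mu, logvar = encoder(prefix)
    q_t = Normal(mu, (0.5 * logvar).exp())
    m = q_t.rsample()                                                     # reparameterised sample m ~ q_phi(. | tau_:t)
    m_rep = m.unsqueeze(1).expand(B, H, -1)
    recon = decoder.log_prob(s, a, m_rep, s_next, r).sum(-1).mean()       # sum over i = 0..H-1, mean over the batch
    if t > 1:                                                             # prior = previous posterior q_phi(m | tau_:t-1)
        mu_p, logvar_p = encoder(prefix[:, :-1])
        prior = Normal(mu_p.detach(), (0.5 * logvar_p).exp().detach())    # treated as fixed here
    else:
        prior = Normal(torch.zeros_like(mu), torch.ones_like(mu))         # p(m) at t = 1: a standard normal
    kl = kl_divergence(q_t, prior).sum(-1).mean()
    return recon - kl
```

Note that the posterior is computed from the prefix $\tau_{t}$ while the reconstruction covers the entire trajectory $\tau_{H}$, including steps the encoder has not yet seen; this is what pushes $m$ to encode the task rather than merely summarize the past.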

A meta-policy $\pi_{\psi}(a \mid s,\ q_{\phi}(m \mid \tau_{t}))$ is trained jointly across tasks sampled from $p(M)$ with standard RL algorithms

$$
\max_{\phi,\ \theta,\ \psi} \mathcal{L}(\phi,\ \theta,\ \psi) = \mathcal{E}_{M \sim p(M)} \left[ \mathcal{J}(\pi_{\psi} \mid M) + \sum_{t = 0}^{H} \operatorname{ELBO}_{t}(\phi,\ \theta \mid M) \right]
$$

where the task distribution $p(M)$ plays the role of the prior in the BAMDP, so meta-RL can itself be viewed as an instance of the BAMDP
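Schematically, the joint objective can be optimised as a single loss; `rl_loss` stands in for whatever surrogate loss the chosen RL algorithm produces for $\pi_{\psi}$, and `elbo_t` refers to the sketch above, so this is a hypothetical outline rather than VariBAD's actual training code.

```python
# A schematic of the joint objective L(phi, theta, psi): the policy's RL loss is combined with
# the VAE terms summed over timesteps; minimising this loss ascends J(pi_psi) + sum_t ELBO_t
# (assuming rl_loss is already a "minimise me" surrogate for -J).
def joint_loss(rl_loss, encoder, decoder, s, a, r, s_next, horizon):
    vae_objective = sum(elbo_t(encoder, decoder, s, a, r, s_next, t) for t in range(1, horizon + 1))
    return rl_loss - vae_objective
```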

