FOCAL

Task Representation Learning

FOCAL considers a set of tasks $\{\mathcal{T}_i\}$ with point-wise unique transition and reward functions, together with their respective offline datasets $\{\mathcal{D}_i\}$:

$$\forall\, (s, a) \in \mathcal{S} \times \mathcal{A}: \quad \mathcal{T}_1 = \mathcal{T}_2 \Longleftrightarrow \mathcal{P}_1(s, a) = \mathcal{P}_2(s, a) \,\wedge\, \mathcal{R}_1(s, a) = \mathcal{R}_2(s, a)$$

To achieve efficient and robust task representation inference, FOCAL adopts a negative-power variant of the contrastive loss for distance metric learning:

$$\mathcal{L}_{dml} = \sum_{i=1}^{n} \sum_{j=1}^{n} \mathbb{E}_{(s_i, a_i, r_i, s_i') \sim \mathcal{D}_i}\, \mathbb{E}_{(s_j, a_j, r_j, s_j') \sim \mathcal{D}_j} \left[ \mathbf{1}\{i = j\}\, \|z_i - z_j\|_2^2 + \mathbf{1}\{i \ne j\}\, \frac{\beta}{\|z_i - z_j\|_2^n + \epsilon} \right]$$

where the task embedding $z$ is produced by a deterministic context encoder $q_\phi(z \mid c)$ on the context $c = (s, a, r, s')$, which is unique across tasks by the assumption above. The negative-power distance term ensures separation between task clusters in the embedding space.
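As a concrete reference, here is a minimal PyTorch sketch of this metric-learning loss; `z` and `task_ids` are illustrative names for a batch of encoder outputs and their task labels, not FOCAL's reference implementation:

```python
import torch

def dml_loss(z, task_ids, beta=1.0, power=2, eps=1e-6):
    """Distance metric learning loss over a batch of task embeddings.

    z: (B, d) embeddings from the context encoder q_phi.
    task_ids: (B,) integer label of the task each transition was drawn from.
    Assumes the batch contains several transitions from each of at least two tasks.
    """
    diff = z.unsqueeze(1) - z.unsqueeze(0)            # (B, B, d) pairwise differences
    sq_dist = diff.pow(2).sum(-1)                     # (B, B) squared L2 distances

    same = task_ids.unsqueeze(0) == task_ids.unsqueeze(1)            # positive-pair mask
    off_diag = ~torch.eye(len(z), dtype=torch.bool, device=z.device)  # drop self-pairs

    # Same-task pairs are pulled together by the squared distance ...
    pos = sq_dist[same & off_diag].mean()
    # ... while different-task pairs are pushed apart by the negative-power term.
    neg = (beta / (sq_dist[~same].pow(power / 2) + eps)).mean()
    return pos + neg
```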

Meta Behavior Learning

Based on the multi-task offline datasets, FOCAL trains a behavior-regularized actor-critic conditioned on the task embedding, penalizing a divergence $D$ between the learned policy $\pi$ and the behavior policy $\pi_b$ that generated the data:

$$\max_{\pi}\ \mathbb{E}_{s_0 \sim \rho_0(\cdot,\, z),\ a_t \sim \pi(\cdot \mid s_t,\, z),\ s_{t+1} \sim p(\cdot \mid s_t,\, a_t,\, z)} \left[ \sum_{t=0}^{\infty} \gamma^t \Big( \mathcal{R}(s_t, a_t, z) - \alpha D(\pi, \pi_b \mid s_t, z) \Big) \right]$$

Similar to SAC, the corresponding critic and actor losses, approximated with samples from the offline datasets, are

$$\mathcal{L}_{critic} = \sum_{i=1}^{n} \mathbb{E}_{(s, a, r, s') \sim \mathcal{D}_i} \left[ r + \gamma \Big( \mathbb{E}_{a' \sim \pi_\theta(\cdot \mid s',\, z_i)} Q_{\psi}(s', a', z_i) - \alpha D(\pi_\theta, \pi_b \mid s', z_i) \Big) - Q_{\psi}(s, a, z_i) \right]^2$$

$$\mathcal{L}_{actor} = \sum_{i=1}^{n} \mathbb{E}_{(s, a, r, s') \sim \mathcal{D}_i} \Big[ \mathbb{E}_{\tilde{a} \sim \pi_\theta(\cdot \mid s,\, z_i)} Q_{\psi}(s, \tilde{a}, z_i) - \alpha D(\pi_\theta, \pi_b \mid s, z_i) \Big]$$
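A hedged PyTorch sketch of one task's losses follows; the names and interfaces are illustrative assumptions, not FOCAL's reference code: a stochastic policy `actor` whose `sample` returns actions and log-probabilities, a Q-network `critic` with a `target` copy, and a cloned behavior policy `behavior` whose `log_prob` is used for a sampled KL-style estimate of the divergence $D$.

```python
import torch
import torch.nn.functional as F

def brac_losses(batch, z, actor, critic, behavior, gamma=0.99, alpha=0.1):
    """Behavior-regularized actor/critic losses for one task, computed on a
    batch (s, a, r, s') sampled from that task's offline dataset."""
    s, a, r, s_next = batch
    zb = z.expand(s.shape[0], -1)                  # broadcast task embedding over batch

    with torch.no_grad():
        a_next, logp_next = actor.sample(s_next, zb)
        # Sampled estimate of D(pi, pi_b | s', z) as log pi - log pi_b
        div_next = logp_next - behavior.log_prob(s_next, a_next, zb)
        target = r + gamma * (critic.target(s_next, a_next, zb) - alpha * div_next)

    critic_loss = F.mse_loss(critic(s, a, zb), target)

    a_pi, logp_pi = actor.sample(s, zb)
    div = logp_pi - behavior.log_prob(s, a_pi, zb)
    # The actor maximizes Q - alpha * D, i.e. minimizes the negated objective.
    actor_loss = (alpha * div - critic(s, a_pi, zb)).mean()
    return critic_loss, actor_loss
```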

The learning processes for the task representation and the behavior policy are decoupled from each other (the actor-critic gradients do not flow back into the context encoder) for better efficiency and stability.
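In practice this decoupling amounts to updating the context encoder only with the metric-learning loss and feeding the actor-critic a detached embedding; a sketch reusing the hypothetical pieces from the snippets above:

```python
# The encoder is updated only by the metric-learning loss ...
z_batch = encoder(context)                     # context: batch of (s, a, r, s') tuples
encoder_opt.zero_grad()
dml_loss(z_batch, task_ids).backward()
encoder_opt.step()

# ... while the actor and critic see a detached embedding, so their
# gradients never reach the encoder.
z_task = encoder(context).mean(dim=0, keepdim=True).detach()
critic_loss, actor_loss = brac_losses(batch, z_task, actor, critic, behavior)
```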

At test time, the learned contextual policy can be deployed on unseen tasks, with the task embedding inferred from a few context samples of the new task.
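For instance, test-time adaptation can be as simple as encoding a handful of transitions collected on the new task and conditioning the frozen policy on the averaged embedding. A sketch with the same assumed interfaces as above and a gym-style `env`:

```python
import torch

@torch.no_grad()
def adapt_and_act(encoder, actor, few_shot_context, env, horizon=200):
    """Infer a task embedding from a few context transitions, then roll out the
    frozen contextual policy on the new task (no gradient updates at test time)."""
    z = encoder(few_shot_context).mean(dim=0, keepdim=True)   # few-shot task embedding

    s = env.reset()
    total_return = 0.0
    for _ in range(horizon):
        s_t = torch.as_tensor(s, dtype=torch.float32).unsqueeze(0)
        a, _ = actor.sample(s_t, z)
        s, r, done, _ = env.step(a.squeeze(0).cpu().numpy())
        total_return += r
        if done:
            break
    return total_return
```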

