RAP


LLM World Model

RAP repurposes the LLM as an internal world model, which enables a problem-specific definition of states and actions

Example tasks: Blocksworld planning, math reasoning, and logical reasoning

The default policy $\pi(a_{t} \mid s_{t}, c)$ and the dynamics function $p(s_{t+1} \mid s_{t}, a_{t}, c')$ are both modeled by a generative LLM, where $c$ and $c'$ are task-specific prompts that make the LLM behave as the policy and the dynamics model, respectively

Compared to previous reasoning methods such as CoT, augmenting the reasoning process with states predicted by the LLM acting as an internal world model yields more grounded and coherent inference
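A minimal sketch of this idea, assuming a generic `llm_sample(prompt, n)` interface (the prompt templates, function names, and interface are illustrative assumptions, not the paper's implementation): the same LLM is prompted once to act as the policy and once to act as the dynamics model.

```python
from typing import Callable, List

# Assumed interface (not from the paper): takes a prompt and a sample count,
# returns that many sampled completions from the LLM.
LLMSampler = Callable[[str, int], List[str]]

POLICY_PROMPT = "{task_prompt}\nCurrent state:\n{state}\nPropose the next action:"
DYNAMICS_PROMPT = (
    "{task_prompt}\nCurrent state:\n{state}\nAction taken:\n{action}\n"
    "Describe the resulting state:"
)

def propose_actions(llm_sample: LLMSampler, task_prompt: str,
                    state: str, d: int) -> List[str]:
    """Use the LLM as the policy pi(a | s, c): sample d candidate actions."""
    prompt = POLICY_PROMPT.format(task_prompt=task_prompt, state=state)
    return llm_sample(prompt, d)

def predict_next_state(llm_sample: LLMSampler, task_prompt: str,
                       state: str, action: str) -> str:
    """Use the LLM as the dynamics p(s' | s, a, c'): predict the next state."""
    prompt = DYNAMICS_PROMPT.format(task_prompt=task_prompt,
                                    state=state, action=action)
    return llm_sample(prompt, 1)[0]
```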

Reward Assessment

Similarly, the reward function $r(s_{t}, a_{t})$ can be specified in different ways depending on the reasoning problem (a code sketch of these options follows the list below)

  1. likelihood of action
    1. incorporate the log probability of the action as the reward $r(s_{t}, a_{t}) = \log \pi(a_{t} \mid s_{t})$
    2. the probability of the specific action reflects the LLM’s preference
  2. confidence of state
    1. draw multiple predicted states $s_{t+1}$ from the world model, $s_{t+1} \sim p(\cdot \mid s_{t}, a_{t})$
    2. use the proportion of the most frequent result (confidence) as the reward
    3. higher confidence indicates that the state prediction is more consistent with the knowledge of LLMs
  3. self-evaluation by the LLM
    1. use the LLM to criticize itself with the question "Is this reasoning step correct?"
    2. use the next-word probability of the token "Yes" as the reward
    3. this evaluates the LLM's own estimation of the correctness of its reasoning
  4. task-specific heuristics
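The first three options can be sketched as follows, assuming hypothetical `logprob` and `sample` interfaces for the LLM (the names, prompt wording, and sample count are illustrative assumptions, not the paper's implementation).

```python
import math
from collections import Counter
from typing import Callable, List

# Assumed interfaces (placeholders, not from the paper):
#   logprob(prompt, continuation) -> total log-probability of `continuation`
#   sample(prompt, n)             -> n sampled completions
LogProbFn = Callable[[str, str], float]
SampleFn = Callable[[str, int], List[str]]

def action_likelihood_reward(logprob: LogProbFn, state: str, action: str) -> float:
    # Option 1: r(s_t, a_t) = log pi(a_t | s_t), the LLM's preference for the action
    return logprob(state, action)

def state_confidence_reward(sample: SampleFn, state_action_prompt: str,
                            n: int = 8) -> float:
    # Option 2: draw n next-state predictions and use the share of the most
    # frequent prediction as a confidence score in (0, 1]
    predictions = sample(state_action_prompt, n)
    most_common_count = Counter(predictions).most_common(1)[0][1]
    return most_common_count / n

def self_eval_reward(logprob: LogProbFn, reasoning_step: str) -> float:
    # Option 3: ask the LLM to judge its own step and use the probability of "Yes"
    question = reasoning_step + "\nIs this reasoning step correct? Answer Yes or No: "
    return math.exp(logprob(question, "Yes"))
```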

MCTS Planning

RAP adopts MCTS to strategically explore the reasoning space while balancing exploration and exploitation

Each internal node of the search tree maintains statistics such as the state-value function $Q(s, a)$ and the visit count $N(s)$
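For concreteness, a node could be represented roughly like this (a hypothetical layout with assumed field names; the paper only requires that Q values and visit counts be tracked per node).

```python
from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class Node:
    """One possible node record for the reasoning tree (field names assumed)."""
    state: str                       # state s_t predicted by the LLM world model
    action: Optional[str] = None     # action a_{t-1} that led to this state
    parent: Optional["Node"] = None
    children: List["Node"] = field(default_factory=list)
    reward: float = 0.0              # immediate reward r(s_{t-1}, a_{t-1})
    q_value: float = 0.0             # running estimate of Q(s, a)
    visit_count: int = 0             # N, used by the UCB exploration term
```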

The reasoning process iterates over the following four phases (selection, expansion, simulation, backup) until a specified computational budget is exhausted; the selection and backup phases are also sketched in code after the list

  1. selection
    1. an action is selected at each level of the reasoning tree according to its UCB value until a leaf node is encountered

    $$a^{\star} = \argmax_{a \in A(s)} \left[ Q(s, a) + w \sqrt{\frac{\ln N(s)}{N(c(s, a))}} \right]$$

    2. the exploration weight $w$ controls the balance between exploration and exploitation
  2. expansion
    1. sample $d$ possible actions $a^{(1:d)}$ from the LLM policy $\pi(a \mid s, c)$ rather than enumerating all actions
    2. use the LLM world model $p(s' \mid s, a)$ to predict the respective next state for each sampled action
  3. simulation
    1. use a lightweight rollout policy and reward assessment to perform quick simulations
    2. the reasoning tree is recursively expanded at each level until a terminal state is reached
  4. backup
    1. a reasoning path $\{ s_{0:T}, a_{0:T-1} \}$ from the root node to a terminal node is obtained from the previous phases
    2. the state-value function $Q(s_{t}, a_{t})$ of each node on the reasoning path is updated
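Continuing the hypothetical `Node` record above, the selection and backup phases might look like the sketch below; the UCB rule matches the formula above, while the Q update uses a running mean of accumulated returns, which is one reasonable aggregation but not necessarily the paper's exact choice.

```python
import math
from typing import List

def ucb_select(node, w: float = 1.0):
    """Selection: descend from the root, picking the child that maximizes
    Q(s, a) + w * sqrt(ln N(s) / N(c(s, a))), until a leaf is reached."""
    while node.children:
        node = max(
            node.children,
            key=lambda c: c.q_value
            + w * math.sqrt(math.log(max(node.visit_count, 1)) / max(c.visit_count, 1)),
        )
    return node

def backup(path: List, terminal_reward: float) -> None:
    """Backup: walk the root-to-leaf path in reverse, updating visit counts and
    Q estimates with the accumulated return."""
    ret = terminal_reward
    for node in reversed(path):
        node.visit_count += 1
        ret += node.reward                                        # accumulate rewards
        node.q_value += (ret - node.q_value) / node.visit_count   # running mean of returns
```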

The final reasoning trace is then selected from the constructed tree, which can be implemented in several ways (the first option is sketched after the list)

  1. choose the action with the highest $Q$ value iteratively until reaching a terminal node
  2. select the path from the iterations that yielded the highest reward
  3. choose the leaf node and the respective root-to-leaf path that has been visited the most
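For example, the first option could look like this, again assuming the hypothetical `Node` record sketched earlier.

```python
from typing import List

def extract_trace_by_q(root) -> List:
    """Option 1: from the root, repeatedly follow the child with the highest
    Q value until a terminal (childless) node is reached."""
    path = [root]
    node = root
    while node.children:
        node = max(node.children, key=lambda c: c.q_value)
        path.append(node)
    return path
```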
