跳过正文

时序差分算法

·1968 字·4 分钟
RL Hands-on-Rl
Hands-on-RL - 这篇文章属于一个选集。
§ 4: 本文

本系列是学习《动手学强化学习》 过程中做的摘抄。

无模型(model-free)的强化学习,如 Sarsa 和 Q-learning,智能体只能和环境进行交互,通过采样到的数据来学习。不同于动态规划算法,model-free 的强化学习算法不需要事先知道环境的奖励函数和状态转移函数,而是直接使用和环境交互的过程中采样到的数据来学习,这使得它可以被应用到一些简单的实际场景中。

将采样数据的策略称为行为策略(behavior policy),称用这些数据来更新的策略为目标策略(target policy)

在线策略(on-policy)学习表示行为策略和目标策略是同一个策略,要求使用在当前策略下采样得到的样本进行学习,一旦策略被更新,当前的样本就被放弃了。

离线策略(off-policy)学习表示行为策略和目标策略不是同一个策略,使用经验回放池将之前采样得到的样本收集起来再次利用,能更好地利用历史数据,并具有更小的样本复杂度。

4.1 时序差分
#

时序差分(temporal difference, TD)结合了蒙特卡洛和动态规划算法的思想:

  • TD 和蒙特卡洛的相似之处在于可以从样本数据中学习,不需要事先知道环境;
  • TD 和动态规划的相似之处在于可以根据贝尔曼方程的思想,利用后续状态的价值估计来更新当前状态的价值估计。

4.2 Sarsa
#

Sarsa 的更新公式必须使用当前策略采样得到的五元组 \((s,a,r,s',a')\),因此它是 on-policy 算法。它直接使用 TD 算法来估计动作价值函数 \(Q(s,a)\):

$$ Q(s_t,a_t) \leftarrow Q(s_t,a_t) + \alpha \left[ r_t + \gamma Q(s_{t+1},a_{t+1}) - Q(s_t,a_t) \right] $$

然后用 \(\epsilon-greedy\) 算法根据动作价值选取动作来和环境交互,再根据得到的数据用 TD 算法更新动作价值估计。

$$ \pi(a|s) = \begin{cases} \epsilon/|A|+1-\epsilon & \text{if } a = \arg \max_a Q(s,a) \\ \epsilon/|A| & \text{otherwise} \end{cases} $$
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import numpy as np


class Sarsa:
    def __init__(self, ncol, nrow, epsilon, alpha, gamma, n_action=4):
        self.Q_table = np.zeros([nrow * ncol, n_action])  # 初始化Q(s,a)表格
        self.n_action = n_action  # 动作个数
        self.alpha = alpha  # 学习率
        self.gamma = gamma  # 折扣因子
        self.epsilon = epsilon  # epsilon-贪婪策略中的参数

    def take_action(self, state) -> int:
        """选取下一步的操作,具体实现为 epsilon-贪婪"""
        if np.random.random() < self.epsilon:
            action = np.random.randint(self.n_action)
        else:
            action = np.argmax(self.Q_table[state])
        return action

    def best_action(self, state) -> list:
        """打印状态 s 对应的策略"""
        Q_max = np.max(self.Q_table[state])
        a = [0 for _ in range(self.n_action)]
        for i in range(self.n_action):  # 若两个动作的价值一样,都会记录下来
            if self.Q_table[state, i] == Q_max:
                a[i] = 1
        return a

    def update(self, s0, a0, r, s1, a1):
        td_error = r + self.gamma * self.Q_table[s1, a1] - self.Q_table[s0, a0]
        self.Q_table[s0, a0] += self.alpha * td_error

4.3 多步Sarsa
#

蒙特卡洛方法无偏的,但是具有比较大的方法;TD 算法具有非常小的方差,但它是有偏的。多步时序差分结合二者的优势,使用 \(n\) 步的奖励,然后使用之后状态的价值估计,用公式表示:将 \(G_t=r_t+\gamma Q(s_{t+1},a_{t+1})\) 替换成 \(G_t = r_t + \gamma r_{t+1} + \cdots + \gamma^{n}Q(s_{t+n},a_{t+n})\)。

于是,相应存在一种多步 Sarsa 算法,它的动作价值函数更新公式变为:

$$ Q(s_t,a_t) \leftarrow Q(s_t,a_t) + \alpha \left[ r_t + \gamma r_{t+1} + \cdots + \gamma^{n} Q(s_{t+n},a_{t+n}) - Q(s_t,a_t) \right] $$
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import numpy as np


class nstep_Sarsa:
    def __init__(self, n, ncol, nrow, epsilon, alpha, gamma, n_action=4):
        self.Q_table = np.zeros([nrow * ncol, n_action])
        self.n_action = n_action
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon
        self.n = n  # 采用 n 步 Sarsa 算法
        self.state_list = []  # 保存之前的状态
        self.action_list = []  # 保存之前的动作
        self.reward_list = []  # 保存之前的奖励

    def take_action(self, state) -> int:
        if np.random.random() < self.epsilon:
            action = np.random.randint(self.n_action)
        else:
            action = np.argmax(self.Q_table[state])
        return action

    def best_action(self, state) -> list:
        """打印状态 s 对应的策略"""
        Q_max = np.max(self.Q_table[state])
        a = [0 for _ in range(self.n_action)]
        for i in range(self.n_action):
            if self.Q_table[state, i] == Q_max:
                a[i] = 1
        return a

    def update(self, s0, a0, r, s1, a1, done):
        # 保存之前的状态、动作和奖励
        self.state_list.append(s0)
        self.action_list.append(a0)
        self.reward_list.append(r)

        if len(self.state_list) == self.n:  # 若保存的数据可以进行 n 步更新
            G = self.Q_table[s1, a1]  # 对应 Q(s_{t+n}, a_{t+n})
            for i in reversed(range(self.n)):
                G = self.gamma * G + self.reward_list[i]  # 不断向前计算每一步的回报
                if done and i > 0:
                    # 如果到达终止状态,最后几步虽然长度不够 n 步,也将其进行更新
                    s = self.state_list[i]
                    a = self.action_list[i]
                    self.Q_table[s, a] += self.alpha * (G - self.Q_table[s, a])

            # 将需要更新的状态动作从列表中删除,下次不必更新
            s = self.state_list.pop(0)
            a = self.action_list.pop(0)
            self.reward_list.pop(0)

            # n 步 Sarsa 的主要更新步骤
            self.Q_table[s, a] += self.alpha * (G - self.Q_table[s, a])

        if done:  # 如果到达终止状态,即将开始下一条序列,则将列表全清空
            self.state_list = []
            self.action_list = []
            self.reward_list = []

4.4 Q-learning
#

Q-learning 和 Sarsa 的最大区别在于 Q-learning 的时序差分更新方式为

$$ Q(s_t,a_t) \leftarrow Q(s_t,a_t) + \gamma \left[ r_t + \gamma \max_{a} Q(s_{t+1},a) - Q(s_t,a_t) \right] $$

Q-learning 的更新公式使用四元组 \((s,a,r,s')\) 来更新当前状态动作对的动作价值 \(Q(s,a)\),数据中的 \(s\) 和 \(a\) 是给定的条件,\(r\) 和 \(s'\) 皆由环境采样得到,该四元组并不需要一定是当前策略采样得到的数据,也可以来自行为策略,因此它是 off-policy 算法。

需要强调的是,Q-learning 的更新并非必须使用当前贪婪策略 \(\arg \max_a Q(s,a)\) 采样得到的数据,因为给定任意 \((s,a,r,s')\) 都可以直接根据更新公式来更新 \(Q\)。为了探索,通常使用一个 \( \epsilon-greedy \) 策略来与环境交互;而 Sarsa 必须使用当前 \( \epsilon-greedy \) 策略采样得到数据。

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import numpy as np


class QLearning:
    def __init__(self, ncol, nrow, epsilon, alpha, gamma, n_action=4):
        self.Q_table = np.zeros([nrow * ncol, n_action])  # 初始化 Q(s,a) 表格
        self.n_action = n_action  # 动作个数
        self.alpha = alpha  # 学习率
        self.gamma = gamma  # 折扣因子
        self.epsilon = epsilon  # epsilon-贪婪策略中的参数

    def take_action(self, state):
        """选取下一步的操作"""
        if np.random.random() < self.epsilon:
            action = np.random.randint(self.n_action)
        else:
            action = np.argmax(self.Q_table[state])
        return action

    def best_action(self, state):
        Q_max = np.max(self.Q_table[state])
        a = [0 for _ in range(self.n_action)]
        for i in range(self.n_action):
            if self.Q_table[state, i] == Q_max:
                a[i] = 1
        return a

    def update(self, s0, a0, r, s1):
        td_error = r + self.gamma * self.Q_table[s1].max() - self.Q_table[s0, a0]
        self.Q_table[s0, a0] += self.alpha * td_error
Hands-on-RL - 这篇文章属于一个选集。
§ 4: 本文