【深度强化学习 DRL 快速实践】深度确定性策略梯度算法 (DDPG)
DDPG(2016,DeepMind) 核心改进点
深度确定性策略梯度算法 (DDPG): 通过融合 DQN+PG 优势,解决连续动作空间下的确定性策略问题
- model-free, off policy, actor-critic, deterministic 策略
核心改进点 | 说明 |
---|---|
策略梯度优化 (继承 PG) | 通过Actor网络直接优化策略,适应连续动作问题 |
延迟目标网络 (继承 DQN) | 避免Q值的估计震荡,提高算法的训练稳定性: θ Q ′ ← τ θ Q + ( 1 − τ ) θ Q ′ , θ μ ′ ← τ θ μ + ( 1 − τ ) θ μ ′ \theta^{Q'} \leftarrow \tau \theta^Q + (1 - \tau) \theta^{Q'}, \theta^{\mu'} \leftarrow \tau \theta^\mu + (1 - \tau) \theta^{\mu'} θQ′←τθQ+(1−τ)θQ′,θμ′←τθμ+(1−τ)θμ′ |
经验回放机制 (继承 DQN) | 训练时从存储的 (s, a, r, s’) 中随机采样,减少数据相关性和样本浪费 |
DDPG 网络更新
Critic 网络更新: θ Q \theta^Q θQ
Critic 网络的目标是最小化与 Q target Q^\text{target} Qtarget 的差距
L ( θ Q ) = 1 N ∑ [ Q ( s , μ ( s ) ) − Q target ] 2 , where Q target = r + γ Q ′ ( s ′ , μ ′ ( s ′ ) ) L(\theta^Q) = \frac{1}{N} \sum \left [ Q(s, \mu(s)) - Q^\text{target} \right ]^2, \text{where} \ Q^\text{target} = r + \gamma Q'(s', \mu'(s')) L(θQ)=N1∑[Q(s,μ(s))−Qtarget]2,where Qtarget=r+γQ′(s′,μ′(s′))
Actor 网络更新: θ μ \theta^\mu θμ
Actor 网络的目标是最大化 Critic 网络估计的 Q 值:
J ( θ μ ) = 1 N ∑ Q ( s , μ ( s ) ) J(\theta^\mu) = \frac{1}{N} \sum Q(s, \mu(s)) J(θμ)=N1∑Q(s,μ(s))
- 【深入思考】这里其实不是直接优化策略,还是基于值来决定策略的 【本质上 value-based】
基于 stable_baselines3 的快速代码示例
注意:训练时添加动作噪声很重要! 因为 DDPG 是确定性策略(deterministic policy)无法主动探索
import gymnasium as gym
import numpy as np
from stable_baselines3 import DDPG
from stable_baselines3.common.noise import NormalActionNoise# 创建环境
env = gym.make("Pendulum-v1")
env.reset(seed=0)# 动作噪声: 重要 !!! 因为 DDPG 是确定性策略(deterministic policy)无法主动探索
n_actions = env.action_space.shape[-1]
action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))# 初始化模型 -- DDPG --
model = DDPG("MlpPolicy", env, action_noise=action_noise, verbose=1)# 训练
model.learn(total_timesteps=100_000)
model.save("ddpg_pendulum_v1")# 测试
obs, _ = env.reset()
total_reward = 0
for _ in range(200):action, _ = model.predict(obs, deterministic=True) ## 测试时才设置为确定性策略:deterministic=True obs, reward, terminated, truncated, _ = env.step(action)total_reward += rewardif terminated or truncated:breakprint("Test total reward:", total_reward)
参考资料:深度确定性策略梯度算法(DDPG)详解