
[Hands-on Reinforcement Learning] Part 7: Actor-Critic Algorithm

This post explains and summarizes my study of the corresponding chapter of Hands-on Reinforcement Learning (动手学强化学习), reproducing the code and working through how it operates.

Table of Contents

  • 1. Algorithm Background
    • 1.1 Algorithm Goal
    • 1.2 Remaining Problems
    • 1.3 Solution
  • 2. Actor-Critic Algorithm
    • 2.1 Preliminaries
      • · Advantage Function
    • 2.2 Pseudocode
      • · Algorithm Flow in Brief
    • 2.3 Algorithm Code
    • 2.4 Results
      • · Result Analysis
    • 2.5 Algorithm Flow Walkthrough
      • · Initialize Parameters
      • · Initialize Environment
      • · Initialize Networks
      • · Sample Episodes
      • · Value Update
      • · Policy Update
  • 3. Open Questions
  • 4. Summary


1. Algorithm Background

1.1 Algorithm Goal

1.2 Remaining Problems

1.3 Solution

2. Actor-Critic Algorithm

  • 🌟Algorithm type
    Environment dependence: ❌ model-based ✅ model-free
    Value estimation: ✅ non-incremental ❌ incremental (the networks are updated only after a complete episode has been sampled, rather than updating Q after every single (s, a, r, s') sample as in TD learning)
    Value representation: ❌ tabular representation ✅ function representation (the value_net (critic) network learns the value function)
    Learning mode: ✅ on-policy ❌ off-policy (the policy that samples episodes and the policy being optimized are both policy_net (actor))
    Policy representation: ❌ value-based ✅ policy-based (Actor-Critic is essentially a policy-based family of algorithms: the goal is always to optimize a parameterized policy; a value function is learned in addition only to help the policy learn better)

2.1 Preliminaries

· Advantage Function

REINFORCE estimates the policy gradient by Monte Carlo sampling. The estimate is unbiased but has very high variance: when updating policy_net, the loss (cross-entropy) uses $q_{\pi_\theta}(s_t, a_t)$ as the guiding signal, and that value is estimated by Monte Carlo only after a whole episode has been sampled.
Other functions can be used as the guiding signal instead, for example:

$$A_{\pi_\theta}(s_t, a_t) = q_{\pi_\theta}(s_t, a_t) - v_{\pi_\theta}(s_t)$$

From the Bellman equation,

$$v_\pi(s) = \sum_a \pi(a|s)\, q_\pi(s, a),$$

so $v_{\pi_\theta}(s_t)$ can be viewed as the weighted average of $q_{\pi_\theta}(s_t, a_t)$ over actions; if some $q(s, a)$ exceeds $v(s)$, that action is certainly above the "average line". The expression above is therefore called the advantage function.
Using the definition of $q(s, a)$,

$$q_\pi(s, a) = \sum_r p(r|s, a)\, r + \gamma \sum_{s'} p(s'|s, a)\, v_\pi(s') \approx r + \gamma v(s')$$

(the right-hand side is the single-sample estimate), the advantage function can be rewritten as

$$A_{\pi_\theta}(s_t, a_t) = r_t + \gamma v_{\pi_\theta}(s_{t+1}) - v_{\pi_\theta}(s_t).$$

In the Actor-Critic algorithm, the value_net (critic) network is trained to estimate $v_{\pi_\theta}(s_t)$, and the advantage function is then used to guide the policy_net update, i.e. the advantage function plays the role of the target weight in the cross-entropy loss.
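
To make the one-step advantage estimate concrete, here is a minimal sketch with made-up reward and critic values (the tensors below are purely illustrative and do not come from the actual networks):

import torch

gamma = 0.98
# Made-up values for a 3-step trajectory; in the real algorithm v_s and v_s_next
# would be produced by the critic network.
rewards  = torch.tensor([[1.0], [1.0], [1.0]])
v_s      = torch.tensor([[10.0], [9.5], [4.0]])   # critic estimate of v(s_t)
v_s_next = torch.tensor([[9.5], [4.0], [0.0]])    # critic estimate of v(s_{t+1})
dones    = torch.tensor([[0.0], [0.0], [1.0]])    # the last transition is terminal

# A(s_t, a_t) ≈ r_t + γ·v(s_{t+1}) - v(s_t); (1 - dones) drops the bootstrap term
# at the terminal state.
advantage = rewards + gamma * v_s_next * (1 - dones) - v_s
print(advantage)  # prints a (3, 1) tensor with values ≈ 0.31, -4.58, -3.00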

2.2 Pseudocode

(Figure: pseudocode of the Actor-Critic algorithm.)

· Algorithm Flow in Brief

① Initialize networks: initialize the value_net and policy_net models.
② Sample episodes: set the number of episodes num_episodes and sample episodes in a loop. Within a single episode, obtain the action for the current state from policy_net, interact with the environment via env.step to get an (s, a, r, s', done) sample, and repeat until the terminal state is reached (done is True), storing the samples in a dictionary (a simplified sketch of this loop is given right after the list).
③ Value update: compute the TD target, compute the loss (mean squared error), then zero the gradients, back-propagate, and update the parameters of value_net.
④ Policy update: compute TD_delta (the advantage function), compute the loss (cross-entropy), then zero the gradients, back-propagate, and update the parameters of policy_net.
⑤ Termination check: the algorithm stops once the number of sampled episodes reaches num_episodes, and the resulting policy is output.
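
The per-episode loop referenced in steps ②–④ is implemented by rl_utils.train_on_policy_agent from the book's companion code. The following is only a simplified sketch of what that helper does (assumption: the real version additionally wraps the loop in tqdm progress bars and prints running average returns); variable names follow the code in 2.3:

def train_on_policy_sketch(env, agent, num_episodes):
    # Simplified stand-in for rl_utils.train_on_policy_agent, for illustration only.
    return_list = []
    for i_episode in range(num_episodes):
        transition_dict = {'states': [], 'actions': [], 'next_states': [],
                           'rewards': [], 'dones': []}
        state = env.reset()
        done = False
        episode_return = 0
        while not done:                        # step ②: sample one episode with the current actor
            action = agent.take_action(state)
            next_state, reward, done, _ = env.step(action)
            transition_dict['states'].append(state)
            transition_dict['actions'].append(action)
            transition_dict['next_states'].append(next_state)
            transition_dict['rewards'].append(reward)
            transition_dict['dones'].append(done)
            state = next_state
            episode_return += reward
        return_list.append(episode_return)
        agent.update(transition_dict)          # steps ③/④: one critic + one actor update per episode
    return return_list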

2.3 Algorithm Code

import gym
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import rl_utils


class PolicyNet(torch.nn.Module):
    '''Policy network, 4-128-2: input is the state, output is a normalized probability over actions'''
    def __init__(self, state_dim, hidden_dim, action_dim):
        super(PolicyNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return F.softmax(self.fc2(x), dim=1)


class ValueNet(torch.nn.Module):
    '''Value network, 4-128-1: input is the state, output is the estimated state value v(s)'''
    def __init__(self, state_dim, hidden_dim):
        super(ValueNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.fc2(x)


class ActorCritic:
    def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr,
                 gamma, device):
        # policy network (actor)
        self.actor = PolicyNet(state_dim, hidden_dim, action_dim).to(device)
        self.critic = ValueNet(state_dim, hidden_dim).to(device)  # value network (critic)
        # policy network optimizer
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),
                                                lr=actor_lr)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),
                                                 lr=critic_lr)  # value network optimizer
        self.gamma = gamma
        self.device = device

    def take_action(self, state):
        state = torch.tensor([state], dtype=torch.float).to(self.device)
        probs = self.actor(state)
        action_dist = torch.distributions.Categorical(probs)
        action = action_dist.sample()
        return action.item()

    def update(self, transition_dict):
        states = torch.tensor(transition_dict['states'],
                              dtype=torch.float).to(self.device)
        actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(
            self.device)
        rewards = torch.tensor(transition_dict['rewards'],
                               dtype=torch.float).view(-1, 1).to(self.device)
        next_states = torch.tensor(transition_dict['next_states'],
                                   dtype=torch.float).to(self.device)
        dones = torch.tensor(transition_dict['dones'],
                             dtype=torch.float).view(-1, 1).to(self.device)

        # TD target
        td_target = rewards + self.gamma * self.critic(next_states) * (1 - dones)
        td_delta = td_target - self.critic(states)  # TD residual (advantage estimate)
        log_probs = torch.log(self.actor(states).gather(1, actions))
        actor_loss = torch.mean(-log_probs * td_delta.detach())
        # mean squared error loss for the critic
        critic_loss = torch.mean(
            F.mse_loss(self.critic(states), td_target.detach()))
        self.actor_optimizer.zero_grad()
        self.critic_optimizer.zero_grad()
        actor_loss.backward()  # compute gradients for the policy network
        critic_loss.backward()  # compute gradients for the value network
        self.actor_optimizer.step()  # update the policy network parameters
        self.critic_optimizer.step()  # update the value network parameters


actor_lr = 1e-3
critic_lr = 1e-2
num_episodes = 1000
hidden_dim = 128
gamma = 0.98
device = torch.device("cuda") if torch.cuda.is_available() else torch.device(
    "cpu")

env_name = 'CartPole-v0'
env = gym.make(env_name)
env.seed(0)
torch.manual_seed(0)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
agent = ActorCritic(state_dim, hidden_dim, action_dim, actor_lr, critic_lr,
                    gamma, device)

return_list = rl_utils.train_on_policy_agent(env, agent, num_episodes)

episodes_list = list(range(len(return_list)))
plt.plot(episodes_list, return_list)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('Actor-Critic on {}'.format(env_name))
plt.show()

mv_return = rl_utils.moving_average(return_list, 9)
plt.plot(episodes_list, mv_return)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('Actor-Critic on {}'.format(env_name))
plt.show()

2.4 Results

Iteration 0: 100%|██████████| 100/100 [00:01<00:00, 62.50it/s, episode=100, return=19.000]
Iteration 1: 100%|██████████| 100/100 [00:03<00:00, 32.51it/s, episode=200, return=58.600]
Iteration 2: 100%|██████████| 100/100 [00:04<00:00, 20.90it/s, episode=300, return=94.000]
Iteration 3: 100%|██████████| 100/100 [00:08<00:00, 11.72it/s, episode=400, return=187.100]
Iteration 4: 100%|██████████| 100/100 [00:10<00:00,  9.66it/s, episode=500, return=155.900]
Iteration 5: 100%|██████████| 100/100 [00:11<00:00,  8.53it/s, episode=600, return=196.000]
Iteration 6: 100%|██████████| 100/100 [00:12<00:00,  8.01it/s, episode=700, return=200.000]
Iteration 7: 100%|██████████| 100/100 [00:12<00:00,  7.82it/s, episode=800, return=200.000]
Iteration 8: 100%|██████████| 100/100 [00:14<00:00,  7.09it/s, episode=900, return=195.500]
Iteration 9: 100%|██████████| 100/100 [00:12<00:00,  8.16it/s, episode=1000, return=200.000]

(Figures: per-episode returns and the 9-episode moving average of returns; both plots are titled "Actor-Critic on CartPole-v0", with Episodes on the x-axis and Returns on the y-axis.)

· Result Analysis

The jitter of the return curve is somewhat reduced compared with the REINFORCE algorithm, which indicates that introducing the advantage function lowers the variance of the gradient estimate.

2.5 Algorithm Flow Walkthrough

· Initialize Parameters

actor_lr = 1e-3
critic_lr = 1e-2
num_episodes = 1000
hidden_dim = 128
gamma = 0.98
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

· Initialize Environment

env_name = 'CartPole-v0'
env = gym.make(env_name)
env.seed(0)
torch.manual_seed(0)
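
Note that this code targets the classic Gym API (CartPole-v0 with env.seed). If you run it on a newer Gym (0.26+) or Gymnasium release, seeding and stepping have changed; a rough adaptation, shown here only as an assumption-laden sketch and not part of the original code, would look like:

import gymnasium as gym   # sketch for the newer Gymnasium / Gym 0.26+ API only

env = gym.make('CartPole-v1')                    # v1 (500-step limit) is the usual variant on newer releases
state, info = env.reset(seed=0)                  # seeding moved from env.seed(0) to reset(seed=...)
next_state, reward, terminated, truncated, info = env.step(env.action_space.sample())
done = terminated or truncated                   # the old single 'done' flag is now split in two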

· Initialize Networks

state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
agent = ActorCritic(state_dim, hidden_dim, action_dim, actor_lr, critic_lr,
                    gamma, device)
...
    def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr,
                 gamma, device):
        # policy network (actor)
        self.actor = PolicyNet(state_dim, hidden_dim, action_dim).to(device)
        self.critic = ValueNet(state_dim, hidden_dim).to(device)  # value network (critic)
        # policy network optimizer
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),
                                                lr=actor_lr)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),
                                                 lr=critic_lr)  # value network optimizer
        self.gamma = gamma
        self.device = device
...
class PolicyNet(torch.nn.Module):
    '''Policy network, 4-128-2: input is the state, output is a normalized probability over actions'''
    def __init__(self, state_dim, hidden_dim, action_dim):
        super(PolicyNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return F.softmax(self.fc2(x), dim=1)


class ValueNet(torch.nn.Module):
    '''Value network, 4-128-1: input is the state, output is the estimated state value v(s)'''
    def __init__(self, state_dim, hidden_dim):
        super(ValueNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.fc2(x)

policy_net (actor) is a 4-128-2 fully connected network whose input is the 4-dimensional state and whose output is the 2-dimensional action probability; it is defined the same way as in the REINFORCE algorithm:

$$y = \mathrm{softmax}\left(fc_2\left(\mathrm{relu}\left(fc_1(x)\right)\right)\right), \quad x = \text{states}$$

value_net (critic) is a 4-128-1 fully connected network whose input is the 4-dimensional state and whose output is the 1-dimensional state value; it is defined similarly to the network in the DQN algorithm:

$$y = fc_2\left(\mathrm{relu}\left(fc_1(x)\right)\right), \quad x = \text{states}$$

Both networks use Adam as the optimizer.
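
As a quick sanity check on the two architectures (a minimal sketch; the batch of random states below is made up and simply reuses the PolicyNet and ValueNet classes from 2.3):

import torch

state_dim, hidden_dim, action_dim = 4, 128, 2
policy_net = PolicyNet(state_dim, hidden_dim, action_dim)   # classes defined in 2.3
value_net = ValueNet(state_dim, hidden_dim)

dummy_states = torch.randn(5, state_dim)        # a batch of 5 fake CartPole states
print(policy_net(dummy_states).shape)           # torch.Size([5, 2]); each row sums to 1 (softmax)
print(value_net(dummy_states).shape)            # torch.Size([5, 1]); one v(s) estimate per state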

· Sample Episodes

return_list = rl_utils.train_on_policy_agent(env, agent, num_episodes)
...
for i_episode in range(int(num_episodes/10)):
    episode_return = 0
    transition_dict = {'states': [], 'actions': [], 'next_states': [],
                       'rewards': [], 'dones': []}
    state = env.reset()
    done = False
    while not done:
        action = agent.take_action(state)
        next_state, reward, done, _ = env.step(action)
        transition_dict['states'].append(state)
        transition_dict['actions'].append(action)
        transition_dict['next_states'].append(next_state)
        transition_dict['rewards'].append(reward)
        transition_dict['dones'].append(done)
        state = next_state
        episode_return += reward
    return_list.append(episode_return)
...
    def take_action(self, state):
        state = torch.tensor([state], dtype=torch.float).to(self.device)
        probs = self.actor(state)
        action_dist = torch.distributions.Categorical(probs)
        action = action_dist.sample()
        return action.item()

① Get the initial state: state = env.reset()
② Choose an action according to the output of policy_net: agent.take_action(state) (a standalone illustration of this sampling step follows the list)
③ Interact with the environment to obtain (s, a, r, s', done): env.step(action)
④ Append the sample to the episode buffer: transition_dict
⑤ Accumulate the immediate rewards of the episode: episode_return += reward
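
Step ② samples the action index from the actor's softmax output via torch.distributions.Categorical rather than always taking the argmax, which keeps the on-policy sampling stochastic. A standalone illustration with a made-up probability vector:

import torch

probs = torch.tensor([[0.7, 0.3]])                    # hypothetical actor output for a single state
action_dist = torch.distributions.Categorical(probs)  # categorical distribution over the 2 actions
action = action_dist.sample()                         # returns 0 with prob 0.7, 1 with prob 0.3
print(action.item())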

· Value Update

agent.update(transition_dict)
...
        states = torch.tensor(transition_dict['states'],
                              dtype=torch.float).to(self.device)
        actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(
            self.device)
        rewards = torch.tensor(transition_dict['rewards'],
                               dtype=torch.float).view(-1, 1).to(self.device)
        next_states = torch.tensor(transition_dict['next_states'],
                                   dtype=torch.float).to(self.device)
        dones = torch.tensor(transition_dict['dones'],
                             dtype=torch.float).view(-1, 1).to(self.device)

        # TD target
        td_target = rewards + self.gamma * self.critic(next_states) * (1 - dones)
        # mean squared error loss for the critic
        critic_loss = torch.mean(
            F.mse_loss(self.critic(states), td_target.detach()))
        self.critic_optimizer.zero_grad()
        ...
        critic_loss.backward()  # compute gradients for the value network
        ...
        self.critic_optimizer.step()  # update the value network parameters

① Convert the sampled episode into tensors.
② Compute the TD target: td_target = rewards + self.gamma * self.critic(next_states) * (1 - dones) (a small worked example follows the list)
③ Compute the loss (mean squared error): critic_loss = torch.mean(F.mse_loss(self.critic(states), td_target.detach()))
④ Train and update the value_net network.
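
A small worked example of steps ② and ③ with made-up critic outputs, mainly to show how the (1 - dones) mask removes the bootstrap term on the terminal transition:

import torch
import torch.nn.functional as F

gamma = 0.98
rewards       = torch.tensor([[1.0], [1.0]])
v_next_states = torch.tensor([[5.0], [3.0]])   # pretend critic outputs for s_{t+1} (made up)
v_states      = torch.tensor([[5.5], [4.0]])   # pretend critic outputs for s_t (made up)
dones         = torch.tensor([[0.0], [1.0]])   # the second transition ends the episode

td_target = rewards + gamma * v_next_states * (1 - dones)   # [[5.9], [1.0]]
critic_loss = F.mse_loss(v_states, td_target.detach())      # MSE between v(s_t) and the TD target
print(td_target, critic_loss)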

· Policy Update

        # TD target
        td_target = rewards + self.gamma * self.critic(next_states) * (1 - dones)
        td_delta = td_target - self.critic(states)  # TD residual (the advantage estimate)
        log_probs = torch.log(self.actor(states).gather(1, actions))
        # cross-entropy (policy-gradient) loss
        actor_loss = torch.mean(-log_probs * td_delta.detach())
        self.actor_optimizer.zero_grad()
        ...
        actor_loss.backward()  # compute gradients for the policy network
        ...
        self.actor_optimizer.step()  # update the policy network parameters

① Convert the sampled episode into tensors.
② Compute the advantage function: td_delta = td_target - self.critic(states)
③ Compute the loss (cross-entropy): actor_loss = torch.mean(-log_probs * td_delta.detach()) (the role of detach() here is illustrated after the list)
④ Train and update the policy_net network.
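
The td_delta.detach() in step ③ is what stops the actor loss from back-propagating into the critic: the advantage is used as a fixed weight on the log-probability. A minimal sketch of the effect, with made-up tensors standing in for the two network outputs:

import torch

log_probs = torch.tensor([[-0.35], [-1.20]], requires_grad=True)  # stand-in for log π(a_t|s_t)
td_delta  = torch.tensor([[0.80], [-0.50]], requires_grad=True)   # stand-in for the advantage estimates

actor_loss = torch.mean(-log_probs * td_delta.detach())
actor_loss.backward()
print(log_probs.grad)   # non-zero: the gradient reaches the actor side
print(td_delta.grad)    # None: detach() blocked the gradient into the critic side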

3. Open Questions

None for now.

4. Summary

  • The Actor-Critic algorithm can be viewed as a combination of DQN and REINFORCE: it integrates value-function approximation with policy-gradient methods.
  • Actor-Critic is an overall framework that covers a whole family of algorithms; many of today's efficient state-of-the-art algorithms belong to the Actor-Critic family.

