当前位置: 首页 > news >正文

ros进阶——强化学习倒立摆的PG算法实现

请添加图片描述
项目地址:https://github.com/chan-yuu/cartpole_ws

git clone https://github.com/chan-yuu/cartpole_ws

依赖安装:

xterm等
python3.8
torch等

上一节中我们定义了很多ros工具,在这里我们将进行验证。
对于launch_robot_test.py来说,主要包含gazebo部分和强化学习训练部分,定义的功能来启动gazebo和相关node

#!/usr/bin/env python
# -*- coding: utf-8 -*-import rospy
import time
import os
import syscurrent_dir = os.path.dirname(os.path.abspath(__file__))
module_path = os.path.join(current_dir, 'utils', 'ros_utils')
parent_dir = os.path.dirname(module_path)
sys.path.append(parent_dir)from ros_utils import ros_controllers
from ros_utils import ros_gazebo
from ros_utils import ros_launch
from ros_utils import ros_node
from ros_utils import ros_params
from ros_utils import ros_spawn
from ros_utils import ros_urdfdef main():ros_node.ros_kill_all_processes()ros_gazebo.launch_Gazebo(paused=True, gui=True)rospy.init_node('launch_script', anonymous=True)show_rviz = rospy.get_param('~show_rviz', True)debug = rospy.get_param('~debug', True)gui = rospy.get_param('~gui', True)# Robot posex = rospy.get_param('~x', 0)y = rospy.get_param('~y', 0)z = rospy.get_param('~z', 0)roll = rospy.get_param('~roll', 0)pitch = rospy.get_param('~pitch', 0)yaw = rospy.get_param('~yaw', 0)# urdf xml robot description loaded on the Parameter Serverpkg_name = "cartpole_gazebo"model_urdf_file = "cartpole_v1.urdf"model_urdf_folder = "/robots"ns = "/"args_xacro = Noneif ros_urdf.urdf_load_from_pkg(pkg_name, model_urdf_file, "robot_description", folder=model_urdf_folder, ns=ns, args_xacro=args_xacro):rospy.logwarn("URDF 文件加载成功")else:rospy.logwarn("加载 URDF 文件时出错")returntime.sleep(0.1)# push robot_description to factory and spawn robot in gazebogazebo_name = "cartpole"gaz_ref_frame = "world"result_spawn, message = ros_gazebo.gazebo_spawn_urdf_param("robot_description", model_name=gazebo_name, robot_namespace=ns, reference_frame=gaz_ref_frame,pos_x=x, pos_y=y, pos_z=z, ori_w=1.0, ori_x=roll, ori_y=pitch, ori_z=yaw)if result_spawn:rospy.logwarn("模型生成成功")else:rospy.logwarn("生成模型时出错")rospy.logwarn(message)returntime.sleep(0.1)# robot visualization in Rvizif show_rviz:# Load YAML filepkg_name = "cartpole_controller"controllers_file = "joint_position_control.yaml"ns = "/"if ros_params.ros_load_yaml_from_pkg(pkg_name, controllers_file, ns=ns):rospy.logwarn("机器人控制器参数加载成功")else:rospy.logwarn("加载机器人控制器参数时出错")returntime.sleep(0.1)# # 启动 controller_manager 的 spawner 节点node_name = "cartpole_controller_node"pkg_name = "controller_manager"node_type = "spawner"controllers = "joint_state_controller stand_cart_position_controller"ros_node.ros_node_from_pkg(pkg_name, node_type, launch_new_term=False, name=node_name, ns=ns, args=controllers)# # # Spawn controllers# controllers_list = ['cart_pole_position_controller', 'stand_cart_position_controller', 'joint_state_controller']# rospy.loginfo("尝试生成控制器: %s", controllers_list)# if ros_controllers.spawn_controllers_srv(controllers_list, ns=ns):#     rospy.logwarn("控制器生成成功")# else:#     rospy.logwarn("生成控制器时出错")#     # 打印更多的错误信息#     # rospy.logerr("请检查控制器管理器服务是否正常运行")#     # return# Start robot_state_publisherros_node.ros_node_from_pkg("robot_state_publisher", "robot_state_publisher", launch_new_term=False, name="robot_state_publisher", ns=ns)# Start joint_state_publishernode_name = "joint_state_publisher"pkg_name = "joint_state_publisher"node_type = "joint_state_publisher"ros_node.ros_node_from_pkg(pkg_name, node_type, launch_new_term=False, name=node_name, ns=ns)rospy.set_param(ns + node_name + "/use_gui", True)

然后是强化学习训练部分

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.tensorboard import SummaryWriter
import os
import rospy
import random
import time
from std_srvs.srv import Empty
from gazebo_msgs.srv import SetModelConfiguration
from gazebo_msgs.srv import SetLinkState
from control_msgs.msg import JointControllerState
from sensor_msgs.msg import JointState
from gazebo_msgs.msg import LinkStates, LinkState
from std_msgs.msg import Float64, String
from sensor_msgs.msg import Joy
import signal
import sys
# 信号处理函数
def signal_handler(sig, frame):print('You pressed Ctrl+C! Stopping the program...')rospy.signal_shutdown("User interrupted")sys.exit(0)# 注册信号处理函数
signal.signal(signal.SIGINT, signal_handler)# 超参数设置
MAX_EPISODES = 10000
EVAL_EPISODES = 5000
MAX_STEPS = 3000
MODEL_SAVE_PATH = 'models/'# 检查并创建模型保存目录
if not os.path.exists(MODEL_SAVE_PATH):os.makedirs(MODEL_SAVE_PATH)# 初始化ROS相关服务和发布者
def init_ros():pub_cart_position = rospy.Publisher('/stand_cart_position_controller/command', Float64, queue_size=1)rospy.Subscriber("/joint_states", JointState, callback_joint_states)reset_world = rospy.ServiceProxy('/gazebo/reset_world', Empty)reset_joints = rospy.ServiceProxy('/gazebo/set_model_configuration', SetModelConfiguration)unpause = rospy.ServiceProxy('/gazebo/unpause_physics', Empty)pause = rospy.ServiceProxy('/gazebo/pause_physics', Empty)return pub_cart_position, reset_world, reset_joints, unpause, pause# 定义机器人状态类
class RobotState:def __init__(self):# 小车水平位置self.cart_x = 0.0# 小车速度self.cart_x_dot = 0.0# 摆杆相对于垂直方向的角度self.pole_theta = 0.0# 摆杆角度变化率self.pole_theta_dot = 0.0# 机器人状态列表self.robot_state = [self.cart_x, self.cart_x_dot, self.pole_theta, self.pole_theta_dot]self.data = Noneself.latest_reward = 0.0self.fall = 0# 摆杆角度阈值self.theta_threshold = 0.20943951023# 小车位置阈值self.x_threshold = 0.4# 当前小车速度self.current_vel = 0.0# 回合是否结束标志self.done = False# 重置环境
def reset(reset_world, reset_joints, pause):try:# 重置世界reset_world()# 重置关节状态reset_joints("cartpole", "robot_description", ["stand_cart", "cart_pole"], [0.0, 0.0])# 暂停物理模拟pause()time.sleep(0.1)except rospy.ServiceException as e:print(f"Service call failed: {e}")set_robot_state()robot_state.current_vel = 0# 更新机器人状态
def set_robot_state():robot_state.robot_state = [robot_state.cart_x, robot_state.cart_x_dot, robot_state.pole_theta, robot_state.pole_theta_dot]# 全局步数计数器
step_count = 0# 执行动作并返回奖励和是否结束标志
def take_action(action, unpause, pub_cart_position):global step_counttry:# 恢复物理模拟unpause()except rospy.ServiceException as e:print(f"Service call failed: {e}")# 根据动作更新小车速度robot_state.current_vel += 0.02 if action == 1 else -0.02# 发布小车速度指令pub_cart_position.publish(robot_state.current_vel)# 获取关节状态数据while robot_state.data is None:try:robot_state.data = rospy.wait_for_message('/joint_states', JointState, timeout=5)except:print("Error getting /joint_states data.")# 更新机器人状态set_robot_state()reward = 0.5 + abs(robot_state.pole_theta) # 0.5-step reward# 计算奖励reward -= abs(robot_state.cart_x)# 判断是否超出边界if (robot_state.cart_x < -robot_state.x_threshold or robot_state.cart_x > robot_state.x_threshold orrobot_state.pole_theta > robot_state.theta_threshold or robot_state.pole_theta < -robot_state.theta_threshold):robot_state.done = Truereward -= 10step_count = 0else:robot_state.done = Falsestep_count += 1return reward, robot_state.done# 处理关节状态消息的回调函数
def callback_joint_states(data):robot_state.cart_x_dot = data.velocity[1] if len(data.velocity) > 0 else 0.0robot_state.pole_theta_dot = data.velocity[0] if len(data.velocity) > 0 else 0.0robot_state.cart_x = data.position[1]robot_state.pole_theta = data.position[0]set_robot_state()# # 订阅关节状态消息
# def listener():
#     rospy.Subscriber("/joint_states", JointState, callback_joint_states)# 策略网络
class PolicyNetwork(nn.Module):def __init__(self):super().__init__()# 全连接层,输入维度为4,输出维度为2self.fc = nn.Linear(4, 2)def forward(self, x):# 经过全连接层后进行softmax操作return torch.softmax(self.fc(x), dim=1)# 价值网络
class ValueNetwork(nn.Module):def __init__(self):super().__init__()# 第一个全连接层,输入维度为4,输出维度为10self.fc1 = nn.Linear(4, 10)# ReLU激活函数self.relu = nn.ReLU()# 第二个全连接层,输入维度为10,输出维度为1self.fc2 = nn.Linear(10, 1)def forward(self, x):# 经过第一个全连接层和ReLU激活函数,再经过第二个全连接层return self.fc2(self.relu(self.fc1(x)))# 初始化策略网络和优化器
def policy_gradient():policy_network = PolicyNetwork()optimizer = optim.Adam(policy_network.parameters(), lr=0.01)return policy_network, optimizer# 初始化价值网络、优化器和损失函数
def value_gradient():value_network = ValueNetwork()optimizer = optim.Adam(value_network.parameters(), lr=0.1)criterion = nn.MSELoss()return value_network, optimizer, criterion# 计算未来奖励
def calculate_future_reward(transitions, index):future_reward = 0discount = 1for index2 in range(len(transitions) - index):future_reward += transitions[index + index2][2] * discountdiscount *= 0.99return future_reward# 运行一个回合
def run_episode(policy_network, policy_optimizer, value_network, value_optimizer, value_criterion, episode, writer,reset_world, reset_joints, pause, unpause, pub_cart_position):# 重置环境reset(reset_world, reset_joints, pause)observation = robot_state.robot_statetotal_reward = 0states = []actions = []advantages = []transitions = []update_vals = []for step in range(MAX_STEPS):# 将观测状态转换为张量obs_tensor = torch.FloatTensor(observation).unsqueeze(0)# 计算动作概率分布probs = policy_network(obs_tensor).detach().numpy()[0]# 在策略梯度方法中,通常使用 random.uniform(0, 1) < probs[0] 来选择动作,\# 以确保动作的选择符合策略网络的输出分布,并且具有随机性。action = 0 if random.uniform(0, 1) < probs[0] else 1states.append(observation)action_blank = np.zeros(2)action_blank[action] = 1actions.append(action_blank)old_observation = observation# 执行动作并获取奖励和是否结束标志reward, done = take_action(action, unpause, pub_cart_position)observation = robot_state.robot_statetransitions.append((old_observation, action, reward))total_reward += rewardprint(f"Episode {episode}, Step {step}: Random {random.uniform(0, 1)}, Prob={probs[0]}")print(f"Episode {episode}, Step {step}: Action={action}, Reward={reward}")if done:robot_state.done = Falsebreakfor index, (obs, _, reward) in enumerate(transitions):# 计算未来奖励future_reward = calculate_future_reward(transitions, index)obs_tensor = torch.FloatTensor(obs).unsqueeze(0)# 估计当前状态价值current_val = value_network(obs_tensor).item()advantages.append(future_reward - current_val)update_vals.append(future_reward)states_tensor = torch.FloatTensor(np.array(states))actions_tensor = torch.FloatTensor(np.array(actions))advantages_tensor = torch.FloatTensor(np.array(advantages)).unsqueeze(1)update_vals_tensor = torch.FloatTensor(np.array(update_vals)).unsqueeze(1)# 更新价值网络value_optimizer.zero_grad()predicted_vals = value_network(states_tensor)value_loss = value_criterion(predicted_vals, update_vals_tensor)value_loss.backward()value_optimizer.step()# 更新策略网络policy_optimizer.zero_grad()probs = policy_network(states_tensor)good_probs = torch.sum(probs * actions_tensor, dim=1, keepdim=True)eligibility = torch.log(good_probs) * advantages_tensorpolicy_loss = -torch.sum(eligibility)policy_loss.backward()policy_optimizer.step()# 使用 TensorBoard 记录奖励和损失writer.add_scalar('Reward', total_reward, episode)writer.add_scalar('Value Loss', value_loss.item(), episode)writer.add_scalar('Policy Loss', policy_loss.item(), episode)return total_rewarddef PG():# 初始化ROS相关服务和发布者pub_cart_position, reset_world, reset_joints, unpause, pause = init_ros()# 初始化TensorBoard写入器writer = SummaryWriter()# 初始化策略网络和优化器policy_network, policy_optimizer = policy_gradient()# 初始化价值网络、优化器和损失函数value_network, value_optimizer, value_criterion = value_gradient()list_rewards = []for i in range(MAX_EPISODES):# 运行一个回合reward = run_episode(policy_network, policy_optimizer, value_network, value_optimizer, value_criterion, i, writer,reset_world, reset_joints, pause, unpause, pub_cart_position)list_rewards.append(reward)if i % EVAL_EPISODES == 0:# 保存策略网络模型torch.save(policy_network.state_dict(), f'{MODEL_SAVE_PATH}policy_network_episode_{i}.pth')# 保存价值网络模型torch.save(value_network.state_dict(), f'{MODEL_SAVE_PATH}value_network_episode_{i}.pth')time.sleep(0.05)# 保存最终的策略网络模型torch.save(policy_network.state_dict(), f'{MODEL_SAVE_PATH}policy_network.pth')# 保存最终的价值网络模型torch.save(value_network.state_dict(), f'{MODEL_SAVE_PATH}value_network.pth')writer.close()

启动运行逻辑

if __name__=='__main__':try:robot_state = RobotState()main()PG()except rospy.ROSInterruptException:pass

算法原理上来说

class PolicyNetwork(nn.Module):def __init__(self):super(PolicyNetwork, self).__init__()self.fc = nn.Linear(4, 2)def forward(self, x):x = self.fc(x)return torch.softmax(x, dim=1)class ValueNetwork(nn.Module):def __init__(self):super(ValueNetwork, self).__init__()self.fc1 = nn.Linear(4, 10)self.relu = nn.ReLU()self.fc2 = nn.Linear(10, 1)def forward(self, x):x = self.fc1(x)x = self.relu(x)x = self.fc2(x)return x

策略网络(Policy Network):输入是状态 ,输出是在该状态下每个动作的概率分布 。这里使用一个全连接层和 softmax 函数来实现。
价值网络(Value Network):输入是状态 ,输出是该状态的价值估计 。这里使用两个全连接层和 ReLU 激活函数来实现。

1采样

首先,重置环境,获取初始状态。
在每个时间步,使用策略网络计算当前状态下的动作概率分布,并根据该分布随机选择一个动作。
执行动作,获取奖励和下一个状态,并将状态转移(旧状态、动作、奖励)存储到 transitions 列表中。
累加奖励,直到回合结束(done 为 True)。

状态:observation 是当前状态。
动作:根据策略网络输出的概率分布采样动作。
奖励:执行动作后获得的奖励。

for i in range(MAX_EPISODES):# 总共MAX_EPISODES个episode# print("Episode ", i)reward = run_episode(policy_network, policy_optimizer, value_network, value_optimizer, value_criterion, i, writer)# print("reward", reward)list_rewards.append(reward)if i % EVAL_EPISODES == 0:torch.save(policy_network.state_dict(), f'{MODEL_SAVE_PATH}policy_network_episode_{i}.pth')torch.save(value_network.state_dict(), f'{MODEL_SAVE_PATH}value_network_episode_{i}.pth')time.sleep(0.05)
for step in range(20000):# 每个回合又要运行20000步obs_tensor = torch.FloatTensor(observation).unsqueeze(0)probs = policy_network(obs_tensor).detach().numpy()[0]action = 0 if random.uniform(0, 1) < probs[0] else 1

2优势函数计算

在这里插入图片描述

未来奖励:从当前状态开始到 episode 结束的所有奖励的折扣和。
当前状态价值:由价值网络估计的 V(s)。

advantages.append(future_reward - currentval)

3价值网络更新

value_loss = value_criterion(predicted_vals, update_vals_tensor)
value_loss.backward()
value_optimizer.step()

在这里插入图片描述

4策略网络的更新

policy_loss = -torch.sum(eligibility)
policy_loss.backward()
policy_optimizer.step()

在这里插入图片描述

综上所述,策略梯度算法通过直接优化策略参数,使得策略能够产生获得更高累积奖励的动作序列。在每个回合中,智能体根据当前策略选择动作,收集状态转移数据,计算累计回报和优势函数,然后分别更新价值网络和策略网络的参数。通过不断迭代,策略逐渐收敛到最优策略。

上面我没太明白的一点是:
random.uniform(0, 1) < probs[0]:这种选择方式确保了动作的选择符合策略网络输出的概率分布,并且具有随机性,有助于探索状态空间。
如果用确定的,比如说0.5 < probs[0]:这种选择方式使用了固定的阈值,忽略了策略网络输出的概率分布,缺乏随机性,不适合强化学习中的动作选择。
因此,在策略梯度方法中,通常使用 random.uniform(0, 1) < probs[0] 来选择动作,以确保动作的选择符合策略网络的输出分布,并且具有随机性。


http://www.mrgr.cn/news/92352.html

相关文章:

  • 架构思维:分布式缓存_提升系统性能的关键手段(上)
  • Kafka面试题汇总
  • 【算法系列】快速排序详解
  • Ubuntu从零创建Hadoop集群
  • 【STL】4.<list>
  • 业务随行原理
  • mac下载MAMP6.8.1
  • 探索超声波的奥秘——定时器与PCA
  • 面试题——简述Vue 3的服务器端渲染(SSR)是如何工作的?
  • MongoDB 面试题目
  • (Arrow)时间处理变得更简单
  • 批量将gitlab仓库转移到gitea中
  • 计算机视觉(opencv-python)入门之图像的读取,显示,与保存
  • 微信小程序网络请求与API调用:实现数据交互
  • 系统调用过程
  • 模型蒸馏与量化技术:让AI模型“瘦身”却不“降智”的底层逻辑
  • 可狱可囚的爬虫系列课程 14:10 秒钟编写一个 requests 爬虫
  • Android AOSP系统裁记录
  • 在 HuggingFace 中使用 SSH 进行下载数据集和模型
  • Java入门基础、JDK安装和配置