Preface
This notebook provides an introductory tutorial for MuJoCo XLA (MJX), a JAX-based implementation of MuJoCo intended for RL training workloads.
A Colab runtime with GPU acceleration is required. If you are on a CPU-only runtime, you can switch via the menu Runtime > Change runtime type.
1. Installing MuJoCo, MJX, and Brax
!pip install mujoco
!pip install mujoco_mjx
!pip install brax
#@title Check if MuJoCo installation was successful
from google.colab import files
import distutils.util
import os
import subprocess
if subprocess.run('nvidia-smi').returncode:
raise RuntimeError(
'Cannot communicate with GPU. '
'Make sure you are using a GPU Colab runtime. '
'Go to the Runtime menu and select Choose runtime type.')
# Add an ICD config so that glvnd can pick up the Nvidia EGL driver.
# This is usually installed as part of an Nvidia driver package, but the Colab
# kernel doesn't install its driver via APT, and as a result the ICD is missing.
# (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)
NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/10_nvidia.json'
if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):
with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:
f.write("""{
"file_format_version" : "1.0.0",
"ICD" : {
"library_path" : "libEGL_nvidia.so.0"
}
}
""")
# Configure MuJoCo to use the EGL rendering backend (requires GPU)
print('Setting environment variable to use GPU rendering:')
%env MUJOCO_GL=egl
try:
print('Checking that the installation succeeded:')
import mujoco
mujoco.MjModel.from_xml_string('<mujoco/>')
except Exception as e:
raise e from RuntimeError(
'Something went wrong during installation. Check the shell output above '
'for more information.\n'
'If using a hosted Colab runtime, make sure you enable GPU acceleration '
'by going to the Runtime menu and selecting "Choose runtime type".')
print('Installation successful.')
# Tell XLA to use Triton GEMM, this improves steps/sec by ~30% on some GPUs
xla_flags = os.environ.get('XLA_FLAGS', '')
xla_flags += ' --xla_gpu_triton_gemm_any=True'
os.environ['XLA_FLAGS'] = xla_flags
#@title Import packages for plotting and creating graphics
import time
import itertools
import numpy as np
from typing import Callable, NamedTuple, Optional, Union, List
# Graphics and plotting.
print('Installing mediapy:')
!command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)
!pip install -q mediapy
import mediapy as media
import matplotlib.pyplot as plt
# More legible printing from numpy.
np.set_printoptions(precision=3, suppress=True, linewidth=100)
#@title Import MuJoCo, MJX, and Brax
from datetime import datetime
from etils import epath
import functools
from IPython.display import HTML
from typing import Any, Dict, Sequence, Tuple, Union
import os
from ml_collections import config_dict
import jax
from jax import numpy as jp
import numpy as np
from flax.training import orbax_utils
from flax import struct
from matplotlib import pyplot as plt
import mediapy as media
from orbax import checkpoint as ocp
import mujoco
from mujoco import mjx
from brax import base
from brax import envs
from brax import math
from brax.base import Base, Motion, Transform
from brax.base import State as PipelineState
from brax.envs.base import Env, PipelineEnv, State
from brax.mjx.base import State as MjxState
from brax.training.agents.ppo import train as ppo
from brax.training.agents.ppo import networks as ppo_networks
from brax.io import html, mjcf, model
2. MJX Basics
MJX is an implementation of MuJoCo written in JAX, which makes it possible to train with large batches on GPU/TPU. In this notebook we demonstrate how to train RL policies with MJX.
Before jumping into heavy RL workloads, let's start with a simple example! The entry point into MJX is MuJoCo itself, so we begin by loading a MuJoCo model:
xml = """
"""
# Make model, data, and renderer
mj_model = mujoco.MjModel.from_xml_string(xml)
mj_data = mujoco.MjData(mj_model)
renderer = mujoco.Renderer(mj_model)
Next, we use MJX to put the MuJoCo model and data onto the GPU device.
mjx_model = mjx.put_model(mj_model)
mjx_data = mjx.put_data(mj_model, mj_data)
Below, we print qpos from both MuJoCo and MJX. Notice that mjData's qpos is a numpy array living on the CPU, whereas mjx.Data's qpos is a JAX array living on the GPU device.
print(mj_data.qpos, type(mj_data.qpos))
print(mjx_data.qpos, type(mjx_data.qpos), mjx_data.qpos.devices())
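If you need the values back on the host (e.g. for logging or plotting), a JAX array can be copied back into a regular numpy array at any time. A minimal sketch:
# Copy the device-resident qpos back to the host as a numpy array.
host_qpos = jax.device_get(mjx_data.qpos)
print(host_qpos, type(host_qpos))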
Let's run a simulation in MuJoCo and render the trajectory. This example is taken from the MuJoCo tutorial.
# enable joint visualization option:
scene_option = mujoco.MjvOption()
scene_option.flags[mujoco.mjtVisFlag.mjVIS_JOINT] = True
duration = 3.8 # (seconds)
framerate = 60 # (Hz)
frames = []
mujoco.mj_resetData(mj_model, mj_data)
while mj_data.time < duration:
mujoco.mj_step(mj_model, mj_data)
if len(frames) < mj_data.time * framerate:
renderer.update_scene(mj_data, scene_option=scene_option)
pixels = renderer.render()
frames.append(pixels)
# Simulate and display video.
media.show_video(frames, fps=framerate)
Now let's run exactly the same simulation on the GPU device with MJX!
In the example below we use mjx.step instead of mujoco.mj_step, and we jax.jit mjx.step so it runs efficiently on the GPU. For each frame we convert mjx.Data back into mjData so that we can use the MuJoCo renderer.
jit_step = jax.jit(mjx.step)
frames = []
mujoco.mj_resetData(mj_model, mj_data)
mjx_data = mjx.put_data(mj_model, mj_data)
while mjx_data.time < duration:
mjx_data = jit_step(mjx_model, mjx_data)
if len(frames) < mjx_data.time * framerate:
mj_data = mjx.get_data(mj_model, mjx_data)
renderer.update_scene(mj_data, scene_option=scene_option)
pixels = renderer.render()
frames.append(pixels)
media.show_video(frames, fps=framerate)
Running single-threaded physics on a GPU is not efficient. The strength of MJX is that environments can be run in parallel on accelerator hardware. Let's try it!
In the example below we create 4096 copies of mjx.Data and run mjx.step on the batched data. Since MJX is implemented in JAX, we use jax.vmap to run mjx.step over all copies of mjx.Data in parallel.
rng = jax.random.PRNGKey(0)
rng = jax.random.split(rng, 4096)
batch = jax.vmap(lambda rng: mjx_data.replace(qpos=jax.random.uniform(rng, (1,))))(rng)
jit_step = jax.jit(jax.vmap(mjx.step, in_axes=(None, 0)))
batch = jit_step(mjx_model, batch)
print(batch.qpos)
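The while-loop earlier stepped one frame at a time from Python; for pure throughput it is common to advance the whole batch inside a single compiled function. Below is a minimal sketch (not part of the original tutorial) that uses jax.lax.scan to take 100 steps per call:
def multi_step(model, data):
  # Step a single mjx.Data 100 times inside one compiled program.
  def body(d, _):
    return mjx.step(model, d), None
  data, _ = jax.lax.scan(body, data, None, length=100)
  return data

jit_multi_step = jax.jit(jax.vmap(multi_step, in_axes=(None, 0)))
batch = jit_multi_step(mjx_model, batch)
print(batch.time[:4])  # every copy has advanced by 100 timesteps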
As before, we can copy the batched mjx.Data back into MuJoCo:
batched_mj_data = mjx.get_data(mj_model, batch)
print([d.qpos for d in batched_mj_data])
3. Training a Policy with MJX
Running physics simulations in large batches is very useful for training RL policies. Here we demonstrate how to train an RL policy with MJX and Brax's RL library.
Below we implement the classic humanoid environment with MJX and Brax. We subclass Brax's PipelineEnv (with the mjx backend), so the physics is stepped with MJX while training with Brax's RL implementations.
#@title Humanoid Env
HUMANOID_ROOT_PATH = epath.Path(epath.resource_path('mujoco')) / 'mjx/test_data/humanoid'
class Humanoid(PipelineEnv):
def __init__(
self,
forward_reward_weight=1.25,
ctrl_cost_weight=0.1,
healthy_reward=5.0,
terminate_when_unhealthy=True,
healthy_z_range=(1.0, 2.0),
reset_noise_scale=1e-2,
exclude_current_positions_from_observation=True,
**kwargs,
):
#
mj_model = mujoco.MjModel.from_xml_path(
(HUMANOID_ROOT_PATH / 'humanoid.xml').as_posix())
mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG
mj_model.opt.iterations = 6
mj_model.opt.ls_iterations = 6
sys = mjcf.load_model(mj_model)
physics_steps_per_control_step = 5
kwargs['n_frames'] = kwargs.get(
'n_frames', physics_steps_per_control_step)
kwargs['backend'] = 'mjx'
super().__init__(sys, **kwargs)
self._forward_reward_weight = forward_reward_weight
self._ctrl_cost_weight = ctrl_cost_weight
self._healthy_reward = healthy_reward
self._terminate_when_unhealthy = terminate_when_unhealthy
self._healthy_z_range = healthy_z_range
self._reset_noise_scale = reset_noise_scale
self._exclude_current_positions_from_observation = (
exclude_current_positions_from_observation
)
def reset(self, rng: jp.ndarray) -> State:
"""Resets the environment to an initial state."""
rng, rng1, rng2 = jax.random.split(rng, 3)
low, hi = -self._reset_noise_scale, self._reset_noise_scale
qpos = self.sys.qpos0 + jax.random.uniform(
rng1, (self.sys.nq,), minval=low, maxval=hi
)
qvel = jax.random.uniform(
rng2, (self.sys.nv,), minval=low, maxval=hi
)
data = self.pipeline_init(qpos, qvel)
obs = self._get_obs(data, jp.zeros(self.sys.nu))
reward, done, zero = jp.zeros(3)
metrics = {
'forward_reward': zero,
'reward_linvel': zero,
'reward_quadctrl': zero,
'reward_alive': zero,
'x_position': zero,
'y_position': zero,
'distance_from_origin': zero,
'x_velocity': zero,
'y_velocity': zero,
}
return State(data, obs, reward, done, metrics)
def step(self, state: State, action: jp.ndarray) -> State:
"""Runs one timestep of the environment's dynamics."""
data0 = state.pipeline_state
data = self.pipeline_step(data0, action)
com_before = data0.subtree_com[1]
com_after = data.subtree_com[1]
velocity = (com_after - com_before) / self.dt
forward_reward = self._forward_reward_weight * velocity[0]
min_z, max_z = self._healthy_z_range
is_healthy = jp.where(data.q[2] < min_z, 0.0, 1.0)
is_healthy = jp.where(data.q[2] > max_z, 0.0, is_healthy)
if self._terminate_when_unhealthy:
healthy_reward = self._healthy_reward
else:
healthy_reward = self._healthy_reward * is_healthy
ctrl_cost = self._ctrl_cost_weight * jp.sum(jp.square(action))
obs = self._get_obs(data, action)
reward = forward_reward + healthy_reward - ctrl_cost
done = 1.0 - is_healthy if self._terminate_when_unhealthy else 0.0
state.metrics.update(
forward_reward=forward_reward,
reward_linvel=forward_reward,
reward_quadctrl=-ctrl_cost,
reward_alive=healthy_reward,
x_position=com_after[0],
y_position=com_after[1],
distance_from_origin=jp.linalg.norm(com_after),
x_velocity=velocity[0],
y_velocity=velocity[1],
)
return state.replace(
pipeline_state=data, obs=obs, reward=reward, done=done
)
def _get_obs(
self, data: mjx.Data, action: jp.ndarray
) -> jp.ndarray:
"""Observes humanoid body position, velocities, and angles."""
position = data.qpos
if self._exclude_current_positions_from_observation:
position = position[2:]
# external_contact_forces are excluded
return jp.concatenate([
position,
data.qvel,
data.cinert[1:].ravel(),
data.cvel[1:].ravel(),
data.qfrc_actuator,
])
envs.register_environment('humanoid', Humanoid)
3.1 Visualizing a Rollout
Let's instantiate the environment and visualize a short rollout.
Note: because the episode terminates early if the torso falls below the healthy z-range, the only relevant contacts in this task are those between the feet and the floor; other contacts are turned off.
# instantiate the environment
env_name = 'humanoid'
env = envs.get_environment(env_name)
# define the jit reset/step functions
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)
# initialize the state
state = jit_reset(jax.random.PRNGKey(0))
rollout = [state.pipeline_state]
# grab a trajectory
for i in range(10):
ctrl = -0.1 * jp.ones(env.sys.nu)
state = jit_step(state, ctrl)
rollout.append(state.pipeline_state)
media.show_video(env.render(rollout, camera='side'), fps=1.0 / env.dt)
3.2 Training a Humanoid Policy
Now let's train a policy with PPO that makes the humanoid run forward. Training takes about 6 minutes on a Tesla A100 GPU.
train_fn = functools.partial(
ppo.train, num_timesteps=20_000_000, num_evals=5, reward_scaling=0.1,
episode_length=1000, normalize_observations=True, action_repeat=1,
unroll_length=10, num_minibatches=24, num_updates_per_batch=8,
discounting=0.97, learning_rate=3e-4, entropy_cost=1e-3, num_envs=3072,
batch_size=512, seed=0)
x_data = []
y_data = []
ydataerr = []
times = [datetime.now()]
max_y, min_y = 13000, 0
def progress(num_steps, metrics):
times.append(datetime.now())
x_data.append(num_steps)
y_data.append(metrics['eval/episode_reward'])
ydataerr.append(metrics['eval/episode_reward_std'])
plt.xlim([0, train_fn.keywords['num_timesteps'] * 1.25])
plt.ylim([min_y, max_y])
plt.xlabel('# environment steps')
plt.ylabel('reward per episode')
plt.title(f'y={y_data[-1]:.3f}')
plt.errorbar(
x_data, y_data, yerr=ydataerr)
plt.show()
make_inference_fn, params, _= train_fn(environment=env, progress_fn=progress)
print(f'time to jit: {times[1] - times[0]}')
print(f'time to train: {times[-1] - times[1]}')
We can save and load the policy using the brax model API.
#@title Save Model
model_path = '/tmp/mjx_brax_policy'
model.save_params(model_path, params)
#@title Load Model and Define Inference Function
params = model.load_params(model_path)
inference_fn = make_inference_fn(params)
jit_inference_fn = jax.jit(inference_fn)
3.3 Visualizing the Policy
Finally, we can visualize the policy.
eval_env = envs.get_environment(env_name)
jit_reset = jax.jit(eval_env.reset)
jit_step = jax.jit(eval_env.step)
# initialize the state
rng = jax.random.PRNGKey(0)
state = jit_reset(rng)
rollout = [state.pipeline_state]
# grab a trajectory
n_steps = 500
render_every = 2
for i in range(n_steps):
act_rng, rng = jax.random.split(rng)
ctrl, _ = jit_inference_fn(state.obs, act_rng)
state = jit_step(state, ctrl)
rollout.append(state.pipeline_state)
if state.done:
break
media.show_video(env.render(rollout[::render_every], camera='side'), fps=1.0 / env.dt / render_every)
3.4 Running the MJX Policy in MuJoCo
We can also step the physics with the plain MuJoCo Python bindings, to show that a policy trained in MJX works in MuJoCo as well.
mj_model = eval_env.sys.mj_model
mj_data = mujoco.MjData(mj_model)
renderer = mujoco.Renderer(mj_model)
ctrl = jp.zeros(mj_model.nu)
images = []
for i in range(n_steps):
act_rng, rng = jax.random.split(rng)
obs = eval_env._get_obs(mjx.put_data(mj_model, mj_data), ctrl)
ctrl, _ = jit_inference_fn(obs, act_rng)
mj_data.ctrl = ctrl
for _ in range(eval_env._n_frames):
mujoco.mj_step(mj_model, mj_data) # Physics step using MuJoCo mj_step.
if i % render_every == 0:
renderer.update_scene(mj_data, camera='side')
images.append(renderer.render())
media.show_video(images, fps=1.0 / eval_env.dt / render_every)
3.5 Training a Policy with Domain Randomization
When training a policy we may also want to randomize certain mjModel parameters. With MJX it is easy to create a batch of environments whose mjx.Model is filled with randomized values. Below we show a function that randomizes friction and the actuator gain/bias.
def domain_randomize(sys, rng):
"""Randomizes the mjx.Model."""
@jax.vmap
def rand(rng):
_, key = jax.random.split(rng, 2)
# friction
friction = jax.random.uniform(key, (1,), minval=0.6, maxval=1.4)
friction = sys.geom_friction.at[:, 0].set(friction)
# actuator
_, key = jax.random.split(key, 2)
gain_range = (-5, 5)
param = jax.random.uniform(
key, (1,), minval=gain_range[0], maxval=gain_range[1]
) + sys.actuator_gainprm[:, 0]
gain = sys.actuator_gainprm.at[:, 0].set(param)
bias = sys.actuator_biasprm.at[:, 1].set(-param)
return friction, gain, bias
friction, gain, bias = rand(rng)
in_axes = jax.tree_util.tree_map(lambda x: None, sys)
in_axes = in_axes.tree_replace({
'geom_friction': 0,
'actuator_gainprm': 0,
'actuator_biasprm': 0,
})
sys = sys.tree_replace({
'geom_friction': friction,
'actuator_gainprm': gain,
'actuator_biasprm': bias,
})
return sys, in_axes
If we want 10 environments with randomized friction and actuator parameters, we call domain_randomize, which returns a batched mjx.Model along with an in_axes structure specifying which fields are batched.
rng = jax.random.PRNGKey(0)
rng = jax.random.split(rng, 10)
batched_sys, _ = domain_randomize(env.sys, rng)
print('Single env friction shape: ', env.sys.geom_friction.shape)
print('Batched env friction shape: ', batched_sys.geom_friction.shape)
print('Friction on geom 0: ', env.sys.geom_friction[0, 0])
print('Random frictions on geom 0: ', batched_sys.geom_friction[:, 0, 0])
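Only the three randomized fields get a leading batch axis; every other field in the model is left unbatched and shared across the batch, which is exactly what the returned in_axes structure communicates to jax.vmap. A quick check (sketch):
# Non-randomized fields keep their original, unbatched shape.
print('qpos0 shape (shared across the batch): ', batched_sys.qpos0.shape)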
3.6 A Quadruped Environment
Let's put the domain randomization function to work and define a quadruped environment. Here we use the Barkour vB quadruped from MuJoCo Menagerie and implement an environment in Brax for training a joystick policy.
Note: for a full suite of robot environments, many of which have been transferred to real robots, check out MuJoCo Playground!
!git clone https://github.com/google-deepmind/mujoco_menagerie
#@title Barkour vb Quadruped Env
BARKOUR_ROOT_PATH = epath.Path('mujoco_menagerie/google_barkour_vb')
def get_config():
"""Returns reward config for barkour quadruped environment."""
def get_default_rewards_config():
default_config = config_dict.ConfigDict(
dict(
# The coefficients for all reward terms used for training. All
# physical quantities are in SI units, if no otherwise specified,
# i.e. joint positions are in rad, positions are measured in meters,
# torques in Nm, and time in seconds, and forces in Newtons.
scales=config_dict.ConfigDict(
dict(
# Tracking rewards are computed using exp(-delta^2/sigma)
# sigma can be a hyperparameters to tune.
# Track the base x-y velocity (no z-velocity tracking.)
tracking_lin_vel=1.5,
# Track the angular velocity along z-axis, i.e. yaw rate.
tracking_ang_vel=0.8,
# Below are regularization terms, we roughly divide the
# terms to base state regularizations, joint
# regularizations, and other behavior regularizations.
# Penalize the base velocity in z direction, L2 penalty.
lin_vel_z=-2.0,
# Penalize the base roll and pitch rate. L2 penalty.
ang_vel_xy=-0.05,
# Penalize non-zero roll and pitch angles. L2 penalty.
orientation=-5.0,
# L2 regularization of joint torques, |tau|^2.
torques=-0.0002,
# Penalize the change in the action and encourage smooth
# actions. L2 regularization |action - last_action|^2
action_rate=-0.01,
# Encourage long swing steps. However, it does not
# encourage high clearances.
feet_air_time=0.2,
# Encourage no motion at zero command, L2 regularization
# |q - q_default|^2.
stand_still=-0.5,
# Early termination penalty.
termination=-1.0,
# Penalizing foot slipping on the ground.
foot_slip=-0.1,
)
),
# Tracking reward = exp(-error^2/sigma).
tracking_sigma=0.25,
)
)
return default_config
default_config = config_dict.ConfigDict(
dict(
rewards=get_default_rewards_config(),
)
)
return default_config
class BarkourEnv(PipelineEnv):
"""Environment for training the barkour quadruped joystick policy in MJX."""
def __init__(
self,
obs_noise: float = 0.05,
action_scale: float = 0.3,
kick_vel: float = 0.05,
scene_file: str = 'scene_mjx.xml',
**kwargs,
):
path = BARKOUR_ROOT_PATH / scene_file
sys = mjcf.load(path.as_posix())
self._dt = 0.02 # this environment is 50 fps
sys = sys.tree_replace({'opt.timestep': 0.004})
# override menagerie params for smoother policy
sys = sys.replace(
dof_damping=sys.dof_damping.at[6:].set(0.5239),
actuator_gainprm=sys.actuator_gainprm.at[:, 0].set(35.0),
actuator_biasprm=sys.actuator_biasprm.at[:, 1].set(-35.0),
)
n_frames = kwargs.pop('n_frames', int(self._dt / sys.opt.timestep))
super().__init__(sys, backend='mjx', n_frames=n_frames)
self.reward_config = get_config()
# set custom from kwargs
for k, v in kwargs.items():
if k.endswith('_scale'):
self.reward_config.rewards.scales[k[:-6]] = v
self._torso_idx = mujoco.mj_name2id(
sys.mj_model, mujoco.mjtObj.mjOBJ_BODY.value, 'torso'
)
self._action_scale = action_scale
self._obs_noise = obs_noise
self._kick_vel = kick_vel
self._init_q = jp.array(sys.mj_model.keyframe('home').qpos)
self._default_pose = sys.mj_model.keyframe('home').qpos[7:]
self.lowers = jp.array([-0.7, -1.0, 0.05] * 4)
self.uppers = jp.array([0.52, 2.1, 2.1] * 4)
feet_site = [
'foot_front_left',
'foot_hind_left',
'foot_front_right',
'foot_hind_right',
]
feet_site_id = [
mujoco.mj_name2id(sys.mj_model, mujoco.mjtObj.mjOBJ_SITE.value, f)
for f in feet_site
]
assert not any(id_ == -1 for id_ in feet_site_id), 'Site not found.'
self._feet_site_id = np.array(feet_site_id)
lower_leg_body = [
'lower_leg_front_left',
'lower_leg_hind_left',
'lower_leg_front_right',
'lower_leg_hind_right',
]
lower_leg_body_id = [
mujoco.mj_name2id(sys.mj_model, mujoco.mjtObj.mjOBJ_BODY.value, l)
for l in lower_leg_body
]
assert not any(id_ == -1 for id_ in lower_leg_body_id), 'Body not found.'
self._lower_leg_body_id = np.array(lower_leg_body_id)
self._foot_radius = 0.0175
self._nv = sys.nv
def sample_command(self, rng: jax.Array) -> jax.Array:
lin_vel_x = [-0.6, 1.5] # min max [m/s]
lin_vel_y = [-0.8, 0.8] # min max [m/s]
ang_vel_yaw = [-0.7, 0.7] # min max [rad/s]
_, key1, key2, key3 = jax.random.split(rng, 4)
lin_vel_x = jax.random.uniform(
key1, (1,), minval=lin_vel_x[0], maxval=lin_vel_x[1]
)
lin_vel_y = jax.random.uniform(
key2, (1,), minval=lin_vel_y[0], maxval=lin_vel_y[1]
)
ang_vel_yaw = jax.random.uniform(
key3, (1,), minval=ang_vel_yaw[0], maxval=ang_vel_yaw[1]
)
new_cmd = jp.array([lin_vel_x[0], lin_vel_y[0], ang_vel_yaw[0]])
return new_cmd
def reset(self, rng: jax.Array) -> State: # pytype: disable=signature-mismatch
rng, key = jax.random.split(rng)
pipeline_state = self.pipeline_init(self._init_q, jp.zeros(self._nv))
state_info = {
'rng': rng,
'last_act': jp.zeros(12),
'last_vel': jp.zeros(12),
'command': self.sample_command(key),
'last_contact': jp.zeros(4, dtype=bool),
'feet_air_time': jp.zeros(4),
'rewards': {k: 0.0 for k in self.reward_config.rewards.scales.keys()},
'kick': jp.array([0.0, 0.0]),
'step': 0,
}
obs_history = jp.zeros(15 * 31) # store 15 steps of history
obs = self._get_obs(pipeline_state, state_info, obs_history)
reward, done = jp.zeros(2)
metrics = {'total_dist': 0.0}
for k in state_info['rewards']:
metrics[k] = state_info['rewards'][k]
state = State(pipeline_state, obs, reward, done, metrics, state_info) # pytype: disable=wrong-arg-types
return state
def step(self, state: State, action: jax.Array) -> State: # pytype: disable=signature-mismatch
rng, cmd_rng, kick_noise_2 = jax.random.split(state.info['rng'], 3)
# kick
push_interval = 10
kick_theta = jax.random.uniform(kick_noise_2, maxval=2 * jp.pi)
kick = jp.array([jp.cos(kick_theta), jp.sin(kick_theta)])
kick *= jp.mod(state.info['step'], push_interval) == 0
qvel = state.pipeline_state.qvel # pytype: disable=attribute-error
qvel = qvel.at[:2].set(kick * self._kick_vel + qvel[:2])
state = state.tree_replace({'pipeline_state.qvel': qvel})
# physics step
motor_targets = self._default_pose + action * self._action_scale
motor_targets = jp.clip(motor_targets, self.lowers, self.uppers)
pipeline_state = self.pipeline_step(state.pipeline_state, motor_targets)
x, xd = pipeline_state.x, pipeline_state.xd
# observation data
obs = self._get_obs(pipeline_state, state.info, state.obs)
joint_angles = pipeline_state.q[7:]
joint_vel = pipeline_state.qd[6:]
# foot contact data based on z-position
foot_pos = pipeline_state.site_xpos[self._feet_site_id] # pytype: disable=attribute-error
foot_contact_z = foot_pos[:, 2] - self._foot_radius
contact = foot_contact_z < 1e-3 # a mm or less off the floor
contact_filt_mm = contact | state.info['last_contact']
contact_filt_cm = (foot_contact_z < 3e-2) | state.info['last_contact']
first_contact = (state.info['feet_air_time'] > 0) * contact_filt_mm
state.info['feet_air_time'] += self.dt
# done if joint limits are reached or robot is falling
up = jp.array([0.0, 0.0, 1.0])
done = jp.dot(math.rotate(up, x.rot[self._torso_idx - 1]), up) < 0
done |= jp.any(joint_angles < self.lowers)
done |= jp.any(joint_angles > self.uppers)
done |= pipeline_state.x.pos[self._torso_idx - 1, 2] < 0.18
# reward
rewards = {
'tracking_lin_vel': (
self._reward_tracking_lin_vel(state.info['command'], x, xd)
),
'tracking_ang_vel': (
self._reward_tracking_ang_vel(state.info['command'], x, xd)
),
'lin_vel_z': self._reward_lin_vel_z(xd),
'ang_vel_xy': self._reward_ang_vel_xy(xd),
'orientation': self._reward_orientation(x),
'torques': self._reward_torques(pipeline_state.qfrc_actuator), # pytype: disable=attribute-error
'action_rate': self._reward_action_rate(action, state.info['last_act']),
'stand_still': self._reward_stand_still(
state.info['command'], joint_angles,
),
'feet_air_time': self._reward_feet_air_time(
state.info['feet_air_time'],
first_contact,
state.info['command'],
),
'foot_slip': self._reward_foot_slip(pipeline_state, contact_filt_cm),
'termination': self._reward_termination(done, state.info['step']),
}
rewards = {
k: v * self.reward_config.rewards.scales[k] for k, v in rewards.items()
}
reward = jp.clip(sum(rewards.values()) * self.dt, 0.0, 10000.0)
# state management
state.info['kick'] = kick
state.info['last_act'] = action
state.info['last_vel'] = joint_vel
state.info['feet_air_time'] *= ~contact_filt_mm
state.info['last_contact'] = contact
state.info['rewards'] = rewards
state.info['step'] += 1
state.info['rng'] = rng
# sample new command if more than 500 timesteps achieved
state.info['command'] = jp.where(
state.info['step'] > 500,
self.sample_command(cmd_rng),
state.info['command'],
)
# reset the step counter when done
state.info['step'] = jp.where(
done | (state.info['step'] > 500), 0, state.info['step']
)
# log total displacement as a proxy metric
state.metrics['total_dist'] = math.normalize(x.pos[self._torso_idx - 1])[1]
state.metrics.update(state.info['rewards'])
done = jp.float32(done)
state = state.replace(
pipeline_state=pipeline_state, obs=obs, reward=reward, done=done
)
return state
def _get_obs(
self,
pipeline_state: base.State,
state_info: dict[str, Any],
obs_history: jax.Array,
) -> jax.Array:
inv_torso_rot = math.quat_inv(pipeline_state.x.rot[0])
local_rpyrate = math.rotate(pipeline_state.xd.ang[0], inv_torso_rot)
obs = jp.concatenate([
jp.array([local_rpyrate[2]]) * 0.25, # yaw rate
math.rotate(jp.array([0, 0, -1]), inv_torso_rot), # projected gravity
state_info['command'] * jp.array([2.0, 2.0, 0.25]), # command
pipeline_state.q[7:] - self._default_pose, # motor angles
state_info['last_act'], # last action
])
# clip, noise
obs = jp.clip(obs, -100.0, 100.0) + self._obs_noise * jax.random.uniform(
state_info['rng'], obs.shape, minval=-1, maxval=1
)
# stack observations through time
obs = jp.roll(obs_history, obs.size).at[:obs.size].set(obs)
return obs
# ------------ reward functions----------------
def _reward_lin_vel_z(self, xd: Motion) -> jax.Array:
# Penalize z axis base linear velocity
return jp.square(xd.vel[0, 2])
def _reward_ang_vel_xy(self, xd: Motion) -> jax.Array:
# Penalize xy axes base angular velocity
return jp.sum(jp.square(xd.ang[0, :2]))
def _reward_orientation(self, x: Transform) -> jax.Array:
# Penalize non flat base orientation
up = jp.array([0.0, 0.0, 1.0])
rot_up = math.rotate(up, x.rot[0])
return jp.sum(jp.square(rot_up[:2]))
def _reward_torques(self, torques: jax.Array) -> jax.Array:
# Penalize torques
return jp.sqrt(jp.sum(jp.square(torques))) + jp.sum(jp.abs(torques))
def _reward_action_rate(
self, act: jax.Array, last_act: jax.Array
) -> jax.Array:
# Penalize changes in actions
return jp.sum(jp.square(act - last_act))
def _reward_tracking_lin_vel(
self, commands: jax.Array, x: Transform, xd: Motion
) -> jax.Array:
# Tracking of linear velocity commands (xy axes)
local_vel = math.rotate(xd.vel[0], math.quat_inv(x.rot[0]))
lin_vel_error = jp.sum(jp.square(commands[:2] - local_vel[:2]))
lin_vel_reward = jp.exp(
-lin_vel_error / self.reward_config.rewards.tracking_sigma
)
return lin_vel_reward
def _reward_tracking_ang_vel(
self, commands: jax.Array, x: Transform, xd: Motion
) -> jax.Array:
# Tracking of angular velocity commands (yaw)
base_ang_vel = math.rotate(xd.ang[0], math.quat_inv(x.rot[0]))
ang_vel_error = jp.square(commands[2] - base_ang_vel[2])
return jp.exp(-ang_vel_error / self.reward_config.rewards.tracking_sigma)
def _reward_feet_air_time(
self, air_time: jax.Array, first_contact: jax.Array, commands: jax.Array
) -> jax.Array:
# Reward air time.
rew_air_time = jp.sum((air_time - 0.1) * first_contact)
rew_air_time *= (
math.normalize(commands[:2])[1] > 0.05
) # no reward for zero command
return rew_air_time
def _reward_stand_still(
self,
commands: jax.Array,
joint_angles: jax.Array,
) -> jax.Array:
# Penalize motion at zero commands
return jp.sum(jp.abs(joint_angles - self._default_pose)) * (
math.normalize(commands[:2])[1] < 0.1
)
def _reward_foot_slip(
self, pipeline_state: base.State, contact_filt: jax.Array
) -> jax.Array:
# get velocities at feet which are offset from lower legs
# pytype: disable=attribute-error
pos = pipeline_state.site_xpos[self._feet_site_id] # feet position
feet_offset = pos - pipeline_state.xpos[self._lower_leg_body_id]
# pytype: enable=attribute-error
offset = base.Transform.create(pos=feet_offset)
foot_indices = self._lower_leg_body_id - 1 # we got rid of the world body
foot_vel = offset.vmap().do(pipeline_state.xd.take(foot_indices)).vel
# Penalize large feet velocity for feet that are in contact with the ground.
return jp.sum(jp.square(foot_vel[:, :2]) * contact_filt.reshape((-1, 1)))
def _reward_termination(self, done: jax.Array, step: jax.Array) -> jax.Array:
return done & (step < 500)
def render(
self, trajectory: List[base.State], camera: str | None = None,
width: int = 240, height: int = 320,
) -> Sequence[np.ndarray]:
camera = camera or 'track'
return super().render(trajectory, camera=camera, width=width, height=height)
envs.register_environment('barkour', BarkourEnv)
env_name = 'barkour'
env = envs.get_environment(env_name)
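Before launching a long training run, a quick smoke test of the environment is worthwhile. A minimal sketch (not part of the original tutorial): reset once and confirm that the stacked observation has the expected size (31 values per step x 15 steps of history = 465):
smoke_state = jax.jit(env.reset)(jax.random.PRNGKey(42))
print(smoke_state.obs.shape)  # expected: (465,)
print(smoke_state.info['command'])  # randomly sampled [x_vel, y_vel, ang_vel] command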
3.7 Training the Policy
To train a policy with domain randomization, we pass the domain randomization function to the brax train function; brax calls it when rolling out episodes. Training the quadruped takes about 6 minutes on a Tesla A100 GPU.
ckpt_path = epath.Path('/tmp/quadruped_joystick/ckpts')
ckpt_path.mkdir(parents=True, exist_ok=True)
def policy_params_fn(current_step, make_policy, params):
# save checkpoints
orbax_checkpointer = ocp.PyTreeCheckpointer()
save_args = orbax_utils.save_args_from_target(params)
path = ckpt_path / f'{current_step}'
orbax_checkpointer.save(path, params, force=True, save_args=save_args)
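For reference, a checkpoint saved this way can later be restored directly with Orbax, e.g. to inspect the parameters outside the training loop. A minimal sketch, where the step-count directory name is hypothetical and should be replaced by one that actually exists under ckpt_path:
# '25000000' is a hypothetical step count; use a directory present in ckpt_path.
restored_params = ocp.PyTreeCheckpointer().restore((ckpt_path / '25000000').as_posix())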
make_networks_factory = functools.partial(
ppo_networks.make_ppo_networks,
policy_hidden_layer_sizes=(128, 128, 128, 128))
train_fn = functools.partial(
ppo.train, num_timesteps=100_000_000, num_evals=10,
reward_scaling=1, episode_length=1000, normalize_observations=True,
action_repeat=1, unroll_length=20, num_minibatches=32,
num_updates_per_batch=4, discounting=0.97, learning_rate=3.0e-4,
entropy_cost=1e-2, num_envs=8192, batch_size=256,
network_factory=make_networks_factory,
randomization_fn=domain_randomize,
policy_params_fn=policy_params_fn,
seed=0)
x_data = []
y_data = []
ydataerr = []
times = [datetime.now()]
max_y, min_y = 40, 0
# Reset environments since internals may be overwritten by tracers from the
# domain randomization function.
env = envs.get_environment(env_name)
eval_env = envs.get_environment(env_name)
make_inference_fn, params, _= train_fn(environment=env,
progress_fn=progress,
eval_env=eval_env)
print(f'time to jit: {times[1] - times[0]}')
print(f'time to train: {times[-1] - times[1]}')
# Save and reload params.
model_path = '/tmp/mjx_brax_quadruped_policy'
model.save_params(model_path, params)
params = model.load_params(model_path)
inference_fn = make_inference_fn(params)
jit_inference_fn = jax.jit(inference_fn)
3.8 Visualizing the Policy
x_vel and y_vel define the forward and sideways linear velocity commands relative to the quadruped torso; ang_vel is the commanded yaw rate.
eval_env = envs.get_environment(env_name)
jit_reset = jax.jit(eval_env.reset)
jit_step = jax.jit(eval_env.step)
# @markdown Commands **only used for Barkour Env**:
x_vel = 1.0 #@param {type: "number"}
y_vel = 0.0 #@param {type: "number"}
ang_vel = -0.5 #@param {type: "number"}
the_command = jp.array([x_vel, y_vel, ang_vel])
# initialize the state
rng = jax.random.PRNGKey(0)
state = jit_reset(rng)
state.info['command'] = the_command
rollout = [state.pipeline_state]
# grab a trajectory
n_steps = 500
render_every = 2
for i in range(n_steps):
act_rng, rng = jax.random.split(rng)
ctrl, _ = jit_inference_fn(state.obs, act_rng)
state = jit_step(state, ctrl)
rollout.append(state.pipeline_state)
media.show_video(
eval_env.render(rollout[::render_every], camera='track'),
fps=1.0 / eval_env.dt / render_every)
We can also render the rollout with the Brax renderer.
HTML(html.render(eval_env.sys.tree_replace({'opt.timestep': eval_env.dt}), rollout))
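Since html.render returns an HTML string, the interactive viewer can also be written to disk and opened in a browser later. A minimal sketch:
# Save the interactive Brax visualization to a standalone HTML file.
html_str = html.render(eval_env.sys.tree_replace({'opt.timestep': eval_env.dt}), rollout)
with open('/tmp/barkour_rollout.html', 'w') as f:
  f.write(html_str)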
3.9 Training a Policy on a Height Field
We may also want the quadruped to learn to walk on rough terrain. Let's take the latest checkpoint from the joystick policy above and fine-tune it on height-field terrain.
# use the height field scene
scene_file = 'scene_hfield_mjx.xml'
env = envs.get_environment(env_name, scene_file=scene_file)
jit_reset = jax.jit(env.reset)
state = jit_reset(jax.random.PRNGKey(0))
plt.imshow(env.render([state.pipeline_state], camera='track')[0])
# grab the latest checkpoint from the flat terrain joystick policy
latest_ckpts = list(ckpt_path.glob('*'))
latest_ckpts.sort(key=lambda x: int(x.as_posix().split('/')[-1]))
latest_ckpt = latest_ckpts[-1]
train_fn = functools.partial(
ppo.train, num_timesteps=40_000_000, num_evals=5,
reward_scaling=1, episode_length=1000, normalize_observations=True,
action_repeat=1, unroll_length=20, num_minibatches=32,
num_updates_per_batch=4, discounting=0.97, learning_rate=3.0e-4,
entropy_cost=1e-2, num_envs=8192, batch_size=256,
network_factory=make_networks_factory,
randomization_fn=domain_randomize, seed=0,
restore_checkpoint_path=latest_ckpt)
x_data = []
y_data = []
ydataerr = []
times = [datetime.now()]
max_y, min_y = 40, 0
# Reset environments since internals may be overwritten by tracers from the
# domain randomization function.
env = envs.get_environment(env_name, scene_file=scene_file)
eval_env = envs.get_environment(env_name, scene_file=scene_file)
make_inference_fn, params, _= train_fn(environment=env,
progress_fn=progress,
eval_env=eval_env)
print(f'time to jit: {times[1] - times[0]}')
print(f'time to train: {times[-1] - times[1]}')
3.10 Visualizing the Policy on the Height Field
eval_env = envs.get_environment(env_name, scene_file=scene_file)
jit_reset = jax.jit(eval_env.reset)
jit_step = jax.jit(eval_env.step)
inference_fn = make_inference_fn(params)
jit_inference_fn = jax.jit(inference_fn)
# @markdown Commands **only used for Barkour Env**:
x_vel = 1.0 #@param {type: "number"}
y_vel = 0.0 #@param {type: "number"}
ang_vel = -0.5 #@param {type: "number"}
the_command = jp.array([x_vel, y_vel, ang_vel])
# initialize the state
rng = jax.random.PRNGKey(0)
state = jit_reset(rng)
state.info['command'] = the_command
rollout = [state.pipeline_state]
# grab a trajectory
n_steps = 500
render_every = 2
for i in range(n_steps):
act_rng, rng = jax.random.split(rng)
ctrl, _ = jit_inference_fn(state.obs, act_rng)
state = jit_step(state, ctrl)
rollout.append(state.pipeline_state)
media.show_video(
eval_env.render(rollout[::render_every], camera='track'),
fps=1.0 / eval_env.dt / render_every)
4. MuJoCo Playground: Robot Locomotion and Manipulation Environments + Sim-to-Real!
So far we have shown how to use MJX to train policies for classic control and robot locomotion. For a full suite of locomotion and manipulation environments, we recommend MuJoCo Playground. As described on its website and in the technical report, many of these environments have been transferred to real robots.