mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 11:39:50 +00:00
Refactored actor.py to use the pipeline
This commit is contained in:
@@ -83,6 +83,13 @@ from lerobot.utils.transition import (
|
|||||||
move_state_dict_to_device,
|
move_state_dict_to_device,
|
||||||
move_transition_to_device,
|
move_transition_to_device,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from lerobot.processor.pipeline import EnvTransition, TransitionKey
|
||||||
|
from lerobot.scripts.rl.gym_manipulator import (
|
||||||
|
create_transition,
|
||||||
|
make_processors,
|
||||||
|
step_env_and_process_transition,
|
||||||
|
)
|
||||||
from lerobot.utils.utils import (
|
from lerobot.utils.utils import (
|
||||||
TimerManager,
|
TimerManager,
|
||||||
get_safe_torch_device,
|
get_safe_torch_device,
|
||||||
@@ -236,7 +243,8 @@ def act_with_policy(
|
|||||||
|
|
||||||
logging.info("make_env online")
|
logging.info("make_env online")
|
||||||
|
|
||||||
online_env = make_robot_env(cfg=cfg.env)
|
online_env, teleop_device = make_robot_env(cfg=cfg.env)
|
||||||
|
env_processor, action_processor = make_processors(online_env, cfg.env, cfg.policy.device)
|
||||||
|
|
||||||
set_seed(cfg.seed)
|
set_seed(cfg.seed)
|
||||||
device = get_safe_torch_device(cfg.policy.device, log=True)
|
device = get_safe_torch_device(cfg.policy.device, log=True)
|
||||||
@@ -257,6 +265,13 @@ def act_with_policy(
|
|||||||
assert isinstance(policy, nn.Module)
|
assert isinstance(policy, nn.Module)
|
||||||
|
|
||||||
obs, info = online_env.reset()
|
obs, info = online_env.reset()
|
||||||
|
complementary_data = {"raw_joint_positions": info.pop("raw_joint_positions")}
|
||||||
|
env_processor.reset()
|
||||||
|
action_processor.reset()
|
||||||
|
|
||||||
|
# Process initial observation
|
||||||
|
transition = create_transition(observation=obs, info=info, complementary_data=complementary_data)
|
||||||
|
transition = env_processor(transition)
|
||||||
|
|
||||||
# NOTE: For the moment we will solely handle the case of a single environment
|
# NOTE: For the moment we will solely handle the case of a single environment
|
||||||
sum_reward_episode = 0
|
sum_reward_episode = 0
|
||||||
@@ -277,7 +292,9 @@ def act_with_policy(
|
|||||||
if interaction_step >= cfg.policy.online_step_before_learning:
|
if interaction_step >= cfg.policy.online_step_before_learning:
|
||||||
# Time policy inference and check if it meets FPS requirement
|
# Time policy inference and check if it meets FPS requirement
|
||||||
with policy_timer:
|
with policy_timer:
|
||||||
action = policy.select_action(batch=obs)
|
# Extract observation from transition for policy
|
||||||
|
batch_obs = transition[TransitionKey.OBSERVATION]
|
||||||
|
action = policy.select_action(batch=batch_obs)
|
||||||
policy_fps = policy_timer.fps_last
|
policy_fps = policy_timer.fps_last
|
||||||
|
|
||||||
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
|
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
|
||||||
@@ -285,34 +302,46 @@ def act_with_policy(
|
|||||||
else:
|
else:
|
||||||
action = online_env.action_space.sample()
|
action = online_env.action_space.sample()
|
||||||
|
|
||||||
next_obs, reward, done, truncated, info = online_env.step(action)
|
# Use the new step function
|
||||||
|
new_transition, terminate_episode = step_env_and_process_transition(
|
||||||
|
env=online_env,
|
||||||
|
transition=transition,
|
||||||
|
action=action,
|
||||||
|
teleop_device=teleop_device,
|
||||||
|
env_processor=env_processor,
|
||||||
|
action_processor=action_processor,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract values from processed transition
|
||||||
|
reward = new_transition[TransitionKey.REWARD]
|
||||||
|
done = new_transition.get(TransitionKey.DONE, False)
|
||||||
|
truncated = new_transition.get(TransitionKey.TRUNCATED, False)
|
||||||
|
processed_action = new_transition[TransitionKey.ACTION]
|
||||||
|
|
||||||
sum_reward_episode += float(reward)
|
sum_reward_episode += float(reward)
|
||||||
# Increment total steps counter for intervention rate
|
|
||||||
episode_total_steps += 1
|
episode_total_steps += 1
|
||||||
|
|
||||||
# NOTE: We override the action if the intervention is True, because the action applied is the intervention action
|
# Check for intervention from transition info
|
||||||
if "is_intervention" in info and info["is_intervention"]:
|
intervention_info = new_transition[TransitionKey.INFO]
|
||||||
# NOTE: The action space for demonstration before hand is with the full action space
|
if intervention_info.get("is_intervention", False):
|
||||||
# but sometimes for example we want to deactivate the gripper
|
|
||||||
action = info["action_intervention"]
|
|
||||||
episode_intervention = True
|
episode_intervention = True
|
||||||
# Increment intervention steps counter
|
|
||||||
episode_intervention_steps += 1
|
episode_intervention_steps += 1
|
||||||
|
|
||||||
|
# Create transition for learner (convert to old format)
|
||||||
list_transition_to_send_to_learner.append(
|
list_transition_to_send_to_learner.append(
|
||||||
Transition(
|
Transition(
|
||||||
state=obs,
|
state=transition[TransitionKey.OBSERVATION],
|
||||||
action=action,
|
action=processed_action,
|
||||||
reward=reward,
|
reward=reward,
|
||||||
next_state=next_obs,
|
next_state=new_transition[TransitionKey.OBSERVATION],
|
||||||
done=done,
|
done=done,
|
||||||
truncated=truncated, # TODO: (azouitine) Handle truncation properly
|
truncated=truncated,
|
||||||
complementary_info=info,
|
complementary_info=new_transition[TransitionKey.COMPLEMENTARY_DATA],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
# assign obs to the next obs and continue the rollout
|
|
||||||
obs = next_obs
|
# Update transition for next iteration
|
||||||
|
transition = new_transition
|
||||||
|
|
||||||
if done or truncated:
|
if done or truncated:
|
||||||
logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")
|
logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")
|
||||||
@@ -347,12 +376,21 @@ def act_with_policy(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Reset intervention counters
|
# Reset intervention counters and environment
|
||||||
sum_reward_episode = 0.0
|
sum_reward_episode = 0.0
|
||||||
episode_intervention = False
|
episode_intervention = False
|
||||||
episode_intervention_steps = 0
|
episode_intervention_steps = 0
|
||||||
episode_total_steps = 0
|
episode_total_steps = 0
|
||||||
|
|
||||||
|
# Reset environment and processors
|
||||||
obs, info = online_env.reset()
|
obs, info = online_env.reset()
|
||||||
|
complementary_data = {"raw_joint_positions": info.pop("raw_joint_positions")}
|
||||||
|
env_processor.reset()
|
||||||
|
action_processor.reset()
|
||||||
|
|
||||||
|
# Process initial observation
|
||||||
|
transition = create_transition(observation=obs, info=info, complementary_data=complementary_data)
|
||||||
|
transition = env_processor(transition)
|
||||||
|
|
||||||
if cfg.env.fps is not None:
|
if cfg.env.fps is not None:
|
||||||
dt_time = time.perf_counter() - start_time
|
dt_time = time.perf_counter() - start_time
|
||||||
|
|||||||
@@ -404,6 +404,9 @@ class ImageCropResizeProcessor:
|
|||||||
if observation is None:
|
if observation is None:
|
||||||
return transition
|
return transition
|
||||||
|
|
||||||
|
if self.resize_size is None and not self.crop_params_dict:
|
||||||
|
return transition
|
||||||
|
|
||||||
new_observation = dict(observation)
|
new_observation = dict(observation)
|
||||||
|
|
||||||
# Process all image keys in the observation
|
# Process all image keys in the observation
|
||||||
@@ -777,6 +780,123 @@ def make_robot_env(cfg: EnvConfig) -> tuple[gym.Env, Any]:
|
|||||||
return env, teleop_device
|
return env, teleop_device
|
||||||
|
|
||||||
|
|
||||||
|
def make_processors(env, cfg, device=None):
    """
    Factory function to create environment and action processors.

    Args:
        env: The robot environment. It is handed to ``MotorCurrentProcessor``
            and ``InverseKinematicsProcessor``.
        cfg: Configuration object containing processor parameters. This
            function reads ``cfg.fps``, the ``cfg.processor`` fields used
            below and, only when ``device`` is not given, ``cfg.device``.
        device: Optional device passed to ``DeviceProcessor``. Defaults to
            ``None``, in which case ``cfg.device`` is used. Accepting it here
            lets callers that hold the device separately from the environment
            config pass it as a third argument.

    Returns:
        tuple: (env_processor, action_processor)
    """
    # Fall back to the config's device so two-argument callers keep working.
    target_device = cfg.device if device is None else device

    env_pipeline_steps = [
        ImageProcessor(),
        StateProcessor(),
        JointVelocityProcessor(dt=1.0 / cfg.fps),
        MotorCurrentProcessor(env=env),
        ImageCropResizeProcessor(
            crop_params_dict=cfg.processor.crop_params_dict,
            resize_size=cfg.processor.resize_size,
        ),
        # The step budget is the configured control time expressed in frames.
        TimeLimitProcessor(max_episode_steps=int(cfg.processor.control_time_s * cfg.fps)),
        GripperPenaltyProcessor(
            penalty=cfg.processor.gripper_penalty,
            max_gripper_pos=cfg.processor.max_gripper_pos,
        ),
        DeviceProcessor(device=target_device),
    ]
    env_processor = RobotProcessor(steps=env_pipeline_steps)

    action_pipeline_steps = [
        InterventionActionProcessor(
            use_gripper=cfg.processor.use_gripper,
        ),
        InverseKinematicsProcessor(
            urdf_path=cfg.processor.urdf_path,
            target_frame_name=cfg.processor.target_frame_name,
            end_effector_step_sizes=cfg.processor.end_effector_step_sizes,
            end_effector_bounds=cfg.processor.end_effector_bounds,
            max_gripper_pos=cfg.processor.max_gripper_pos,
            env=env,
        ),
    ]
    action_processor = RobotProcessor(steps=action_pipeline_steps)

    return env_processor, action_processor
|
||||||
|
|
||||||
|
|
||||||
|
def step_env_and_process_transition(
    env,
    transition,
    action,
    teleop_device,
    env_processor,
    action_processor,
):
    """
    Run one environment step through the action and observation pipelines.

    The teleoperator is polled, its action and events are attached to a copy
    of the current transition, and the action pipeline runs on that copy. The
    environment is then stepped with the resulting action, and the step
    outcome is wrapped in a fresh transition that goes through the
    environment pipeline.

    Args:
        env: The robot environment; ``env.step`` is called with the processed action.
        transition: Current transition. It is copied with ``dict(...)`` rather
            than modified directly.
        action: Action placed under ``TransitionKey.ACTION`` before the action
            pipeline runs.
        teleop_device: Teleoperator device providing ``get_action()`` and
            ``get_teleop_events()``.
        env_processor: Environment processor applied to the new transition.
        action_processor: Action processor applied to the action transition.

    Returns:
        tuple: (new_transition, terminate_episode), where ``terminate_episode``
        is the ``TransitionKey.DONE`` entry of the action pipeline's output
        (``False`` when absent).
    """
    # Poll the teleoperator up front: action first, then events.
    teleop_action = teleop_device.get_action()
    teleop_events = teleop_device.get_teleop_events()

    # Build the action pipeline's input on a shallow copy of the transition.
    pipeline_input = dict(transition)
    pipeline_input[TransitionKey.ACTION] = action

    # Copy the complementary data before extending it with teleop information.
    teleop_extras = pipeline_input.get(TransitionKey.COMPLEMENTARY_DATA, {}).copy()
    teleop_extras["teleop_action"] = teleop_action
    teleop_extras.update(teleop_events)
    pipeline_input[TransitionKey.COMPLEMENTARY_DATA] = teleop_extras

    # The action pipeline yields the action to execute and a termination flag.
    pipeline_output = action_processor(pipeline_input)
    env_action = pipeline_output[TransitionKey.ACTION]
    terminate_episode = pipeline_output.get(TransitionKey.DONE, False)

    obs, env_reward, terminated, truncated, info = env.step(env_action)

    # The environment reward is summed with the action pipeline's reward.
    total_reward = env_reward + pipeline_output[TransitionKey.REWARD]

    # Raw joint positions move from `info` into the complementary data; entries
    # from the action pipeline are merged afterwards and win on key collisions.
    raw_joint_positions = info.pop("raw_joint_positions")
    next_complementary_data = {
        "raw_joint_positions": raw_joint_positions,
        **pipeline_output[TransitionKey.COMPLEMENTARY_DATA],
    }
    info.update(pipeline_output[TransitionKey.INFO])

    # Assemble the new transition and run it through the environment pipeline.
    return (
        env_processor(
            create_transition(
                observation=obs,
                action=env_action,
                reward=total_reward,
                done=terminated or terminate_episode,
                truncated=truncated,
                info=info,
                complementary_data=next_complementary_data,
            )
        ),
        terminate_episode,
    )
|
||||||
|
|
||||||
|
|
||||||
def control_loop(env, env_processor, action_processor, teleop_device, cfg: EnvConfig):
|
def control_loop(env, env_processor, action_processor, teleop_device, cfg: EnvConfig):
|
||||||
dt = 1.0 / cfg.fps
|
dt = 1.0 / cfg.fps
|
||||||
|
|
||||||
@@ -841,54 +961,20 @@ def control_loop(env, env_processor, action_processor, teleop_device, cfg: EnvCo
|
|||||||
while episode_idx < cfg.num_episodes:
|
while episode_idx < cfg.num_episodes:
|
||||||
step_start_time = time.perf_counter()
|
step_start_time = time.perf_counter()
|
||||||
|
|
||||||
# Get teleoperation action and extra signals
|
|
||||||
teleop_action = teleop_device.get_action()
|
|
||||||
teleop_events = teleop_device.get_teleop_events()
|
|
||||||
|
|
||||||
# Create a neutral action (no movement)
|
# Create a neutral action (no movement)
|
||||||
neutral_action = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32)
|
neutral_action = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32)
|
||||||
if hasattr(env, "use_gripper") and env.use_gripper:
|
if hasattr(env, "use_gripper") and env.use_gripper:
|
||||||
neutral_action = torch.cat([neutral_action, torch.tensor([1.0])]) # Gripper stay
|
neutral_action = torch.cat([neutral_action, torch.tensor([1.0])]) # Gripper stay
|
||||||
|
|
||||||
# Create action transition
|
# Use the new step function
|
||||||
action_transition = dict(transition)
|
transition, terminate_episode = step_env_and_process_transition(
|
||||||
action_transition[TransitionKey.ACTION] = neutral_action
|
env=env,
|
||||||
|
transition=transition,
|
||||||
# Add teleoperation data to complementary data
|
action=neutral_action,
|
||||||
action_complementary_data = action_transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).copy()
|
teleop_device=teleop_device,
|
||||||
action_complementary_data["teleop_action"] = teleop_action
|
env_processor=env_processor,
|
||||||
action_complementary_data.update(teleop_events)
|
action_processor=action_processor,
|
||||||
action_transition[TransitionKey.COMPLEMENTARY_DATA] = action_complementary_data
|
|
||||||
|
|
||||||
# Process action through action pipeline (handles intervention)
|
|
||||||
processed_action_transition = action_processor(action_transition)
|
|
||||||
|
|
||||||
# Extract processed action and metadata
|
|
||||||
processed_action = processed_action_transition[TransitionKey.ACTION]
|
|
||||||
terminate_episode = processed_action_transition.get(TransitionKey.DONE, False)
|
|
||||||
|
|
||||||
# Step environment with processed action
|
|
||||||
obs, reward, terminated, truncated, info = env.step(processed_action)
|
|
||||||
|
|
||||||
reward = reward + processed_action_transition[TransitionKey.REWARD]
|
|
||||||
|
|
||||||
# Process new observation
|
|
||||||
complementary_data = {
|
|
||||||
"raw_joint_positions": info.pop("raw_joint_positions"),
|
|
||||||
**processed_action_transition[TransitionKey.COMPLEMENTARY_DATA],
|
|
||||||
}
|
|
||||||
info.update(processed_action_transition[TransitionKey.INFO])
|
|
||||||
|
|
||||||
transition = create_transition(
|
|
||||||
observation=obs,
|
|
||||||
action=processed_action,
|
|
||||||
reward=reward,
|
|
||||||
done=terminated or terminate_episode,
|
|
||||||
truncated=truncated,
|
|
||||||
info=info,
|
|
||||||
complementary_data=complementary_data,
|
|
||||||
)
|
)
|
||||||
transition = env_processor(transition)
|
|
||||||
terminated = transition.get(TransitionKey.DONE, False)
|
terminated = transition.get(TransitionKey.DONE, False)
|
||||||
truncated = transition.get(TransitionKey.TRUNCATED, False)
|
truncated = transition.get(TransitionKey.TRUNCATED, False)
|
||||||
|
|
||||||
@@ -963,38 +1049,7 @@ def replay_trajectory(env, action_processor, cfg):
|
|||||||
@parser.wrap()
|
@parser.wrap()
|
||||||
def main(cfg: EnvConfig):
|
def main(cfg: EnvConfig):
|
||||||
env, teleop_device = make_robot_env(cfg)
|
env, teleop_device = make_robot_env(cfg)
|
||||||
env_pipeline_steps = [
|
env_processor, action_processor = make_processors(env, cfg)
|
||||||
ImageProcessor(),
|
|
||||||
StateProcessor(),
|
|
||||||
JointVelocityProcessor(dt=1.0 / cfg.fps),
|
|
||||||
MotorCurrentProcessor(env=env),
|
|
||||||
ImageCropResizeProcessor(
|
|
||||||
crop_params_dict=cfg.processor.crop_params_dict, resize_size=cfg.processor.resize_size
|
|
||||||
),
|
|
||||||
TimeLimitProcessor(max_episode_steps=int(cfg.processor.control_time_s * cfg.fps)),
|
|
||||||
GripperPenaltyProcessor(
|
|
||||||
penalty=cfg.processor.gripper_penalty, max_gripper_pos=cfg.processor.max_gripper_pos
|
|
||||||
),
|
|
||||||
DeviceProcessor(device=cfg.device),
|
|
||||||
]
|
|
||||||
|
|
||||||
env_processor = RobotProcessor(steps=env_pipeline_steps)
|
|
||||||
|
|
||||||
action_pipeline_steps = [
|
|
||||||
InterventionActionProcessor(
|
|
||||||
use_gripper=cfg.processor.use_gripper,
|
|
||||||
),
|
|
||||||
InverseKinematicsProcessor(
|
|
||||||
urdf_path=cfg.processor.urdf_path,
|
|
||||||
target_frame_name=cfg.processor.target_frame_name,
|
|
||||||
end_effector_step_sizes=cfg.processor.end_effector_step_sizes,
|
|
||||||
end_effector_bounds=cfg.processor.end_effector_bounds,
|
|
||||||
max_gripper_pos=cfg.processor.max_gripper_pos,
|
|
||||||
env=env,
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
action_processor = RobotProcessor(steps=action_pipeline_steps)
|
|
||||||
|
|
||||||
print("Environment observation space:", env.observation_space)
|
print("Environment observation space:", env.observation_space)
|
||||||
print("Environment action space:", env.action_space)
|
print("Environment action space:", env.action_space)
|
||||||
|
|||||||
Reference in New Issue
Block a user