From cfa672129e1d4c606fbadccb2985b5f4847c3be4 Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Sat, 2 Aug 2025 19:06:56 +0200 Subject: [PATCH] Refactored `actor.py` to use the pipeline --- src/lerobot/scripts/rl/actor.py | 76 ++++++-- src/lerobot/scripts/rl/gym_manipulator.py | 203 ++++++++++++++-------- 2 files changed, 186 insertions(+), 93 deletions(-) diff --git a/src/lerobot/scripts/rl/actor.py b/src/lerobot/scripts/rl/actor.py index 1c8f9286b..fae3be753 100644 --- a/src/lerobot/scripts/rl/actor.py +++ b/src/lerobot/scripts/rl/actor.py @@ -83,6 +83,13 @@ from lerobot.utils.transition import ( move_state_dict_to_device, move_transition_to_device, ) + +from lerobot.processor.pipeline import EnvTransition, TransitionKey +from lerobot.scripts.rl.gym_manipulator import ( + create_transition, + make_processors, + step_env_and_process_transition, +) from lerobot.utils.utils import ( TimerManager, get_safe_torch_device, @@ -236,7 +243,8 @@ def act_with_policy( logging.info("make_env online") - online_env = make_robot_env(cfg=cfg.env) + online_env, teleop_device = make_robot_env(cfg=cfg.env) + env_processor, action_processor = make_processors(online_env, cfg.env, cfg.policy.device) set_seed(cfg.seed) device = get_safe_torch_device(cfg.policy.device, log=True) @@ -257,6 +265,13 @@ def act_with_policy( assert isinstance(policy, nn.Module) obs, info = online_env.reset() + complementary_data = {"raw_joint_positions": info.pop("raw_joint_positions")} + env_processor.reset() + action_processor.reset() + + # Process initial observation + transition = create_transition(observation=obs, info=info, complementary_data=complementary_data) + transition = env_processor(transition) # NOTE: For the moment we will solely handle the case of a single environment sum_reward_episode = 0 @@ -277,7 +292,9 @@ def act_with_policy( if interaction_step >= cfg.policy.online_step_before_learning: # Time policy inference and check if it meets FPS requirement with policy_timer: - action = policy.select_action(batch=obs) + # Extract observation from transition for policy + batch_obs = transition[TransitionKey.OBSERVATION] + action = policy.select_action(batch=batch_obs) policy_fps = policy_timer.fps_last log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step) @@ -285,34 +302,46 @@ def act_with_policy( else: action = online_env.action_space.sample() - next_obs, reward, done, truncated, info = online_env.step(action) - + # Use the new step function + new_transition, terminate_episode = step_env_and_process_transition( + env=online_env, + transition=transition, + action=action, + teleop_device=teleop_device, + env_processor=env_processor, + action_processor=action_processor, + ) + + # Extract values from processed transition + reward = new_transition[TransitionKey.REWARD] + done = new_transition.get(TransitionKey.DONE, False) + truncated = new_transition.get(TransitionKey.TRUNCATED, False) + processed_action = new_transition[TransitionKey.ACTION] + sum_reward_episode += float(reward) - # Increment total steps counter for intervention rate episode_total_steps += 1 - # NOTE: We override the action if the intervention is True, because the action applied is the intervention action - if "is_intervention" in info and info["is_intervention"]: - # NOTE: The action space for demonstration before hand is with the full action space - # but sometimes for example we want to deactivate the gripper - action = info["action_intervention"] + # Check for intervention from transition info + intervention_info = new_transition[TransitionKey.INFO] + if intervention_info.get("is_intervention", False): episode_intervention = True - # Increment intervention steps counter episode_intervention_steps += 1 + # Create transition for learner (convert to old format) list_transition_to_send_to_learner.append( Transition( - state=obs, - action=action, + state=transition[TransitionKey.OBSERVATION], + action=processed_action, reward=reward, - next_state=next_obs, + next_state=new_transition[TransitionKey.OBSERVATION], done=done, - truncated=truncated, # TODO: (azouitine) Handle truncation properly - complementary_info=info, + truncated=truncated, + complementary_info=new_transition[TransitionKey.COMPLEMENTARY_DATA], ) ) - # assign obs to the next obs and continue the rollout - obs = next_obs + + # Update transition for next iteration + transition = new_transition if done or truncated: logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}") @@ -347,12 +376,21 @@ def act_with_policy( ) ) - # Reset intervention counters + # Reset intervention counters and environment sum_reward_episode = 0.0 episode_intervention = False episode_intervention_steps = 0 episode_total_steps = 0 + + # Reset environment and processors obs, info = online_env.reset() + complementary_data = {"raw_joint_positions": info.pop("raw_joint_positions")} + env_processor.reset() + action_processor.reset() + + # Process initial observation + transition = create_transition(observation=obs, info=info, complementary_data=complementary_data) + transition = env_processor(transition) if cfg.env.fps is not None: dt_time = time.perf_counter() - start_time diff --git a/src/lerobot/scripts/rl/gym_manipulator.py b/src/lerobot/scripts/rl/gym_manipulator.py index 1d60380f8..a102e3ff2 100644 --- a/src/lerobot/scripts/rl/gym_manipulator.py +++ b/src/lerobot/scripts/rl/gym_manipulator.py @@ -404,6 +404,9 @@ class ImageCropResizeProcessor: if observation is None: return transition + if self.resize_size is None and not self.crop_params_dict: + return transition + new_observation = dict(observation) # Process all image keys in the observation @@ -777,6 +780,123 @@ def make_robot_env(cfg: EnvConfig) -> tuple[gym.Env, Any]: return env, teleop_device +def make_processors(env, cfg): + """ + Factory function to create environment and action processors. + + Args: + env: The robot environment + cfg: Configuration object containing processor parameters + + Returns: + tuple: (env_processor, action_processor) + """ + env_pipeline_steps = [ + ImageProcessor(), + StateProcessor(), + JointVelocityProcessor(dt=1.0 / cfg.fps), + MotorCurrentProcessor(env=env), + ImageCropResizeProcessor( + crop_params_dict=cfg.processor.crop_params_dict, + resize_size=cfg.processor.resize_size + ), + TimeLimitProcessor(max_episode_steps=int(cfg.processor.control_time_s * cfg.fps)), + GripperPenaltyProcessor( + penalty=cfg.processor.gripper_penalty, + max_gripper_pos=cfg.processor.max_gripper_pos + ), + DeviceProcessor(device=cfg.device), + ] + env_processor = RobotProcessor(steps=env_pipeline_steps) + + action_pipeline_steps = [ + InterventionActionProcessor( + use_gripper=cfg.processor.use_gripper, + ), + InverseKinematicsProcessor( + urdf_path=cfg.processor.urdf_path, + target_frame_name=cfg.processor.target_frame_name, + end_effector_step_sizes=cfg.processor.end_effector_step_sizes, + end_effector_bounds=cfg.processor.end_effector_bounds, + max_gripper_pos=cfg.processor.max_gripper_pos, + env=env, + ), + ] + action_processor = RobotProcessor(steps=action_pipeline_steps) + + return env_processor, action_processor + + +def step_env_and_process_transition( + env, + transition, + action, + teleop_device, + env_processor, + action_processor, +): + """ + Execute one step with processors handling intervention and observation processing. + + Args: + env: The robot environment + transition: Current transition state + action: Action to execute (will be replaced by neutral action in gym_manipulator mode) + teleop_device: Teleoperator device for getting intervention signals + env_processor: Environment processor for observations + action_processor: Action processor for handling interventions + + Returns: + tuple: (new_transition, terminate_episode) + """ + # Get teleoperation action and events + teleop_action = teleop_device.get_action() + teleop_events = teleop_device.get_teleop_events() + + # Create action transition + action_transition = dict(transition) + action_transition[TransitionKey.ACTION] = action + + # Add teleoperation data to complementary data + action_complementary_data = action_transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).copy() + action_complementary_data["teleop_action"] = teleop_action + action_complementary_data.update(teleop_events) + action_transition[TransitionKey.COMPLEMENTARY_DATA] = action_complementary_data + + # Process action through action pipeline (handles intervention) + processed_action_transition = action_processor(action_transition) + + # Extract processed action and metadata + processed_action = processed_action_transition[TransitionKey.ACTION] + terminate_episode = processed_action_transition.get(TransitionKey.DONE, False) + + # Step environment with processed action + obs, reward, terminated, truncated, info = env.step(processed_action) + + # Combine rewards from environment and action processor + reward = reward + processed_action_transition[TransitionKey.REWARD] + + # Process new observation + complementary_data = { + "raw_joint_positions": info.pop("raw_joint_positions"), + **processed_action_transition[TransitionKey.COMPLEMENTARY_DATA], + } + info.update(processed_action_transition[TransitionKey.INFO]) + + new_transition = create_transition( + observation=obs, + action=processed_action, + reward=reward, + done=terminated or terminate_episode, + truncated=truncated, + info=info, + complementary_data=complementary_data, + ) + new_transition = env_processor(new_transition) + + return new_transition, terminate_episode + + def control_loop(env, env_processor, action_processor, teleop_device, cfg: EnvConfig): dt = 1.0 / cfg.fps @@ -841,54 +961,20 @@ def control_loop(env, env_processor, action_processor, teleop_device, cfg: EnvCo while episode_idx < cfg.num_episodes: step_start_time = time.perf_counter() - # Get teleoperation action and extra signals - teleop_action = teleop_device.get_action() - teleop_events = teleop_device.get_teleop_events() - # Create a neutral action (no movement) neutral_action = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32) if hasattr(env, "use_gripper") and env.use_gripper: neutral_action = torch.cat([neutral_action, torch.tensor([1.0])]) # Gripper stay - # Create action transition - action_transition = dict(transition) - action_transition[TransitionKey.ACTION] = neutral_action - - # Add teleoperation data to complementary data - action_complementary_data = action_transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).copy() - action_complementary_data["teleop_action"] = teleop_action - action_complementary_data.update(teleop_events) - action_transition[TransitionKey.COMPLEMENTARY_DATA] = action_complementary_data - - # Process action through action pipeline (handles intervention) - processed_action_transition = action_processor(action_transition) - - # Extract processed action and metadata - processed_action = processed_action_transition[TransitionKey.ACTION] - terminate_episode = processed_action_transition.get(TransitionKey.DONE, False) - - # Step environment with processed action - obs, reward, terminated, truncated, info = env.step(processed_action) - - reward = reward + processed_action_transition[TransitionKey.REWARD] - - # Process new observation - complementary_data = { - "raw_joint_positions": info.pop("raw_joint_positions"), - **processed_action_transition[TransitionKey.COMPLEMENTARY_DATA], - } - info.update(processed_action_transition[TransitionKey.INFO]) - - transition = create_transition( - observation=obs, - action=processed_action, - reward=reward, - done=terminated or terminate_episode, - truncated=truncated, - info=info, - complementary_data=complementary_data, + # Use the new step function + transition, terminate_episode = step_env_and_process_transition( + env=env, + transition=transition, + action=neutral_action, + teleop_device=teleop_device, + env_processor=env_processor, + action_processor=action_processor, ) - transition = env_processor(transition) terminated = transition.get(TransitionKey.DONE, False) truncated = transition.get(TransitionKey.TRUNCATED, False) @@ -963,38 +1049,7 @@ def replay_trajectory(env, action_processor, cfg): @parser.wrap() def main(cfg: EnvConfig): env, teleop_device = make_robot_env(cfg) - env_pipeline_steps = [ - ImageProcessor(), - StateProcessor(), - JointVelocityProcessor(dt=1.0 / cfg.fps), - MotorCurrentProcessor(env=env), - ImageCropResizeProcessor( - crop_params_dict=cfg.processor.crop_params_dict, resize_size=cfg.processor.resize_size - ), - TimeLimitProcessor(max_episode_steps=int(cfg.processor.control_time_s * cfg.fps)), - GripperPenaltyProcessor( - penalty=cfg.processor.gripper_penalty, max_gripper_pos=cfg.processor.max_gripper_pos - ), - DeviceProcessor(device=cfg.device), - ] - - env_processor = RobotProcessor(steps=env_pipeline_steps) - - action_pipeline_steps = [ - InterventionActionProcessor( - use_gripper=cfg.processor.use_gripper, - ), - InverseKinematicsProcessor( - urdf_path=cfg.processor.urdf_path, - target_frame_name=cfg.processor.target_frame_name, - end_effector_step_sizes=cfg.processor.end_effector_step_sizes, - end_effector_bounds=cfg.processor.end_effector_bounds, - max_gripper_pos=cfg.processor.max_gripper_pos, - env=env, - ), - ] - - action_processor = RobotProcessor(steps=action_pipeline_steps) + env_processor, action_processor = make_processors(env, cfg) print("Environment observation space:", env.observation_space) print("Environment action space:", env.action_space)