From f49280e89bb4e4de81b98fdf6034a7d1a86012df Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Sun, 3 Aug 2025 23:21:13 +0200 Subject: [PATCH] RL works at this commit - fixed actor.py and bugs in gym_manipulator --- src/lerobot/scripts/rl/actor.py | 30 +++++++------- src/lerobot/scripts/rl/gym_manipulator.py | 49 +++++++++++++---------- 2 files changed, 41 insertions(+), 38 deletions(-) diff --git a/src/lerobot/scripts/rl/actor.py b/src/lerobot/scripts/rl/actor.py index d2331166e..df4792a0c 100644 --- a/src/lerobot/scripts/rl/actor.py +++ b/src/lerobot/scripts/rl/actor.py @@ -243,7 +243,7 @@ def act_with_policy( logging.info("make_env online") online_env, teleop_device = make_robot_env(cfg=cfg.env) - env_processor, action_processor = make_processors(online_env, cfg.env, cfg.policy.device) + env_processor, action_processor = make_processors(online_env, cfg.env) set_seed(cfg.seed) device = get_safe_torch_device(cfg.policy.device, log=True) @@ -288,18 +288,15 @@ def act_with_policy( logging.info("[ACTOR] Shutting down act_with_policy") return - if interaction_step >= cfg.policy.online_step_before_learning: - # Time policy inference and check if it meets FPS requirement - with policy_timer: - # Extract observation from transition for policy - batch_obs = transition[TransitionKey.OBSERVATION] - action = policy.select_action(batch=batch_obs) - policy_fps = policy_timer.fps_last + observation = transition[TransitionKey.OBSERVATION] - log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step) + # Time policy inference and check if it meets FPS requirement + with policy_timer: + # Extract observation from transition for policy + action = policy.select_action(batch=observation) + policy_fps = policy_timer.fps_last - else: - action = online_env.action_space.sample() + log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step) # Use the new step function new_transition, terminate_episode = step_env_and_process_transition( @@ -312,10 +309,11 @@ def act_with_policy( ) # Extract values from processed transition + next_observation = new_transition[TransitionKey.OBSERVATION] + executed_action = new_transition[TransitionKey.COMPLEMENTARY_DATA]["teleop_action"] reward = new_transition[TransitionKey.REWARD] done = new_transition.get(TransitionKey.DONE, False) truncated = new_transition.get(TransitionKey.TRUNCATED, False) - processed_action = new_transition[TransitionKey.ACTION] sum_reward_episode += float(reward) episode_total_steps += 1 @@ -329,13 +327,13 @@ def act_with_policy( # Create transition for learner (convert to old format) list_transition_to_send_to_learner.append( Transition( - state=transition[TransitionKey.OBSERVATION], - action=processed_action, + state=observation, + action=executed_action, reward=reward, - next_state=new_transition[TransitionKey.OBSERVATION], + next_state=next_observation, done=done, truncated=truncated, - complementary_info=new_transition[TransitionKey.COMPLEMENTARY_DATA], + complementary_info={}, # new_transition[TransitionKey.COMPLEMENTARY_DATA], ) ) diff --git a/src/lerobot/scripts/rl/gym_manipulator.py b/src/lerobot/scripts/rl/gym_manipulator.py index 29401a9ed..cb3a97701 100644 --- a/src/lerobot/scripts/rl/gym_manipulator.py +++ b/src/lerobot/scripts/rl/gym_manipulator.py @@ -800,11 +800,15 @@ def make_processors(env, cfg): crop_params_dict=cfg.processor.crop_params_dict, resize_size=cfg.processor.resize_size ), TimeLimitProcessor(max_episode_steps=int(cfg.processor.control_time_s * cfg.fps)), - GripperPenaltyProcessor( - penalty=cfg.processor.gripper_penalty, max_gripper_pos=cfg.processor.max_gripper_pos - ), - DeviceProcessor(device=cfg.device), ] + if cfg.processor.use_gripper: + env_pipeline_steps.append( + GripperPenaltyProcessor( + penalty=cfg.processor.gripper_penalty, max_gripper_pos=cfg.processor.max_gripper_pos + ) + ) + env_pipeline_steps.append(DeviceProcessor(device=cfg.device)) + env_processor = RobotProcessor(steps=env_pipeline_steps) action_pipeline_steps = [ @@ -920,12 +924,13 @@ def control_loop(env, env_processor, action_processor, teleop_device, cfg: EnvCo "action": action_features, "next.reward": {"dtype": "float32", "shape": (1,), "names": None}, "next.done": {"dtype": "bool", "shape": (1,), "names": None}, - "complementary_info.discrete_penalty": { + } + if cfg.processor.use_gripper: + features["complementary_info.discrete_penalty"] = { "dtype": "float32", "shape": (1,), "names": ["discrete_penalty"], - }, - } + } for key, value in transition[TransitionKey.OBSERVATION].items(): if key == "observation.state": @@ -977,16 +982,17 @@ def control_loop(env, env_processor, action_processor, teleop_device, cfg: EnvCo truncated = transition.get(TransitionKey.TRUNCATED, False) if cfg.mode == "record": - observations = {k: v.squeeze(0) for k, v in transition[TransitionKey.OBSERVATION].items()} + observations = {k: v.squeeze(0).cpu() for k, v in transition[TransitionKey.OBSERVATION].items()} frame = { **observations, - "action": transition[TransitionKey.COMPLEMENTARY_DATA]["teleop_action"], + "action": transition[TransitionKey.COMPLEMENTARY_DATA]["teleop_action"].cpu(), "next.reward": np.array([transition[TransitionKey.REWARD]], dtype=np.float32), "next.done": np.array([terminated or truncated], dtype=bool), - "complementary_info.discrete_penalty": np.array( - [transition[TransitionKey.COMPLEMENTARY_DATA]["discrete_penalty"]], dtype=np.float32 - ), } + if cfg.processor.use_gripper: + frame["complementary_info.discrete_penalty"] = np.array( + [transition[TransitionKey.COMPLEMENTARY_DATA]["discrete_penalty"]], dtype=np.float32 + ) dataset.add_frame(frame, task=cfg.task) episode_step += 1 @@ -997,16 +1003,6 @@ def control_loop(env, env_processor, action_processor, teleop_device, cfg: EnvCo logging.info( f"Episode ended after {episode_step} steps in {episode_time:.1f}s with reward {transition[TransitionKey.REWARD]}" ) - - # Reset for new episode - obs, info = env.reset() - complementary_data = {"raw_joint_positions": info.pop("raw_joint_positions")} - env_processor.reset() - action_processor.reset() - - transition = create_transition(observation=obs, info=info, complementary_data=complementary_data) - transition = env_processor(transition) - episode_step = 0 episode_idx += 1 @@ -1019,6 +1015,15 @@ def control_loop(env, env_processor, action_processor, teleop_device, cfg: EnvCo logging.info(f"Saving episode {episode_idx}") dataset.save_episode() + # Reset for new episode + obs, info = env.reset() + complementary_data = {"raw_joint_positions": info.pop("raw_joint_positions")} + env_processor.reset() + action_processor.reset() + + transition = create_transition(observation=obs, info=info, complementary_data=complementary_data) + transition = env_processor(transition) + # Maintain fps timing busy_wait(dt - (time.perf_counter() - step_start_time))