diff --git a/src/lerobot/scripts/rl/gym_manipulator.py b/src/lerobot/scripts/rl/gym_manipulator.py index 4ee934f09..1d60380f8 100644 --- a/src/lerobot/scripts/rl/gym_manipulator.py +++ b/src/lerobot/scripts/rl/gym_manipulator.py @@ -655,9 +655,7 @@ class InverseKinematicsProcessor: def __call__(self, transition: EnvTransition) -> EnvTransition: action = transition.get(TransitionKey.ACTION) - observation = transition.get(TransitionKey.OBSERVATION) - - if action is None or observation is None: + if action is None: return transition action_np = action.detach().cpu().numpy().squeeze() @@ -945,6 +943,23 @@ def control_loop(env, env_processor, action_processor, teleop_device, cfg: EnvCo dataset.push_to_hub() +def replay_trajectory(env, action_processor, cfg): + dataset = LeRobotDataset( + cfg.repo_id, root=cfg.dataset_root, episodes=[cfg.episode], download_videos=False + ) + dataset_actions = dataset.hf_dataset.select_columns(["action"]) + _, info = env.reset() + + for _, action in enumerate(dataset_actions): + start_time = time.perf_counter() + transition = create_transition( + action=action["action"], complementary_data={"raw_joint_positions": info["raw_joint_positions"]} + ) + transition = action_processor(transition) + env.step(transition[TransitionKey.ACTION]) + busy_wait(1 / cfg.fps - (time.perf_counter() - start_time)) + + @parser.wrap() def main(cfg: EnvConfig): env, teleop_device = make_robot_env(cfg) @@ -986,6 +1001,10 @@ def main(cfg: EnvConfig): print("Environment processor:", env_processor) print("Action processor:", action_processor) + if cfg.mode == "replay": + replay_trajectory(env, action_processor, cfg) + exit() + control_loop(env, env_processor, action_processor, teleop_device, cfg)