diff --git a/src/lerobot/scripts/rl/gym_manipulator.py b/src/lerobot/scripts/rl/gym_manipulator.py index 2eae35820..4ee934f09 100644 --- a/src/lerobot/scripts/rl/gym_manipulator.py +++ b/src/lerobot/scripts/rl/gym_manipulator.py @@ -27,6 +27,7 @@ import torchvision.transforms.functional as F # noqa: N812 from lerobot.cameras import opencv # noqa: F401 from lerobot.configs import parser from lerobot.configs.types import PolicyFeature +from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.envs.configs import EnvConfig from lerobot.model.kinematics import RobotKinematics from lerobot.processor import ( @@ -279,7 +280,7 @@ class RobotEnv(gym.Env): @dataclass -@ProcessorStepRegistry.register("joint_velocity_processor_") +@ProcessorStepRegistry.register("joint_velocity_processor") class JointVelocityProcessor: """Add joint velocity information to observations. @@ -589,18 +590,18 @@ class InterventionActionProcessor: new_transition[TransitionKey.ACTION] = teleop_action_tensor # Handle episode termination - if terminate_episode: - new_transition[TransitionKey.DONE] = True - if success: - new_transition[TransitionKey.REWARD] = 1.0 + new_transition[TransitionKey.DONE] = bool(terminate_episode) + new_transition[TransitionKey.REWARD] = float(success) # Update info with intervention metadata info = new_transition.get(TransitionKey.INFO, {}) info["is_intervention"] = is_intervention - info["action_intervention"] = new_transition[TransitionKey.ACTION] info["rerecord_episode"] = rerecord_episode info["next.success"] = success if terminate_episode else info.get("next.success", False) new_transition[TransitionKey.INFO] = info + new_transition[TransitionKey.COMPLEMENTARY_DATA]["teleop_action"] = new_transition[ + TransitionKey.ACTION + ] return new_transition @@ -797,10 +798,49 @@ def control_loop(env, env_processor, action_processor, teleop_device, cfg: EnvCo transition = create_transition(observation=obs, info=info, complementary_data=complementary_data) transition = env_processor(transition) + if cfg.mode == "record": + action_features = teleop_device.action_features + features = { + "action": action_features, + "next.reward": {"dtype": "float32", "shape": (1,), "names": None}, + "next.done": {"dtype": "bool", "shape": (1,), "names": None}, + "complementary_info.discrete_penalty": { + "dtype": "float32", + "shape": (1,), + "names": ["discrete_penalty"], + }, + } + + for key, value in transition[TransitionKey.OBSERVATION].items(): + if key == "observation.state": + features[key] = { + "dtype": "float32", + "shape": value.squeeze(0).shape, + "names": None, + } + if "image" in key: + features[key] = { + "dtype": "video", + "shape": value.squeeze(0).shape, + "names": ["channels", "height", "width"], + } + + # Create dataset + dataset = LeRobotDataset.create( + cfg.repo_id, + cfg.fps, + root=cfg.dataset_root, + use_videos=True, + image_writer_threads=4, + image_writer_processes=0, + features=features, + ) + + episode_idx = 0 episode_step = 0 episode_start_time = time.perf_counter() - while True: + while episode_idx < cfg.num_episodes: step_start_time = time.perf_counter() # Get teleoperation action and extra signals @@ -827,14 +867,20 @@ def control_loop(env, env_processor, action_processor, teleop_device, cfg: EnvCo # Extract processed action and metadata processed_action = processed_action_transition[TransitionKey.ACTION] - action_info = processed_action_transition.get(TransitionKey.INFO, {}) terminate_episode = processed_action_transition.get(TransitionKey.DONE, False) # Step environment with processed action obs, reward, terminated, truncated, info = env.step(processed_action) + reward = reward + processed_action_transition[TransitionKey.REWARD] + # Process new observation - complementary_data = {"raw_joint_positions": info.pop("raw_joint_positions")} + complementary_data = { + "raw_joint_positions": info.pop("raw_joint_positions"), + **processed_action_transition[TransitionKey.COMPLEMENTARY_DATA], + } + info.update(processed_action_transition[TransitionKey.INFO]) + transition = create_transition( observation=obs, action=processed_action, @@ -848,14 +894,27 @@ def control_loop(env, env_processor, action_processor, teleop_device, cfg: EnvCo terminated = transition.get(TransitionKey.DONE, False) truncated = transition.get(TransitionKey.TRUNCATED, False) + if cfg.mode == "record": + observations = {k: v.squeeze(0) for k, v in transition[TransitionKey.OBSERVATION].items()} + frame = { + **observations, + "action": transition[TransitionKey.COMPLEMENTARY_DATA]["teleop_action"], + "next.reward": np.array([transition[TransitionKey.REWARD]], dtype=np.float32), + "next.done": np.array([terminated or truncated], dtype=bool), + "complementary_info.discrete_penalty": np.array( + [transition[TransitionKey.COMPLEMENTARY_DATA]["discrete_penalty"]], dtype=np.float32 + ), + } + dataset.add_frame(frame, task=cfg.task) + episode_step += 1 # Handle episode termination if terminated or truncated or terminate_episode: - episode_end_reason = "success" if action_info.get("next.success", False) else "terminated" episode_time = time.perf_counter() - episode_start_time - print(f"Episode ended ({episode_end_reason}) after {episode_step} steps in {episode_time:.1f}s") - print(f"Rerecord episode: {action_info.get('rerecord_episode', False)}") + logging.info( + f"Episode ended after {episode_step} steps in {episode_time:.1f}s with reward {transition[TransitionKey.REWARD]}" + ) # Reset for new episode obs, info = env.reset() @@ -867,11 +926,24 @@ def control_loop(env, env_processor, action_processor, teleop_device, cfg: EnvCo transition = env_processor(transition) episode_step = 0 - episode_start_time = time.perf_counter() + episode_idx += 1 + + if cfg.mode == "record": + if transition[TransitionKey.INFO].get("rerecord_episode", False): + logging.info(f"Re-recording episode {episode_idx}") + dataset.clear_episode_buffer() + episode_idx -= 1 + else: + logging.info(f"Saving episode {episode_idx}") + dataset.save_episode() # Maintain fps timing busy_wait(dt - (time.perf_counter() - step_start_time)) + if cfg.mode == "record" and cfg.push_to_hub: + logging.info("Pushing dataset to hub") + dataset.push_to_hub() + @parser.wrap() def main(cfg: EnvConfig): @@ -914,7 +986,6 @@ def main(cfg: EnvConfig): print("Environment processor:", env_processor) print("Action processor:", action_processor) - # Run the control loop control_loop(env, env_processor, action_processor, teleop_device, cfg)