diff --git a/docs/source/hilserl.mdx b/docs/source/hilserl.mdx index 4b2e398a3..dfdf431ac 100644 --- a/docs/source/hilserl.mdx +++ b/docs/source/hilserl.mdx @@ -127,11 +127,11 @@ class RewardClassifierConfig: # Dataset configuration class DatasetConfig: repo_id: str # LeRobot dataset repository ID - dataset_root: str # Local dataset root directory task: str # Task identifier - num_episodes: int # Number of episodes for recording - episode: int # Episode index for replay - push_to_hub: bool # Whether to push datasets to Hub + root: str | None = None # Local dataset root directory + num_episodes_to_record: int = 5 # Number of episodes for recording + replay_episode: int | None = None # Episode index for replay + push_to_hub: bool = False # Whether to push datasets to Hub ``` @@ -351,7 +351,7 @@ Create a configuration file for recording demonstrations (or edit an existing on 1. Set `mode` to `"record"` at the root level 2. Specify a unique `repo_id` for your dataset in the `dataset` section (e.g., "username/task_name") -3. Set `num_episodes` in the `dataset` section to the number of demonstrations you want to collect +3. Set `num_episodes_to_record` in the `dataset` section to the number of demonstrations you want to collect 4. Set `env.processor.image_preprocessing.crop_params_dict` to `{}` initially (we'll determine crops later) 5. 
Configure `env.robot`, `env.teleop`, and other hardware settings in the `env` section @@ -390,10 +390,10 @@ Example configuration section: }, "dataset": { "repo_id": "username/pick_lift_cube", - "dataset_root": null, + "root": null, "task": "pick_and_lift", - "num_episodes": 15, - "episode": 0, + "num_episodes_to_record": 15, + "replay_episode": 0, "push_to_hub": true }, "mode": "record", @@ -626,7 +626,7 @@ python -m lerobot.scripts.rl.gym_manipulator --config_path src/lerobot/configs/r - **mode**: set it to `"record"` to collect a dataset (at root level) - **dataset.repo_id**: `"hf_username/dataset_name"`, name of the dataset and repo on the hub -- **dataset.num_episodes**: Number of episodes to record +- **dataset.num_episodes_to_record**: Number of episodes to record - **env.processor.reset.terminate_on_success**: Whether to automatically terminate episodes when success is detected (default: `true`) - **env.fps**: Number of frames per second to record - **dataset.push_to_hub**: Whether to push the dataset to the hub @@ -664,8 +664,8 @@ Example configuration section for data collection: "repo_id": "hf_username/dataset_name", - "dataset_root": "data/your_dataset", + "root": "data/your_dataset", "task": "reward_classifier_task", - "num_episodes": 20, - "episode": 0, + "num_episodes_to_record": 20, + "replay_episode": null, "push_to_hub": true }, "mode": "record", diff --git a/docs/source/hilserl_sim.mdx b/docs/source/hilserl_sim.mdx index bbb0cc6f9..656e650a0 100644 --- a/docs/source/hilserl_sim.mdx +++ b/docs/source/hilserl_sim.mdx @@ -107,10 +107,10 @@ To collect a dataset, set the mode to `record` whilst defining the repo_id and n }, "dataset": { "repo_id": "username/sim_dataset", - "dataset_root": null, + "root": null, "task": "pick_cube", - "num_episodes": 10, - "episode": 0, + "num_episodes_to_record": 10, + "replay_episode": null, "push_to_hub": true }, "mode": "record" diff --git a/docs/source/il_sim.mdx b/docs/source/il_sim.mdx index 17d5c46c8..7f93580e5 100644 ---
a/docs/source/il_sim.mdx +++ b/docs/source/il_sim.mdx @@ -36,10 +36,10 @@ To teleoperate and collect a dataset, we need to modify this config file. Here's }, "dataset": { "repo_id": "your_username/il_gym", - "dataset_root": null, + "root": null, "task": "pick_cube", - "num_episodes": 30, - "episode": 0, + "num_episodes_to_record": 30, + "replay_episode": null, "push_to_hub": true }, "mode": "record", @@ -50,7 +50,7 @@ To teleoperate and collect a dataset, we need to modify this config file. Here's Key configuration points: - Set your `repo_id` in the `dataset` section: `"repo_id": "your_username/il_gym"` -- Set `num_episodes: 30` to collect 30 demonstration episodes +- Set `num_episodes_to_record: 30` to collect 30 demonstration episodes - Ensure `mode` is set to `"record"` - If you don't have an NVIDIA GPU, change `"device": "cuda"` to `"mps"` for macOS or `"cpu"` - To use keyboard instead of gamepad, change `"task"` to `"PandaPickCubeKeyboard-v0"` diff --git a/examples/lekiwi/evaluate.py b/examples/lekiwi/evaluate.py index 564648329..c81130d4d 100644 --- a/examples/lekiwi/evaluate.py +++ b/examples/lekiwi/evaluate.py @@ -1,7 +1,7 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.utils import hw_to_dataset_features from lerobot.policies.act.modeling_act import ACTPolicy -from lerobot.policies.factory import make_processor +from lerobot.policies.factory import make_pre_post_processors from lerobot.record import record_loop from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig from lerobot.utils.control_utils import init_keyboard_listener @@ -46,7 +46,7 @@ listener, events = init_keyboard_listener() if not robot.is_connected: raise ValueError("Robot is not connected!") -preprocessor, postprocessor = make_processor( +preprocessor, postprocessor = make_pre_post_processors( policy_cfg=policy, pretrained_path=HF_MODEL_ID, dataset_stats=dataset.meta.stats, diff --git a/examples/phone_so100_eval.py 
b/examples/phone_to_so100/evaluate.py similarity index 97% rename from examples/phone_so100_eval.py rename to examples/phone_to_so100/evaluate.py index e3a577de5..d1190b363 100644 --- a/examples/phone_so100_eval.py +++ b/examples/phone_to_so100/evaluate.py @@ -20,7 +20,7 @@ from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_featur from lerobot.datasets.utils import merge_features from lerobot.model.kinematics import RobotKinematics from lerobot.policies.act.modeling_act import ACTPolicy -from lerobot.policies.factory import make_processor +from lerobot.policies.factory import make_pre_post_processors from lerobot.processor.converters import ( to_output_robot_action, to_transition_robot_observation, @@ -127,7 +127,7 @@ robot.connect() episode_idx = 0 policy = ACTPolicy.from_pretrained(HF_MODEL_ID) -preprocessor, postprocessor = make_processor( +preprocessor, postprocessor = make_pre_post_processors( policy_cfg=policy, pretrained_path=HF_MODEL_ID, dataset_stats=dataset.meta.stats, diff --git a/examples/phone_so100_record.py b/examples/phone_to_so100/record.py similarity index 99% rename from examples/phone_so100_record.py rename to examples/phone_to_so100/record.py index 4ec3948ea..e9d22ef80 100644 --- a/examples/phone_so100_record.py +++ b/examples/phone_to_so100/record.py @@ -38,8 +38,8 @@ from lerobot.robots.so100_follower.robot_kinematic_processor import ( ) from lerobot.robots.so100_follower.so100_follower import SO100Follower from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS -from lerobot.teleoperators.phone.phone import Phone from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction +from lerobot.teleoperators.phone.teleop_phone import Phone from lerobot.utils.control_utils import init_keyboard_listener from lerobot.utils.utils import log_say from lerobot.utils.visualization_utils import _init_rerun diff --git a/examples/phone_so100_replay.py b/examples/phone_to_so100/replay.py similarity 
index 79% rename from examples/phone_so100_replay.py rename to examples/phone_to_so100/replay.py index f44207789..e39f482c2 100644 --- a/examples/phone_so100_replay.py +++ b/examples/phone_to_so100/replay.py @@ -19,7 +19,7 @@ import time from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.model.kinematics import RobotKinematics -from lerobot.processor.converters import to_output_robot_action +from lerobot.processor.converters import to_output_robot_action, to_transition_teleop_action from lerobot.processor.pipeline import RobotProcessor from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig from lerobot.robots.so100_follower.robot_kinematic_processor import ( @@ -49,31 +49,6 @@ kinematics_solver = RobotKinematics( joint_names=list(robot.bus.motors.keys()), ) - -# This method converts the action from the dataset to a transition for pipeline -def action_to_transition(action: dict): - act = {} - - # EE pose - for k in ("ee.x", "ee.y", "ee.z", "ee.wx", "ee.wy", "ee.wz"): - if k in action: - act[f"action.{k}"] = float(action[k]) - - # Gripper: your dataset has absolute position - if "gripper.pos" in action: - act["action.gripper.pos"] = float(action["gripper.pos"]) - - return { - "observation": None, - "action": act, - "reward": None, - "done": False, - "truncated": False, - "info": {}, - "complementary_data": {}, - } - - # Build pipeline to convert ee pose action to joint action robot_ee_to_joints = RobotProcessor( steps=[ @@ -84,7 +59,7 @@ robot_ee_to_joints = RobotProcessor( initial_guess_current_joints=False, # Because replay is open loop ), ], - to_transition=action_to_transition, + to_transition=to_transition_teleop_action, to_output=to_output_robot_action, ) diff --git a/examples/phone_so100_teleop.py b/examples/phone_to_so100/teleoperate.py similarity index 83% rename from examples/phone_so100_teleop.py rename to examples/phone_to_so100/teleoperate.py index 82515c98f..1eef0f8ae 100644 --- 
a/examples/phone_so100_teleop.py +++ b/examples/phone_to_so100/teleoperate.py @@ -28,8 +28,8 @@ from lerobot.robots.so100_follower.robot_kinematic_processor import ( ) from lerobot.robots.so100_follower.so100_follower import SO100Follower from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS -from lerobot.teleoperators.phone.phone import Phone from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction +from lerobot.teleoperators.phone.teleop_phone import Phone # Initialize the robot and teleoperator robot_config = SO100FollowerConfig( @@ -48,8 +48,8 @@ kinematics_solver = RobotKinematics( joint_names=list(robot.bus.motors.keys()), ) -# Build pipeline to convert phone action to ee pose action -phone_to_robot_ee_pose = RobotProcessor( +# Build pipeline to convert phone action to ee pose action to joint action +phone_to_robot_joints = RobotProcessor( steps=[ MapPhoneActionToRobotAction(platform=teleop_config.phone_os), AddRobotObservationAsComplimentaryData(robot=robot), @@ -63,14 +63,6 @@ phone_to_robot_ee_pose = RobotProcessor( max_ee_step_m=0.10, max_ee_twist_step_rad=0.50, ), - ], - to_transition=to_transition_teleop_action, - to_output=lambda tr: tr, -) - -# Build pipeline to convert ee pose action to joint action -robot_ee_to_joints = RobotProcessor( - steps=[ InverseKinematicsEEToJoints( kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys()), @@ -80,7 +72,7 @@ robot_ee_to_joints = RobotProcessor( speed_factor=20.0, ), ], - to_transition=lambda tr: tr, + to_transition=to_transition_teleop_action, to_output=to_output_robot_action, ) @@ -89,19 +81,11 @@ teleop_device.connect() print("Starting teleop loop. 
Move your phone to teleoperate the robot.") while True: - phone_obs = teleop_device.get_action() - if not phone_obs: - time.sleep(0.01) - continue - # Get teleop observation phone_obs = teleop_device.get_action() - # Phone to EE pose transition - ee_transition = phone_to_robot_ee_pose(phone_obs) - - # EE pose to Joints transition - joint_action = robot_ee_to_joints(ee_transition) + # Phone -> EE pose -> Joints transition + joint_action = phone_to_robot_joints(phone_obs) if joint_action: robot.send_action(joint_action) diff --git a/src/lerobot/policies/act/modeling_act.py b/src/lerobot/policies/act/modeling_act.py index 6a4ad82ce..e0f3462cc 100644 --- a/src/lerobot/policies/act/modeling_act.py +++ b/src/lerobot/policies/act/modeling_act.py @@ -287,7 +287,7 @@ class ACT(nn.Module): └───────────────────────┘ """ - def __init__(self, config: ACTConfig, dataset_stats=None): + def __init__(self, config: ACTConfig): # BERT style VAE encoder with input tokens [cls, robot_state, *action_sequence]. # The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]). 
super().__init__() diff --git a/src/lerobot/policies/act/processor_act.py b/src/lerobot/policies/act/processor_act.py index de004f676..db6d4d85a 100644 --- a/src/lerobot/policies/act/processor_act.py +++ b/src/lerobot/policies/act/processor_act.py @@ -27,7 +27,7 @@ from lerobot.processor import ( ) -def make_act_processor( +def make_act_pre_post_processors( config: ACTConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None ) -> tuple[RobotProcessor, RobotProcessor]: input_steps = [ diff --git a/src/lerobot/policies/diffusion/processor_diffusion.py b/src/lerobot/policies/diffusion/processor_diffusion.py index 89491d022..2ecf5d700 100644 --- a/src/lerobot/policies/diffusion/processor_diffusion.py +++ b/src/lerobot/policies/diffusion/processor_diffusion.py @@ -28,7 +28,7 @@ from lerobot.processor import ( ) -def make_diffusion_processor( +def make_diffusion_pre_post_processors( config: DiffusionConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None ) -> tuple[RobotProcessor, RobotProcessor]: input_steps = [ diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index 6990e8e5e..67dd928ec 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -17,10 +17,9 @@ from __future__ import annotations import logging -from typing import Any, TypedDict, cast +from typing import Any, TypedDict import torch -from torch import nn from typing_extensions import Unpack from lerobot.configs.policies import PreTrainedConfig @@ -117,7 +116,7 @@ class ProcessorConfigKwargs(TypedDict, total=False): dataset_stats: dict[str, dict[str, torch.Tensor]] | None -def make_processor( +def make_pre_post_processors( policy_cfg: PreTrainedConfig, pretrained_path: str | None = None, **kwargs: Unpack[ProcessorConfigKwargs], @@ -154,68 +153,65 @@ def make_processor( ) # Create a new processor based on policy type - if policy_cfg.type == "tdmpc": - from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig - 
from lerobot.policies.tdmpc.processor_tdmpc import make_tdmpc_processor + if isinstance(policy_cfg, TDMPCConfig): + from lerobot.policies.tdmpc.processor_tdmpc import make_tdmpc_pre_post_processors - processors = make_tdmpc_processor( - config=cast(TDMPCConfig, policy_cfg), dataset_stats=kwargs.get("dataset_stats") + processors = make_tdmpc_pre_post_processors( + config=policy_cfg, dataset_stats=kwargs.get("dataset_stats") ) - elif policy_cfg.type == "diffusion": - from lerobot.policies.diffusion.processor_diffusion import make_diffusion_processor + elif isinstance(policy_cfg, DiffusionConfig): + from lerobot.policies.diffusion.processor_diffusion import make_diffusion_pre_post_processors - processors = make_diffusion_processor( - cast(DiffusionConfig, policy_cfg), dataset_stats=kwargs.get("dataset_stats") + processors = make_diffusion_pre_post_processors( + config=policy_cfg, dataset_stats=kwargs.get("dataset_stats") ) - elif policy_cfg.type == "act": - from lerobot.policies.act.processor_act import make_act_processor + elif isinstance(policy_cfg, ACTConfig): + from lerobot.policies.act.processor_act import make_act_pre_post_processors - processors = make_act_processor( - config=cast(ACTConfig, policy_cfg), dataset_stats=kwargs.get("dataset_stats") + processors = make_act_pre_post_processors( + config=policy_cfg, dataset_stats=kwargs.get("dataset_stats") ) - elif policy_cfg.type == "vqbet": - from lerobot.policies.vqbet.processor_vqbet import make_vqbet_processor + elif isinstance(policy_cfg, VQBeTConfig): + from lerobot.policies.vqbet.processor_vqbet import make_vqbet_pre_post_processors - processors = make_vqbet_processor( - config=cast(VQBeTConfig, policy_cfg), dataset_stats=kwargs.get("dataset_stats") + processors = make_vqbet_pre_post_processors( + config=policy_cfg, dataset_stats=kwargs.get("dataset_stats") ) - elif policy_cfg.type == "pi0": - from lerobot.policies.pi0.processor_pi0 import make_pi0_processor + elif isinstance(policy_cfg, PI0Config): + from 
lerobot.policies.pi0.processor_pi0 import make_pi0_pre_post_processors - processors = make_pi0_processor( - config=cast(PI0Config, policy_cfg), dataset_stats=kwargs.get("dataset_stats") + processors = make_pi0_pre_post_processors( + config=policy_cfg, dataset_stats=kwargs.get("dataset_stats") ) - elif policy_cfg.type == "pi0fast": - from lerobot.policies.pi0fast.processor_pi0fast import make_pi0fast_processor + elif policy_cfg.type == "pi0fast": + from lerobot.policies.pi0fast.processor_pi0fast import make_pi0fast_pre_post_processors - processors = make_pi0fast_processor( - cast(PI0Config, policy_cfg), dataset_stats=kwargs.get("dataset_stats") + processors = make_pi0fast_pre_post_processors( + config=policy_cfg, dataset_stats=kwargs.get("dataset_stats") ) - elif policy_cfg.type == "sac": - from lerobot.policies.sac.processor_sac import make_sac_processor + elif isinstance(policy_cfg, SACConfig): + from lerobot.policies.sac.processor_sac import make_sac_pre_post_processors - processors = make_sac_processor( - cast(SACConfig, policy_cfg), dataset_stats=kwargs.get("dataset_stats") + processors = make_sac_pre_post_processors( + config=policy_cfg, dataset_stats=kwargs.get("dataset_stats") ) - elif policy_cfg.type == "reward_classifier": + elif isinstance(policy_cfg, RewardClassifierConfig): from lerobot.policies.sac.reward_model.processor_classifier import make_classifier_processor - processors = make_classifier_processor( - cast(RewardClassifierConfig, policy_cfg), dataset_stats=kwargs.get("dataset_stats") - ) + processors = make_classifier_processor(config=policy_cfg, dataset_stats=kwargs.get("dataset_stats")) - elif policy_cfg.type == "smolvla": - from lerobot.policies.smolvla.processor_smolvla import make_smolvla_processor + elif isinstance(policy_cfg, SmolVLAConfig): + from lerobot.policies.smolvla.processor_smolvla import make_smolvla_pre_post_processors - processors = make_smolvla_processor( - cast(SmolVLAConfig, policy_cfg),
dataset_stats=kwargs.get("dataset_stats") + processors = make_smolvla_pre_post_processors( + config=policy_cfg, dataset_stats=kwargs.get("dataset_stats") ) else: @@ -295,7 +291,7 @@ def make_policy( policy = policy_cls(**kwargs) policy.to(cfg.device) - assert isinstance(policy, nn.Module) + assert isinstance(policy, torch.nn.Module) # policy = torch.compile(policy, mode="reduce-overhead") diff --git a/src/lerobot/policies/pi0/processor_pi0.py b/src/lerobot/policies/pi0/processor_pi0.py index 3629f1071..8b1fc8301 100644 --- a/src/lerobot/policies/pi0/processor_pi0.py +++ b/src/lerobot/policies/pi0/processor_pi0.py @@ -14,11 +14,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any import torch -from lerobot.configs.types import PolicyFeature from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME from lerobot.policies.pi0.configuration_pi0 import PI0Config from lerobot.processor import ( @@ -30,64 +28,43 @@ from lerobot.processor import ( UnnormalizerProcessor, ) from lerobot.processor.pipeline import ( - EnvTransition, + ComplementaryDataProcessor, ProcessorStep, ProcessorStepRegistry, - TransitionKey, ) from lerobot.processor.rename_processor import RenameProcessor @ProcessorStepRegistry.register(name="pi0_new_line_processor") -class Pi0NewLineProcessor(ProcessorStep): +class Pi0NewLineProcessor(ComplementaryDataProcessor): """Add a new line to the end of the task if it doesn't have one. This is required for the PaliGemma tokenizer. 
""" - def __call__(self, transition: EnvTransition) -> EnvTransition: - # Check if complementary_data exists - complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) - if complementary_data is None or "task" not in complementary_data: - return transition + def complementary_data(self, complementary_data): + if "task" not in complementary_data: + return complementary_data task = complementary_data["task"] if task is None: - return transition + return complementary_data + + new_complementary_data = dict(complementary_data) # Handle both string and list of strings if isinstance(task, str): # Single string: add newline if not present if not task.endswith("\n"): - complementary_data["task"] = f"{task}\n" + new_complementary_data["task"] = f"{task}\n" elif isinstance(task, list) and all(isinstance(t, str) for t in task): # List of strings: add newline to each if not present - complementary_data["task"] = [t if t.endswith("\n") else f"{t}\n" for t in task] + new_complementary_data["task"] = [t if t.endswith("\n") else f"{t}\n" for t in task] # If task is neither string nor list of strings, leave unchanged - return transition - - def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - """Add tokenized task features to the features.""" - return features - - def state_dict(self) -> dict[str, torch.Tensor]: - """Return state dictionary (empty for this processor).""" - return {} - - def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: - """Load state dictionary (no-op for this processor).""" - pass - - def reset(self) -> None: - """Reset processor state (no-op for this processor).""" - pass - - def get_config(self) -> dict[str, Any]: - """Return configuration for serialization.""" - return {} + return new_complementary_data -def make_pi0_processor( +def make_pi0_pre_post_processors( config: PI0Config, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None ) -> tuple[RobotProcessor, RobotProcessor]: # 
Add remaining processors diff --git a/src/lerobot/policies/pi0fast/processor_pi0fast.py b/src/lerobot/policies/pi0fast/processor_pi0fast.py index eccfeb44f..9cfd24c76 100644 --- a/src/lerobot/policies/pi0fast/processor_pi0fast.py +++ b/src/lerobot/policies/pi0fast/processor_pi0fast.py @@ -28,7 +28,7 @@ from lerobot.processor import ( ) -def make_pi0fast_processor( +def make_pi0fast_pre_post_processors( config: PI0Config, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None ) -> tuple[RobotProcessor, RobotProcessor]: input_steps = [ diff --git a/src/lerobot/policies/sac/processor_sac.py b/src/lerobot/policies/sac/processor_sac.py index 14a976bbe..d96d1e003 100644 --- a/src/lerobot/policies/sac/processor_sac.py +++ b/src/lerobot/policies/sac/processor_sac.py @@ -29,7 +29,7 @@ from lerobot.processor import ( ) -def make_sac_processor( +def make_sac_pre_post_processors( config: SACConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None ) -> tuple[RobotProcessor, RobotProcessor]: input_steps = [ diff --git a/src/lerobot/policies/smolvla/processor_smolvla.py b/src/lerobot/policies/smolvla/processor_smolvla.py index 231f2969e..f9d6a4594 100644 --- a/src/lerobot/policies/smolvla/processor_smolvla.py +++ b/src/lerobot/policies/smolvla/processor_smolvla.py @@ -13,11 +13,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any import torch -from lerobot.configs.types import PolicyFeature from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig from lerobot.processor import ( @@ -29,10 +27,13 @@ from lerobot.processor import ( TokenizerProcessor, UnnormalizerProcessor, ) -from lerobot.processor.pipeline import EnvTransition, ProcessorStep, ProcessorStepRegistry, TransitionKey +from lerobot.processor.pipeline import ( + ComplementaryDataProcessor, + ProcessorStepRegistry, +) -def make_smolvla_processor( +def make_smolvla_pre_post_processors( config: SmolVLAConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None ) -> tuple[RobotProcessor, RobotProcessor]: input_steps = [ @@ -64,47 +65,27 @@ def make_smolvla_processor( @ProcessorStepRegistry.register(name="smolvla_new_line_processor") -class SmolVLANewLineProcessor(ProcessorStep): +class SmolVLANewLineProcessor(ComplementaryDataProcessor): """Add a new line to the end of the task if it doesn't have one.""" - def __call__(self, transition: EnvTransition) -> EnvTransition: - # Check if complementary_data exists - complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) - if complementary_data is None or "task" not in complementary_data: - return transition + def complementary_data(self, complementary_data): + if "task" not in complementary_data: + return complementary_data task = complementary_data["task"] if task is None: - return transition + return complementary_data + + new_complementary_data = dict(complementary_data) # Handle both string and list of strings if isinstance(task, str): # Single string: add newline if not present if not task.endswith("\n"): - complementary_data["task"] = f"{task}\n" + new_complementary_data["task"] = f"{task}\n" elif isinstance(task, list) and all(isinstance(t, str) for t in task): # List of strings: add newline to each if not present - 
complementary_data["task"] = [t if t.endswith("\n") else f"{t}\n" for t in task] + new_complementary_data["task"] = [t if t.endswith("\n") else f"{t}\n" for t in task] # If task is neither string nor list of strings, leave unchanged - return transition - - def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - """Adds nothing to the features.""" - return features - - def state_dict(self) -> dict[str, torch.Tensor]: - """Return state dictionary (empty for this processor).""" - return {} - - def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: - """Load state dictionary (no-op for this processor).""" - pass - - def reset(self) -> None: - """Reset processor state (no-op for this processor).""" - pass - - def get_config(self) -> dict[str, Any]: - """Return configuration for serialization.""" - return {} + return new_complementary_data diff --git a/src/lerobot/policies/tdmpc/processor_tdmpc.py b/src/lerobot/policies/tdmpc/processor_tdmpc.py index 553aa1d04..9d13ef0e1 100644 --- a/src/lerobot/policies/tdmpc/processor_tdmpc.py +++ b/src/lerobot/policies/tdmpc/processor_tdmpc.py @@ -28,7 +28,7 @@ from lerobot.processor import ( ) -def make_tdmpc_processor( +def make_tdmpc_pre_post_processors( config: TDMPCConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None ) -> tuple[RobotProcessor, RobotProcessor]: input_steps = [ diff --git a/src/lerobot/policies/vqbet/processor_vqbet.py b/src/lerobot/policies/vqbet/processor_vqbet.py index cfcd31ddd..242e6c3b0 100644 --- a/src/lerobot/policies/vqbet/processor_vqbet.py +++ b/src/lerobot/policies/vqbet/processor_vqbet.py @@ -29,7 +29,7 @@ from lerobot.processor import ( ) -def make_vqbet_processor( +def make_vqbet_pre_post_processors( config: VQBeTConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None ) -> tuple[RobotProcessor, RobotProcessor]: input_steps = [ diff --git a/src/lerobot/processor/__init__.py b/src/lerobot/processor/__init__.py index 
979f7ebc4..813756490 100644 --- a/src/lerobot/processor/__init__.py +++ b/src/lerobot/processor/__init__.py @@ -15,18 +15,17 @@ # limitations under the License. from .batch_processor import ToBatchProcessor -from .delta_action_processor import MapDeltaActionToRobotAction +from .delta_action_processor import MapDeltaActionToRobotAction, MapTensorToDeltaActionDict from .device_processor import DeviceProcessor +from .gym_action_processor import Numpy2TorchActionProcessor, Torch2NumpyActionProcessor from .hil_processor import ( AddTeleopActionAsComplimentaryData, AddTeleopEventsAsInfo, GripperPenaltyProcessor, ImageCropResizeProcessor, InterventionActionProcessor, - Numpy2TorchActionProcessor, RewardClassifierProcessor, TimeLimitProcessor, - Torch2NumpyActionProcessor, ) from .joint_observations_processor import JointVelocityProcessor, MotorCurrentProcessor from .normalize_processor import NormalizerProcessor, UnnormalizerProcessor, hotswap_stats @@ -55,6 +54,7 @@ __all__ = [ "DeviceProcessor", "DoneProcessor", "MapDeltaActionToRobotAction", + "MapTensorToDeltaActionDict", "EnvTransition", "GripperPenaltyProcessor", "IdentityProcessor", diff --git a/src/lerobot/processor/batch_processor.py b/src/lerobot/processor/batch_processor.py index 8a74afd3e..aab575ef7 100644 --- a/src/lerobot/processor/batch_processor.py +++ b/src/lerobot/processor/batch_processor.py @@ -11,20 +11,88 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from dataclasses import dataclass -from typing import Any +from dataclasses import dataclass, field -import torch from torch import Tensor -from lerobot.configs.types import PolicyFeature from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE -from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey +from lerobot.processor.pipeline import ( + ActionProcessor, + ComplementaryDataProcessor, + EnvTransition, + ObservationProcessor, + ProcessorStep, + ProcessorStepRegistry, +) + + +@dataclass +@ProcessorStepRegistry.register(name="to_batch_processor_action") +class ToBatchProcessorAction(ActionProcessor): + """Process action component in-place, adding batch dimension if needed.""" + + def action(self, action): + if not isinstance(action, Tensor) or action.dim() != 1: + return action + + return action.unsqueeze(0) + + +@dataclass +@ProcessorStepRegistry.register(name="to_batch_processor_observation") +class ToBatchProcessorObservation(ObservationProcessor): + """Process observation component in-place, adding batch dimensions where needed.""" + + def observation(self, observation): + # Process state observations - add batch dim if 1D + for state_key in [OBS_STATE, OBS_ENV_STATE]: + if state_key in observation: + state_value = observation[state_key] + if isinstance(state_value, Tensor) and state_value.dim() == 1: + observation[state_key] = state_value.unsqueeze(0) + + # Process single image observation - add batch dim if 3D + if OBS_IMAGE in observation: + image_value = observation[OBS_IMAGE] + if isinstance(image_value, Tensor) and image_value.dim() == 3: + observation[OBS_IMAGE] = image_value.unsqueeze(0) + + # Process multiple image observations - add batch dim if 3D + for key, value in observation.items(): + if key.startswith(f"{OBS_IMAGES}.") and isinstance(value, Tensor) and value.dim() == 3: + observation[key] = value.unsqueeze(0) + return observation + + +@dataclass 
+@ProcessorStepRegistry.register(name="to_batch_processor_complementary_data") +class ToBatchProcessorComplementaryData(ComplementaryDataProcessor): + """Process complementary data in-place, handling task field batching.""" + + def complementary_data(self, complementary_data): + # Process task field - wrap string in list to add batch dimension + if "task" in complementary_data: + task_value = complementary_data["task"] + if isinstance(task_value, str): + complementary_data["task"] = [task_value] + + # Process index field - add batch dim if 0D + if "index" in complementary_data: + index_value = complementary_data["index"] + if isinstance(index_value, Tensor) and index_value.dim() == 0: + complementary_data["index"] = index_value.unsqueeze(0) + + # Process task_index field - add batch dim if 0D + if "task_index" in complementary_data: + task_index_value = complementary_data["task_index"] + if isinstance(task_index_value, Tensor) and task_index_value.dim() == 0: + complementary_data["task_index"] = task_index_value.unsqueeze(0) + return complementary_data @dataclass @ProcessorStepRegistry.register(name="to_batch_processor") -class ToBatchProcessor: +class ToBatchProcessor(ProcessorStep): """Processor that adds batch dimensions to observations and actions when needed. 
This processor ensures that observations and actions have proper batch dimensions for model processing: @@ -59,81 +127,16 @@ class ToBatchProcessor: ``` """ + to_batch_action_processor: ToBatchProcessorAction = field(default_factory=ToBatchProcessorAction) + to_batch_observation_processor: ToBatchProcessorObservation = field( + default_factory=ToBatchProcessorObservation + ) + to_batch_complementary_data_processor: ToBatchProcessorComplementaryData = field( + default_factory=ToBatchProcessorComplementaryData + ) + def __call__(self, transition: EnvTransition) -> EnvTransition: - self._process_observation(transition) - self._process_action(transition) - self._process_complementary_data(transition) + transition = self.to_batch_action_processor(transition) + transition = self.to_batch_observation_processor(transition) + transition = self.to_batch_complementary_data_processor(transition) return transition - - def _process_observation(self, transition: EnvTransition) -> None: - """Process observation component in-place, adding batch dimensions where needed.""" - observation = transition.get(TransitionKey.OBSERVATION) - if observation is None: - return - - # Process state observations - add batch dim if 1D - for state_key in [OBS_STATE, OBS_ENV_STATE]: - if state_key in observation: - state_value = observation[state_key] - if isinstance(state_value, Tensor) and state_value.dim() == 1: - observation[state_key] = state_value.unsqueeze(0) - - # Process single image observation - add batch dim if 3D - if OBS_IMAGE in observation: - image_value = observation[OBS_IMAGE] - if isinstance(image_value, Tensor) and image_value.dim() == 3: - observation[OBS_IMAGE] = image_value.unsqueeze(0) - - # Process multiple image observations - add batch dim if 3D - for key, value in observation.items(): - if key.startswith(f"{OBS_IMAGES}.") and isinstance(value, Tensor) and value.dim() == 3: - observation[key] = value.unsqueeze(0) - - def _process_action(self, transition: EnvTransition) -> 
None: - """Process action component in-place, adding batch dimension if needed.""" - action = transition.get(TransitionKey.ACTION) - if action is not None and isinstance(action, Tensor) and action.dim() == 1: - transition[TransitionKey.ACTION] = action.unsqueeze(0) - - def _process_complementary_data(self, transition: EnvTransition) -> None: - """Process complementary data in-place, handling task field batching.""" - complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) - if complementary_data is None: - return - - # Process task field - wrap string in list to add batch dimension - if "task" in complementary_data: - task_value = complementary_data["task"] - if isinstance(task_value, str): - complementary_data["task"] = [task_value] - - # Process index field - add batch dim if 0D - if "index" in complementary_data: - index_value = complementary_data["index"] - if isinstance(index_value, Tensor) and index_value.dim() == 0: - complementary_data["index"] = index_value.unsqueeze(0) - - # Process task_index field - add batch dim if 0D - if "task_index" in complementary_data: - task_index_value = complementary_data["task_index"] - if isinstance(task_index_value, Tensor) and task_index_value.dim() == 0: - complementary_data["task_index"] = task_index_value.unsqueeze(0) - - def get_config(self) -> dict[str, Any]: - """Return configuration for serialization.""" - return {} - - def state_dict(self) -> dict[str, torch.Tensor]: - """Return state dictionary (empty for this processor).""" - return {} - - def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: - """Load state dictionary (no-op for this processor).""" - pass - - def reset(self) -> None: - """Reset processor state (no-op for this processor).""" - pass - - def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - return features diff --git a/src/lerobot/processor/converters.py b/src/lerobot/processor/converters.py index f0e081577..3a8f8b109 100644 --- 
a/src/lerobot/processor/converters.py +++ b/src/lerobot/processor/converters.py @@ -53,7 +53,7 @@ def _is_image(arr: Any) -> bool: def _split_obs_to_state_and_images(obs: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]: state, images = {}, {} for k, v in obs.items(): - if _is_image(v): + if "image" in k.lower() or _is_image(v): images[k] = v else: state[k] = v @@ -116,6 +116,9 @@ def to_output_robot_action(transition: EnvTransition) -> dict[str, Any]: out: dict[str, Any] = {} action_dict = transition.get(TransitionKey.ACTION) or {} + if action_dict is None: + return out + for k, v in action_dict.items(): if isinstance(k, str) and k.startswith("action.") and k.endswith((".pos", ".vel")): out_key = k[len("action.") :] # Strip the 'action.' prefix. diff --git a/src/lerobot/processor/delta_action_processor.py b/src/lerobot/processor/delta_action_processor.py index a575bb07a..63eff9aad 100644 --- a/src/lerobot/processor/delta_action_processor.py +++ b/src/lerobot/processor/delta_action_processor.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass, field +from dataclasses import dataclass from torch import Tensor @@ -22,6 +22,30 @@ from lerobot.configs.types import FeatureType, PolicyFeature from lerobot.processor.pipeline import ActionProcessor, ProcessorStepRegistry +@ProcessorStepRegistry.register("map_tensor_to_delta_action_dict") +@dataclass +class MapTensorToDeltaActionDict(ActionProcessor): + """ + Map a tensor to a delta action dictionary. 
+ """ + + def action(self, action: Tensor) -> dict: + if isinstance(action, dict): + return action + if action.dim() > 1: + action = action.squeeze(0) + + # TODO (maractingi): add rotation + delta_action = { + "action.delta_x": action[0], + "action.delta_y": action[1], + "action.delta_z": action[2], + } + if action.shape[0] > 3: + delta_action["action.gripper"] = action[3] + return delta_action + + @ProcessorStepRegistry.register("map_delta_action_to_robot_action") @dataclass class MapDeltaActionToRobotAction(ActionProcessor): @@ -53,35 +77,25 @@ class MapDeltaActionToRobotAction(ActionProcessor): # Scale factors for delta movements position_scale: float = 1.0 rotation_scale: float = 0.0 # No rotation deltas for gamepad/keyboard - gripper_deadzone: float = 0.1 # Threshold for gripper activation - _prev_enabled: bool = field(default=False, init=False, repr=False) - - def action(self, action: dict | Tensor | None) -> dict: - if action is None: - return {} + noise_threshold: float = 1e-3 # 1 mm threshold to filter out noise + def action(self, action: dict) -> dict: # NOTE (maractingi): Action can be a dict from the teleop_devices or a tensor from the policy # TODO (maractingi): changing this target_xyz naming convention from the teleop_devices - if isinstance(action, dict): - delta_x = action.pop("action.delta_x", 0.0) - delta_y = action.pop("action.delta_y", 0.0) - delta_z = action.pop("action.delta_z", 0.0) - gripper = action.pop("action.gripper", 1.0) # Default to "stay" (1.0) - else: - delta_x = action[0].item() - delta_y = action[1].item() - delta_z = action[2].item() - gripper = action[3].item() + delta_x = action.pop("action.delta_x", 0.0) + delta_y = action.pop("action.delta_y", 0.0) + delta_z = action.pop("action.delta_z", 0.0) + gripper = action.pop("action.gripper", 1.0) # Default to "stay" (1.0) # Determine if the teleoperator is actively providing input # Consider enabled if any significant movement delta is detected - position_magnitude = abs(delta_x) + 
abs(delta_y) + abs(delta_z) - enabled = position_magnitude > 1e-6 # Small threshold to avoid noise + position_magnitude = (delta_x**2 + delta_y**2 + delta_z**2) ** 0.5 # Use Euclidean norm for position + enabled = position_magnitude > self.noise_threshold # Small threshold to avoid noise # Scale the deltas appropriately - scaled_delta_x = float(delta_x) * self.position_scale - scaled_delta_y = float(delta_y) * self.position_scale - scaled_delta_z = float(delta_z) * self.position_scale + scaled_delta_x = delta_x * self.position_scale + scaled_delta_y = delta_y * self.position_scale + scaled_delta_z = delta_z * self.position_scale # For gamepad/keyboard, we don't have rotation input, so set to 0 # These could be extended in the future for more sophisticated teleoperators @@ -101,7 +115,6 @@ class MapDeltaActionToRobotAction(ActionProcessor): "action.gripper": float(gripper), } - self._prev_enabled = enabled return action def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: @@ -120,6 +133,3 @@ class MapDeltaActionToRobotAction(ActionProcessor): } ) return features - - def reset(self): - self._prev_enabled = False diff --git a/src/lerobot/processor/device_processor.py b/src/lerobot/processor/device_processor.py index 4188e6208..78a3ad797 100644 --- a/src/lerobot/processor/device_processor.py +++ b/src/lerobot/processor/device_processor.py @@ -18,14 +18,13 @@ from typing import Any import torch -from lerobot.configs.types import PolicyFeature -from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey +from lerobot.processor.pipeline import EnvTransition, ProcessorStep, ProcessorStepRegistry, TransitionKey from lerobot.utils.utils import get_safe_torch_device @ProcessorStepRegistry.register("device_processor") @dataclass -class DeviceProcessor: +class DeviceProcessor(ProcessorStep): """Processes transitions by moving tensors to the specified device and optionally converting float dtypes. 
This processor ensures that all tensors in the transition are moved to the @@ -36,32 +35,30 @@ class DeviceProcessor: device: str = "cpu" float_dtype: str | None = None - _device: torch.device | None = None + + DTYPE_MAPPING = { + "float16": torch.float16, + "float32": torch.float32, + "float64": torch.float64, + "bfloat16": torch.bfloat16, + "half": torch.float16, + "float": torch.float32, + "double": torch.float64, + } def __post_init__(self): - self._device = get_safe_torch_device(self.device) - self.device = self._device.type + self._device: torch.device = get_safe_torch_device(self.device) + self.device = self._device.type # cuda might have changed to cuda:1 self.non_blocking = "cuda" in str(self.device) # Validate and convert float_dtype string to torch dtype if self.float_dtype is not None: - dtype_mapping = { - "float16": torch.float16, - "float32": torch.float32, - "float64": torch.float64, - "bfloat16": torch.bfloat16, - "half": torch.float16, - "float": torch.float32, - "double": torch.float64, - } - - if self.float_dtype not in dtype_mapping: - available_dtypes = list(dtype_mapping.keys()) + if self.float_dtype not in self.DTYPE_MAPPING: raise ValueError( - f"Invalid float_dtype '{self.float_dtype}'. Available options: {available_dtypes}" + f"Invalid float_dtype '{self.float_dtype}'. 
Available options: {list(self.DTYPE_MAPPING.keys())}" ) - self._target_float_dtype = dtype_mapping[self.float_dtype] + self._target_float_dtype = self.DTYPE_MAPPING[self.float_dtype] else: self._target_float_dtype = None @@ -94,69 +91,38 @@ class DeviceProcessor: return tensor def __call__(self, transition: EnvTransition) -> EnvTransition: - # Create a copy of the transition new_transition = transition.copy() - # Process observation tensors - observation = transition.get(TransitionKey.OBSERVATION) - if observation is not None: - new_observation = { - k: self._process_tensor(v) if isinstance(v, torch.Tensor) else v - for k, v in observation.items() - } - new_transition[TransitionKey.OBSERVATION] = new_observation + simple_tensor_keys = [ + TransitionKey.ACTION, + TransitionKey.REWARD, + TransitionKey.DONE, + TransitionKey.TRUNCATED, + ] - # Process action tensor - action = transition.get(TransitionKey.ACTION) - if action is not None and isinstance(action, torch.Tensor): - new_transition[TransitionKey.ACTION] = self._process_tensor(action) + dict_tensor_keys = [ + TransitionKey.OBSERVATION, + TransitionKey.COMPLEMENTARY_DATA, + ] - # Process reward tensor - reward = transition.get(TransitionKey.REWARD) - if reward is not None and isinstance(reward, torch.Tensor): - new_transition[TransitionKey.REWARD] = self._process_tensor(reward) + # Process simple tensors + for key in simple_tensor_keys: + value = transition.get(key) + if isinstance(value, torch.Tensor): + new_transition[key] = self._process_tensor(value) - # Process done tensor - done = transition.get(TransitionKey.DONE) - if done is not None and isinstance(done, torch.Tensor): - new_transition[TransitionKey.DONE] = self._process_tensor(done) - - # Process truncated tensor - truncated = transition.get(TransitionKey.TRUNCATED) - if truncated is not None and isinstance(truncated, torch.Tensor): - new_transition[TransitionKey.TRUNCATED] = self._process_tensor(truncated) - - # Process complementary data tensors - 
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) - if complementary_data is not None: - new_complementary_data = {} - - # Process all items in complementary_data - for key, value in complementary_data.items(): - if isinstance(value, torch.Tensor): - new_complementary_data[key] = self._process_tensor(value) - else: - new_complementary_data[key] = value - - new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data + # Process dictionary-like tensors + for key in dict_tensor_keys: + data_dict = transition.get(key) + if data_dict is not None: + new_data_dict = { + k: self._process_tensor(v) if isinstance(v, torch.Tensor) else v + for k, v in data_dict.items() + } + new_transition[key] = new_data_dict return new_transition def get_config(self) -> dict[str, Any]: """Return configuration for serialization.""" return {"device": self.device, "float_dtype": self.float_dtype} - - def state_dict(self) -> dict[str, torch.Tensor]: - """Return state dictionary (empty for this processor).""" - return {} - - def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: - """Load state dictionary (no-op for this processor).""" - pass - - def reset(self) -> None: - """Reset processor state (no-op for this processor).""" - pass - - def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - return features diff --git a/src/lerobot/processor/gym_action_processor.py b/src/lerobot/processor/gym_action_processor.py new file mode 100644 index 000000000..70c54cbde --- /dev/null +++ b/src/lerobot/processor/gym_action_processor.py @@ -0,0 +1,63 @@ +#! /usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, + +from dataclasses import dataclass + +import numpy as np +import torch + +from lerobot.processor.pipeline import ActionProcessor, ProcessorStepRegistry + + +@ProcessorStepRegistry.register("torch2numpy_action_processor") +@dataclass +class Torch2NumpyActionProcessor(ActionProcessor): + """Convert PyTorch tensor actions to NumPy arrays.""" + + squeeze_batch_dim: bool = True + + def action(self, action: torch.Tensor) -> np.ndarray: + if not isinstance(action, torch.Tensor): + raise TypeError( + f"Expected torch.Tensor or None, got {type(action).__name__}. " + "Use appropriate processor for non-tensor actions." + ) + + numpy_action = action.detach().cpu().numpy() + + # Remove batch dimensions but preserve action dimensions + # Only squeeze if there's a batch dimension (first dim == 1) + if ( + self.squeeze_batch_dim + and numpy_action.shape + and len(numpy_action.shape) > 1 + and numpy_action.shape[0] == 1 + ): + numpy_action = numpy_action.squeeze(0) + + return numpy_action + + +@ProcessorStepRegistry.register("numpy2torch_action_processor") +@dataclass +class Numpy2TorchActionProcessor(ActionProcessor): + """Convert NumPy array action to PyTorch tensor.""" + + def action(self, action: np.ndarray) -> torch.Tensor: + if not isinstance(action, np.ndarray): + raise TypeError( + f"Expected np.ndarray or None, got {type(action).__name__}. " + "Use appropriate processor for non-tensor actions." 
+ ) + torch_action = torch.from_numpy(action) + return torch_action diff --git a/src/lerobot/processor/hil_processor.py b/src/lerobot/processor/hil_processor.py index 9e31548b2..c1f4569ed 100644 --- a/src/lerobot/processor/hil_processor.py +++ b/src/lerobot/processor/hil_processor.py @@ -1,3 +1,4 @@ +import math import time from dataclasses import dataclass from typing import Any @@ -8,13 +9,14 @@ import torchvision.transforms.functional as F # noqa: N812 from lerobot.configs.types import PolicyFeature from lerobot.processor.pipeline import ( - ActionProcessor, ComplementaryDataProcessor, EnvTransition, InfoProcessor, ObservationProcessor, + ProcessorStep, ProcessorStepRegistry, TransitionKey, + TruncatedProcessor, ) from lerobot.teleoperators.teleoperator import Teleoperator from lerobot.teleoperators.utils import TeleopEvents @@ -29,10 +31,10 @@ class AddTeleopActionAsComplimentaryData(ComplementaryDataProcessor): teleop_device: Teleoperator - def complementary_data(self, complementary_data: dict | None) -> dict: - complementary_data = {} if complementary_data is None else dict(complementary_data) - complementary_data["teleop_action"] = self.teleop_device.get_action() - return complementary_data + def complementary_data(self, complementary_data: dict) -> dict: + new_complementary_data = dict(complementary_data) + new_complementary_data["teleop_action"] = self.teleop_device.get_action() + return new_complementary_data @ProcessorStepRegistry.register("add_teleop_action_as_info") @@ -42,60 +44,11 @@ class AddTeleopEventsAsInfo(InfoProcessor): teleop_device: Teleoperator - def info(self, info: dict | None) -> dict: - info = {} if info is None else dict(info) + def info(self, info: dict) -> dict: + new_info = dict(info) teleop_events = getattr(self.teleop_device, "get_teleop_events", lambda: {})() - info.update(teleop_events) - return info - - -@ProcessorStepRegistry.register("torch2numpy_action_processor") -@dataclass -class 
Torch2NumpyActionProcessor(ActionProcessor): - """Convert PyTorch tensor actions to NumPy arrays.""" - - squeeze_batch_dim: bool = True - - def action(self, action: torch.Tensor | None) -> np.ndarray | None: - if action is None: - return None - - if not isinstance(action, torch.Tensor): - raise TypeError( - f"Expected torch.Tensor or None, got {type(action).__name__}. " - "Use appropriate processor for non-tensor actions." - ) - - numpy_action = action.detach().cpu().numpy() - - # Remove batch dimensions but preserve action dimensions - # Only squeeze if there's a batch dimension (first dim == 1) - if ( - self.squeeze_batch_dim - and numpy_action.shape - and len(numpy_action.shape) > 1 - and numpy_action.shape[0] == 1 - ): - numpy_action = numpy_action.squeeze(0) - - return numpy_action - - -@ProcessorStepRegistry.register("numpy2torch_action_processor") -@dataclass -class Numpy2TorchActionProcessor(ActionProcessor): - """Convert NumPy array action to PyTorch tensor.""" - - def action(self, action: np.ndarray | None) -> torch.Tensor | None: - if action is None: - return None - if not isinstance(action, np.ndarray): - raise TypeError( - f"Expected np.ndarray or None, got {type(action).__name__}. " - "Use appropriate processor for non-tensor actions." 
- ) - torch_action = torch.from_numpy(action) - return torch_action + new_info.update(teleop_events) + return new_info @ProcessorStepRegistry.register("image_crop_resize_processor") @@ -106,10 +59,7 @@ class ImageCropResizeProcessor(ObservationProcessor): crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None resize_size: tuple[int, int] | None = None - def observation(self, observation: dict | None) -> dict | None: - if observation is None: - return None - + def observation(self, observation: dict) -> dict: if self.resize_size is None and not self.crop_params_dict: return observation @@ -153,61 +103,43 @@ class ImageCropResizeProcessor(ObservationProcessor): @dataclass @ProcessorStepRegistry.register("time_limit_processor") -class TimeLimitProcessor: +class TimeLimitProcessor(TruncatedProcessor): """Track episode steps and enforce time limits.""" max_episode_steps: int current_step: int = 0 - def __call__(self, transition: EnvTransition) -> EnvTransition: - truncated = transition.get(TransitionKey.TRUNCATED) - if truncated is None: - return transition - + def truncated(self, truncated): self.current_step += 1 if self.current_step >= self.max_episode_steps: truncated = True - new_transition = transition.copy() - new_transition[TransitionKey.TRUNCATED] = truncated - return new_transition + # TODO (steven): missing an else truncated = False? 
+ return truncated def get_config(self) -> dict[str, Any]: return { "max_episode_steps": self.max_episode_steps, } - def state_dict(self) -> dict[str, torch.Tensor]: - return {} - - def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: - pass - def reset(self) -> None: self.current_step = 0 - def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - return features - @dataclass @ProcessorStepRegistry.register("gripper_penalty_processor") -class GripperPenaltyProcessor: +class GripperPenaltyProcessor(ComplementaryDataProcessor): """Apply penalty for inappropriate gripper usage.""" penalty: float = -0.01 max_gripper_pos: float = 30.0 - def __call__(self, transition: EnvTransition) -> EnvTransition: + def complementary_data(self, complementary_data): """Calculate gripper penalty and add to complementary data.""" - action = transition.get(TransitionKey.ACTION) - complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) - - if complementary_data is None or action is None: - return transition + action = self.transition.get(TransitionKey.ACTION) current_gripper_pos = complementary_data.get("raw_joint_positions", None).get(GRIPPER_KEY, None) if current_gripper_pos is None: - return transition + return complementary_data gripper_action = action[f"action.{GRIPPER_KEY}.pos"] gripper_action_normalized = gripper_action / self.max_gripper_pos @@ -222,19 +154,11 @@ class GripperPenaltyProcessor: gripper_penalty = self.penalty * int(gripper_penalty_bool) - # Add penalty information to complementary data - complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) - # Create new complementary data with penalty info new_complementary_data = dict(complementary_data) new_complementary_data["discrete_penalty"] = gripper_penalty - # Create new transition with updated complementary data - new_transition = transition.copy() - existing_comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) - 
existing_comp_data.update(new_complementary_data) - new_transition[TransitionKey.COMPLEMENTARY_DATA] = existing_comp_data # type: ignore[misc] - return new_transition + return new_complementary_data def get_config(self) -> dict[str, Any]: return { @@ -242,23 +166,14 @@ class GripperPenaltyProcessor: "max_gripper_pos": self.max_gripper_pos, } - def state_dict(self) -> dict[str, torch.Tensor]: - return {} - - def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: - pass - def reset(self) -> None: """Reset the processor state.""" self.last_gripper_state = None - def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - return features - @dataclass @ProcessorStepRegistry.register("intervention_action_processor") -class InterventionActionProcessor: +class InterventionActionProcessor(ProcessorStep): """Handle human intervention actions and episode termination.""" use_gripper: bool = False @@ -271,7 +186,8 @@ class InterventionActionProcessor: # Get intervention signals from complementary data info = transition.get(TransitionKey.INFO, {}) - teleop_action = info.get("teleop_action", {}) + complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) + teleop_action = complementary_data.get("teleop_action", {}) is_intervention = info.get(TeleopEvents.IS_INTERVENTION, False) terminate_episode = info.get(TeleopEvents.TERMINATE_EPISODE, False) success = info.get(TeleopEvents.SUCCESS, False) @@ -321,24 +237,13 @@ class InterventionActionProcessor: def get_config(self) -> dict[str, Any]: return { "use_gripper": self.use_gripper, + "terminate_on_success": self.terminate_on_success, } - def state_dict(self) -> dict[str, torch.Tensor]: - return {} - - def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: - pass - - def reset(self) -> None: - pass - - def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - return features - @dataclass 
@ProcessorStepRegistry.register("reward_classifier_processor") -class RewardClassifierProcessor: +class RewardClassifierProcessor(ProcessorStep): """Apply reward classification to image observations.""" pretrained_path: str | None = None @@ -380,7 +285,7 @@ class RewardClassifierProcessor: reward = transition.get(TransitionKey.REWARD, 0.0) terminated = transition.get(TransitionKey.DONE, False) - if success == 1.0: + if math.isclose(success, 1, abs_tol=1e-2): reward = self.success_reward if self.terminate_on_success: terminated = True @@ -404,15 +309,3 @@ class RewardClassifierProcessor: "success_reward": self.success_reward, "terminate_on_success": self.terminate_on_success, } - - def state_dict(self) -> dict[str, torch.Tensor]: - return {} - - def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: - pass - - def reset(self) -> None: - pass - - def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - return features diff --git a/src/lerobot/processor/joint_observations_processor.py b/src/lerobot/processor/joint_observations_processor.py index 185db4048..40ec4ea1e 100644 --- a/src/lerobot/processor/joint_observations_processor.py +++ b/src/lerobot/processor/joint_observations_processor.py @@ -13,30 +13,28 @@ from lerobot.robots import Robot @dataclass @ProcessorStepRegistry.register("joint_velocity_processor") -class JointVelocityProcessor: +class JointVelocityProcessor(ObservationProcessor): """Add joint velocity information to observations.""" - joint_velocity_limits: float = 100.0 - dt: float = 1.0 / 10 - num_dof: int | None = None + dt: float = 0.1 last_joint_positions: torch.Tensor | None = None - def observation(self, observation: dict | None) -> dict | None: - if observation is None: - return None - + def observation(self, observation: dict) -> dict: # Get current joint positions (assuming they're in observation.state) current_positions = observation.get("observation.state") if current_positions is None: + # 
TODO(steven): if we get here, then the transform_features method will not hold return observation # Initialize last joint positions if not already set if self.last_joint_positions is None: self.last_joint_positions = current_positions.clone() + joint_velocities = torch.zeros_like(current_positions) + else: + # Compute velocities + joint_velocities = (current_positions - self.last_joint_positions) / self.dt - # Compute velocities - joint_velocities = (current_positions - self.last_joint_positions) / self.dt self.last_joint_positions = current_positions.clone() # Extend observation with velocities @@ -50,7 +48,6 @@ class JointVelocityProcessor: def get_config(self) -> dict[str, Any]: return { - "joint_velocity_limits": self.joint_velocity_limits, "dt": self.dt, } @@ -58,12 +55,11 @@ class JointVelocityProcessor: self.last_joint_positions = None def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - if "observation.state" in features and self.num_dof is not None: - from lerobot.configs.types import PolicyFeature - + if "observation.state" in features: original_feature = features["observation.state"] # Double the shape to account for positions + velocities - new_shape = (original_feature.shape[0] + self.num_dof,) + original_feature.shape[1:] + new_shape = (original_feature.shape[0] * 2,) + original_feature.shape[1:] + features["observation.state"] = PolicyFeature(type=original_feature.type, shape=new_shape) return features @@ -75,10 +71,7 @@ class MotorCurrentProcessor(ObservationProcessor): robot: Robot | None = None - def observation(self, observation: dict | None) -> dict | None: - if observation is None: - return None - + def observation(self, observation: dict) -> dict: # Get current values from robot state if self.robot is None: return observation diff --git a/src/lerobot/processor/normalize_processor.py b/src/lerobot/processor/normalize_processor.py index 92e654472..fa635414c 100644 --- 
a/src/lerobot/processor/normalize_processor.py +++ b/src/lerobot/processor/normalize_processor.py @@ -1,6 +1,5 @@ from __future__ import annotations -from collections.abc import Mapping from copy import deepcopy from dataclasses import dataclass, field from typing import Any @@ -11,222 +10,110 @@ from torch import Tensor from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from lerobot.datasets.lerobot_dataset import LeRobotDataset -from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, RobotProcessor, TransitionKey +from lerobot.processor.pipeline import ( + EnvTransition, + ProcessorStep, + ProcessorStepRegistry, + RobotProcessor, + TransitionKey, +) -def _convert_stats_to_tensors(stats: dict[str, dict[str, Any]]) -> dict[str, dict[str, Tensor]]: - """Convert numpy arrays and other types to torch tensors.""" +def _to_tensor(value: Any, device: torch.device | None = None) -> Tensor: + """Convert common python/numpy/torch types to a torch.float32 tensor. + + Always returns float32; preserves device if provided. 
+ """ + if isinstance(value, torch.Tensor): + return value.to(dtype=torch.float32, device=device) + if isinstance(value, np.ndarray): + # ensure contiguous, cast to float32 then convert + return torch.from_numpy(np.ascontiguousarray(value.astype(np.float32))).to(device=device) + if isinstance(value, (int, float)): + return torch.tensor(value, dtype=torch.float32, device=device) + if isinstance(value, (list, tuple)): + return torch.tensor(value, dtype=torch.float32, device=device) + raise TypeError(f"Unsupported type for stats value: {type(value)}") + + +def _convert_stats_to_tensors( + stats: dict[str, dict[str, Any]], device: torch.device | None = None +) -> dict[str, dict[str, Tensor]]: + """Convert numeric stats values to torch tensors, preserving keys.""" tensor_stats: dict[str, dict[str, Tensor]] = {} - for key, sub in stats.items(): + for key, sub in (stats or {}).items(): + if sub is None: + continue tensor_stats[key] = {} for stat_name, value in sub.items(): - if isinstance(value, np.ndarray): - tensor_val = torch.from_numpy(value.astype(np.float32)) - elif isinstance(value, torch.Tensor): - tensor_val = value.to(dtype=torch.float32) - elif isinstance(value, (int, float, list, tuple)): - tensor_val = torch.tensor(value, dtype=torch.float32) - else: - raise TypeError(f"Unsupported type for stats['{key}']['{stat_name}']: {type(value)}") - tensor_stats[key][stat_name] = tensor_val + tensor_stats[key][stat_name] = _to_tensor(value, device=device) return tensor_stats @dataclass -@ProcessorStepRegistry.register(name="normalizer_processor") -class NormalizerProcessor: - """Normalizes observations and actions in a single processor step. +class _NormalizationMixin: + """ + A mixin class providing core functionality for normalization and unnormalization. - This processor handles normalization of both observation and action tensors - using either mean/std normalization or min/max scaling to a [-1, 1] range. 
- - For each tensor key in the stats dictionary, the processor will: - - Use mean/std normalization if those statistics are provided: (x - mean) / std - - Use min/max scaling if those statistics are provided: 2 * (x - min) / (max - min) - 1 - - The processor can be configured to normalize only specific keys by setting - the normalize_keys parameter. + This class manages normalization statistics, their conversion to tensors, device placement, + and the application of normalization transformations. It is designed to be inherited by + concrete ProcessorStep implementations. """ - # Features and normalisation map are mandatory to match the design of normalize.py features: dict[str, PolicyFeature] norm_map: dict[FeatureType, NormalizationMode] - - # Pre-computed statistics coming from dataset.meta.stats for instance. stats: dict[str, dict[str, Any]] | None = None - - # Explicit subset of keys to normalise. If ``None`` every key (except - # "action") found in ``stats`` will be normalised. Using a ``set`` makes - # membership checks O(1). - normalize_keys: set[str] | None = None - + device: torch.device | str | None = None eps: float = 1e-8 + normalize_observation_keys: set[str] | None = None _tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False) - @classmethod - def from_lerobot_dataset( - cls, - dataset: LeRobotDataset, - features: dict[str, PolicyFeature], - norm_map: dict[FeatureType, NormalizationMode], - *, - normalize_keys: set[str] | None = None, - eps: float = 1e-8, - ) -> NormalizerProcessor: - """Factory helper that pulls statistics from a :class:`LeRobotDataset`. - - The features and norm_map parameters are mandatory to match the design - pattern used in normalize.py. 
- """ - - return cls( - features=features, - norm_map=norm_map, - stats=dataset.meta.stats, - normalize_keys=normalize_keys, - eps=eps, - ) - def __post_init__(self): - # Handle deserialization from JSON config - if self.features and isinstance(list(self.features.values())[0], dict): - # Features came from JSON - need to reconstruct PolicyFeature objects - reconstructed_features = {} - for key, ft_dict in self.features.items(): - reconstructed_features[key] = PolicyFeature( - type=FeatureType(ft_dict["type"]), shape=tuple(ft_dict["shape"]) - ) - self.features = reconstructed_features + # Robust JSON deserialization handling (guard empty maps) + if self.features: + first_val = next(iter(self.features.values())) + if isinstance(first_val, dict): + reconstructed = {} + for key, ft_dict in self.features.items(): + reconstructed[key] = PolicyFeature( + type=FeatureType(ft_dict["type"]), shape=tuple(ft_dict["shape"]) + ) + self.features = reconstructed - if self.norm_map and isinstance(list(self.norm_map.keys())[0], str): - # norm_map came from JSON - need to reconstruct enum keys and values - reconstructed_norm_map = {} - for ft_type_str, norm_mode_str in self.norm_map.items(): - reconstructed_norm_map[FeatureType(ft_type_str)] = NormalizationMode(norm_mode_str) - self.norm_map = reconstructed_norm_map + if self.norm_map: + # if keys are strings (JSON), rebuild enum map + if all(isinstance(k, str) for k in self.norm_map.keys()): + reconstructed = {} + for ft_type_str, norm_mode_str in self.norm_map.items(): + reconstructed[FeatureType(ft_type_str)] = NormalizationMode(norm_mode_str) + self.norm_map = reconstructed - # Convert statistics once so we avoid repeated numpy→Tensor conversions - # during runtime. + # Convert stats to tensors and move to the target device once during initialization. 
self.stats = self.stats or {} - self._tensor_stats = _convert_stats_to_tensors(self.stats) + self._tensor_stats = _convert_stats_to_tensors(self.stats, device=self.device) - # Ensure *normalize_keys* is a set for fast look-ups and compare by - # value later when returning the configuration. - if self.normalize_keys is not None and not isinstance(self.normalize_keys, set): - self.normalize_keys = set(self.normalize_keys) + def to(self, device: torch.device | str) -> _NormalizationMixin: + """Moves the processor's normalization stats to the specified device and returns self.""" + self.device = device + self._tensor_stats = _convert_stats_to_tensors(self.stats, device=self.device) + return self - def _normalize_obs(self, observation, normalized_info): - if observation is None: - return None + def state_dict(self) -> dict[str, Tensor]: + flat: dict[str, Tensor] = {} + for key, sub in self._tensor_stats.items(): + for stat_name, tensor in sub.items(): + flat[f"{key}.{stat_name}"] = tensor.cpu() # Always save to CPU + return flat - # Decide which keys should be normalised for this call. - if self.normalize_keys is not None: - keys_to_norm = self.normalize_keys - else: - # Use feature map to skip action keys. 
- keys_to_norm = {k for k, ft in self.features.items() if ft.type is not FeatureType.ACTION} - - processed = dict(observation) - for key in keys_to_norm: - if key not in processed or key not in self.features: - continue - - # Check the normalization mode for this feature type - feature = self.features[key] - norm_mode = self.norm_map.get(feature.type, NormalizationMode.IDENTITY) - - # Skip normalization if mode is IDENTITY - if norm_mode is NormalizationMode.IDENTITY: - normalized_info[key] = "IDENTITY" - continue - - # Skip if no stats available for this key - if key not in self._tensor_stats: - continue - - orig_val = processed[key] - tensor = ( - orig_val.to(dtype=torch.float32) - if isinstance(orig_val, torch.Tensor) - else torch.as_tensor(orig_val, dtype=torch.float32) + def load_state_dict(self, state: dict[str, Tensor]) -> None: + self._tensor_stats.clear() + for flat_key, tensor in state.items(): + key, stat_name = flat_key.rsplit(".", 1) + # Load to the processor's configured device. 
+ self._tensor_stats.setdefault(key, {})[stat_name] = tensor.to( + dtype=torch.float32, device=self.device ) - stats = {k: v.to(tensor.device) for k, v in self._tensor_stats[key].items()} - - if norm_mode is NormalizationMode.MEAN_STD: - if "mean" in stats and "std" in stats: - mean, std = stats["mean"], stats["std"] - processed[key] = (tensor - mean) / (std + self.eps) - normalized_info[key] = "MEAN_STD" - elif norm_mode is NormalizationMode.MIN_MAX: - if "min" in stats and "max" in stats: - min_val, max_val = stats["min"], stats["max"] - processed[key] = 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1 - normalized_info[key] = "MIN_MAX" - else: - raise ValueError(f"Unsupported normalization mode: {norm_mode}") - - return processed - - def _normalize_action(self, action, normalized_info): - if action is None: - return action - - # Check the normalization mode for actions - norm_mode = self.norm_map.get(FeatureType.ACTION, NormalizationMode.IDENTITY) - - # Skip normalization if mode is IDENTITY - if norm_mode is NormalizationMode.IDENTITY: - normalized_info["action"] = "IDENTITY" - return action - - # Skip if no stats available for actions - if "action" not in self._tensor_stats: - return action - - tensor = ( - action.to(dtype=torch.float32) - if isinstance(action, torch.Tensor) - else torch.as_tensor(action, dtype=torch.float32) - ) - stats = {k: v.to(tensor.device) for k, v in self._tensor_stats["action"].items()} - - if norm_mode is NormalizationMode.MEAN_STD: - if "mean" in stats and "std" in stats: - mean, std = stats["mean"], stats["std"] - normalized_info["action"] = "MEAN_STD" - return (tensor - mean) / (std + self.eps) - elif norm_mode is NormalizationMode.MIN_MAX: - if "min" in stats and "max" in stats: - min_val, max_val = stats["min"], stats["max"] - normalized_info["action"] = "MIN_MAX" - return 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1 - else: - raise ValueError(f"Unsupported normalization mode: {norm_mode}") - - # If we 
reach here, the required stats for the normalization mode are not available - raise ValueError(f"Action stats must contain appropriate values for {norm_mode} normalization") - - def __call__(self, transition: EnvTransition) -> EnvTransition: - # Track what was normalized - normalized_info = {} - - observation = self._normalize_obs(transition.get(TransitionKey.OBSERVATION), normalized_info) - action = self._normalize_action(transition.get(TransitionKey.ACTION), normalized_info) - - # Create a new transition with normalized values - new_transition = transition.copy() - new_transition[TransitionKey.OBSERVATION] = observation - new_transition[TransitionKey.ACTION] = action - - # Add normalization info to complementary data - if normalized_info: - comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) - comp_data = {} if comp_data is None else dict(comp_data) - comp_data["normalized_keys"] = normalized_info - new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data - - return new_transition def get_config(self) -> dict[str, Any]: config = { @@ -236,45 +123,87 @@ class NormalizerProcessor: }, "norm_map": {ft_type.value: norm_mode.value for ft_type, norm_mode in self.norm_map.items()}, } - if self.normalize_keys is not None: - # Serialise as a list for YAML / JSON friendliness - config["normalize_keys"] = sorted(self.normalize_keys) + if self.normalize_observation_keys is not None: + config["normalize_observation_keys"] = sorted(self.normalize_observation_keys) return config - def state_dict(self) -> dict[str, Tensor]: - flat = {} - for key, sub in self._tensor_stats.items(): - for stat_name, tensor in sub.items(): - flat[f"{key}.{stat_name}"] = tensor - return flat + def _normalize_observation(self, observation: dict[str, Any], inverse: bool) -> dict[str, Tensor]: + new_observation = dict(observation) + for key, feature in self.features.items(): + if self.normalize_observation_keys is not None and key not in self.normalize_observation_keys: + continue + 
if feature.type != FeatureType.ACTION and key in new_observation: + tensor = torch.as_tensor(new_observation[key], dtype=torch.float32) + new_observation[key] = self._apply_transform(tensor, key, feature.type, inverse=inverse) + return new_observation - def load_state_dict(self, state: Mapping[str, Tensor]) -> None: - self._tensor_stats.clear() - for flat_key, tensor in state.items(): - key, stat_name = flat_key.rsplit(".", 1) - self._tensor_stats.setdefault(key, {})[stat_name] = tensor + def _normalize_action(self, action: Any, inverse: bool) -> Tensor: + tensor = torch.as_tensor(action, dtype=torch.float32) + processed_action = self._apply_transform(tensor, "action", FeatureType.ACTION, inverse=inverse) + return processed_action - def reset(self): - pass + def _apply_transform( + self, tensor: Tensor, key: str, feature_type: FeatureType, *, inverse: bool = False + ) -> Tensor: + """Core logic to apply normalization or unnormalization.""" + norm_mode = self.norm_map.get(feature_type, NormalizationMode.IDENTITY) + if norm_mode == NormalizationMode.IDENTITY or key not in self._tensor_stats: + return tensor - def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - return features + if norm_mode not in (NormalizationMode.MEAN_STD, NormalizationMode.MIN_MAX): + raise ValueError(f"Unsupported normalization mode: {norm_mode}") + + # Ensure input tensor is on the same device as the stats. 
+ if self.device and tensor.device != self.device: + tensor = tensor.to(self.device) + + # For Accelerate compatibility: move stats to match input tensor device + input_device = tensor.device + stats = self._tensor_stats[key] + tensor = tensor.to(dtype=torch.float32) + + # Move stats to input device if needed + stats_device = next(iter(stats.values())).device + if stats_device != input_device: + stats = _convert_stats_to_tensors({key: self._tensor_stats[key]}, device=input_device)[key] + + if norm_mode == NormalizationMode.MEAN_STD and "mean" in stats and "std" in stats: + mean, std = stats["mean"], stats["std"] + # Avoid division by zero by adding a small epsilon. + denom = std + self.eps + if inverse: + return tensor * std + mean + return (tensor - mean) / denom + + if norm_mode == NormalizationMode.MIN_MAX and "min" in stats and "max" in stats: + min_val, max_val = stats["min"], stats["max"] + denom = max_val - min_val + # When min_val == max_val, substitute the denominator with a small epsilon + # to prevent division by zero. This consistently maps an input equal to + # min_val to -1, ensuring a stable transformation. + denom = torch.where( + denom == 0, torch.tensor(self.eps, device=input_device, dtype=torch.float32), denom + ) + if inverse: + # Map from [-1, 1] back to [min, max] + return (tensor + 1) / 2 * denom + min_val + # Map from [min, max] to [-1, 1] + return 2 * (tensor - min_val) / denom - 1 + + # If necessary stats are missing, return input unchanged. + return tensor @dataclass -@ProcessorStepRegistry.register(name="unnormalizer_processor") -class UnnormalizerProcessor: - """Inverse normalisation for observations and actions. - - Exactly mirrors :class:`NormalizerProcessor` but applies the inverse - transform. +@ProcessorStepRegistry.register(name="normalizer_processor") +class NormalizerProcessor(_NormalizationMixin, ProcessorStep): """ + A processor that applies normalization to observations and actions in a transition. 
- features: dict[str, PolicyFeature] - norm_map: dict[FeatureType, NormalizationMode] - stats: dict[str, dict[str, Any]] | None = None - - _tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False) + This class directly implements the normalization logic for both observation and action + components of an `EnvTransition`, using statistics (mean/std or min/max) provided at + initialization. + """ @classmethod def from_lerobot_dataset( @@ -282,194 +211,89 @@ class UnnormalizerProcessor: dataset: LeRobotDataset, features: dict[str, PolicyFeature], norm_map: dict[FeatureType, NormalizationMode], - ) -> UnnormalizerProcessor: - return cls(features=features, norm_map=norm_map, stats=dataset.meta.stats) - - def __post_init__(self): - # Handle deserialization from JSON config - if self.features and isinstance(list(self.features.values())[0], dict): - # Features came from JSON - need to reconstruct PolicyFeature objects - reconstructed_features = {} - for key, ft_dict in self.features.items(): - reconstructed_features[key] = PolicyFeature( - type=FeatureType(ft_dict["type"]), shape=tuple(ft_dict["shape"]) - ) - self.features = reconstructed_features - - if self.norm_map and isinstance(list(self.norm_map.keys())[0], str): - # norm_map came from JSON - need to reconstruct enum keys and values - reconstructed_norm_map = {} - for ft_type_str, norm_mode_str in self.norm_map.items(): - reconstructed_norm_map[FeatureType(ft_type_str)] = NormalizationMode(norm_mode_str) - self.norm_map = reconstructed_norm_map - - self.stats = self.stats or {} - self._tensor_stats = _convert_stats_to_tensors(self.stats) - - def _unnormalize_obs(self, observation, unnormalized_info): - if observation is None: - return None - keys = [k for k, ft in self.features.items() if ft.type is not FeatureType.ACTION] - processed = dict(observation) - for key in keys: - if key not in processed or key not in self.features: - continue - - # Check the normalization mode for 
this feature type - feature = self.features[key] - norm_mode = self.norm_map.get(feature.type, NormalizationMode.IDENTITY) - - # Skip unnormalization if mode is IDENTITY - if norm_mode is NormalizationMode.IDENTITY: - unnormalized_info[key] = "IDENTITY" - continue - - # Skip if no stats available for this key - if key not in self._tensor_stats: - continue - - orig_val = processed[key] - tensor = ( - orig_val.to(dtype=torch.float32) - if isinstance(orig_val, torch.Tensor) - else torch.as_tensor(orig_val, dtype=torch.float32) - ) - stats = {k: v.to(tensor.device) for k, v in self._tensor_stats[key].items()} - - if norm_mode is NormalizationMode.MEAN_STD: - if "mean" in stats and "std" in stats: - mean, std = stats["mean"], stats["std"] - processed[key] = tensor * std + mean - unnormalized_info[key] = "MEAN_STD" - elif norm_mode is NormalizationMode.MIN_MAX: - if "min" in stats and "max" in stats: - min_val, max_val = stats["min"], stats["max"] - processed[key] = (tensor + 1) / 2 * (max_val - min_val) + min_val - unnormalized_info[key] = "MIN_MAX" - else: - raise ValueError(f"Unsupported normalization mode: {norm_mode}") - - return processed - - def _unnormalize_action(self, action, unnormalized_info): - if action is None: - return action - - # Check the normalization mode for actions - norm_mode = self.norm_map.get(FeatureType.ACTION, NormalizationMode.IDENTITY) - - # Skip unnormalization if mode is IDENTITY - if norm_mode is NormalizationMode.IDENTITY: - unnormalized_info["action"] = "IDENTITY" - return action - - # Skip if no stats available for actions - if "action" not in self._tensor_stats: - return action - - tensor = ( - action.to(dtype=torch.float32) - if isinstance(action, torch.Tensor) - else torch.as_tensor(action, dtype=torch.float32) + *, + normalize_observation_keys: set[str] | None = None, + eps: float = 1e-8, + device: torch.device | str | None = None, + ) -> NormalizerProcessor: + return cls( + features=features, + norm_map=norm_map, + 
stats=dataset.meta.stats, + normalize_observation_keys=normalize_observation_keys, + eps=eps, + device=device, ) - stats = {k: v.to(tensor.device) for k, v in self._tensor_stats["action"].items()} - - if norm_mode is NormalizationMode.MEAN_STD: - if "mean" in stats and "std" in stats: - mean, std = stats["mean"], stats["std"] - unnormalized_info["action"] = "MEAN_STD" - return tensor * std + mean - elif norm_mode is NormalizationMode.MIN_MAX: - if "min" in stats and "max" in stats: - min_val, max_val = stats["min"], stats["max"] - unnormalized_info["action"] = "MIN_MAX" - return (tensor + 1) / 2 * (max_val - min_val) + min_val - else: - raise ValueError(f"Unsupported normalization mode: {norm_mode}") - - # If we reach here, the required stats for the normalization mode are not available - raise ValueError(f"Action stats must contain appropriate values for {norm_mode} normalization") def __call__(self, transition: EnvTransition) -> EnvTransition: - # Track what was unnormalized - unnormalized_info = {} - - observation = self._unnormalize_obs(transition.get(TransitionKey.OBSERVATION), unnormalized_info) - action = self._unnormalize_action(transition.get(TransitionKey.ACTION), unnormalized_info) - - # Create a new transition with unnormalized values new_transition = transition.copy() - new_transition[TransitionKey.OBSERVATION] = observation - new_transition[TransitionKey.ACTION] = action - # Add unnormalization info to complementary data - if unnormalized_info: - comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) - comp_data = {} if comp_data is None else dict(comp_data) - comp_data["unnormalized_keys"] = unnormalized_info - new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data + # Handle observation normalization. 
+ observation = new_transition.get(TransitionKey.OBSERVATION) + if observation is not None: + new_transition[TransitionKey.OBSERVATION] = self._normalize_observation( + observation, inverse=False + ) + + # Handle action normalization. + action = new_transition.get(TransitionKey.ACTION) + if action is not None: + new_transition[TransitionKey.ACTION] = self._normalize_action(action, inverse=False) return new_transition - def get_config(self) -> dict[str, Any]: - return { - "features": { - key: {"type": ft.type.value, "shape": ft.shape} for key, ft in self.features.items() - }, - "norm_map": {ft_type.value: norm_mode.value for ft_type, norm_mode in self.norm_map.items()}, - } - def state_dict(self) -> dict[str, Tensor]: - flat = {} - for key, sub in self._tensor_stats.items(): - for stat_name, tensor in sub.items(): - flat[f"{key}.{stat_name}"] = tensor - return flat +@dataclass +@ProcessorStepRegistry.register(name="unnormalizer_processor") +class UnnormalizerProcessor(_NormalizationMixin, ProcessorStep): + """ + A processor that applies unnormalization (the inverse of normalization) to + observations and actions in a transition. - def load_state_dict(self, state: Mapping[str, Tensor]) -> None: - self._tensor_stats.clear() - for flat_key, tensor in state.items(): - key, stat_name = flat_key.rsplit(".", 1) - self._tensor_stats.setdefault(key, {})[stat_name] = tensor + This is typically used to transform actions from a normalized policy output back into + the original scale for execution in an environment. 
+ """ - def reset(self): - pass + @classmethod + def from_lerobot_dataset( + cls, + dataset: LeRobotDataset, + features: dict[str, PolicyFeature], + norm_map: dict[FeatureType, NormalizationMode], + *, + device: torch.device | str | None = None, + ) -> UnnormalizerProcessor: + return cls(features=features, norm_map=norm_map, stats=dataset.meta.stats, device=device) - def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - return features + def __call__(self, transition: EnvTransition) -> EnvTransition: + new_transition = transition.copy() + + # Handle observation unnormalization. + observation = new_transition.get(TransitionKey.OBSERVATION) + if observation is not None: + new_transition[TransitionKey.OBSERVATION] = self._normalize_observation(observation, inverse=True) + + # Handle action unnormalization. + action = new_transition.get(TransitionKey.ACTION) + if action is not None: + new_transition[TransitionKey.ACTION] = self._normalize_action(action, inverse=True) + + return new_transition def hotswap_stats(robot_processor: RobotProcessor, stats: dict[str, dict[str, Any]]) -> RobotProcessor: - robot_processor = deepcopy(robot_processor) - for step in robot_processor.steps: - if isinstance(step, NormalizerProcessor) or isinstance(step, UnnormalizerProcessor): - step: NormalizerProcessor | UnnormalizerProcessor - step.stats = stats - step._tensor_stats = _convert_stats_to_tensors(stats) - return robot_processor - - -def rename_stats(stats: dict[str, dict[str, Any]], rename_map: dict[str, str]) -> dict[str, dict[str, Any]]: - """Rename keys in the stats dictionary according to the provided mapping. 
- - Args: - stats: The statistics dictionary with structure {feature_key: {stat_name: value}} - rename_map: Dictionary mapping old key names to new key names - - Returns: - A new stats dictionary with renamed keys - - Example: - >>> stats = {"observation.state": {"mean": 0.0, "std": 1.0}, "action": {"mean": 0.5, "std": 0.5}} - >>> rename_map = {"observation.state": "observation.robot_state"} - >>> new_stats = rename_stats(stats, rename_map) - >>> # new_stats will have "observation.robot_state" instead of "observation.state" """ - renamed_stats = {} + Replaces normalization statistics in a RobotProcessor pipeline. - for old_key, sub_stats in stats.items(): - # Use the new key if it exists in the rename map, otherwise keep the old key - new_key = rename_map.get(old_key, old_key) - renamed_stats[new_key] = deepcopy(sub_stats) - - return renamed_stats + This function creates a deep copy of the provided `RobotProcessor` and updates the + statistics of any `NormalizerProcessor` or `UnnormalizerProcessor` steps within it. + It's useful for adapting a trained policy to a new environment or dataset with + different data distributions. + """ + rp = deepcopy(robot_processor) + for step in rp.steps: + if isinstance(step, _NormalizationMixin): + step.stats = stats + # Re-initialize tensor_stats on the correct device. 
+ step._tensor_stats = _convert_stats_to_tensors(stats, device=step.device) + return rp diff --git a/src/lerobot/processor/pipeline.py b/src/lerobot/processor/pipeline.py index 19dc668f7..8ffe490d6 100644 --- a/src/lerobot/processor/pipeline.py +++ b/src/lerobot/processor/pipeline.py @@ -18,12 +18,13 @@ from __future__ import annotations import importlib import json import os +from abc import ABC, abstractmethod from collections.abc import Callable, Iterable, Sequence from copy import deepcopy from dataclasses import dataclass, field from enum import Enum from pathlib import Path -from typing import Any, Protocol, TypedDict, runtime_checkable +from typing import Any, TypedDict import torch from huggingface_hub import ModelHubMixin, hf_hub_download @@ -132,8 +133,7 @@ class ProcessorStepRegistry: cls._registry.clear() -@runtime_checkable -class ProcessorStep(Protocol): +class ProcessorStep(ABC): """Structural typing interface for a single processor step. A step is any callable accepting a full `EnvTransition` dict and @@ -166,17 +166,34 @@ class ProcessorStep(Protocol): - state_dict(): {"weights": torch.tensor(...), "running_mean": torch.tensor(...)} """ - def __call__(self, transition: EnvTransition) -> EnvTransition: ... + _current_transition: EnvTransition | None = None - def get_config(self) -> dict[str, Any]: ... + @property + def transition(self) -> EnvTransition: + """The current transition being processed by this step.""" + if self._current_transition is None: + raise ValueError("Transition is not set. Make sure to call the step with a transition first.") + return self._current_transition - def state_dict(self) -> dict[str, torch.Tensor]: ... + @abstractmethod + def __call__(self, transition: EnvTransition) -> EnvTransition: + return transition - def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: ... + def get_config(self) -> dict[str, Any]: + return {} - def reset(self) -> None: ... 
+ def state_dict(self) -> dict[str, torch.Tensor]: + return {} - def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: ... + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + return None + + def reset(self) -> None: + return None + + # TODO(Steven): Consider making this abstract so it is more explicit + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features def _default_batch_to_transition(batch: dict[str, Any]) -> EnvTransition: # noqa: D401 @@ -820,6 +837,7 @@ class RobotProcessor(ModelHubMixin): def __post_init__(self): for i, step in enumerate(self.steps): if not callable(step): + # TODO(steven): This should instead check isinstance(step, ProcessorStep), test need to be updated raise TypeError( f"Step {i} ({type(step).__name__}) must define __call__(transition) -> EnvTransition" ) @@ -837,7 +855,7 @@ class RobotProcessor(ModelHubMixin): return features -class ObservationProcessor: +class ObservationProcessor(ProcessorStep, ABC): """Base class for processors that modify only the observation component of a transition. Subclasses should override the `observation` method to implement custom observation processing. @@ -858,7 +876,8 @@ class ObservationProcessor: manipulation, focusing only on the specific observation processing logic. """ - def observation(self, observation): + @abstractmethod + def observation(self, observation) -> dict[str, Any]: """Process the observation component. Args: @@ -867,36 +886,22 @@ class ObservationProcessor: Returns: The processed observation """ - return observation + ... 
def __call__(self, transition: EnvTransition) -> EnvTransition: - observation = transition.get(TransitionKey.OBSERVATION) + self._current_transition = transition.copy() + new_transition = self._current_transition + + observation = new_transition.get(TransitionKey.OBSERVATION) if observation is None: - return transition + return new_transition processed_observation = self.observation(observation) - # Create a new transition dict with the processed observation - new_transition = transition.copy() new_transition[TransitionKey.OBSERVATION] = processed_observation return new_transition - def get_config(self) -> dict[str, Any]: - return {} - def state_dict(self) -> dict[str, torch.Tensor]: - return {} - - def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: - pass - - def reset(self) -> None: - pass - - def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - return features - - -class ActionProcessor: +class ActionProcessor(ProcessorStep, ABC): """Base class for processors that modify only the action component of a transition. Subclasses should override the `action` method to implement custom action processing. @@ -918,7 +923,8 @@ class ActionProcessor: manipulation, focusing only on the specific action processing logic. """ - def action(self, action): + @abstractmethod + def action(self, action) -> Any | torch.Tensor: """Process the action component. Args: @@ -927,36 +933,22 @@ class ActionProcessor: Returns: The processed action """ - return action + ... 
def __call__(self, transition: EnvTransition) -> EnvTransition: - action = transition.get(TransitionKey.ACTION) + self._current_transition = transition.copy() + new_transition = self._current_transition + + action = new_transition.get(TransitionKey.ACTION) if action is None: - return transition + return new_transition processed_action = self.action(action) - # Create a new transition dict with the processed action - new_transition = transition.copy() new_transition[TransitionKey.ACTION] = processed_action return new_transition - def get_config(self) -> dict[str, Any]: - return {} - def state_dict(self) -> dict[str, torch.Tensor]: - return {} - - def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: - pass - - def reset(self) -> None: - pass - - def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - return features - - -class RewardProcessor: +class RewardProcessor(ProcessorStep, ABC): """Base class for processors that modify only the reward component of a transition. Subclasses should override the `reward` method to implement custom reward processing. @@ -977,7 +969,8 @@ class RewardProcessor: manipulation, focusing only on the specific reward processing logic. """ - def reward(self, reward): + @abstractmethod + def reward(self, reward) -> float | torch.Tensor: """Process the reward component. Args: @@ -986,36 +979,22 @@ class RewardProcessor: Returns: The processed reward """ - return reward + ... 
def __call__(self, transition: EnvTransition) -> EnvTransition: - reward = transition.get(TransitionKey.REWARD) + self._current_transition = transition.copy() + new_transition = self._current_transition + + reward = new_transition.get(TransitionKey.REWARD) if reward is None: - return transition + return new_transition processed_reward = self.reward(reward) - # Create a new transition dict with the processed reward - new_transition = transition.copy() new_transition[TransitionKey.REWARD] = processed_reward return new_transition - def get_config(self) -> dict[str, Any]: - return {} - def state_dict(self) -> dict[str, torch.Tensor]: - return {} - - def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: - pass - - def reset(self) -> None: - pass - - def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - return features - - -class DoneProcessor: +class DoneProcessor(ProcessorStep, ABC): """Base class for processors that modify only the done flag of a transition. Subclasses should override the `done` method to implement custom done flag processing. @@ -1041,7 +1020,8 @@ class DoneProcessor: manipulation, focusing only on the specific done flag processing logic. """ - def done(self, done): + @abstractmethod + def done(self, done) -> bool | torch.Tensor: """Process the done flag. Args: @@ -1050,36 +1030,22 @@ class DoneProcessor: Returns: The processed done flag """ - return done + ... 
def __call__(self, transition: EnvTransition) -> EnvTransition: - done = transition.get(TransitionKey.DONE) + self._current_transition = transition.copy() + new_transition = self._current_transition + + done = new_transition.get(TransitionKey.DONE) if done is None: - return transition + return new_transition processed_done = self.done(done) - # Create a new transition dict with the processed done flag - new_transition = transition.copy() new_transition[TransitionKey.DONE] = processed_done return new_transition - def get_config(self) -> dict[str, Any]: - return {} - def state_dict(self) -> dict[str, torch.Tensor]: - return {} - - def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: - pass - - def reset(self) -> None: - pass - - def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - return features - - -class TruncatedProcessor: +class TruncatedProcessor(ProcessorStep, ABC): """Base class for processors that modify only the truncated flag of a transition. Subclasses should override the `truncated` method to implement custom truncated flag processing. @@ -1101,7 +1067,8 @@ class TruncatedProcessor: manipulation, focusing only on the specific truncated flag processing logic. """ - def truncated(self, truncated): + @abstractmethod + def truncated(self, truncated) -> bool | torch.Tensor: """Process the truncated flag. Args: @@ -1110,36 +1077,22 @@ class TruncatedProcessor: Returns: The processed truncated flag """ - return truncated + ... 
def __call__(self, transition: EnvTransition) -> EnvTransition: - truncated = transition.get(TransitionKey.TRUNCATED) + self._current_transition = transition.copy() + new_transition = self._current_transition + + truncated = new_transition.get(TransitionKey.TRUNCATED) if truncated is None: - return transition + return new_transition processed_truncated = self.truncated(truncated) - # Create a new transition dict with the processed truncated flag - new_transition = transition.copy() new_transition[TransitionKey.TRUNCATED] = processed_truncated return new_transition - def get_config(self) -> dict[str, Any]: - return {} - def state_dict(self) -> dict[str, torch.Tensor]: - return {} - - def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: - pass - - def reset(self) -> None: - pass - - def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - return features - - -class InfoProcessor: +class InfoProcessor(ProcessorStep, ABC): """Base class for processors that modify only the info dictionary of a transition. Subclasses should override the `info` method to implement custom info processing. @@ -1166,7 +1119,8 @@ class InfoProcessor: manipulation, focusing only on the specific info dictionary processing logic. """ - def info(self, info): + @abstractmethod + def info(self, info) -> dict[str, Any]: """Process the info dictionary. Args: @@ -1175,36 +1129,22 @@ class InfoProcessor: Returns: The processed info dictionary """ - return info + ... 
def __call__(self, transition: EnvTransition) -> EnvTransition: - info = transition.get(TransitionKey.INFO) + self._current_transition = transition.copy() + new_transition = self._current_transition + + info = new_transition.get(TransitionKey.INFO) if info is None: - return transition + return new_transition processed_info = self.info(info) - # Create a new transition dict with the processed info - new_transition = transition.copy() new_transition[TransitionKey.INFO] = processed_info return new_transition - def get_config(self) -> dict[str, Any]: - return {} - def state_dict(self) -> dict[str, torch.Tensor]: - return {} - - def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: - pass - - def reset(self) -> None: - pass - - def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - return features - - -class ComplementaryDataProcessor: +class ComplementaryDataProcessor(ProcessorStep, ABC): """Base class for processors that modify only the complementary data of a transition. Subclasses should override the `complementary_data` method to implement custom complementary data processing. @@ -1212,7 +1152,8 @@ class ComplementaryDataProcessor: into the transition dict, eliminating the need to implement the `__call__` method in subclasses. """ - def complementary_data(self, complementary_data): + @abstractmethod + def complementary_data(self, complementary_data) -> dict[str, Any]: """Process the complementary data. Args: @@ -1221,52 +1162,23 @@ class ComplementaryDataProcessor: Returns: The processed complementary data """ - return complementary_data + ... 
def __call__(self, transition: EnvTransition) -> EnvTransition: - complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) + self._current_transition = transition.copy() + new_transition = self._current_transition + + complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA) if complementary_data is None: - return transition + return new_transition processed_complementary_data = self.complementary_data(complementary_data) - # Create a new transition dict with the processed complementary data - new_transition = transition.copy() new_transition[TransitionKey.COMPLEMENTARY_DATA] = processed_complementary_data return new_transition - def get_config(self) -> dict[str, Any]: - return {} - def state_dict(self) -> dict[str, torch.Tensor]: - return {} - - def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: - pass - - def reset(self) -> None: - pass - - def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - return features - - -class IdentityProcessor: +class IdentityProcessor(ProcessorStep): """Identity processor that does nothing.""" def __call__(self, transition: EnvTransition) -> EnvTransition: return transition - - def get_config(self) -> dict[str, Any]: - return {} - - def state_dict(self) -> dict[str, torch.Tensor]: - return {} - - def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: - pass - - def reset(self) -> None: - pass - - def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - return features diff --git a/src/lerobot/processor/rename_processor.py b/src/lerobot/processor/rename_processor.py index db20424df..ebc867cac 100644 --- a/src/lerobot/processor/rename_processor.py +++ b/src/lerobot/processor/rename_processor.py @@ -13,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from copy import deepcopy from dataclasses import dataclass, field from typing import Any @@ -49,3 +50,14 @@ class RenameProcessor(ObservationProcessor): - Keys not in `rename_map` remain unchanged. """ return {self.rename_map.get(k, k): v for k, v in features.items()} + + +def rename_stats(stats: dict[str, dict[str, Any]], rename_map: dict[str, str]) -> dict[str, dict[str, Any]]: + """Rename keys in the stats dictionary according to rename_map (defensive copy).""" + if not stats: + return {} + renamed: dict[str, dict[str, Any]] = {} + for old_key, sub_stats in stats.items(): + new_key = rename_map.get(old_key, old_key) + renamed[new_key] = deepcopy(sub_stats) if sub_stats is not None else {} + return renamed diff --git a/src/lerobot/processor/tokenizer_processor.py b/src/lerobot/processor/tokenizer_processor.py index 003b00a4b..d2c04e44c 100644 --- a/src/lerobot/processor/tokenizer_processor.py +++ b/src/lerobot/processor/tokenizer_processor.py @@ -11,7 +11,12 @@ import torch from lerobot.configs.types import FeatureType, PolicyFeature from lerobot.constants import OBS_LANGUAGE -from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey +from lerobot.processor.pipeline import ( + EnvTransition, + ObservationProcessor, + ProcessorStepRegistry, + TransitionKey, +) from lerobot.utils.import_utils import _transformers_available if TYPE_CHECKING or _transformers_available: @@ -22,7 +27,7 @@ else: @dataclass @ProcessorStepRegistry.register(name="tokenizer_processor") -class TokenizerProcessor: +class TokenizerProcessor(ObservationProcessor): """Tokenizes text tasks in complementary data using a huggingface tokenizer. This processor handles tokenization of task strings found in the complementary_data @@ -118,7 +123,7 @@ class TokenizerProcessor: return None - def __call__(self, transition: EnvTransition) -> EnvTransition: + def observation(self, observation): """Process the transition by tokenizing the task text. 
Args: @@ -130,15 +135,15 @@ class TokenizerProcessor: Raises: ValueError: If tokenizer initialization failed. """ - task = self.get_task(transition) + task = self.get_task(self.transition) if task is None: - return transition + return observation # Tokenize the task (creates CPU tensors) tokenized_prompt = self._tokenize_text(task) # Detect device from existing tensors in the transition - target_device = self._detect_device(transition) + target_device = self._detect_device(self.transition) # Move tokenized tensors to match the device of other data if target_device is not None: @@ -148,20 +153,15 @@ class TokenizerProcessor: } # Get or create observation dict - observation = transition.get(TransitionKey.OBSERVATION) - if observation is None: - observation = {} - else: - observation = dict(observation) # Make a copy + new_observation = dict(observation) # Add tokenized data to observation - observation[f"{OBS_LANGUAGE}.tokens"] = tokenized_prompt["input_ids"] - observation[f"{OBS_LANGUAGE}.attention_mask"] = tokenized_prompt["attention_mask"].to( + new_observation[f"{OBS_LANGUAGE}.tokens"] = tokenized_prompt["input_ids"] + new_observation[f"{OBS_LANGUAGE}.attention_mask"] = tokenized_prompt["attention_mask"].to( dtype=torch.bool ) - transition[TransitionKey.OBSERVATION.value] = observation # type: ignore[misc] - return transition + return new_observation def _detect_device(self, transition: EnvTransition) -> torch.device | None: """Detect device from existing tensors in the transition. 
@@ -187,19 +187,6 @@ class TokenizerProcessor: if isinstance(action, torch.Tensor): return action.device - # Check other tensor fields - for key in [TransitionKey.REWARD, TransitionKey.DONE, TransitionKey.TRUNCATED]: - value = transition.get(key) - if isinstance(value, torch.Tensor): - return value.device - - # Check complementary data for tensors - complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) - if complementary_data: - for value in complementary_data.values(): - if isinstance(value, torch.Tensor): - return value.device - return None # No tensors found, keep on CPU def _tokenize_text(self, text: str | list[str]) -> dict[str, torch.Tensor]: @@ -235,23 +222,12 @@ class TokenizerProcessor: } # Only include tokenizer_name if it was used (not when tokenizer object was provided) - if self.tokenizer_name is not None: + # TODO(steven): Consider saving the name of the _tokenizer if it was loaded + if self.tokenizer_name is not None and self.tokenizer is None: config["tokenizer_name"] = self.tokenizer_name return config - def state_dict(self) -> dict[str, torch.Tensor]: - """Return state dictionary (empty for this processor).""" - return {} - - def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: - """Load state dictionary (no-op for this processor).""" - pass - - def reset(self) -> None: - """Reset processor state (no-op for this processor).""" - pass - def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: """Add tokenized task features to the feature contract. 
diff --git a/src/lerobot/record.py b/src/lerobot/record.py index 41e63f4a0..0ebe23501 100644 --- a/src/lerobot/record.py +++ b/src/lerobot/record.py @@ -74,7 +74,7 @@ from lerobot.datasets.image_writer import safe_stop_image_writer from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.utils import hw_to_dataset_features from lerobot.datasets.video_utils import VideoEncodingManager -from lerobot.policies.factory import make_policy, make_processor +from lerobot.policies.factory import make_policy, make_pre_post_processors from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.processor import RobotProcessor from lerobot.processor.converters import ( @@ -83,8 +83,8 @@ from lerobot.processor.converters import ( to_transition_robot_observation, to_transition_teleop_action, ) -from lerobot.processor.normalize_processor import rename_stats from lerobot.processor.pipeline import IdentityProcessor, TransitionKey +from lerobot.processor.rename_processor import rename_stats from lerobot.robots import ( # noqa: F401 Robot, RobotConfig, @@ -346,13 +346,14 @@ def record_loop( else: logging.info( "No policy or teleoperator provided, skipping action generation. " - "This is likely to happen during environment reset." + "This is likely to happen when resetting the environment without a teleop device." + "The robot won't be at its rest position at the start of the next episode." 
) - # Still continue to next loop to respect timing + continue # Applies a pipeline to the action, default is IdentityProcessor # IMPORTANT: action_pipeline.to_output must return a dict suitable for robot.send_action() - if policy_transition is not None: + if policy is not None and policy_transition is not None: robot_action_to_send = robot_action_processor(policy_transition) else: robot_action_to_send = robot_action_processor(teleop_transition) @@ -434,7 +435,7 @@ def record(cfg: RecordConfig) -> LeRobotDataset: preprocessor = None postprocessor = None if cfg.policy is not None: - preprocessor, postprocessor = make_processor( + preprocessor, postprocessor = make_pre_post_processors( policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path, dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map), @@ -510,5 +511,9 @@ def record(cfg: RecordConfig) -> LeRobotDataset: return dataset -if __name__ == "__main__": +def main(): record() + + +if __name__ == "__main__": + main() diff --git a/src/lerobot/replay.py b/src/lerobot/replay.py index 2b62fd67f..f6628ac6b 100644 --- a/src/lerobot/replay.py +++ b/src/lerobot/replay.py @@ -45,9 +45,11 @@ from dataclasses import asdict, dataclass from pathlib import Path from pprint import pformat -import draccus - +from lerobot.configs import parser from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.processor import RobotProcessor +from lerobot.processor.converters import to_output_robot_action, to_transition_teleop_action +from lerobot.processor.pipeline import IdentityProcessor from lerobot.robots import ( # noqa: F401 Robot, RobotConfig, @@ -83,13 +85,25 @@ class ReplayConfig: dataset: DatasetReplayConfig # Use vocal synthesis to read events. 
play_sounds: bool = True + # Optional processor for actions before sending to robot + robot_action_processor: RobotProcessor | None = None -@draccus.wrap() +@parser.wrap() def replay(cfg: ReplayConfig): init_logging() logging.info(pformat(asdict(cfg))) + # Initialize robot action processor with default if not provided + robot_action_processor = cfg.robot_action_processor or RobotProcessor( + steps=[IdentityProcessor()], + to_transition=to_transition_teleop_action, + to_output=to_output_robot_action, # type: ignore[arg-type] + ) + + # Reset processor + robot_action_processor.reset() + robot = make_robot_from_config(cfg.robot) dataset = LeRobotDataset(cfg.dataset.repo_id, root=cfg.dataset.root, episodes=[cfg.dataset.episode]) actions = dataset.hf_dataset.select_columns("action") @@ -104,7 +118,10 @@ def replay(cfg: ReplayConfig): for i, name in enumerate(dataset.features["action"]["names"]): action[name] = action_array[i] - robot.send_action(action) + # Process action through robot action processor + processed_action = robot_action_processor(action) + + robot.send_action(processed_action) # type: ignore[arg-type] dt_s = time.perf_counter() - start_episode_t busy_wait(1 / dataset.fps - dt_s) diff --git a/src/lerobot/robots/so100_follower/robot_kinematic_processor.py b/src/lerobot/robots/so100_follower/robot_kinematic_processor.py index 6d85507bb..7c6c73a4d 100644 --- a/src/lerobot/robots/so100_follower/robot_kinematic_processor.py +++ b/src/lerobot/robots/so100_follower/robot_kinematic_processor.py @@ -19,13 +19,14 @@ from dataclasses import dataclass, field import numpy as np from scipy.spatial.transform import Rotation -from lerobot.configs.types import PolicyFeature +from lerobot.configs.types import FeatureType, PolicyFeature from lerobot.model.kinematics import RobotKinematics from lerobot.processor.pipeline import ( ActionProcessor, ComplementaryDataProcessor, EnvTransition, ObservationProcessor, + ProcessorStep, ProcessorStepRegistry, TransitionKey, ) @@ -34,7 
+35,7 @@ from lerobot.robots.robot import Robot @ProcessorStepRegistry.register("ee_reference_and_delta") @dataclass -class EEReferenceAndDelta: +class EEReferenceAndDelta(ActionProcessor): """ Compute the desired end-effector pose from the target pose and the current pose. @@ -61,9 +62,9 @@ class EEReferenceAndDelta: _prev_enabled: bool = field(default=False, init=False, repr=False) _command_when_disabled: np.ndarray | None = field(default=None, init=False, repr=False) - def __call__(self, transition: EnvTransition) -> EnvTransition: - act = transition.get(TransitionKey.ACTION) or {} - comp = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {} + def action(self, action): + new_action = action.copy() + comp = self.transition.get(TransitionKey.COMPLEMENTARY_DATA) # Get joint positions from complimentary data raw = comp.get("raw_joint_positions", None) @@ -80,13 +81,13 @@ class EEReferenceAndDelta: # Current pose from FK on measured joints t_curr = self.kinematics.forward_kinematics(q) - enabled = bool(act.pop("action.enabled", 0)) - tx = float(act.pop("action.target_x", 0.0)) - ty = float(act.pop("action.target_y", 0.0)) - tz = float(act.pop("action.target_z", 0.0)) - wx = float(act.pop("action.target_wx", 0.0)) - wy = float(act.pop("action.target_wy", 0.0)) - wz = float(act.pop("action.target_wz", 0.0)) + enabled = bool(new_action.pop("action.enabled", 0)) + tx = float(new_action.pop("action.target_x", 0.0)) + ty = float(new_action.pop("action.target_y", 0.0)) + tz = float(new_action.pop("action.target_z", 0.0)) + wx = float(new_action.pop("action.target_wx", 0.0)) + wy = float(new_action.pop("action.target_wy", 0.0)) + wz = float(new_action.pop("action.target_wz", 0.0)) desired = None @@ -122,22 +123,36 @@ class EEReferenceAndDelta: # Write action fields pos = desired[:3, 3] tw = Rotation.from_matrix(desired[:3, :3]).as_rotvec() - act.update( - { - "action.ee.x": float(pos[0]), - "action.ee.y": float(pos[1]), - "action.ee.z": float(pos[2]), - "action.ee.wx": 
float(tw[0]), - "action.ee.wy": float(tw[1]), - "action.ee.wz": float(tw[2]), - } - ) + new_action["action.ee.x"] = float(pos[0]) + new_action["action.ee.y"] = float(pos[1]) + new_action["action.ee.z"] = float(pos[2]) + new_action["action.ee.wx"] = float(tw[0]) + new_action["action.ee.wy"] = float(tw[1]) + new_action["action.ee.wz"] = float(tw[2]) self._prev_enabled = enabled - transition[TransitionKey.ACTION] = act - return transition + return new_action + + def reset(self): + self._prev_enabled = False + self.reference_ee_pose = None + self._command_when_disabled = None def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + features.pop("action.enabled", None) + features.pop("action.target_x", None) + features.pop("action.target_y", None) + features.pop("action.target_z", None) + features.pop("action.target_wx", None) + features.pop("action.target_wy", None) + features.pop("action.target_wz", None) + + features["action.ee.x"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),) + features["action.ee.y"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),) + features["action.ee.z"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),) + features["action.ee.wx"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),) + features["action.ee.wy"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),) + features["action.ee.wz"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),) return features @@ -162,14 +177,15 @@ class EEBoundsAndSafety(ActionProcessor): max_ee_step_m: float = 0.05 max_ee_twist_step_rad: float = 0.20 _last_pos: np.ndarray | None = field(default=None, init=False, repr=False) + _last_twist: np.ndarray | None = field(default=None, init=False, repr=False) - def action(self, act: dict | None) -> dict: - x = act.pop("action.ee.x", None) - y = act.pop("action.ee.y", None) - z = act.pop("action.ee.z", None) - wx = act.pop("action.ee.wx", None) - wy = act.pop("action.ee.wy", None) - wz = 
act.pop("action.ee.wz", None) + def action(self, act: dict) -> dict: + x = act.get("action.ee.x", None) + y = act.get("action.ee.y", None) + z = act.get("action.ee.z", None) + wx = act.get("action.ee.wx", None) + wy = act.get("action.ee.wy", None) + wz = act.get("action.ee.wz", None) if None in (x, y, z, wx, wy, wz): return act @@ -191,35 +207,22 @@ class EEBoundsAndSafety(ActionProcessor): self._last_pos = pos self._last_twist = twist - act.update( - { - "action.ee.x": float(pos[0]), - "action.ee.y": float(pos[1]), - "action.ee.z": float(pos[2]), - "action.ee.wx": float(twist[0]), - "action.ee.wy": float(twist[1]), - "action.ee.wz": float(twist[2]), - } - ) + act["action.ee.x"] = float(pos[0]) + act["action.ee.y"] = float(pos[1]) + act["action.ee.z"] = float(pos[2]) + act["action.ee.wx"] = float(twist[0]) + act["action.ee.wy"] = float(twist[1]) + act["action.ee.wz"] = float(twist[2]) return act def reset(self): self._last_pos = None - - def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - # Because this is last step we specify the dataset features of this step that we want to be stored in the dataset - features["action.ee.x"] = float - features["action.ee.y"] = float - features["action.ee.z"] = float - features["action.ee.wx"] = float - features["action.ee.wy"] = float - features["action.ee.wz"] = float - return features + self._last_twist = None @ProcessorStepRegistry.register("inverse_kinematics_ee_to_joints") @dataclass -class InverseKinematicsEEToJoints: +class InverseKinematicsEEToJoints(ProcessorStep): """ Compute the desired joint positions from the desired end-effector pose. 
@@ -255,18 +258,6 @@ class InverseKinematicsEEToJoints: wz = act.get("action.ee.wz", None) if None in (x, y, z, wx, wy, wz): - # Nothing to do; restore what we popped and return - act.update( - { - "action.ee.x": x, - "action.ee.y": y, - "action.ee.z": z, - "action.ee.wx": wx, - "action.ee.wy": wy, - "action.ee.wz": wz, - } - ) - transition[TransitionKey.ACTION] = act return transition # Get joint positions from complimentary data @@ -303,16 +294,11 @@ class InverseKinematicsEEToJoints: return transition def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - # We specify the dataset features of this step that we want to be stored in the dataset - features["action.ee.x"] = float - features["action.ee.y"] = float - features["action.ee.z"] = float - features["action.ee.wx"] = float - features["action.ee.wy"] = float - features["action.ee.wz"] = float + features["observation.state.gripper.pos"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),) + features["action.gripper.pos"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),) + for name in self.motor_names: + features[f"action.{name}.pos"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),) - features["observation.state.gripper.pos"] = float - features["action.gripper.pos"] = float return features def reset(self): @@ -321,7 +307,7 @@ class InverseKinematicsEEToJoints: @ProcessorStepRegistry.register("gripper_velocity_to_joint") @dataclass -class GripperVelocityToJoint: +class GripperVelocityToJoint(ProcessorStep): """ Convert the gripper velocity to a joint velocity. 
@@ -379,14 +365,13 @@ class GripperVelocityToJoint: new_act.pop("action.gripper", None) transition[TransitionKey.ACTION] = new_act - obs.update({"observation.state.gripper.pos": curr_pos}) + obs["observation.state.gripper.pos"] = curr_pos transition[TransitionKey.OBSERVATION] = obs return transition def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - # We specify the dataset features of this step that we want to be stored in the dataset - features["observation.state.gripper.pos"] = float - features["action.gripper.pos"] = float + features.pop("action.gripper", None) + features["action.gripper.pos"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),) return features @@ -410,7 +395,7 @@ class ForwardKinematicsJointsToEE(ObservationProcessor): kinematics: RobotKinematics motor_names: list[str] - def observation(self, obs: dict | None) -> dict: + def observation(self, obs: dict) -> dict: if not all(f"observation.state.{n}.pos" in obs for n in self.motor_names): return obs @@ -419,22 +404,18 @@ class ForwardKinematicsJointsToEE(ObservationProcessor): pos = t[:3, 3] tw = Rotation.from_matrix(t[:3, :3]).as_rotvec() - obs.update( - { - "observation.state.ee.x": float(pos[0]), - "observation.state.ee.y": float(pos[1]), - "observation.state.ee.z": float(pos[2]), - "observation.state.ee.wx": float(tw[0]), - "observation.state.ee.wy": float(tw[1]), - "observation.state.ee.wz": float(tw[2]), - } - ) + obs["observation.state.ee.x"] = float(pos[0]) + obs["observation.state.ee.y"] = float(pos[1]) + obs["observation.state.ee.z"] = float(pos[2]) + obs["observation.state.ee.wx"] = float(tw[0]) + obs["observation.state.ee.wy"] = float(tw[1]) + obs["observation.state.ee.wz"] = float(tw[2]) return obs def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: # We specify the dataset features of this step that we want to be stored in the dataset for k in ["x", "y", "z", "wx", "wy", "wz"]: - 
features[f"observation.state.ee.{k}"] = float + features[f"observation.state.ee.{k}"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),) return features @@ -451,15 +432,14 @@ class AddRobotObservationAsComplimentaryData(ComplementaryDataProcessor): robot: Robot def complementary_data(self, comp: dict | None) -> dict: - comp = {} if comp is None else dict(comp) - obs = self.robot.get_observation() + new_comp = dict(comp) + obs = ( + self.robot.get_observation() + ) # todo(steven): why not self.trtansition.get(TransitionKey.OBSERVATION)? - comp["raw_joint_positions"] = { + new_comp["raw_joint_positions"] = { k.removesuffix(".pos"): float(v) for k, v in obs.items() if isinstance(k, str) and k.endswith(".pos") } - return comp - - def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - return features + return new_comp diff --git a/src/lerobot/robots/utils.py b/src/lerobot/robots/utils.py index 87e751b26..dca1adc83 100644 --- a/src/lerobot/robots/utils.py +++ b/src/lerobot/robots/utils.py @@ -29,10 +29,6 @@ def make_robot_from_config(config: RobotConfig) -> Robot: from .so100_follower import SO100Follower return SO100Follower(config) - elif config.type == "so100_follower_end_effector": - from .so100_follower import SO100FollowerEndEffector - - return SO100FollowerEndEffector(config) elif config.type == "so101_follower": from .so101_follower import SO101Follower diff --git a/src/lerobot/scripts/rl/actor.py b/src/lerobot/scripts/rl/actor.py index 5f44d3c5f..997bc620b 100644 --- a/src/lerobot/scripts/rl/actor.py +++ b/src/lerobot/scripts/rl/actor.py @@ -98,7 +98,6 @@ from lerobot.utils.utils import ( ACTOR_SHUTDOWN_TIMEOUT = 30 - ################################################# # Main entry point # ################################################# @@ -288,7 +287,9 @@ def act_with_policy( logging.info("[ACTOR] Shutting down act_with_policy") return - observation = transition[TransitionKey.OBSERVATION] + observation = { + k: v 
for k, v in transition[TransitionKey.OBSERVATION].items() if k in cfg.policy.input_features + } # Time policy inference and check if it meets FPS requirement with policy_timer: @@ -308,8 +309,16 @@ def act_with_policy( ) # Extract values from processed transition - next_observation = new_transition[TransitionKey.OBSERVATION] - executed_action = new_transition[TransitionKey.ACTION] + next_observation = { + k: v + for k, v in new_transition[TransitionKey.OBSERVATION].items() + if k in cfg.policy.input_features + } + + # Teleop action is the action that was executed in the environment + # It is either the action from the teleop device or the action from the policy + executed_action = new_transition[TransitionKey.COMPLEMENTARY_DATA]["teleop_action"] + reward = new_transition[TransitionKey.REWARD] done = new_transition.get(TransitionKey.DONE, False) truncated = new_transition.get(TransitionKey.TRUNCATED, False) diff --git a/src/lerobot/scripts/rl/gym_manipulator.py b/src/lerobot/scripts/rl/gym_manipulator.py index 37ff1cc7e..835b85190 100644 --- a/src/lerobot/scripts/rl/gym_manipulator.py +++ b/src/lerobot/scripts/rl/gym_manipulator.py @@ -37,6 +37,7 @@ from lerobot.processor import ( InterventionActionProcessor, JointVelocityProcessor, MapDeltaActionToRobotAction, + MapTensorToDeltaActionDict, MotorCurrentProcessor, Numpy2TorchActionProcessor, RewardClassifierProcessor, @@ -80,11 +81,11 @@ class DatasetConfig: """Configuration for dataset creation and management.""" repo_id: str - dataset_root: str task: str - num_episodes: int - episode: int - push_to_hub: bool + root: str | None = None + num_episodes_to_record: int = 5 + replay_episode: int | None = None + push_to_hub: bool = False @dataclass @@ -401,13 +402,11 @@ def make_processors( joint_names=motor_names, ) - env_pipeline_steps = [ - VanillaObservationProcessor(), - ] + env_pipeline_steps = [VanillaObservationProcessor()] if cfg.processor.observation is not None: if 
cfg.processor.observation.add_joint_velocity_to_observation: - env_pipeline_steps.append(JointVelocityProcessor(dt=1.0 / cfg.fps, num_dof=len(motor_names))) + env_pipeline_steps.append(JointVelocityProcessor(dt=1.0 / cfg.fps)) if cfg.processor.observation.add_current_to_observation: env_pipeline_steps.append(MotorCurrentProcessor(robot=env.robot)) @@ -473,6 +472,7 @@ def make_processors( if cfg.processor.inverse_kinematics is not None and kinematics_solver is not None: # Add EE bounds and safety processor inverse_kinematics_steps = [ + MapTensorToDeltaActionDict(), MapDeltaActionToRobotAction(), EEReferenceAndDelta( kinematics=kinematics_solver, @@ -625,7 +625,7 @@ def control_loop( dataset = LeRobotDataset.create( cfg.dataset.repo_id, cfg.env.fps, - root=cfg.dataset.dataset_root, + root=cfg.dataset.root, use_videos=True, image_writer_threads=4, image_writer_processes=0, @@ -636,7 +636,7 @@ def control_loop( episode_step = 0 episode_start_time = time.perf_counter() - while episode_idx < cfg.dataset.num_episodes: + while episode_idx < cfg.dataset.num_episodes_to_record: step_start_time = time.perf_counter() # Create a neutral action (no movement) @@ -711,10 +711,12 @@ def control_loop( def replay_trajectory(env: gym.Env, action_processor: RobotProcessor, cfg: GymManipulatorConfig) -> None: """Replay recorded trajectory on robot environment.""" + assert cfg.dataset.replay_episode is not None, "Replay episode must be provided for replay" + dataset = LeRobotDataset( cfg.dataset.repo_id, - root=cfg.dataset.dataset_root, - episodes=[cfg.dataset.episode], + root=cfg.dataset.root, + episodes=[cfg.dataset.replay_episode], download_videos=False, ) dataset_actions = dataset.hf_dataset.select_columns(["action"]) diff --git a/src/lerobot/scripts/train.py b/src/lerobot/scripts/train.py index 57eb0db60..68361fe14 100644 --- a/src/lerobot/scripts/train.py +++ b/src/lerobot/scripts/train.py @@ -32,7 +32,7 @@ from lerobot.datasets.sampler import EpisodeAwareSampler from 
lerobot.datasets.utils import cycle from lerobot.envs.factory import make_env from lerobot.optim.factory import make_optimizer_and_scheduler -from lerobot.policies.factory import make_policy, make_processor +from lerobot.policies.factory import make_policy, make_pre_post_processors from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.utils import get_device_from_parameters from lerobot.scripts.eval import eval_policy @@ -141,7 +141,7 @@ def train(cfg: TrainPipelineConfig): cfg=cfg.policy, ds_meta=dataset.meta, ) - preprocessor, postprocessor = make_processor( + preprocessor, postprocessor = make_pre_post_processors( policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path, dataset_stats=dataset.meta.stats ) diff --git a/src/lerobot/teleoperate.py b/src/lerobot/teleoperate.py index 92a26311a..d8101f0b3 100644 --- a/src/lerobot/teleoperate.py +++ b/src/lerobot/teleoperate.py @@ -56,11 +56,18 @@ import time from dataclasses import asdict, dataclass from pprint import pformat -import draccus import rerun as rr from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401 from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401 +from lerobot.configs import parser +from lerobot.processor import RobotProcessor +from lerobot.processor.converters import ( + to_output_robot_action, + to_transition_robot_observation, + to_transition_teleop_action, +) +from lerobot.processor.pipeline import IdentityProcessor from lerobot.robots import ( # noqa: F401 Robot, RobotConfig, @@ -97,39 +104,82 @@ class TeleoperateConfig: teleop_time_s: float | None = None # Display all cameras on screen display_data: bool = False + # Optional processors for data transformation + teleop_action_processor: RobotProcessor | None = None # runs after teleop + robot_action_processor: RobotProcessor | None = None # runs before robot + robot_observation_processor: RobotProcessor | None = None # runs after 
robot def teleop_loop( - teleop: Teleoperator, robot: Robot, fps: int, display_data: bool = False, duration: float | None = None + teleop: Teleoperator, + robot: Robot, + fps: int, + display_data: bool = False, + duration: float | None = None, + teleop_action_processor: RobotProcessor | None = None, + robot_action_processor: RobotProcessor | None = None, + robot_observation_processor: RobotProcessor | None = None, ): + # Initialize processors with defaults if not provided + teleop_action_processor = teleop_action_processor or RobotProcessor( + steps=[IdentityProcessor()], to_transition=to_transition_teleop_action, to_output=lambda tr: tr + ) + robot_action_processor = robot_action_processor or RobotProcessor( + steps=[IdentityProcessor()], + to_transition=lambda tr: tr, + to_output=to_output_robot_action, # type: ignore[arg-type] + ) + robot_observation_processor = robot_observation_processor or RobotProcessor( + steps=[IdentityProcessor()], to_transition=to_transition_robot_observation, to_output=lambda tr: tr + ) + + # Reset processors + teleop_action_processor.reset() + robot_action_processor.reset() + robot_observation_processor.reset() + display_len = max(len(key) for key in robot.action_features) start = time.perf_counter() + while True: loop_start = time.perf_counter() - action = teleop.get_action() - if display_data: - observation = robot.get_observation() - log_rerun_data(observation=observation, action=action) - robot.send_action(action) + # Get teleop action + raw_action = teleop.get_action() + + # Process teleop action through pipeline + teleop_transition = teleop_action_processor(raw_action) + + # Process action for robot through pipeline + robot_action_to_send = robot_action_processor(teleop_transition) + + # Send processed action to robot (robot_action_processor.to_output should return dict[str, Any]) + robot.send_action(robot_action_to_send) # type: ignore[arg-type] + + if display_data: + # Get robot observation + obs = robot.get_observation() + # 
Process robot observation through pipeline + obs_transition = robot_observation_processor(obs) + log_rerun_data([obs_transition, teleop_transition]) + + print("\n" + "-" * (display_len + 10)) + print(f"{'NAME':<{display_len}} | {'NORM':>7}") + # Display the final robot action that was sent + for motor, value in robot_action_to_send.items(): + print(f"{motor:<{display_len}} | {value:>7.2f}") + move_cursor_up(len(robot_action_to_send) + 5) + dt_s = time.perf_counter() - loop_start busy_wait(1 / fps - dt_s) - loop_s = time.perf_counter() - loop_start - - print("\n" + "-" * (display_len + 10)) - print(f"{'NAME':<{display_len}} | {'NORM':>7}") - for motor, value in action.items(): - print(f"{motor:<{display_len}} | {value:>7.2f}") print(f"\ntime: {loop_s * 1e3:.2f}ms ({1 / loop_s:.0f} Hz)") if duration is not None and time.perf_counter() - start >= duration: return - move_cursor_up(len(action) + 5) - -@draccus.wrap() +@parser.wrap() def teleoperate(cfg: TeleoperateConfig): init_logging() logging.info(pformat(asdict(cfg))) @@ -143,7 +193,16 @@ def teleoperate(cfg: TeleoperateConfig): robot.connect() try: - teleop_loop(teleop, robot, cfg.fps, display_data=cfg.display_data, duration=cfg.teleop_time_s) + teleop_loop( + teleop=teleop, + robot=robot, + fps=cfg.fps, + display_data=cfg.display_data, + duration=cfg.teleop_time_s, + teleop_action_processor=cfg.teleop_action_processor, + robot_action_processor=cfg.robot_action_processor, + robot_observation_processor=cfg.robot_observation_processor, + ) except KeyboardInterrupt: pass finally: diff --git a/src/lerobot/teleoperators/keyboard/teleop_keyboard.py b/src/lerobot/teleoperators/keyboard/teleop_keyboard.py index a8b003ede..c4daaa1e2 100644 --- a/src/lerobot/teleoperators/keyboard/teleop_keyboard.py +++ b/src/lerobot/teleoperators/keyboard/teleop_keyboard.py @@ -177,16 +177,6 @@ class KeyboardEndEffectorTeleop(KeyboardTeleop): "names": {"action.delta_x": 0, "action.delta_y": 1, "action.delta_z": 2}, } - def _on_press(self, 
key): - if hasattr(key, "char"): - key = key.char - self.event_queue.put((key, True)) - - def _on_release(self, key): - if hasattr(key, "char"): - key = key.char - self.event_queue.put((key, False)) - def get_action(self) -> dict[str, Any]: if not self.is_connected: raise DeviceNotConnectedError( diff --git a/src/lerobot/teleoperators/phone/phone.py b/src/lerobot/teleoperators/phone/phone.py deleted file mode 100644 index 3c6d5fc5d..000000000 --- a/src/lerobot/teleoperators/phone/phone.py +++ /dev/null @@ -1,246 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Docs: -# hebi: https://docs.hebi.us/tools.html#mobile-io -# teleop: https://github.com/SpesRobotics/teleop - -import logging -import threading -import time - -import hebi -import numpy as np -from scipy.spatial.transform import Rotation -from teleop import Teleop - -from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError -from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS -from lerobot.teleoperators.teleoperator import Teleoperator - -logger = logging.getLogger(__name__) - - -class Phone(Teleoperator): - """ - Phone-based teleoperator using ARKit (iOS via HEBI Mobile I/O App) or the teleop Python package (Android via WebXR API). - For HEBI Mobile I/O we also expose 8 analog (a1-a8) and 8 digital (b1-b8) inputs. - - Press and hold **B1** to enable teleoperation. 
While enabled, the first B1 press - captures a reference pose and rotation, when disabled and pressed again the position is reapplied. - """ - - config_class = PhoneConfig - name = "phone" - - def __init__(self, config: PhoneConfig): - super().__init__(config) - self.config = config - self._group = None - self._teleop = None - self._teleop_thread = None - self._latest_pose = None - self._latest_message = None - self._enabled: bool = False - self._calib_pos: np.ndarray | None = None - self._calib_rot_inv: Rotation | None = None - - @property - def is_connected(self) -> bool: - return (self.config.phone_os == PhoneOS.IOS and self._group is not None) or ( - self.config.phone_os == PhoneOS.ANDROID and self._teleop is not None - ) - - def connect(self) -> None: - if self.is_connected: - raise DeviceAlreadyConnectedError(f"{self} already connected") - - if self.config.phone_os == PhoneOS.IOS: - logger.info("Connecting to IPhone, make sure to open the HEBI Mobile I/O app.") - lookup = hebi.Lookup() - time.sleep(2.0) - group = lookup.get_group_from_names(["HEBI"], ["mobileIO"]) - if group is None: - raise RuntimeError("Mobile I/O not found — check name/family settings in the app.") - self._group = group - logger.info(f"{self} connected to HEBI group with {group.size} module(s).") - elif self.config.phone_os == PhoneOS.ANDROID: - logger.info("Starting teleop stream for Android...") - self._teleop = Teleop() - self._teleop.subscribe(self._android_callback) - self._teleop_thread = threading.Thread(target=self._teleop.run, daemon=True) - self._teleop_thread.start() - logger.info(f"{self} connected, teleop stream started.") - else: - raise ValueError(f"Invalid config phone_os: {self.config.phone_os}") - - self.calibrate() - - def calibrate(self) -> None: - print( - "Hold the phone so that: top edge points forward in same direction as the robot (robot +x) and screen points up (robot +z)" - ) - if self.config.phone_os == PhoneOS.IOS: - print("Press and hold B1 in the HEBI Mobile 
I/O app to capture this pose...\n") - else: - print("Touch and move on the WebXR page to capture this pose...\n") - - pos, rot = self._wait_for_capture_trigger() - self._calib_pos = pos.copy() - self._calib_rot_inv = rot.inv() - self._enabled = False - print("Calibration done\n") - - def _reapply_position_calibration(self, pos: np.ndarray) -> None: - self._calib_pos = pos.copy() - - @property - def is_calibrated(self) -> bool: - return (self._calib_pos is not None) and (self._calib_rot_inv is not None) - - @property - def action_features(self) -> dict[str, type]: - return { - "phone.pos": np.ndarray, # shape (3,) - "phone.rot": Rotation, # scipy.spatial.transform.Rotation - "phone.raw_inputs": dict, # analogs/buttons or webXR meta - "phone.enabled": bool, - } - - def _wait_for_capture_trigger(self) -> tuple[np.ndarray, Rotation]: - """Wait trigger for calibration: iOS: B1. Android: 'move'.""" - while True: - ok, pos, rot, pose = self._read_current_pose() - if not ok: - time.sleep(0.01) - continue - - if self.config.phone_os == PhoneOS.IOS: - io = getattr(pose, "io", None) - b = getattr(io, "b", None) if io is not None else None - b1 = False - if b is not None: - b1 = bool(b.get_int(1)) - if b1: - return pos, rot - else: - msg = self._latest_message or {} - if bool(msg.get("move", False)): - return pos, rot - - time.sleep(0.01) - - def _read_current_pose(self) -> tuple[bool, np.ndarray | None, Rotation | None, object | None]: - if self.config.phone_os == PhoneOS.IOS: - fbk = self._group.get_next_feedback() - pose = fbk[0] - ar_pos = getattr(pose, "ar_position", None) - ar_quat = getattr(pose, "ar_orientation", None) - if ar_pos is None or ar_quat is None: - return False, None, None, None - quat_xyzw = np.concatenate((ar_quat[1:], [ar_quat[0]])) # wxyz to xyzw - rot = Rotation.from_quat(quat_xyzw) - pos = ar_pos - rot.apply(self.config.camera_offset) - return True, pos, rot, pose - else: - p = self._latest_pose - if p is None: - return False, None, None, None - rot = 
Rotation.from_matrix(p[:3, :3]) - pos = p[:3, 3] - rot.apply(self.config.camera_offset) - pose = self._latest_pose - return True, pos, rot, pose - - @property - def feedback_features(self) -> dict[str, type]: - # No haptic or other feedback implemented yet - pass - - def configure(self) -> None: - # No additional configuration required for phone teleop - pass - - def _android_callback(self, pose: np.ndarray, message: dict) -> None: - self._latest_pose = pose - self._latest_message = message - time.sleep(0.001) # 1ms delay to avoid race condition - - def get_action(self) -> dict: - ok, raw_pos, raw_rot, pose = self._read_current_pose() - if not ok or not self.is_calibrated: - return {} - - # Collect raw inputs (B1 / analogs on iOS, move/scale on Android) - raw_inputs: dict[str, float | int | bool] = {} - if self.config.phone_os == PhoneOS.IOS: - io = getattr(pose, "io", None) - if io is not None: - bank_a, bank_b = io.a, io.b - if bank_a: - for ch in range(1, 9): - if bank_a.has_float(ch): - raw_inputs[f"a{ch}"] = float(bank_a.get_float(ch)) - if bank_b: - for ch in range(1, 9): - if bank_b.has_int(ch): - raw_inputs[f"b{ch}"] = int(bank_b.get_int(ch)) - elif hasattr(bank_b, "has_bool") and bank_b.has_bool(ch): - raw_inputs[f"b{ch}"] = int(bank_b.get_bool(ch)) - else: - msg = self._latest_message or {} - raw_inputs["move"] = bool(msg.get("move", False)) - raw_inputs["scale"] = float(msg.get("scale", 1.0)) - raw_inputs["reservedButtonA"] = bool(msg.get("reservedButtonA", False)) - raw_inputs["reservedButtonB"] = bool(msg.get("reservedButtonB", False)) - - if self.config.phone_os == PhoneOS.IOS: - enable = bool(raw_inputs.get("b1", 0)) - else: - enable = bool(raw_inputs.get("move", False)) - - # Rising edge then re-capture calibration immediately from current raw pose - if enable and not self._enabled: - self._reapply_position_calibration(raw_pos) - - # Apply calibration - pos_cal = self._calib_rot_inv.apply(raw_pos - self._calib_pos) - rot_cal = self._calib_rot_inv * 
raw_rot - - self._enabled = enable - - return { - "phone.pos": pos_cal, - "phone.rot": rot_cal, - "phone.raw_inputs": raw_inputs, - "phone.enabled": self._enabled, - } - - def send_feedback(self, feedback: dict[str, float]) -> None: - # We could add haptic feedback (vibrations) here, but it's not implemented yet - raise NotImplementedError - - def disconnect(self) -> None: - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - - if self.config.phone_os == PhoneOS.IOS: - self._group = None - else: - self._teleop = None - if self._teleop_thread and self._teleop_thread.is_alive(): - self._teleop_thread.join(timeout=1.0) - self._teleop_thread = None - self._latest_pose = None diff --git a/src/lerobot/teleoperators/phone/phone_processor.py b/src/lerobot/teleoperators/phone/phone_processor.py index 436ee8444..36880e0c8 100644 --- a/src/lerobot/teleoperators/phone/phone_processor.py +++ b/src/lerobot/teleoperators/phone/phone_processor.py @@ -16,7 +16,7 @@ from dataclasses import dataclass, field -from lerobot.configs.types import PolicyFeature +from lerobot.configs.types import FeatureType, PolicyFeature from lerobot.processor.pipeline import ActionProcessor, ProcessorStepRegistry from lerobot.teleoperators.phone.config_phone import PhoneOS @@ -46,9 +46,9 @@ class MapPhoneActionToRobotAction(ActionProcessor): platform: PhoneOS _enabled_prev: bool = field(default=False, init=False, repr=False) - def action(self, act: dict | None) -> dict: + def action(self, act: dict) -> dict: # Pop them from the action - enabled = act.pop("action.phone.enabled", 0) + enabled = bool(act.pop("action.phone.enabled", 0)) pos = act.pop("action.phone.pos", None) rot = act.pop("action.phone.rot", None) inputs = act.pop("action.phone.raw_inputs", {}) @@ -69,19 +69,28 @@ class MapPhoneActionToRobotAction(ActionProcessor): ) # Positive if a is pressed, negative if b is pressed, 0 if both or neither are pressed # For some actions we need to invert the axis - 
act.update( - { - "action.enabled": enabled, - "action.target_x": -pos[1] if enabled else 0.0, - "action.target_y": pos[0] if enabled else 0.0, - "action.target_z": pos[2] if enabled else 0.0, - "action.target_wx": rotvec[1] if enabled else 0.0, - "action.target_wy": rotvec[0] if enabled else 0.0, - "action.target_wz": -rotvec[2] if enabled else 0.0, - "action.gripper": gripper, # Still send gripper action when disabled - } - ) + act["action.enabled"] = enabled + act["action.target_x"] = -pos[1] if enabled else 0.0 + act["action.target_y"] = pos[0] if enabled else 0.0 + act["action.target_z"] = pos[2] if enabled else 0.0 + act["action.target_wx"] = rotvec[1] if enabled else 0.0 + act["action.target_wy"] = rotvec[0] if enabled else 0.0 + act["action.target_wz"] = -rotvec[2] if enabled else 0.0 + act["action.gripper"] = gripper # Still send gripper action when disabled return act def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + features.pop("action.phone.enabled", None) + features.pop("action.phone.pos", None) + features.pop("action.phone.rot", None) + features.pop("action.phone.raw_inputs", None) + + features["action.enabled"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,)) + features["action.target_x"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,)) + features["action.target_y"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,)) + features["action.target_z"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,)) + features["action.target_wx"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,)) + features["action.target_wy"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,)) + features["action.target_wz"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,)) + features["action.gripper"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,)) return features diff --git a/src/lerobot/teleoperators/phone/teleop_phone.py b/src/lerobot/teleoperators/phone/teleop_phone.py new file mode 100644
index 000000000..ed985c081 --- /dev/null +++ b/src/lerobot/teleoperators/phone/teleop_phone.py @@ -0,0 +1,359 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Docs: +# hebi: https://docs.hebi.us/tools.html#mobile-io +# teleop: https://github.com/SpesRobotics/teleop + +import logging +import threading +import time + +import hebi +import numpy as np +from scipy.spatial.transform import Rotation +from teleop import Teleop + +from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS +from lerobot.teleoperators.teleoperator import Teleoperator + +logger = logging.getLogger(__name__) + + +class BasePhone: + _enabled: bool = False + _calib_pos: np.ndarray | None = None + _calib_rot_inv: Rotation | None = None + + def _reapply_position_calibration(self, pos: np.ndarray) -> None: + self._calib_pos = pos.copy() + + @property + def is_calibrated(self) -> bool: + return (self._calib_pos is not None) and (self._calib_rot_inv is not None) + + @property + def action_features(self) -> dict[str, type]: + return { + "phone.pos": np.ndarray, # shape (3,) + "phone.rot": Rotation, # scipy.spatial.transform.Rotation + "phone.raw_inputs": dict, # analogs/buttons or webXR meta + "phone.enabled": bool, + } + + @property + def feedback_features(self) -> dict[str, type]: + # No haptic or other feedback 
implemented yet + pass + + def configure(self) -> None: + # No additional configuration required for phone teleop + pass + + def send_feedback(self, feedback: dict[str, float]) -> None: + # We could add haptic feedback (vibrations) here, but it's not implemented yet + raise NotImplementedError + + +class IOSPhone(BasePhone, Teleoperator): + name = "ios_phone" + + def __init__(self, config: PhoneConfig): + super().__init__(config) + self.config = config + self._group = None + + @property + def is_connected(self) -> bool: + return self._group is not None + + def connect(self) -> None: + if self.is_connected: + raise DeviceAlreadyConnectedError(f"{self} already connected") + + logger.info("Connecting to IPhone, make sure to open the HEBI Mobile I/O app.") + lookup = hebi.Lookup() + time.sleep(2.0) + group = lookup.get_group_from_names(["HEBI"], ["mobileIO"]) + if group is None: + raise RuntimeError("Mobile I/O not found — check name/family settings in the app.") + self._group = group + logger.info(f"{self} connected to HEBI group with {group.size} module(s).") + + self.calibrate() + + def calibrate(self) -> None: + print( + "Hold the phone so that: top edge points forward in same direction as the robot (robot +x) and screen points up (robot +z)" + ) + print("Press and hold B1 in the HEBI Mobile I/O app to capture this pose...\n") + + pos, rot = self._wait_for_capture_trigger() + self._calib_pos = pos.copy() + self._calib_rot_inv = rot.inv() + self._enabled = False + print("Calibration done\n") + + def _wait_for_capture_trigger(self) -> tuple[np.ndarray, Rotation]: + """Wait trigger for calibration: iOS: B1. 
Android: 'move'.""" + while True: + ok, pos, rot, pose = self._read_current_pose() + if not ok: + time.sleep(0.01) + continue + + io = getattr(pose, "io", None) + b = getattr(io, "b", None) if io is not None else None + b1 = False + if b is not None: + b1 = bool(b.get_int(1)) + if b1: + return pos, rot + + time.sleep(0.01) + + def _read_current_pose(self) -> tuple[bool, np.ndarray | None, Rotation | None, object | None]: + fbk = self._group.get_next_feedback() + pose = fbk[0] + ar_pos = getattr(pose, "ar_position", None) + ar_quat = getattr(pose, "ar_orientation", None) + if ar_pos is None or ar_quat is None: + return False, None, None, None + # HEBI provides orientation in w, x, y, z format. + # Scipy's Rotation expects x, y, z, w. + quat_xyzw = np.concatenate((ar_quat[1:], [ar_quat[0]])) # wxyz to xyzw + rot = Rotation.from_quat(quat_xyzw) + pos = ar_pos - rot.apply(self.config.camera_offset) + return True, pos, rot, pose + + def get_action(self) -> dict: + ok, raw_pos, raw_rot, pose = self._read_current_pose() + if not ok or not self.is_calibrated: + return {} + + # Collect raw inputs (B1 / analogs on iOS, move/scale on Android) + raw_inputs: dict[str, float | int | bool] = {} + io = getattr(pose, "io", None) + if io is not None: + bank_a, bank_b = io.a, io.b + if bank_a: + for ch in range(1, 9): + if bank_a.has_float(ch): + raw_inputs[f"a{ch}"] = float(bank_a.get_float(ch)) + if bank_b: + for ch in range(1, 9): + if bank_b.has_int(ch): + raw_inputs[f"b{ch}"] = int(bank_b.get_int(ch)) + elif hasattr(bank_b, "has_bool") and bank_b.has_bool(ch): + raw_inputs[f"b{ch}"] = int(bank_b.get_bool(ch)) + + enable = bool(raw_inputs.get("b1", 0)) + + # Rising edge then re-capture calibration immediately from current raw pose + if enable and not self._enabled: + self._reapply_position_calibration(raw_pos) + + # Apply calibration + pos_cal = self._calib_rot_inv.apply(raw_pos - self._calib_pos) + rot_cal = self._calib_rot_inv * raw_rot + + self._enabled = enable + + return { + 
"phone.pos": pos_cal, + "phone.rot": rot_cal, + "phone.raw_inputs": raw_inputs, + "phone.enabled": self._enabled, + } + + def disconnect(self) -> None: + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + self._group = None + + +class AndroidPhone(BasePhone, Teleoperator): + name = "android_phone" + + def __init__(self, config: PhoneConfig): + super().__init__(config) + self.config = config + self._teleop = None + self._teleop_thread = None + self._latest_pose = None + self._latest_message = None + self._android_lock = threading.Lock() + + @property + def is_connected(self) -> bool: + return self._teleop is not None + + def connect(self) -> None: + if self.is_connected: + raise DeviceAlreadyConnectedError(f"{self} already connected") + + logger.info("Starting teleop stream for Android...") + self._teleop = Teleop() + self._teleop.subscribe(self._android_callback) + self._teleop_thread = threading.Thread(target=self._teleop.run, daemon=True) + self._teleop_thread.start() + logger.info(f"{self} connected, teleop stream started.") + + self.calibrate() + + def calibrate(self) -> None: + print( + "Hold the phone so that: top edge points forward in same direction as the robot (robot +x) and screen points up (robot +z)" + ) + print("Touch and move on the WebXR page to capture this pose...\n") + + pos, rot = self._wait_for_capture_trigger() + self._calib_pos = pos.copy() + self._calib_rot_inv = rot.inv() + self._enabled = False + print("Calibration done\n") + + def _wait_for_capture_trigger(self) -> tuple[np.ndarray, Rotation]: + """Wait trigger for calibration: iOS: B1. 
Android: 'move'.""" + while True: + with self._android_lock: + msg = self._latest_message or {} + + if bool(msg.get("move", False)): + ok, pos, rot, _pose = self._read_current_pose() + if ok: + return pos, rot + + time.sleep(0.01) + + def _read_current_pose(self) -> tuple[bool, np.ndarray | None, Rotation | None, object | None]: + with self._android_lock: + if self._latest_pose is None: + return False, None, None, None + p = self._latest_pose.copy() + pose = self._latest_pose + rot = Rotation.from_matrix(p[:3, :3]) + pos = p[:3, 3] - rot.apply(self.config.camera_offset) + return True, pos, rot, pose + + def _android_callback(self, pose: np.ndarray, message: dict) -> None: + with self._android_lock: + self._latest_pose = pose + self._latest_message = message + + def get_action(self) -> dict: + ok, raw_pos, raw_rot, pose = self._read_current_pose() + if not ok or not self.is_calibrated: + return {} + + # Collect raw inputs (B1 / analogs on iOS, move/scale on Android) + raw_inputs: dict[str, float | int | bool] = {} + msg = self._latest_message or {} + raw_inputs["move"] = bool(msg.get("move", False)) + raw_inputs["scale"] = float(msg.get("scale", 1.0)) + raw_inputs["reservedButtonA"] = bool(msg.get("reservedButtonA", False)) + raw_inputs["reservedButtonB"] = bool(msg.get("reservedButtonB", False)) + + enable = bool(raw_inputs.get("move", False)) + + # Rising edge then re-capture calibration immediately from current raw pose + if enable and not self._enabled: + self._reapply_position_calibration(raw_pos) + + # Apply calibration + pos_cal = self._calib_rot_inv.apply(raw_pos - self._calib_pos) + rot_cal = self._calib_rot_inv * raw_rot + + self._enabled = enable + + return { + "phone.pos": pos_cal, + "phone.rot": rot_cal, + "phone.raw_inputs": raw_inputs, + "phone.enabled": self._enabled, + } + + def disconnect(self) -> None: + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + self._teleop = None + if self._teleop_thread and 
self._teleop_thread.is_alive(): + self._teleop_thread.join(timeout=1.0) + self._teleop_thread = None + self._latest_pose = None + + +class Phone(Teleoperator): + """ + Phone-based teleoperator using ARKit (iOS via HEBI Mobile I/O App) or the teleop Python package (Android via WebXR API). + For HEBI Mobile I/O we also expose 8 analog (a1-a8) and 8 digital (b1-b8) inputs. + + Press and hold **B1** to enable teleoperation. While enabled, the first B1 press + captures a reference pose and rotation, when disabled and pressed again the position is reapplied. + """ + + config_class = PhoneConfig + name = "phone" + + def __init__(self, config: PhoneConfig): + super().__init__(config) + self.config = config + + self._phone_impl: Teleoperator + + if self.config.phone_os == PhoneOS.IOS: + self._phone_impl = IOSPhone(config) + elif self.config.phone_os == PhoneOS.ANDROID: + self._phone_impl = AndroidPhone(config) + else: + raise ValueError(f"Invalid config phone_os: {self.config.phone_os}") + + @property + def is_connected(self) -> bool: + return self._phone_impl.is_connected + + def connect(self) -> None: + return self._phone_impl.connect() + + def calibrate(self) -> None: + return self._phone_impl.calibrate() + + @property + def is_calibrated(self) -> bool: + return self._phone_impl.is_calibrated + + @property + def action_features(self) -> dict[str, type]: + return self._phone_impl.action_features + + @property + def feedback_features(self) -> dict[str, type]: + return self._phone_impl.feedback_features + + def configure(self) -> None: + return self._phone_impl.configure() + + def get_action(self) -> dict: + return self._phone_impl.get_action() + + def send_feedback(self, feedback: dict[str, float]) -> None: + return self._phone_impl.send_feedback(feedback) + + def disconnect(self) -> None: + return self._phone_impl.disconnect() diff --git a/tests/artifacts/policies/save_policy_to_safetensors.py b/tests/artifacts/policies/save_policy_to_safetensors.py index 
5a4606a8c..d91c31dbd 100644 --- a/tests/artifacts/policies/save_policy_to_safetensors.py +++ b/tests/artifacts/policies/save_policy_to_safetensors.py @@ -23,7 +23,7 @@ from lerobot.configs.default import DatasetConfig from lerobot.configs.train import TrainPipelineConfig from lerobot.datasets.factory import make_dataset from lerobot.optim.factory import make_optimizer_and_scheduler -from lerobot.policies.factory import make_policy, make_policy_config, make_processor +from lerobot.policies.factory import make_policy, make_policy_config, make_pre_post_processors from lerobot.processor import TransitionKey from lerobot.utils.random_utils import set_seed @@ -40,7 +40,7 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict): dataset = make_dataset(train_cfg) dataset_stats = dataset.meta.stats policy = make_policy(train_cfg.policy, ds_meta=dataset.meta) - preprocessor, postprocessor = make_processor(train_cfg.policy, dataset_stats=dataset_stats) + preprocessor, postprocessor = make_pre_post_processors(train_cfg.policy, dataset_stats=dataset_stats) policy.train() optimizer, _ = make_optimizer_and_scheduler(train_cfg, policy) diff --git a/tests/policies/test_policies.py b/tests/policies/test_policies.py index a135b344f..f0dfe3c9f 100644 --- a/tests/policies/test_policies.py +++ b/tests/policies/test_policies.py @@ -39,7 +39,7 @@ from lerobot.policies.factory import ( get_policy_class, make_policy, make_policy_config, - make_processor, + make_pre_post_processors, ) from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.utils.random_utils import seeded_context @@ -151,7 +151,7 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs): # Check that we can make the policy object. 
dataset = make_dataset(train_cfg) - preprocessor, _ = make_processor(train_cfg.policy, None) + preprocessor, _ = make_pre_post_processors(train_cfg.policy, None) policy = make_policy(train_cfg.policy, ds_meta=dataset.meta) assert isinstance(policy, PreTrainedPolicy) @@ -225,7 +225,7 @@ def test_act_backbone_lr(): assert cfg.policy.optimizer_lr_backbone == 0.001 dataset = make_dataset(cfg) - preprocessor, _ = make_processor(cfg.policy, None) + preprocessor, _ = make_pre_post_processors(cfg.policy, None) policy = make_policy(cfg.policy, ds_meta=dataset.meta) optimizer, _ = make_optimizer_and_scheduler(cfg, policy) assert len(optimizer.param_groups) == 2 diff --git a/tests/processor/test_act_processor.py b/tests/processor/test_act_processor.py index 03fa35a2b..afba2aa26 100644 --- a/tests/processor/test_act_processor.py +++ b/tests/processor/test_act_processor.py @@ -23,7 +23,7 @@ import torch from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from lerobot.constants import ACTION, OBS_STATE from lerobot.policies.act.configuration_act import ACTConfig -from lerobot.policies.act.processor_act import make_act_processor +from lerobot.policies.act.processor_act import make_act_pre_post_processors from lerobot.processor import ( DeviceProcessor, NormalizerProcessor, @@ -78,7 +78,7 @@ def test_make_act_processor_basic(): config = create_default_config() stats = create_default_stats() - preprocessor, postprocessor = make_act_processor(config, stats) + preprocessor, postprocessor = make_act_pre_post_processors(config, stats) # Check processor names assert preprocessor.name == "robot_preprocessor" @@ -102,7 +102,7 @@ def test_act_processor_normalization(): config = create_default_config() stats = create_default_stats() - preprocessor, postprocessor = make_act_processor(config, stats) + preprocessor, postprocessor = make_act_pre_post_processors(config, stats) # Create test data observation = {OBS_STATE: torch.randn(7)} @@ -131,7 +131,7 @@ def 
test_act_processor_cuda(): config.device = "cuda" stats = create_default_stats() - preprocessor, postprocessor = make_act_processor(config, stats) + preprocessor, postprocessor = make_act_pre_post_processors(config, stats) # Create CPU data observation = {OBS_STATE: torch.randn(7)} @@ -160,7 +160,7 @@ def test_act_processor_accelerate_scenario(): config.device = "cuda:0" stats = create_default_stats() - preprocessor, postprocessor = make_act_processor(config, stats) + preprocessor, postprocessor = make_act_pre_post_processors(config, stats) # Simulate Accelerate: data already on GPU device = torch.device("cuda:0") @@ -183,7 +183,7 @@ def test_act_processor_multi_gpu(): config.device = "cuda:0" stats = create_default_stats() - preprocessor, postprocessor = make_act_processor(config, stats) + preprocessor, postprocessor = make_act_pre_post_processors(config, stats) # Simulate data on different GPU (like in multi-GPU training) device = torch.device("cuda:1") @@ -203,7 +203,7 @@ def test_act_processor_without_stats(): """Test ACT processor creation without dataset statistics.""" config = create_default_config() - preprocessor, postprocessor = make_act_processor(config, dataset_stats=None) + preprocessor, postprocessor = make_act_pre_post_processors(config, dataset_stats=None) # Should still create processors, but normalization won't have stats assert preprocessor is not None @@ -223,7 +223,7 @@ def test_act_processor_save_and_load(): config = create_default_config() stats = create_default_stats() - preprocessor, postprocessor = make_act_processor(config, stats) + preprocessor, postprocessor = make_act_pre_post_processors(config, stats) with tempfile.TemporaryDirectory() as tmpdir: # Save preprocessor @@ -249,7 +249,7 @@ def test_act_processor_device_placement_preservation(): # Test with CPU config config.device = "cpu" - preprocessor, _ = make_act_processor(config, stats) + preprocessor, _ = make_act_pre_post_processors(config, stats) # Process CPU data observation = 
{OBS_STATE: torch.randn(7)} @@ -269,7 +269,7 @@ def test_act_processor_mixed_precision(): stats = create_default_stats() # Modify the device processor to use float16 - preprocessor, postprocessor = make_act_processor(config, stats) + preprocessor, postprocessor = make_act_pre_post_processors(config, stats) # Replace DeviceProcessor with one that uses float16 for i, step in enumerate(preprocessor.steps): @@ -294,7 +294,7 @@ def test_act_processor_batch_consistency(): config = create_default_config() stats = create_default_stats() - preprocessor, postprocessor = make_act_processor(config, stats) + preprocessor, postprocessor = make_act_pre_post_processors(config, stats) # Test single sample (unbatched) observation = {OBS_STATE: torch.randn(7)} diff --git a/tests/processor/test_batch_processor.py b/tests/processor/test_batch_processor.py index c9c4cd1dd..0bf050e20 100644 --- a/tests/processor/test_batch_processor.py +++ b/tests/processor/test_batch_processor.py @@ -603,24 +603,6 @@ def test_action_dtype_preservation(): assert result[TransitionKey.ACTION].shape == (1, 4) -def test_action_in_place_mutation(): - """Test that the processor mutates the transition in place for actions.""" - processor = ToBatchProcessor() - - action = torch.randn(4) - transition = create_transition(action=action) - - # Store reference to original transition - original_transition = transition - - # Process - result = processor(transition) - - # Should be the same object (in-place mutation) - assert result is original_transition - assert result[TransitionKey.ACTION].shape == (1, 4) - - def test_empty_action_tensor(): """Test handling of empty action tensors.""" processor = ToBatchProcessor() @@ -851,27 +833,6 @@ def test_task_comprehensive_string_cases(): processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] assert processed_comp_data["task"] == task_list assert isinstance(processed_comp_data["task"], list) - assert processed_comp_data["task"] is task_list # Should be same object 
(in-place) - - -def test_task_in_place_mutation(): - """Test that the processor mutates complementary_data in place for tasks.""" - processor = ToBatchProcessor() - - complementary_data = {"task": "sort_objects"} - transition = create_transition(complementary_data=complementary_data) - - # Store reference to original transition and complementary_data - original_transition = transition - original_comp_data = complementary_data - - # Process - result = processor(transition) - - # Should be the same objects (in-place mutation) - assert result is original_transition - assert result[TransitionKey.COMPLEMENTARY_DATA] is original_comp_data - assert original_comp_data["task"] == ["sort_objects"] def test_task_preserves_other_keys(): @@ -1127,3 +1088,49 @@ def test_empty_index_tensor(): # Should remain unchanged (already 1D) assert result[TransitionKey.COMPLEMENTARY_DATA]["index"].shape == (0,) + + +def test_action_processing_creates_new_transition(): + """Test that the processor creates a new transition object with correctly processed action.""" + processor = ToBatchProcessor() + + action = torch.randn(4) + transition = create_transition(action=action) + + # Store reference to original transition + original_transition = transition + + # Process + result = processor(transition) + + # Should be a different object (functional design, not in-place mutation) + assert result is not original_transition + # Original transition should remain unchanged + assert original_transition[TransitionKey.ACTION].shape == (4,) + # Result should have correctly processed action with batch dimension + assert result[TransitionKey.ACTION].shape == (1, 4) + assert torch.equal(result[TransitionKey.ACTION][0], action) + + +def test_task_processing_creates_new_transition(): + """Test that the processor creates a new transition object with correctly processed task.""" + processor = ToBatchProcessor() + + complementary_data = {"task": "sort_objects"} + transition = 
create_transition(complementary_data=complementary_data) + + # Store reference to original transition and complementary_data + original_transition = transition + original_comp_data = complementary_data + + # Process + result = processor(transition) + + # Should be different transition object (functional design) + assert result is not original_transition + # But complementary_data is the same reference (current implementation behavior) + assert result[TransitionKey.COMPLEMENTARY_DATA] is original_comp_data + # The task should be processed correctly (wrapped in list) + assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == ["sort_objects"] + # Original complementary data is also modified (current behavior) + assert original_comp_data["task"] == ["sort_objects"] diff --git a/tests/processor/test_diffusion_processor.py b/tests/processor/test_diffusion_processor.py index 4b029d64c..e6d3ea590 100644 --- a/tests/processor/test_diffusion_processor.py +++ b/tests/processor/test_diffusion_processor.py @@ -23,7 +23,7 @@ import torch from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig -from lerobot.policies.diffusion.processor_diffusion import make_diffusion_processor +from lerobot.policies.diffusion.processor_diffusion import make_diffusion_pre_post_processors from lerobot.processor import ( DeviceProcessor, NormalizerProcessor, @@ -81,7 +81,7 @@ def test_make_diffusion_processor_basic(): config = create_default_config() stats = create_default_stats() - preprocessor, postprocessor = make_diffusion_processor(config, stats) + preprocessor, postprocessor = make_diffusion_pre_post_processors(config, stats) # Check processor names assert preprocessor.name == "robot_preprocessor" @@ -105,7 +105,7 @@ def test_diffusion_processor_with_images(): config = create_default_config() stats = create_default_stats() - 
preprocessor, postprocessor = make_diffusion_processor(config, stats) + preprocessor, postprocessor = make_diffusion_pre_post_processors(config, stats) # Create test data with images observation = { @@ -131,7 +131,7 @@ def test_diffusion_processor_cuda(): config.device = "cuda" stats = create_default_stats() - preprocessor, postprocessor = make_diffusion_processor(config, stats) + preprocessor, postprocessor = make_diffusion_pre_post_processors(config, stats) # Create CPU data observation = { @@ -164,7 +164,7 @@ def test_diffusion_processor_accelerate_scenario(): config.device = "cuda:0" stats = create_default_stats() - preprocessor, postprocessor = make_diffusion_processor(config, stats) + preprocessor, postprocessor = make_diffusion_pre_post_processors(config, stats) # Simulate Accelerate: data already on GPU device = torch.device("cuda:0") @@ -191,7 +191,7 @@ def test_diffusion_processor_multi_gpu(): config.device = "cuda:0" stats = create_default_stats() - preprocessor, postprocessor = make_diffusion_processor(config, stats) + preprocessor, postprocessor = make_diffusion_pre_post_processors(config, stats) # Simulate data on different GPU device = torch.device("cuda:1") @@ -215,7 +215,7 @@ def test_diffusion_processor_without_stats(): """Test Diffusion processor creation without dataset statistics.""" config = create_default_config() - preprocessor, postprocessor = make_diffusion_processor(config, dataset_stats=None) + preprocessor, postprocessor = make_diffusion_pre_post_processors(config, dataset_stats=None) # Should still create processors assert preprocessor is not None @@ -238,7 +238,7 @@ def test_diffusion_processor_save_and_load(): config = create_default_config() stats = create_default_stats() - preprocessor, postprocessor = make_diffusion_processor(config, stats) + preprocessor, postprocessor = make_diffusion_pre_post_processors(config, stats) with tempfile.TemporaryDirectory() as tmpdir: # Save preprocessor @@ -269,7 +269,7 @@ def 
test_diffusion_processor_mixed_precision(): stats = create_default_stats() # Create processor - preprocessor, postprocessor = make_diffusion_processor(config, stats) + preprocessor, postprocessor = make_diffusion_pre_post_processors(config, stats) # Replace DeviceProcessor with one that uses float16 for i, step in enumerate(preprocessor.steps): @@ -298,7 +298,7 @@ def test_diffusion_processor_identity_normalization(): config = create_default_config() stats = create_default_stats() - preprocessor, postprocessor = make_diffusion_processor(config, stats) + preprocessor, postprocessor = make_diffusion_pre_post_processors(config, stats) # Create test data image_value = torch.rand(3, 224, 224) * 255 # Large values @@ -322,7 +322,7 @@ def test_diffusion_processor_batch_consistency(): config = create_default_config() stats = create_default_stats() - preprocessor, postprocessor = make_diffusion_processor(config, stats) + preprocessor, postprocessor = make_diffusion_pre_post_processors(config, stats) # Test with different batch sizes for batch_size in [1, 8, 32]: diff --git a/tests/processor/test_normalize_processor.py b/tests/processor/test_normalize_processor.py index 5813cc37d..6b904eee7 100644 --- a/tests/processor/test_normalize_processor.py +++ b/tests/processor/test_normalize_processor.py @@ -25,7 +25,6 @@ from lerobot.processor.normalize_processor import ( UnnormalizerProcessor, _convert_stats_to_tensors, hotswap_stats, - rename_stats, ) from lerobot.processor.pipeline import IdentityProcessor, RobotProcessor, TransitionKey @@ -182,7 +181,10 @@ def test_selective_normalization(observation_stats): features = _create_observation_features() norm_map = _create_observation_norm_map() normalizer = NormalizerProcessor( - features=features, norm_map=norm_map, stats=observation_stats, normalize_keys={"observation.image"} + features=features, + norm_map=norm_map, + stats=observation_stats, + normalize_observation_keys={"observation.image"}, ) observation = { @@ -243,6 +245,7 
@@ def test_from_lerobot_dataset(): def test_state_dict_save_load(observation_normalizer): # Save state state_dict = observation_normalizer.state_dict() + print("State dict:", state_dict) # Create new normalizer and load state features = _create_observation_features() @@ -464,10 +467,10 @@ def test_processor_from_lerobot_dataset(full_stats): norm_map = _create_full_norm_map() processor = NormalizerProcessor.from_lerobot_dataset( - mock_dataset, features, norm_map, normalize_keys={"observation.image"} + mock_dataset, features, norm_map, normalize_observation_keys={"observation.image"} ) - assert processor.normalize_keys == {"observation.image"} + assert processor.normalize_observation_keys == {"observation.image"} assert "observation.image" in processor._tensor_stats assert "action" in processor._tensor_stats @@ -476,12 +479,16 @@ def test_get_config(full_stats): features = _create_full_features() norm_map = _create_full_norm_map() processor = NormalizerProcessor( - features=features, norm_map=norm_map, stats=full_stats, normalize_keys={"observation.image"}, eps=1e-6 + features=features, + norm_map=norm_map, + stats=full_stats, + normalize_observation_keys={"observation.image"}, + eps=1e-6, ) config = processor.get_config() expected_config = { - "normalize_keys": ["observation.image"], + "normalize_observation_keys": ["observation.image"], "eps": 1e-6, "features": { "observation.image": {"type": "VISUAL", "shape": (3, 96, 96)}, @@ -580,7 +587,11 @@ def test_serialization_roundtrip(full_stats): features = _create_full_features() norm_map = _create_full_norm_map() original_processor = NormalizerProcessor( - features=features, norm_map=norm_map, stats=full_stats, normalize_keys={"observation.image"}, eps=1e-6 + features=features, + norm_map=norm_map, + stats=full_stats, + normalize_observation_keys={"observation.image"}, + eps=1e-6, ) # Get config (serialization) @@ -591,7 +602,7 @@ def test_serialization_roundtrip(full_stats): features=config["features"], 
norm_map=config["norm_map"], stats=full_stats, - normalize_keys=set(config["normalize_keys"]), + normalize_observation_keys=set(config["normalize_observation_keys"]), eps=config["eps"], ) @@ -939,31 +950,31 @@ def test_identity_config_serialization(): assert torch.allclose(result1[TransitionKey.ACTION], result2[TransitionKey.ACTION]) -def test_unsupported_normalization_mode_error(): - """Test that unsupported normalization modes raise appropriate errors.""" - features = {"observation.state": PolicyFeature(FeatureType.STATE, (2,))} +# def test_unsupported_normalization_mode_error(): +# """Test that unsupported normalization modes raise appropriate errors.""" +# features = {"observation.state": PolicyFeature(FeatureType.STATE, (2,))} - # Create an invalid norm_map (this would never happen in practice, but tests error handling) - from enum import Enum +# # Create an invalid norm_map (this would never happen in practice, but tests error handling) +# from enum import Enum - class InvalidMode(str, Enum): - INVALID = "INVALID" +# class InvalidMode(str, Enum): +# INVALID = "INVALID" - # We can't actually pass an invalid enum to the processor due to type checking, - # but we can test the error by manipulating the norm_map after creation - norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD} - stats = {"observation.state": {"mean": [0.0, 0.0], "std": [1.0, 1.0]}} +# # We can't actually pass an invalid enum to the processor due to type checking, +# # but we can test the error by manipulating the norm_map after creation +# norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD} +# stats = {"observation.state": {"mean": [0.0, 0.0], "std": [1.0, 1.0]}} - normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats) +# normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats) - # Manually inject an invalid mode to test error handling - normalizer.norm_map[FeatureType.STATE] = "INVALID_MODE" +# # Manually inject an invalid 
mode to test error handling +# normalizer.norm_map[FeatureType.STATE] = "INVALID_MODE" - observation = {"observation.state": torch.tensor([1.0, -0.5])} - transition = create_transition(observation=observation) +# observation = {"observation.state": torch.tensor([1.0, -0.5])} +# transition = create_transition(observation=observation) - with pytest.raises(ValueError, match="Unsupported normalization mode"): - normalizer(transition) +# with pytest.raises(ValueError, match="Unsupported normalization mode"): +# normalizer(transition) def test_hotswap_stats_basic_functionality(): @@ -1149,11 +1160,15 @@ def test_hotswap_stats_preserves_other_attributes(): "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), } norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} - normalize_keys = {"observation.image"} + normalize_observation_keys = {"observation.image"} eps = 1e-6 normalizer = NormalizerProcessor( - features=features, norm_map=norm_map, stats=initial_stats, normalize_keys=normalize_keys, eps=eps + features=features, + norm_map=norm_map, + stats=initial_stats, + normalize_observation_keys=normalize_observation_keys, + eps=eps, ) robot_processor = RobotProcessor(steps=[normalizer]) @@ -1164,7 +1179,7 @@ def test_hotswap_stats_preserves_other_attributes(): new_normalizer = new_processor.steps[0] assert new_normalizer.features == features assert new_normalizer.norm_map == norm_map - assert new_normalizer.normalize_keys == normalize_keys + assert new_normalizer.normalize_observation_keys == normalize_observation_keys assert new_normalizer.eps == eps # But stats should be updated @@ -1270,273 +1285,6 @@ def test_hotswap_stats_with_different_data_types(): torch.testing.assert_close(tensor_stats["observation.image"]["max"], torch.tensor(1.0)) -def test_normalization_info_tracking(): - """Test that normalization info is tracked in complementary_data.""" - features = { - "observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), - 
"observation.state": PolicyFeature(FeatureType.STATE, (2,)), - "action": PolicyFeature(FeatureType.ACTION, (2,)), - } - - norm_map = { - FeatureType.VISUAL: NormalizationMode.MEAN_STD, - FeatureType.STATE: NormalizationMode.MIN_MAX, - FeatureType.ACTION: NormalizationMode.IDENTITY, - } - - stats = { - "observation.image": { - "mean": np.array([0.5, 0.5, 0.5]), - "std": np.array([0.2, 0.2, 0.2]), - }, - "observation.state": { - "min": np.array([0.0, -1.0]), - "max": np.array([1.0, 1.0]), - }, - "action": { - "mean": np.array([0.0, 0.0]), - "std": np.array([1.0, 1.0]), - }, - } - - normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats) - - observation = { - "observation.image": torch.tensor([0.7, 0.5, 0.3]), - "observation.state": torch.tensor([0.5, 0.0]), - } - action = torch.tensor([1.0, -0.5]) - transition = create_transition(observation=observation, action=action) - - # Process the transition - normalized_transition = normalizer(transition) - - # Check that normalization info is added - comp_data = normalized_transition.get(TransitionKey.COMPLEMENTARY_DATA) - assert comp_data is not None - assert "normalized_keys" in comp_data - - norm_info = comp_data["normalized_keys"] - assert norm_info["observation.image"] == "MEAN_STD" - assert norm_info["observation.state"] == "MIN_MAX" - assert norm_info["action"] == "IDENTITY" - - -def test_unnormalization_info_tracking(): - """Test that unnormalization info is tracked in complementary_data.""" - features = { - "observation.image": PolicyFeature(FeatureType.VISUAL, (3,)), - "action": PolicyFeature(FeatureType.ACTION, (2,)), - } - - norm_map = { - FeatureType.VISUAL: NormalizationMode.MEAN_STD, - FeatureType.ACTION: NormalizationMode.MIN_MAX, - } - - stats = { - "observation.image": { - "mean": np.array([0.5, 0.5, 0.5]), - "std": np.array([0.2, 0.2, 0.2]), - }, - "action": { - "min": np.array([-1.0, -1.0]), - "max": np.array([1.0, 1.0]), - }, - } - - unnormalizer = 
UnnormalizerProcessor(features=features, norm_map=norm_map, stats=stats) - - observation = {"observation.image": torch.tensor([0.7, 0.5, 0.3])} - action = torch.tensor([0.0, -0.5]) - transition = create_transition(observation=observation, action=action) - - # Process the transition - unnormalized_transition = unnormalizer(transition) - - # Check that unnormalization info is added - comp_data = unnormalized_transition.get(TransitionKey.COMPLEMENTARY_DATA) - assert comp_data is not None - assert "unnormalized_keys" in comp_data - - unnorm_info = comp_data["unnormalized_keys"] - assert unnorm_info["observation.image"] == "MEAN_STD" - assert unnorm_info["action"] == "MIN_MAX" - - -def test_normalization_info_with_missing_stats(): - """Test normalization info when stats are missing for some keys.""" - features = { - "observation.image": PolicyFeature(FeatureType.VISUAL, (3,)), - "observation.state": PolicyFeature(FeatureType.STATE, (2,)), - } - - norm_map = { - FeatureType.VISUAL: NormalizationMode.MEAN_STD, - FeatureType.STATE: NormalizationMode.MIN_MAX, - } - - # Only provide stats for image, not state - stats = { - "observation.image": { - "mean": np.array([0.5, 0.5, 0.5]), - "std": np.array([0.2, 0.2, 0.2]), - }, - } - - normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats) - - observation = { - "observation.image": torch.tensor([0.7, 0.5, 0.3]), - "observation.state": torch.tensor([0.5, 0.0]), - } - transition = create_transition(observation=observation) - - # Process the transition - normalized_transition = normalizer(transition) - - # Check that only keys with stats are in normalization info - comp_data = normalized_transition.get(TransitionKey.COMPLEMENTARY_DATA) - assert comp_data is not None - assert "normalized_keys" in comp_data - - norm_info = comp_data["normalized_keys"] - assert norm_info["observation.image"] == "MEAN_STD" - # State should not be in the normalization info since it has no stats - assert "observation.state" not 
in norm_info - - -def test_normalization_info_with_selective_keys(): - """Test normalization info with selective normalization.""" - features = { - "observation.image": PolicyFeature(FeatureType.VISUAL, (3,)), - "observation.state": PolicyFeature(FeatureType.STATE, (2,)), - } - - norm_map = { - FeatureType.VISUAL: NormalizationMode.MEAN_STD, - FeatureType.STATE: NormalizationMode.MIN_MAX, - } - - stats = { - "observation.image": { - "mean": np.array([0.5, 0.5, 0.5]), - "std": np.array([0.2, 0.2, 0.2]), - }, - "observation.state": { - "min": np.array([0.0, -1.0]), - "max": np.array([1.0, 1.0]), - }, - } - - # Only normalize image - normalizer = NormalizerProcessor( - features=features, norm_map=norm_map, stats=stats, normalize_keys={"observation.image"} - ) - - observation = { - "observation.image": torch.tensor([0.7, 0.5, 0.3]), - "observation.state": torch.tensor([0.5, 0.0]), - } - transition = create_transition(observation=observation) - - # Process the transition - normalized_transition = normalizer(transition) - - # Check that only selected keys are in normalization info - comp_data = normalized_transition.get(TransitionKey.COMPLEMENTARY_DATA) - assert comp_data is not None - assert "normalized_keys" in comp_data - - norm_info = comp_data["normalized_keys"] - assert norm_info["observation.image"] == "MEAN_STD" - # State should not be in the normalization info since it wasn't in normalize_keys - assert "observation.state" not in norm_info - - -def test_normalization_info_preserved_in_pipeline(): - """Test that normalization info is preserved when using RobotProcessor pipeline.""" - features = { - "observation.image": PolicyFeature(FeatureType.VISUAL, (3,)), - "action": PolicyFeature(FeatureType.ACTION, (2,)), - } - - norm_map = { - FeatureType.VISUAL: NormalizationMode.MEAN_STD, - FeatureType.ACTION: NormalizationMode.MIN_MAX, - } - - stats = { - "observation.image": { - "mean": np.array([0.5, 0.5, 0.5]), - "std": np.array([0.2, 0.2, 0.2]), - }, - "action": { - 
"min": np.array([-1.0, -1.0]), - "max": np.array([1.0, 1.0]), - }, - } - - normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats) - unnormalizer = UnnormalizerProcessor(features=features, norm_map=norm_map, stats=stats) - - # Create pipeline - pipeline = RobotProcessor([normalizer, unnormalizer]) - - observation = {"observation.image": torch.tensor([0.7, 0.5, 0.3])} - action = torch.tensor([0.5, -0.5]) - transition = create_transition(observation=observation, action=action) - - # Process through pipeline - result = pipeline(transition) - - # Check that both normalization and unnormalization info are present - comp_data = result.get(TransitionKey.COMPLEMENTARY_DATA) - assert comp_data is not None - assert "normalized_keys" in comp_data - assert "unnormalized_keys" in comp_data - - # Check normalization info - norm_info = comp_data["normalized_keys"] - assert norm_info["observation.image"] == "MEAN_STD" - assert norm_info["action"] == "MIN_MAX" - - # Check unnormalization info - unnorm_info = comp_data["unnormalized_keys"] - assert unnorm_info["observation.image"] == "MEAN_STD" - assert unnorm_info["action"] == "MIN_MAX" - - -def test_normalization_info_empty_transition(): - """Test that no normalization info is added for empty transitions.""" - features = { - "observation.image": PolicyFeature(FeatureType.VISUAL, (3,)), - "action": PolicyFeature(FeatureType.ACTION, (2,)), - } - - norm_map = { - FeatureType.VISUAL: NormalizationMode.MEAN_STD, - FeatureType.ACTION: NormalizationMode.MIN_MAX, - } - - stats = { - "observation.image": {"mean": [0.5], "std": [0.2]}, - "action": {"min": [-1.0], "max": [1.0]}, - } - - normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats) - - # Empty transition - transition = create_transition() - - # Process the transition - normalized_transition = normalizer(transition) - - # Check that no normalization info is added - comp_data = 
normalized_transition.get(TransitionKey.COMPLEMENTARY_DATA) - assert comp_data is None or "normalized_keys" not in comp_data - - def test_hotswap_stats_functional_test(): """Test that hotswapped processor actually works functionally.""" # Create test data @@ -1631,8 +1379,8 @@ def test_min_equals_max_maps_to_minus_one(): assert torch.allclose(out[TransitionKey.OBSERVATION]["observation.state"], torch.tensor([-1.0])) -def test_action_normalized_despite_normalize_keys(): - """Action normalization is independent of normalize_keys filter for observations.""" +def test_action_normalized_despite_normalize_observation_keys(): + """Action normalization is independent of normalize_observation_keys filter for observations.""" features = { "observation.state": PolicyFeature(FeatureType.STATE, (1,)), "action": PolicyFeature(FeatureType.ACTION, (2,)), @@ -1640,7 +1388,7 @@ def test_action_normalized_despite_normalize_keys(): norm_map = {FeatureType.STATE: NormalizationMode.IDENTITY, FeatureType.ACTION: NormalizationMode.MEAN_STD} stats = {"action": {"mean": np.array([1.0, -1.0]), "std": np.array([2.0, 4.0])}} normalizer = NormalizerProcessor( - features=features, norm_map=norm_map, stats=stats, normalize_keys={"observation.state"} + features=features, norm_map=norm_map, stats=stats, normalize_observation_keys={"observation.state"} ) transition = create_transition( @@ -1680,19 +1428,6 @@ def test_unnormalize_observations_mean_std_and_min_max(): assert torch.allclose(out_mm, torch.tensor([1.0, 0.0])) # mid of [0,2] and [-2,2] -def test_rename_stats_basic(): - orig = { - "observation.state": {"mean": np.array([0.0]), "std": np.array([1.0])}, - "action": {"mean": np.array([0.0])}, - } - mapping = {"observation.state": "observation.robot_state"} - renamed = rename_stats(orig, mapping) - assert "observation.robot_state" in renamed and "observation.state" not in renamed - # Ensure deep copy: mutate original and verify renamed unaffected - orig["observation.state"]["mean"][0] = 42.0 - 
assert renamed["observation.robot_state"]["mean"][0] != 42.0 - - def test_unknown_observation_keys_ignored(): features = {"observation.state": PolicyFeature(FeatureType.STATE, (1,))} norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD} @@ -1705,8 +1440,6 @@ def test_unknown_observation_keys_ignored(): # Unknown key should pass through unchanged and not be tracked assert torch.allclose(out[TransitionKey.OBSERVATION]["observation.unknown"], obs["observation.unknown"]) - comp = out.get(TransitionKey.COMPLEMENTARY_DATA) or {} - assert "normalized_keys" in comp and "observation.unknown" not in comp["normalized_keys"] def test_batched_action_normalization(): @@ -1731,7 +1464,7 @@ def test_complementary_data_preservation(): tr = create_transition(observation={"observation.state": torch.tensor([1.0])}, complementary_data=comp) out = normalizer(tr) new_comp = out[TransitionKey.COMPLEMENTARY_DATA] - assert new_comp["existing"] == 123 and "normalized_keys" in new_comp + assert new_comp["existing"] == 123 def test_roundtrip_normalize_unnormalize_non_identity(): diff --git a/tests/processor/test_pi0_processor.py b/tests/processor/test_pi0_processor.py index 41a41ca2c..056814f37 100644 --- a/tests/processor/test_pi0_processor.py +++ b/tests/processor/test_pi0_processor.py @@ -23,7 +23,7 @@ import torch from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE from lerobot.policies.pi0.configuration_pi0 import PI0Config -from lerobot.policies.pi0.processor_pi0 import Pi0NewLineProcessor, make_pi0_processor +from lerobot.policies.pi0.processor_pi0 import Pi0NewLineProcessor, make_pi0_pre_post_processors from lerobot.processor import ( DeviceProcessor, NormalizerProcessor, @@ -84,7 +84,7 @@ def test_make_pi0_processor_basic(): stats = create_default_stats() with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessor"): - preprocessor, postprocessor = make_pi0_processor(config, stats) + 
preprocessor, postprocessor = make_pi0_pre_post_processors(config, stats) # Check processor names assert preprocessor.name == "robot_preprocessor" @@ -183,7 +183,7 @@ def test_pi0_processor_cuda(): return features with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessor", MockTokenizerProcessor): - preprocessor, postprocessor = make_pi0_processor(config, stats) + preprocessor, postprocessor = make_pi0_pre_post_processors(config, stats) # Create CPU data observation = { @@ -233,7 +233,7 @@ def test_pi0_processor_accelerate_scenario(): return features with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessor", MockTokenizerProcessor): - preprocessor, postprocessor = make_pi0_processor(config, stats) + preprocessor, postprocessor = make_pi0_pre_post_processors(config, stats) # Simulate Accelerate: data already on GPU and batched device = torch.device("cuda:0") @@ -284,7 +284,7 @@ def test_pi0_processor_multi_gpu(): return features with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessor", MockTokenizerProcessor): - preprocessor, postprocessor = make_pi0_processor(config, stats) + preprocessor, postprocessor = make_pi0_pre_post_processors(config, stats) # Simulate data on different GPU device = torch.device("cuda:1") @@ -310,7 +310,7 @@ def test_pi0_processor_without_stats(): # Mock the tokenizer processor with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessor"): - preprocessor, postprocessor = make_pi0_processor(config, dataset_stats=None) + preprocessor, postprocessor = make_pi0_pre_post_processors(config, dataset_stats=None) # Should still create processors assert preprocessor is not None diff --git a/tests/processor/test_rename_processor.py b/tests/processor/test_rename_processor.py index 398b3ec9c..4efb249dd 100644 --- a/tests/processor/test_rename_processor.py +++ b/tests/processor/test_rename_processor.py @@ -21,6 +21,7 @@ import torch from lerobot.configs.types import FeatureType from lerobot.processor import 
ProcessorStepRegistry, RenameProcessor, RobotProcessor, TransitionKey +from lerobot.processor.rename_processor import rename_stats from tests.conftest import assert_contract_is_typed @@ -465,3 +466,16 @@ def test_features_chained_processors(policy_feature_factory): assert out["observation.image"] == spec["img"] assert out["extra"] == spec["extra"] assert_contract_is_typed(out) + + +def test_rename_stats_basic(): + orig = { + "observation.state": {"mean": np.array([0.0]), "std": np.array([1.0])}, + "action": {"mean": np.array([0.0])}, + } + mapping = {"observation.state": "observation.robot_state"} + renamed = rename_stats(orig, mapping) + assert "observation.robot_state" in renamed and "observation.state" not in renamed + # Ensure deep copy: mutate original and verify renamed unaffected + orig["observation.state"]["mean"][0] = 42.0 + assert renamed["observation.robot_state"]["mean"][0] != 42.0 diff --git a/tests/processor/test_sac_processor.py b/tests/processor/test_sac_processor.py index f0825acb4..33f3330dc 100644 --- a/tests/processor/test_sac_processor.py +++ b/tests/processor/test_sac_processor.py @@ -23,7 +23,7 @@ import torch from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from lerobot.constants import ACTION, OBS_STATE from lerobot.policies.sac.configuration_sac import SACConfig -from lerobot.policies.sac.processor_sac import make_sac_processor +from lerobot.policies.sac.processor_sac import make_sac_pre_post_processors from lerobot.processor import ( DeviceProcessor, NormalizerProcessor, @@ -78,7 +78,7 @@ def test_make_sac_processor_basic(): config = create_default_config() stats = create_default_stats() - preprocessor, postprocessor = make_sac_processor(config, stats) + preprocessor, postprocessor = make_sac_pre_post_processors(config, stats) # Check processor names assert preprocessor.name == "robot_preprocessor" @@ -102,7 +102,7 @@ def test_sac_processor_normalization_modes(): config = create_default_config() stats = 
create_default_stats() - preprocessor, postprocessor = make_sac_processor(config, stats) + preprocessor, postprocessor = make_sac_pre_post_processors(config, stats) # Create test data observation = {OBS_STATE: torch.randn(10) * 2} # Larger values to test normalization @@ -133,7 +133,7 @@ def test_sac_processor_cuda(): config.device = "cuda" stats = create_default_stats() - preprocessor, postprocessor = make_sac_processor(config, stats) + preprocessor, postprocessor = make_sac_pre_post_processors(config, stats) # Create CPU data observation = {OBS_STATE: torch.randn(10)} @@ -162,7 +162,7 @@ def test_sac_processor_accelerate_scenario(): config.device = "cuda:0" stats = create_default_stats() - preprocessor, postprocessor = make_sac_processor(config, stats) + preprocessor, postprocessor = make_sac_pre_post_processors(config, stats) # Simulate Accelerate: data already on GPU device = torch.device("cuda:0") @@ -185,7 +185,7 @@ def test_sac_processor_multi_gpu(): config.device = "cuda:0" stats = create_default_stats() - preprocessor, postprocessor = make_sac_processor(config, stats) + preprocessor, postprocessor = make_sac_pre_post_processors(config, stats) # Simulate data on different GPU device = torch.device("cuda:1") @@ -205,7 +205,7 @@ def test_sac_processor_without_stats(): """Test SAC processor creation without dataset statistics.""" config = create_default_config() - preprocessor, postprocessor = make_sac_processor(config, dataset_stats=None) + preprocessor, postprocessor = make_sac_pre_post_processors(config, dataset_stats=None) # Should still create processors assert preprocessor is not None @@ -225,7 +225,7 @@ def test_sac_processor_save_and_load(): config = create_default_config() stats = create_default_stats() - preprocessor, postprocessor = make_sac_processor(config, stats) + preprocessor, postprocessor = make_sac_pre_post_processors(config, stats) with tempfile.TemporaryDirectory() as tmpdir: # Save preprocessor @@ -252,7 +252,7 @@ def 
test_sac_processor_mixed_precision(): stats = create_default_stats() # Create processor - preprocessor, postprocessor = make_sac_processor(config, stats) + preprocessor, postprocessor = make_sac_pre_post_processors(config, stats) # Replace DeviceProcessor with one that uses float16 for i, step in enumerate(preprocessor.steps): @@ -277,7 +277,7 @@ def test_sac_processor_batch_data(): config = create_default_config() stats = create_default_stats() - preprocessor, postprocessor = make_sac_processor(config, stats) + preprocessor, postprocessor = make_sac_pre_post_processors(config, stats) # Test with batched data batch_size = 32 @@ -298,7 +298,7 @@ def test_sac_processor_edge_cases(): config = create_default_config() stats = create_default_stats() - preprocessor, postprocessor = make_sac_processor(config, stats) + preprocessor, postprocessor = make_sac_pre_post_processors(config, stats) # Test with empty observation transition = create_transition(observation={}, action=torch.randn(5)) diff --git a/tests/processor/test_smolvla_processor.py b/tests/processor/test_smolvla_processor.py index be538b017..d32731739 100644 --- a/tests/processor/test_smolvla_processor.py +++ b/tests/processor/test_smolvla_processor.py @@ -23,7 +23,10 @@ import torch from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig -from lerobot.policies.smolvla.processor_smolvla import SmolVLANewLineProcessor, make_smolvla_processor +from lerobot.policies.smolvla.processor_smolvla import ( + SmolVLANewLineProcessor, + make_smolvla_pre_post_processors, +) from lerobot.processor import ( DeviceProcessor, NormalizerProcessor, @@ -86,7 +89,7 @@ def test_make_smolvla_processor_basic(): stats = create_default_stats() with patch("lerobot.policies.smolvla.processor_smolvla.TokenizerProcessor"): - preprocessor, postprocessor = make_smolvla_processor(config, 
stats) + preprocessor, postprocessor = make_smolvla_pre_post_processors(config, stats) # Check processor names assert preprocessor.name == "robot_preprocessor" @@ -185,7 +188,7 @@ def test_smolvla_processor_cuda(): return features with patch("lerobot.policies.smolvla.processor_smolvla.TokenizerProcessor", MockTokenizerProcessor): - preprocessor, postprocessor = make_smolvla_processor(config, stats) + preprocessor, postprocessor = make_smolvla_pre_post_processors(config, stats) # Create CPU data observation = { @@ -235,7 +238,7 @@ def test_smolvla_processor_accelerate_scenario(): return features with patch("lerobot.policies.smolvla.processor_smolvla.TokenizerProcessor", MockTokenizerProcessor): - preprocessor, postprocessor = make_smolvla_processor(config, stats) + preprocessor, postprocessor = make_smolvla_pre_post_processors(config, stats) # Simulate Accelerate: data already on GPU and batched device = torch.device("cuda:0") @@ -286,7 +289,7 @@ def test_smolvla_processor_multi_gpu(): return features with patch("lerobot.policies.smolvla.processor_smolvla.TokenizerProcessor", MockTokenizerProcessor): - preprocessor, postprocessor = make_smolvla_processor(config, stats) + preprocessor, postprocessor = make_smolvla_pre_post_processors(config, stats) # Simulate data on different GPU device = torch.device("cuda:1") @@ -312,7 +315,7 @@ def test_smolvla_processor_without_stats(): # Mock the tokenizer processor with patch("lerobot.policies.smolvla.processor_smolvla.TokenizerProcessor"): - preprocessor, postprocessor = make_smolvla_processor(config, dataset_stats=None) + preprocessor, postprocessor = make_smolvla_pre_post_processors(config, dataset_stats=None) # Should still create processors assert preprocessor is not None diff --git a/tests/processor/test_tdmpc_processor.py b/tests/processor/test_tdmpc_processor.py index c6bac6442..473fc93db 100644 --- a/tests/processor/test_tdmpc_processor.py +++ b/tests/processor/test_tdmpc_processor.py @@ -23,7 +23,7 @@ import torch 
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig -from lerobot.policies.tdmpc.processor_tdmpc import make_tdmpc_processor +from lerobot.policies.tdmpc.processor_tdmpc import make_tdmpc_pre_post_processors from lerobot.processor import ( DeviceProcessor, NormalizerProcessor, @@ -81,7 +81,7 @@ def test_make_tdmpc_processor_basic(): config = create_default_config() stats = create_default_stats() - preprocessor, postprocessor = make_tdmpc_processor(config, stats) + preprocessor, postprocessor = make_tdmpc_pre_post_processors(config, stats) # Check processor names assert preprocessor.name == "robot_preprocessor" @@ -105,7 +105,7 @@ def test_tdmpc_processor_normalization(): config = create_default_config() stats = create_default_stats() - preprocessor, postprocessor = make_tdmpc_processor(config, stats) + preprocessor, postprocessor = make_tdmpc_pre_post_processors(config, stats) # Create test data observation = { @@ -138,7 +138,7 @@ def test_tdmpc_processor_cuda(): config.device = "cuda" stats = create_default_stats() - preprocessor, postprocessor = make_tdmpc_processor(config, stats) + preprocessor, postprocessor = make_tdmpc_pre_post_processors(config, stats) # Create CPU data observation = { @@ -171,7 +171,7 @@ def test_tdmpc_processor_accelerate_scenario(): config.device = "cuda:0" stats = create_default_stats() - preprocessor, postprocessor = make_tdmpc_processor(config, stats) + preprocessor, postprocessor = make_tdmpc_pre_post_processors(config, stats) # Simulate Accelerate: data already on GPU device = torch.device("cuda:0") @@ -198,7 +198,7 @@ def test_tdmpc_processor_multi_gpu(): config.device = "cuda:0" stats = create_default_stats() - preprocessor, postprocessor = make_tdmpc_processor(config, stats) + preprocessor, postprocessor = make_tdmpc_pre_post_processors(config, stats) # Simulate data on different 
GPU device = torch.device("cuda:1") @@ -222,7 +222,7 @@ def test_tdmpc_processor_without_stats(): """Test TDMPC processor creation without dataset statistics.""" config = create_default_config() - preprocessor, postprocessor = make_tdmpc_processor(config, dataset_stats=None) + preprocessor, postprocessor = make_tdmpc_pre_post_processors(config, dataset_stats=None) # Should still create processors assert preprocessor is not None @@ -245,7 +245,7 @@ def test_tdmpc_processor_save_and_load(): config = create_default_config() stats = create_default_stats() - preprocessor, postprocessor = make_tdmpc_processor(config, stats) + preprocessor, postprocessor = make_tdmpc_pre_post_processors(config, stats) with tempfile.TemporaryDirectory() as tmpdir: # Save preprocessor @@ -276,7 +276,7 @@ def test_tdmpc_processor_mixed_precision(): stats = create_default_stats() # Create processor - preprocessor, postprocessor = make_tdmpc_processor(config, stats) + preprocessor, postprocessor = make_tdmpc_pre_post_processors(config, stats) # Replace DeviceProcessor with one that uses float16 for i, step in enumerate(preprocessor.steps): @@ -305,7 +305,7 @@ def test_tdmpc_processor_batch_data(): config = create_default_config() stats = create_default_stats() - preprocessor, postprocessor = make_tdmpc_processor(config, stats) + preprocessor, postprocessor = make_tdmpc_pre_post_processors(config, stats) # Test with batched data batch_size = 64 @@ -330,7 +330,7 @@ def test_tdmpc_processor_edge_cases(): config = create_default_config() stats = create_default_stats() - preprocessor, postprocessor = make_tdmpc_processor(config, stats) + preprocessor, postprocessor = make_tdmpc_pre_post_processors(config, stats) # Test with only state observation (no image) observation = {OBS_STATE: torch.randn(12)} diff --git a/tests/processor/test_tokenizer_processor.py b/tests/processor/test_tokenizer_processor.py index 802b2edb7..300191d86 100644 --- a/tests/processor/test_tokenizer_processor.py +++ 
b/tests/processor/test_tokenizer_processor.py @@ -98,7 +98,11 @@ def test_basic_tokenization(mock_auto_tokenizer): processor = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=10) - transition = create_transition(complementary_data={"task": "pick up the red cube"}) + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "pick up the red cube"}, + ) result = processor(transition) @@ -126,7 +130,11 @@ def test_basic_tokenization_with_tokenizer_object(): processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=10) - transition = create_transition(complementary_data={"task": "pick up the red cube"}) + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "pick up the red cube"}, + ) result = processor(transition) @@ -156,7 +164,11 @@ def test_list_of_strings_tokenization(mock_auto_tokenizer): processor = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=8) - transition = create_transition(complementary_data={"task": ["pick up cube", "place on table"]}) + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": ["pick up cube", "place on table"]}, + ) result = processor(transition) @@ -180,7 +192,11 @@ def test_custom_keys(mock_auto_tokenizer): processor = TokenizerProcessor(tokenizer_name="test-tokenizer", task_key="instruction", max_length=5) - transition = create_transition(complementary_data={"instruction": "move forward"}) + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"instruction": "move forward"}, + ) result = processor(transition) @@ -421,7 +437,11 @@ def test_save_and_load_pretrained_with_tokenizer_name(mock_auto_tokenizer): loaded_processor = 
RobotProcessor.from_pretrained(temp_dir) # Test that loaded processor works - transition = create_transition(complementary_data={"instruction": "test instruction"}) + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"instruction": "test instruction"}, + ) result = loaded_processor(transition) assert TransitionKey.OBSERVATION in result @@ -448,7 +468,11 @@ def test_save_and_load_pretrained_with_tokenizer_object(): ) # Test that loaded processor works - transition = create_transition(complementary_data={"instruction": "test instruction"}) + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"instruction": "test instruction"}, + ) result = loaded_processor(transition) assert TransitionKey.OBSERVATION in result @@ -569,7 +593,11 @@ def test_tokenization_parameters(mock_auto_tokenizer): padding_side="left", ) - transition = create_transition(complementary_data={"task": "test task"}) + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "test task"}, + ) processor(transition) @@ -592,12 +620,14 @@ def test_preserves_other_complementary_data(mock_auto_tokenizer): processor = TokenizerProcessor(tokenizer_name="test-tokenizer") transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), complementary_data={ "task": "test task", "episode_id": 123, "timestamp": 456.789, "other_field": {"nested": "data"}, - } + }, ) result = processor(transition) @@ -624,7 +654,11 @@ def test_deterministic_tokenization(mock_auto_tokenizer): processor = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=10) - transition = create_transition(complementary_data={"task": "consistent test"}) + transition = create_transition( + observation={"state": 
torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "consistent test"}, + ) result1 = processor(transition) result2 = processor(transition) @@ -648,7 +682,11 @@ def test_empty_string_task(mock_auto_tokenizer): processor = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=8) - transition = create_transition(complementary_data={"task": ""}) + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": ""}, + ) result = processor(transition) @@ -669,7 +707,11 @@ def test_very_long_task(mock_auto_tokenizer): processor = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=5, truncation=True) long_task = " ".join(["word"] * 100) # Very long task - transition = create_transition(complementary_data={"task": long_task}) + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": long_task}, + ) result = processor(transition) @@ -714,7 +756,11 @@ def test_custom_padding_side(mock_auto_tokenizer): # Test left padding processor_left = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=10, padding_side="left") - transition = create_transition(complementary_data={"task": "test task"}) + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "test task"}, + ) processor_left(transition) assert tracking_tokenizer.padding_side_calls[-1] == "left" @@ -873,32 +919,6 @@ def test_device_detection_from_action(): assert attention_mask.device.type == "cuda" -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@require_package("transformers") -def test_device_detection_from_complementary_data(): - """Test that device is detected from tensors in complementary_data.""" - mock_tokenizer = MockTokenizer(vocab_size=100) - 
processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=10) - - # Create transition with tensor in complementary_data - transition = create_transition( - observation={"metadata": {"key": "value"}}, # No tensors - complementary_data={ - "task": "comp data test", - "index": torch.tensor([42]).cuda(), # Tensor in complementary_data - }, - ) - - result = processor(transition) - - # Check that tokenized tensors match complementary_data tensor's device - tokens = result[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.tokens"] - attention_mask = result[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.attention_mask"] - - assert tokens.device.type == "cuda" - assert attention_mask.device.type == "cuda" - - @require_package("transformers") def test_device_detection_preserves_dtype(): """Test that device detection doesn't affect dtype of tokenized tensors.""" diff --git a/tests/processor/test_vqbet_processor.py b/tests/processor/test_vqbet_processor.py index 6369d92d1..6df59c99d 100644 --- a/tests/processor/test_vqbet_processor.py +++ b/tests/processor/test_vqbet_processor.py @@ -23,7 +23,7 @@ import torch from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig -from lerobot.policies.vqbet.processor_vqbet import make_vqbet_processor +from lerobot.policies.vqbet.processor_vqbet import make_vqbet_pre_post_processors from lerobot.processor import ( DeviceProcessor, NormalizerProcessor, @@ -81,7 +81,7 @@ def test_make_vqbet_processor_basic(): config = create_default_config() stats = create_default_stats() - preprocessor, postprocessor = make_vqbet_processor(config, stats) + preprocessor, postprocessor = make_vqbet_pre_post_processors(config, stats) # Check processor names assert preprocessor.name == "robot_preprocessor" @@ -105,7 +105,7 @@ def test_vqbet_processor_with_images(): config = create_default_config() stats = 
create_default_stats() - preprocessor, postprocessor = make_vqbet_processor(config, stats) + preprocessor, postprocessor = make_vqbet_pre_post_processors(config, stats) # Create test data with images and states observation = { @@ -131,7 +131,7 @@ def test_vqbet_processor_cuda(): config.device = "cuda" stats = create_default_stats() - preprocessor, postprocessor = make_vqbet_processor(config, stats) + preprocessor, postprocessor = make_vqbet_pre_post_processors(config, stats) # Create CPU data observation = { @@ -164,7 +164,7 @@ def test_vqbet_processor_accelerate_scenario(): config.device = "cuda:0" stats = create_default_stats() - preprocessor, postprocessor = make_vqbet_processor(config, stats) + preprocessor, postprocessor = make_vqbet_pre_post_processors(config, stats) # Simulate Accelerate: data already on GPU and batched device = torch.device("cuda:0") @@ -191,7 +191,7 @@ def test_vqbet_processor_multi_gpu(): config.device = "cuda:0" stats = create_default_stats() - preprocessor, postprocessor = make_vqbet_processor(config, stats) + preprocessor, postprocessor = make_vqbet_pre_post_processors(config, stats) # Simulate data on different GPU device = torch.device("cuda:1") @@ -215,7 +215,7 @@ def test_vqbet_processor_without_stats(): """Test VQBeT processor creation without dataset statistics.""" config = create_default_config() - preprocessor, postprocessor = make_vqbet_processor(config, dataset_stats=None) + preprocessor, postprocessor = make_vqbet_pre_post_processors(config, dataset_stats=None) # Should still create processors assert preprocessor is not None @@ -238,7 +238,7 @@ def test_vqbet_processor_save_and_load(): config = create_default_config() stats = create_default_stats() - preprocessor, postprocessor = make_vqbet_processor(config, stats) + preprocessor, postprocessor = make_vqbet_pre_post_processors(config, stats) with tempfile.TemporaryDirectory() as tmpdir: # Save preprocessor @@ -269,7 +269,7 @@ def test_vqbet_processor_mixed_precision(): stats 
= create_default_stats() # Create processor - preprocessor, postprocessor = make_vqbet_processor(config, stats) + preprocessor, postprocessor = make_vqbet_pre_post_processors(config, stats) # Replace DeviceProcessor with one that uses float16 for i, step in enumerate(preprocessor.steps): @@ -298,7 +298,7 @@ def test_vqbet_processor_large_batch(): config = create_default_config() stats = create_default_stats() - preprocessor, postprocessor = make_vqbet_processor(config, stats) + preprocessor, postprocessor = make_vqbet_pre_post_processors(config, stats) # Test with large batch batch_size = 128 @@ -323,7 +323,7 @@ def test_vqbet_processor_sequential_processing(): config = create_default_config() stats = create_default_stats() - preprocessor, postprocessor = make_vqbet_processor(config, stats) + preprocessor, postprocessor = make_vqbet_pre_post_processors(config, stats) # Process multiple samples sequentially results = []