From 43d878a102a5dd5f4f75d220505a68671d3e0c84 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Thu, 25 Sep 2025 15:36:47 +0200 Subject: [PATCH 1/3] chore: replace hard-coded obs values with constants throughout all the source code (#2037) * chore: replace hard-coded OBS values with constants throughout all the source code * chore(tests): replace hard-coded OBS values with constants throughout all the test code --- benchmarks/video/run_video_benchmark.py | 3 +- examples/lekiwi/evaluate.py | 3 +- examples/lekiwi/record.py | 3 +- src/lerobot/async_inference/helpers.py | 6 +- src/lerobot/datasets/factory.py | 3 +- src/lerobot/datasets/pipeline_features.py | 18 +- src/lerobot/datasets/utils.py | 7 +- src/lerobot/envs/utils.py | 9 +- src/lerobot/policies/act/modeling_act.py | 22 +- .../policies/diffusion/modeling_diffusion.py | 18 +- src/lerobot/policies/pi0/configuration_pi0.py | 3 +- .../conversion_scripts/compare_with_jax.py | 15 +- .../policies/pi0fast/configuration_pi0fast.py | 3 +- src/lerobot/policies/sac/modeling_sac.py | 13 +- .../reward_model/configuration_classifier.py | 3 +- .../policies/smolvla/configuration_smolvla.py | 3 +- src/lerobot/policies/tdmpc/modeling_tdmpc.py | 24 +- src/lerobot/policies/vqbet/modeling_vqbet.py | 14 +- src/lerobot/processor/converters.py | 4 +- .../processor/observation_processor.py | 6 +- src/lerobot/rl/buffer.py | 3 +- src/lerobot/rl/gym_manipulator.py | 7 +- src/lerobot/robots/lekiwi/lekiwi_client.py | 3 +- src/lerobot/scripts/lerobot_dataset_viz.py | 5 +- src/lerobot/scripts/lerobot_eval.py | 7 +- src/lerobot/scripts/lerobot_record.py | 3 +- src/lerobot/utils/constants.py | 18 +- src/lerobot/utils/visualization_utils.py | 4 +- .../policies/save_policy_to_safetensors.py | 3 +- tests/async_inference/test_helpers.py | 55 +-- tests/async_inference/test_policy_server.py | 5 +- tests/datasets/test_compute_stats.py | 33 +- tests/datasets/test_dataset_utils.py | 9 +- tests/datasets/test_datasets.py | 9 +- .../hilserl/test_modeling_classifier.py | 9 +- tests/policies/test_policies.py | 14 +- tests/policies/test_sac_config.py | 13 +- tests/policies/test_sac_policy.py | 21 +- tests/processor/test_act_processor.py | 2 +- tests/processor/test_batch_conversion.py | 79 ++-- tests/processor/test_converters.py | 17 +- tests/processor/test_device_processor.py | 99 ++-- tests/processor/test_migration_detection.py | 3 +- tests/processor/test_normalize_processor.py | 437 +++++++++--------- tests/processor/test_observation_processor.py | 46 +- tests/processor/test_pipeline.py | 37 +- tests/processor/test_rename_processor.py | 73 ++- tests/processor/test_tokenizer_processor.py | 24 +- tests/rl/test_actor.py | 5 +- tests/rl/test_actor_learner.py | 5 +- tests/utils/test_replay_buffer.py | 73 ++- tests/utils/test_visualization_utils.py | 7 +- 52 files changed, 659 insertions(+), 649 deletions(-) diff --git a/benchmarks/video/run_video_benchmark.py b/benchmarks/video/run_video_benchmark.py index f041a9066..9f34b2273 100644 --- a/benchmarks/video/run_video_benchmark.py +++ b/benchmarks/video/run_video_benchmark.py @@ -41,6 +41,7 @@ from lerobot.datasets.video_utils import ( decode_video_frames_torchvision, encode_video_frames, ) +from lerobot.utils.constants import OBS_IMAGE BASE_ENCODING = OrderedDict( [ @@ -117,7 +118,7 @@ def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None: hf_dataset = dataset.hf_dataset.with_format(None) # We only save images from the first camera - img_keys = [key for key in hf_dataset.features if key.startswith("observation.image")] + img_keys = [key for key in hf_dataset.features if key.startswith(OBS_IMAGE)] imgs_dataset = hf_dataset.select_columns(img_keys[0]) for i, item in enumerate( diff --git a/examples/lekiwi/evaluate.py b/examples/lekiwi/evaluate.py index 32a5e0a2b..174486eb8 100644 --- a/examples/lekiwi/evaluate.py +++ b/examples/lekiwi/evaluate.py @@ -21,6 +21,7 @@ from lerobot.policies.factory import make_pre_post_processors from lerobot.processor import make_default_processors from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig from lerobot.scripts.lerobot_record import record_loop +from lerobot.utils.constants import OBS_STR from lerobot.utils.control_utils import init_keyboard_listener from lerobot.utils.utils import log_say from lerobot.utils.visualization_utils import init_rerun @@ -42,7 +43,7 @@ policy = ACTPolicy.from_pretrained(HF_MODEL_ID) # Configure the dataset features action_features = hw_to_dataset_features(robot.action_features, "action") -obs_features = hw_to_dataset_features(robot.observation_features, "observation") +obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR) dataset_features = {**action_features, **obs_features} # Create the dataset diff --git a/examples/lekiwi/record.py b/examples/lekiwi/record.py index 30f34e718..471cb3668 100644 --- a/examples/lekiwi/record.py +++ b/examples/lekiwi/record.py @@ -22,6 +22,7 @@ from lerobot.robots.lekiwi.lekiwi_client import LeKiwiClient from lerobot.scripts.lerobot_record import record_loop from lerobot.teleoperators.keyboard import KeyboardTeleop, KeyboardTeleopConfig from lerobot.teleoperators.so100_leader import SO100Leader, SO100LeaderConfig +from lerobot.utils.constants import OBS_STR from lerobot.utils.control_utils import init_keyboard_listener from lerobot.utils.utils import log_say from lerobot.utils.visualization_utils import init_rerun @@ -48,7 +49,7 @@ teleop_action_processor, robot_action_processor, robot_observation_processor = m # Configure the dataset features action_features = hw_to_dataset_features(robot.action_features, "action") -obs_features = hw_to_dataset_features(robot.observation_features, "observation") +obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR) dataset_features = {**action_features, **obs_features} # Create the dataset diff --git a/src/lerobot/async_inference/helpers.py b/src/lerobot/async_inference/helpers.py index 175cecf6d..75d81a0f3 100644 --- a/src/lerobot/async_inference/helpers.py +++ b/src/lerobot/async_inference/helpers.py @@ -27,7 +27,7 @@ from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features # NOTE: Configs need to be loaded for the client to be able to instantiate the policy config from lerobot.policies import ACTConfig, DiffusionConfig, PI0Config, SmolVLAConfig, VQBeTConfig # noqa: F401 from lerobot.robots.robot import Robot -from lerobot.utils.constants import OBS_IMAGES, OBS_STATE +from lerobot.utils.constants import OBS_IMAGES, OBS_STATE, OBS_STR from lerobot.utils.utils import init_logging Action = torch.Tensor @@ -66,7 +66,7 @@ def validate_robot_cameras_for_policy( def map_robot_keys_to_lerobot_features(robot: Robot) -> dict[str, dict]: - return hw_to_dataset_features(robot.observation_features, "observation", use_video=False) + return hw_to_dataset_features(robot.observation_features, OBS_STR, use_video=False) def is_image_key(k: str) -> bool: @@ -141,7 +141,7 @@ def make_lerobot_observation( lerobot_features: dict[str, dict], ) -> LeRobotObservation: """Make a lerobot observation from a raw observation.""" - return build_dataset_frame(lerobot_features, robot_obs, prefix="observation") + return build_dataset_frame(lerobot_features, robot_obs, prefix=OBS_STR) def prepare_raw_observation( diff --git a/src/lerobot/datasets/factory.py b/src/lerobot/datasets/factory.py index a71e978bc..2bac84aed 100644 --- a/src/lerobot/datasets/factory.py +++ b/src/lerobot/datasets/factory.py @@ -27,6 +27,7 @@ from lerobot.datasets.lerobot_dataset import ( ) from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset from lerobot.datasets.transforms import ImageTransforms +from lerobot.utils.constants import OBS_PREFIX IMAGENET_STATS = { "mean": [[[0.485]], [[0.456]], [[0.406]]], # (c,1,1) @@ -58,7 +59,7 @@ def resolve_delta_timestamps( delta_timestamps[key] = [i / ds_meta.fps for i in cfg.reward_delta_indices] if key == "action" and cfg.action_delta_indices is not None: delta_timestamps[key] = [i / ds_meta.fps for i in cfg.action_delta_indices] - if key.startswith("observation.") and cfg.observation_delta_indices is not None: + if key.startswith(OBS_PREFIX) and cfg.observation_delta_indices is not None: delta_timestamps[key] = [i / ds_meta.fps for i in cfg.observation_delta_indices] if len(delta_timestamps) == 0: diff --git a/src/lerobot/datasets/pipeline_features.py b/src/lerobot/datasets/pipeline_features.py index cdf0b7448..13555dd31 100644 --- a/src/lerobot/datasets/pipeline_features.py +++ b/src/lerobot/datasets/pipeline_features.py @@ -19,7 +19,7 @@ from typing import Any from lerobot.configs.types import PipelineFeatureType from lerobot.datasets.utils import hw_to_dataset_features from lerobot.processor import DataProcessorPipeline -from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE +from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE, OBS_STR def create_initial_features( @@ -92,8 +92,8 @@ def aggregate_pipeline_dataset_features( # Intermediate storage for categorized and filtered features. processed_features: dict[str, dict[str, Any]] = { - "action": {}, - "observation": {}, + ACTION: {}, + OBS_STR: {}, } images_token = OBS_IMAGES.split(".")[-1] @@ -125,17 +125,15 @@ def aggregate_pipeline_dataset_features( # 3. Add the feature to the appropriate group with a clean name. name = strip_prefix(key, PREFIXES_TO_STRIP) if is_action: - processed_features["action"][name] = value + processed_features[ACTION][name] = value else: - processed_features["observation"][name] = value + processed_features[OBS_STR][name] = value # Convert the processed features into the final dataset format. dataset_features = {} - if processed_features["action"]: + if processed_features[ACTION]: dataset_features.update(hw_to_dataset_features(processed_features["action"], ACTION, use_videos)) - if processed_features["observation"]: - dataset_features.update( - hw_to_dataset_features(processed_features["observation"], "observation", use_videos) - ) + if processed_features[OBS_STR]: + dataset_features.update(hw_to_dataset_features(processed_features[OBS_STR], OBS_STR, use_videos)) return dataset_features diff --git a/src/lerobot/datasets/utils.py b/src/lerobot/datasets/utils.py index 922fc4e3f..96ae2eca6 100644 --- a/src/lerobot/datasets/utils.py +++ b/src/lerobot/datasets/utils.py @@ -43,6 +43,7 @@ from lerobot.datasets.backward_compatibility import ( BackwardCompatibilityError, ForwardCompatibilityError, ) +from lerobot.utils.constants import OBS_ENV_STATE, OBS_STR from lerobot.utils.utils import is_valid_numpy_dtype_string DEFAULT_CHUNK_SIZE = 1000 # Max number of files per chunk @@ -652,7 +653,7 @@ def hw_to_dataset_features( "names": list(joint_fts), } - if joint_fts and prefix == "observation": + if joint_fts and prefix == OBS_STR: features[f"{prefix}.state"] = { "dtype": "float32", "shape": (len(joint_fts),), @@ -728,9 +729,9 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea # Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets. if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w) shape = (shape[2], shape[0], shape[1]) - elif key == "observation.environment_state": + elif key == OBS_ENV_STATE: type = FeatureType.ENV - elif key.startswith("observation"): + elif key.startswith(OBS_STR): type = FeatureType.STATE elif key.startswith("action"): type = FeatureType.ACTION diff --git a/src/lerobot/envs/utils.py b/src/lerobot/envs/utils.py index f0aa0b5c6..023ceea67 100644 --- a/src/lerobot/envs/utils.py +++ b/src/lerobot/envs/utils.py @@ -26,6 +26,7 @@ from torch import Tensor from lerobot.configs.types import FeatureType, PolicyFeature from lerobot.envs.configs import EnvConfig +from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE from lerobot.utils.utils import get_channel_first_image_shape @@ -41,9 +42,9 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten return_observations = {} if "pixels" in observations: if isinstance(observations["pixels"], dict): - imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()} + imgs = {f"{OBS_IMAGES}.{key}": img for key, img in observations["pixels"].items()} else: - imgs = {"observation.image": observations["pixels"]} + imgs = {OBS_IMAGE: observations["pixels"]} for imgkey, img in imgs.items(): # TODO(aliberts, rcadene): use transforms.ToTensor()? @@ -72,13 +73,13 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten if env_state.dim() == 1: env_state = env_state.unsqueeze(0) - return_observations["observation.environment_state"] = env_state + return_observations[OBS_ENV_STATE] = env_state # TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing agent_pos = torch.from_numpy(observations["agent_pos"]).float() if agent_pos.dim() == 1: agent_pos = agent_pos.unsqueeze(0) - return_observations["observation.state"] = agent_pos + return_observations[OBS_STATE] = agent_pos return return_observations diff --git a/src/lerobot/policies/act/modeling_act.py b/src/lerobot/policies/act/modeling_act.py index e4ebec199..f8261bb7f 100644 --- a/src/lerobot/policies/act/modeling_act.py +++ b/src/lerobot/policies/act/modeling_act.py @@ -35,7 +35,7 @@ from torchvision.ops.misc import FrozenBatchNorm2d from lerobot.policies.act.configuration_act import ACTConfig from lerobot.policies.pretrained import PreTrainedPolicy -from lerobot.utils.constants import ACTION, OBS_IMAGES +from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE class ACTPolicy(PreTrainedPolicy): @@ -398,10 +398,10 @@ class ACT(nn.Module): "actions must be provided when using the variational objective in training mode." ) - if "observation.images" in batch: - batch_size = batch["observation.images"][0].shape[0] + if OBS_IMAGES in batch: + batch_size = batch[OBS_IMAGES][0].shape[0] else: - batch_size = batch["observation.environment_state"].shape[0] + batch_size = batch[OBS_ENV_STATE].shape[0] # Prepare the latent for input to the transformer encoder. if self.config.use_vae and "action" in batch and self.training: @@ -410,7 +410,7 @@ class ACT(nn.Module): self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size ) # (B, 1, D) if self.config.robot_state_feature: - robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"]) + robot_state_embed = self.vae_encoder_robot_state_input_proj(batch[OBS_STATE]) robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D) action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D) @@ -430,7 +430,7 @@ class ACT(nn.Module): cls_joint_is_pad = torch.full( (batch_size, 2 if self.config.robot_state_feature else 1), False, - device=batch["observation.state"].device, + device=batch[OBS_STATE].device, ) key_padding_mask = torch.cat( [cls_joint_is_pad, batch["action_is_pad"]], axis=1 @@ -454,7 +454,7 @@ class ACT(nn.Module): mu = log_sigma_x2 = None # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer latent_sample = torch.zeros([batch_size, self.config.latent_dim], dtype=torch.float32).to( - batch["observation.state"].device + batch[OBS_STATE].device ) # Prepare transformer encoder inputs. @@ -462,18 +462,16 @@ class ACT(nn.Module): encoder_in_pos_embed = list(self.encoder_1d_feature_pos_embed.weight.unsqueeze(1)) # Robot state token. if self.config.robot_state_feature: - encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch["observation.state"])) + encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch[OBS_STATE])) # Environment state token. if self.config.env_state_feature: - encoder_in_tokens.append( - self.encoder_env_state_input_proj(batch["observation.environment_state"]) - ) + encoder_in_tokens.append(self.encoder_env_state_input_proj(batch[OBS_ENV_STATE])) if self.config.image_features: # For a list of images, the H and W may vary but H*W is constant. # NOTE: If modifying this section, verify on MPS devices that # gradients remain stable (no explosions or NaNs). - for img in batch["observation.images"]: + for img in batch[OBS_IMAGES]: cam_features = self.backbone(img)["feature_map"] cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype) cam_features = self.encoder_img_feat_input_proj(cam_features) diff --git a/src/lerobot/policies/diffusion/modeling_diffusion.py b/src/lerobot/policies/diffusion/modeling_diffusion.py index 0bd2e282b..af1327ba2 100644 --- a/src/lerobot/policies/diffusion/modeling_diffusion.py +++ b/src/lerobot/policies/diffusion/modeling_diffusion.py @@ -81,13 +81,13 @@ class DiffusionPolicy(PreTrainedPolicy): def reset(self): """Clear observation and action queues. Should be called on `env.reset()`""" self._queues = { - "observation.state": deque(maxlen=self.config.n_obs_steps), + OBS_STATE: deque(maxlen=self.config.n_obs_steps), "action": deque(maxlen=self.config.n_action_steps), } if self.config.image_features: - self._queues["observation.images"] = deque(maxlen=self.config.n_obs_steps) + self._queues[OBS_IMAGES] = deque(maxlen=self.config.n_obs_steps) if self.config.env_state_feature: - self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps) + self._queues[OBS_ENV_STATE] = deque(maxlen=self.config.n_obs_steps) @torch.no_grad() def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: @@ -234,7 +234,7 @@ class DiffusionModel(nn.Module): if self.config.image_features: if self.config.use_separate_rgb_encoder_per_camera: # Combine batch and sequence dims while rearranging to make the camera index dimension first. - images_per_camera = einops.rearrange(batch["observation.images"], "b s n ... -> n (b s) ...") + images_per_camera = einops.rearrange(batch[OBS_IMAGES], "b s n ... -> n (b s) ...") img_features_list = torch.cat( [ encoder(images) @@ -249,7 +249,7 @@ class DiffusionModel(nn.Module): else: # Combine batch, sequence, and "which camera" dims before passing to shared encoder. img_features = self.rgb_encoder( - einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...") + einops.rearrange(batch[OBS_IMAGES], "b s n ... -> (b s n) ...") ) # Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the # feature dim (effectively concatenating the camera features). @@ -275,7 +275,7 @@ class DiffusionModel(nn.Module): "observation.environment_state": (B, n_obs_steps, environment_dim) } """ - batch_size, n_obs_steps = batch["observation.state"].shape[:2] + batch_size, n_obs_steps = batch[OBS_STATE].shape[:2] assert n_obs_steps == self.config.n_obs_steps # Encode image features and concatenate them all together along with the state vector. @@ -306,9 +306,9 @@ class DiffusionModel(nn.Module): } """ # Input validation. - assert set(batch).issuperset({"observation.state", "action", "action_is_pad"}) - assert "observation.images" in batch or "observation.environment_state" in batch - n_obs_steps = batch["observation.state"].shape[1] + assert set(batch).issuperset({OBS_STATE, "action", "action_is_pad"}) + assert OBS_IMAGES in batch or OBS_ENV_STATE in batch + n_obs_steps = batch[OBS_STATE].shape[1] horizon = batch["action"].shape[1] assert horizon == self.config.horizon assert n_obs_steps == self.config.n_obs_steps diff --git a/src/lerobot/policies/pi0/configuration_pi0.py b/src/lerobot/policies/pi0/configuration_pi0.py index c9728e418..bd5bbf7ee 100644 --- a/src/lerobot/policies/pi0/configuration_pi0.py +++ b/src/lerobot/policies/pi0/configuration_pi0.py @@ -20,6 +20,7 @@ from lerobot.optim.optimizers import AdamWConfig from lerobot.optim.schedulers import ( CosineDecayWithWarmupSchedulerConfig, ) +from lerobot.utils.constants import OBS_IMAGES @PreTrainedConfig.register_subclass("pi0") @@ -113,7 +114,7 @@ class PI0Config(PreTrainedConfig): # raise ValueError("You must provide at least one image or the environment state among the inputs.") for i in range(self.empty_cameras): - key = f"observation.images.empty_camera_{i}" + key = f"{OBS_IMAGES}.empty_camera_{i}" empty_camera = PolicyFeature( type=FeatureType.VISUAL, shape=(3, 480, 640), diff --git a/src/lerobot/policies/pi0/conversion_scripts/compare_with_jax.py b/src/lerobot/policies/pi0/conversion_scripts/compare_with_jax.py index c0c2e4816..fe9865697 100644 --- a/src/lerobot/policies/pi0/conversion_scripts/compare_with_jax.py +++ b/src/lerobot/policies/pi0/conversion_scripts/compare_with_jax.py @@ -21,6 +21,7 @@ import torch from lerobot.configs.policies import PreTrainedConfig from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata from lerobot.policies.factory import make_policy +from lerobot.utils.constants import OBS_IMAGES, OBS_STATE def display(tensor: torch.Tensor): @@ -60,26 +61,26 @@ def main(): # Override stats dataset_meta = LeRobotDatasetMetadata(dataset_repo_id) - dataset_meta.stats["observation.state"]["mean"] = torch.tensor( + dataset_meta.stats[OBS_STATE]["mean"] = torch.tensor( norm_stats["norm_stats"]["state"]["mean"][:num_motors], dtype=torch.float32 ) - dataset_meta.stats["observation.state"]["std"] = torch.tensor( + dataset_meta.stats[OBS_STATE]["std"] = torch.tensor( norm_stats["norm_stats"]["state"]["std"][:num_motors], dtype=torch.float32 ) # Create LeRobot batch from Jax batch = {} for cam_key, uint_chw_array in example["images"].items(): - batch[f"observation.images.{cam_key}"] = torch.from_numpy(uint_chw_array) / 255.0 - batch["observation.state"] = torch.from_numpy(example["state"]) + batch[f"{OBS_IMAGES}.{cam_key}"] = torch.from_numpy(uint_chw_array) / 255.0 + batch[OBS_STATE] = torch.from_numpy(example["state"]) batch["action"] = torch.from_numpy(outputs["actions"]) batch["task"] = example["prompt"] if model_name == "pi0_aloha_towel": - del batch["observation.images.cam_low"] + del batch[f"{OBS_IMAGES}.cam_low"] elif model_name == "pi0_aloha_sim": - batch["observation.images.top"] = batch["observation.images.cam_high"] - del batch["observation.images.cam_high"] + batch[f"{OBS_IMAGES}.top"] = batch[f"{OBS_IMAGES}.cam_high"] + del batch[f"{OBS_IMAGES}.cam_high"] # Batchify for key in batch: diff --git a/src/lerobot/policies/pi0fast/configuration_pi0fast.py b/src/lerobot/policies/pi0fast/configuration_pi0fast.py index b72bcd735..705b61ea8 100644 --- a/src/lerobot/policies/pi0fast/configuration_pi0fast.py +++ b/src/lerobot/policies/pi0fast/configuration_pi0fast.py @@ -6,6 +6,7 @@ from lerobot.optim.optimizers import AdamWConfig from lerobot.optim.schedulers import ( CosineDecayWithWarmupSchedulerConfig, ) +from lerobot.utils.constants import OBS_IMAGES @PreTrainedConfig.register_subclass("pi0fast") @@ -99,7 +100,7 @@ class PI0FASTConfig(PreTrainedConfig): def validate_features(self) -> None: for i in range(self.empty_cameras): - key = f"observation.images.empty_camera_{i}" + key = f"{OBS_IMAGES}.empty_camera_{i}" empty_camera = PolicyFeature( type=FeatureType.VISUAL, shape=(3, 480, 640), diff --git a/src/lerobot/policies/sac/modeling_sac.py b/src/lerobot/policies/sac/modeling_sac.py index fcaf02a4b..a6ed79d4e 100644 --- a/src/lerobot/policies/sac/modeling_sac.py +++ b/src/lerobot/policies/sac/modeling_sac.py @@ -31,6 +31,7 @@ from torch.distributions import MultivariateNormal, TanhTransform, Transform, Tr from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.sac.configuration_sac import SACConfig, is_image_feature from lerobot.policies.utils import get_device_from_parameters +from lerobot.utils.constants import OBS_ENV_STATE, OBS_STATE DISCRETE_DIMENSION_INDEX = -1 # Gripper is always the last dimension @@ -513,17 +514,17 @@ class SACObservationEncoder(nn.Module): ) def _init_state_layers(self) -> None: - self.has_env = "observation.environment_state" in self.config.input_features - self.has_state = "observation.state" in self.config.input_features + self.has_env = OBS_ENV_STATE in self.config.input_features + self.has_state = OBS_STATE in self.config.input_features if self.has_env: - dim = self.config.input_features["observation.environment_state"].shape[0] + dim = self.config.input_features[OBS_ENV_STATE].shape[0] self.env_encoder = nn.Sequential( nn.Linear(dim, self.config.latent_dim), nn.LayerNorm(self.config.latent_dim), nn.Tanh(), ) if self.has_state: - dim = self.config.input_features["observation.state"].shape[0] + dim = self.config.input_features[OBS_STATE].shape[0] self.state_encoder = nn.Sequential( nn.Linear(dim, self.config.latent_dim), nn.LayerNorm(self.config.latent_dim), @@ -549,9 +550,9 @@ class SACObservationEncoder(nn.Module): cache = self.get_cached_image_features(obs) parts.append(self._encode_images(cache, detach)) if self.has_env: - parts.append(self.env_encoder(obs["observation.environment_state"])) + parts.append(self.env_encoder(obs[OBS_ENV_STATE])) if self.has_state: - parts.append(self.state_encoder(obs["observation.state"])) + parts.append(self.state_encoder(obs[OBS_STATE])) if parts: return torch.cat(parts, dim=-1) diff --git a/src/lerobot/policies/sac/reward_model/configuration_classifier.py b/src/lerobot/policies/sac/reward_model/configuration_classifier.py index fc53283b3..9b76b8037 100644 --- a/src/lerobot/policies/sac/reward_model/configuration_classifier.py +++ b/src/lerobot/policies/sac/reward_model/configuration_classifier.py @@ -19,6 +19,7 @@ from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.types import NormalizationMode from lerobot.optim.optimizers import AdamWConfig, OptimizerConfig from lerobot.optim.schedulers import LRSchedulerConfig +from lerobot.utils.constants import OBS_IMAGE @PreTrainedConfig.register_subclass(name="reward_classifier") @@ -69,7 +70,7 @@ class RewardClassifierConfig(PreTrainedConfig): def validate_features(self) -> None: """Validate feature configurations.""" - has_image = any(key.startswith("observation.image") for key in self.input_features) + has_image = any(key.startswith(OBS_IMAGE) for key in self.input_features) if not has_image: raise ValueError( "You must provide an image observation (key starting with 'observation.image') in the input features" diff --git a/src/lerobot/policies/smolvla/configuration_smolvla.py b/src/lerobot/policies/smolvla/configuration_smolvla.py index 571900c4a..eedf477a5 100644 --- a/src/lerobot/policies/smolvla/configuration_smolvla.py +++ b/src/lerobot/policies/smolvla/configuration_smolvla.py @@ -20,6 +20,7 @@ from lerobot.optim.optimizers import AdamWConfig from lerobot.optim.schedulers import ( CosineDecayWithWarmupSchedulerConfig, ) +from lerobot.utils.constants import OBS_IMAGES @PreTrainedConfig.register_subclass("smolvla") @@ -117,7 +118,7 @@ class SmolVLAConfig(PreTrainedConfig): def validate_features(self) -> None: for i in range(self.empty_cameras): - key = f"observation.images.empty_camera_{i}" + key = f"{OBS_IMAGES}.empty_camera_{i}" empty_camera = PolicyFeature( type=FeatureType.VISUAL, shape=(3, 480, 640), diff --git a/src/lerobot/policies/tdmpc/modeling_tdmpc.py b/src/lerobot/policies/tdmpc/modeling_tdmpc.py index f83048862..4b5e8b7bd 100644 --- a/src/lerobot/policies/tdmpc/modeling_tdmpc.py +++ b/src/lerobot/policies/tdmpc/modeling_tdmpc.py @@ -38,7 +38,7 @@ from torch import Tensor from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig from lerobot.policies.utils import get_device_from_parameters, get_output_shape, populate_queues -from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_STATE, REWARD +from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_PREFIX, OBS_STATE, OBS_STR, REWARD class TDMPCPolicy(PreTrainedPolicy): @@ -91,13 +91,13 @@ class TDMPCPolicy(PreTrainedPolicy): called on `env.reset()` """ self._queues = { - "observation.state": deque(maxlen=1), + OBS_STATE: deque(maxlen=1), "action": deque(maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)), } if self.config.image_features: - self._queues["observation.image"] = deque(maxlen=1) + self._queues[OBS_IMAGE] = deque(maxlen=1) if self.config.env_state_feature: - self._queues["observation.environment_state"] = deque(maxlen=1) + self._queues[OBS_ENV_STATE] = deque(maxlen=1) # Previous mean obtained from the cross-entropy method (CEM) used during MPC. It is used to warm start # CEM for the next step. self._prev_mean: torch.Tensor | None = None @@ -325,7 +325,7 @@ class TDMPCPolicy(PreTrainedPolicy): action = batch[ACTION] # (t, b, action_dim) reward = batch[REWARD] # (t, b) - observations = {k: v for k, v in batch.items() if k.startswith("observation.")} + observations = {k: v for k, v in batch.items() if k.startswith(OBS_PREFIX)} # Apply random image augmentations. if self.config.image_features and self.config.max_random_shift_ratio > 0: @@ -387,10 +387,10 @@ class TDMPCPolicy(PreTrainedPolicy): temporal_loss_coeffs * F.mse_loss(z_preds[1:], z_targets, reduction="none").mean(dim=-1) # `z_preds` depends on the current observation and the actions. - * ~batch["observation.state_is_pad"][0] + * ~batch[f"{OBS_STR}.state_is_pad"][0] * ~batch["action_is_pad"] # `z_targets` depends on the next observation. - * ~batch["observation.state_is_pad"][1:] + * ~batch[f"{OBS_STR}.state_is_pad"][1:] ) .sum(0) .mean() @@ -403,7 +403,7 @@ class TDMPCPolicy(PreTrainedPolicy): * F.mse_loss(reward_preds, reward, reduction="none") * ~batch["next.reward_is_pad"] # `reward_preds` depends on the current observation and the actions. - * ~batch["observation.state_is_pad"][0] + * ~batch[f"{OBS_STR}.state_is_pad"][0] * ~batch["action_is_pad"] ) .sum(0) @@ -419,11 +419,11 @@ class TDMPCPolicy(PreTrainedPolicy): reduction="none", ).sum(0) # sum over ensemble # `q_preds_ensemble` depends on the first observation and the actions. - * ~batch["observation.state_is_pad"][0] + * ~batch[f"{OBS_STR}.state_is_pad"][0] * ~batch["action_is_pad"] # q_targets depends on the reward and the next observations. * ~batch["next.reward_is_pad"] - * ~batch["observation.state_is_pad"][1:] + * ~batch[f"{OBS_STR}.state_is_pad"][1:] ) .sum(0) .mean() @@ -441,7 +441,7 @@ class TDMPCPolicy(PreTrainedPolicy): temporal_loss_coeffs * raw_v_value_loss # `v_targets` depends on the first observation and the actions, as does `v_preds`. - * ~batch["observation.state_is_pad"][0] + * ~batch[f"{OBS_STR}.state_is_pad"][0] * ~batch["action_is_pad"] ) .sum(0) @@ -477,7 +477,7 @@ class TDMPCPolicy(PreTrainedPolicy): * mse * temporal_loss_coeffs # `action_preds` depends on the first observation and the actions. - * ~batch["observation.state_is_pad"][0] + * ~batch[f"{OBS_STR}.state_is_pad"][0] * ~batch["action_is_pad"] ).mean() diff --git a/src/lerobot/policies/vqbet/modeling_vqbet.py b/src/lerobot/policies/vqbet/modeling_vqbet.py index 34e5b1c0d..91d609701 100644 --- a/src/lerobot/policies/vqbet/modeling_vqbet.py +++ b/src/lerobot/policies/vqbet/modeling_vqbet.py @@ -133,7 +133,7 @@ class VQBeTPolicy(PreTrainedPolicy): batch.pop(ACTION) batch = dict(batch) # shallow copy so that adding a key doesn't modify the original # NOTE: It's important that this happens after stacking the images into a single key. - batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4) + batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4) # NOTE: for offline evaluation, we have action in the batch, so we need to pop it out if ACTION in batch: batch.pop(ACTION) @@ -340,14 +340,12 @@ class VQBeTModel(nn.Module): def forward(self, batch: dict[str, Tensor], rollout: bool) -> tuple[dict, dict]: # Input validation. - assert set(batch).issuperset({"observation.state", "observation.images"}) - batch_size, n_obs_steps = batch["observation.state"].shape[:2] + assert set(batch).issuperset({OBS_STATE, OBS_IMAGES}) + batch_size, n_obs_steps = batch[OBS_STATE].shape[:2] assert n_obs_steps == self.config.n_obs_steps # Extract image feature (first combine batch and sequence dims). - img_features = self.rgb_encoder( - einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...") - ) + img_features = self.rgb_encoder(einops.rearrange(batch[OBS_IMAGES], "b s n ... -> (b s n) ...")) # Separate batch and sequence dims. img_features = einops.rearrange( img_features, "(b s n) ... -> b s n ...", b=batch_size, s=n_obs_steps, n=self.num_images @@ -359,9 +357,7 @@ class VQBeTModel(nn.Module): img_features ) # (batch, obs_step, number of different cameras, projection dims) input_tokens = [rgb_tokens[:, :, i] for i in range(rgb_tokens.size(2))] - input_tokens.append( - self.state_projector(batch["observation.state"]) - ) # (batch, obs_step, projection dims) + input_tokens.append(self.state_projector(batch[OBS_STATE])) # (batch, obs_step, projection dims) input_tokens.append(einops.repeat(self.action_token, "1 1 d -> b n d", b=batch_size, n=n_obs_steps)) # Interleave tokens by stacking and rearranging. input_tokens = torch.stack(input_tokens, dim=2) diff --git a/src/lerobot/processor/converters.py b/src/lerobot/processor/converters.py index 440f8b1db..2e80cf4bb 100644 --- a/src/lerobot/processor/converters.py +++ b/src/lerobot/processor/converters.py @@ -23,6 +23,8 @@ from typing import Any import numpy as np import torch +from lerobot.utils.constants import OBS_PREFIX + from .core import EnvTransition, PolicyAction, RobotAction, RobotObservation, TransitionKey @@ -347,7 +349,7 @@ def batch_to_transition(batch: dict[str, Any]) -> EnvTransition: raise ValueError(f"Action should be a PolicyAction type got {type(action)}") # Extract observation and complementary data keys. - observation_keys = {k: v for k, v in batch.items() if k.startswith("observation.")} + observation_keys = {k: v for k, v in batch.items() if k.startswith(OBS_PREFIX)} complementary_data = _extract_complementary_data(batch) return create_transition( diff --git a/src/lerobot/processor/observation_processor.py b/src/lerobot/processor/observation_processor.py index 2b9402bee..486218157 100644 --- a/src/lerobot/processor/observation_processor.py +++ b/src/lerobot/processor/observation_processor.py @@ -21,7 +21,7 @@ import torch from torch import Tensor from lerobot.configs.types import PipelineFeatureType, PolicyFeature -from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE +from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE, OBS_STR from .pipeline import ObservationProcessorStep, ProcessorStepRegistry @@ -171,7 +171,7 @@ class VanillaObservationProcessorStep(ObservationProcessorStep): # Prefix-based rules (e.g. pixels.cam1 -> OBS_IMAGES.cam1) for old_prefix, new_prefix in prefix_pairs.items(): - prefixed_old = f"observation.{old_prefix}" + prefixed_old = f"{OBS_STR}.{old_prefix}" if key.startswith(prefixed_old): suffix = key[len(prefixed_old) :] new_key = f"{new_prefix}{suffix}" @@ -191,7 +191,7 @@ class VanillaObservationProcessorStep(ObservationProcessorStep): # Exact-name rules (pixels, environment_state, agent_pos) for old, new in exact_pairs.items(): - if key == old or key == f"observation.{old}": + if key == old or key == f"{OBS_STR}.{old}": new_key = new new_features[src_ft][new_key] = feat handled = True diff --git a/src/lerobot/rl/buffer.py b/src/lerobot/rl/buffer.py index c65801896..fbf36de36 100644 --- a/src/lerobot/rl/buffer.py +++ b/src/lerobot/rl/buffer.py @@ -24,6 +24,7 @@ import torch.nn.functional as F # noqa: N812 from tqdm import tqdm from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.utils.constants import OBS_IMAGE from lerobot.utils.transition import Transition @@ -240,7 +241,7 @@ class ReplayBuffer: idx = torch.randint(low=0, high=high, size=(batch_size,), device=self.storage_device) # Identify image keys that need augmentation - image_keys = [k for k in self.states if k.startswith("observation.image")] if self.use_drq else [] + image_keys = [k for k in self.states if k.startswith(OBS_IMAGE)] if self.use_drq else [] # Create batched state and next_state batch_state = {} diff --git a/src/lerobot/rl/gym_manipulator.py b/src/lerobot/rl/gym_manipulator.py index f91d077f4..393135708 100644 --- a/src/lerobot/rl/gym_manipulator.py +++ b/src/lerobot/rl/gym_manipulator.py @@ -73,6 +73,7 @@ from lerobot.teleoperators import ( ) from lerobot.teleoperators.teleoperator import Teleoperator from lerobot.teleoperators.utils import TeleopEvents +from lerobot.utils.constants import OBS_IMAGES, OBS_STATE from lerobot.utils.robot_utils import busy_wait from lerobot.utils.utils import log_say @@ -180,7 +181,7 @@ class RobotEnv(gym.Env): # Define observation spaces for images and other states. if current_observation is not None and "pixels" in current_observation: - prefix = "observation.images" + prefix = OBS_IMAGES observation_spaces = { f"{prefix}.{key}": gym.spaces.Box( low=0, high=255, shape=current_observation["pixels"][key].shape, dtype=np.uint8 @@ -190,7 +191,7 @@ class RobotEnv(gym.Env): if current_observation is not None: agent_pos = current_observation["agent_pos"] - observation_spaces["observation.state"] = gym.spaces.Box( + observation_spaces[OBS_STATE] = gym.spaces.Box( low=0, high=10, shape=agent_pos.shape, @@ -612,7 +613,7 @@ def control_loop( } for key, value in transition[TransitionKey.OBSERVATION].items(): - if key == "observation.state": + if key == OBS_STATE: features[key] = { "dtype": "float32", "shape": value.squeeze(0).shape, diff --git a/src/lerobot/robots/lekiwi/lekiwi_client.py b/src/lerobot/robots/lekiwi/lekiwi_client.py index 9f6367152..392d6d575 100644 --- a/src/lerobot/robots/lekiwi/lekiwi_client.py +++ b/src/lerobot/robots/lekiwi/lekiwi_client.py @@ -23,6 +23,7 @@ from typing import Any import cv2 import numpy as np +from lerobot.utils.constants import OBS_STATE from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from ..robot import Robot @@ -203,7 +204,7 @@ class LeKiwiClient(Robot): state_vec = np.array([flat_state[key] for key in self._state_order], dtype=np.float32) - obs_dict: dict[str, Any] = {**flat_state, "observation.state": state_vec} + obs_dict: dict[str, Any] = {**flat_state, OBS_STATE: state_vec} # Decode images current_frames: dict[str, np.ndarray] = {} diff --git a/src/lerobot/scripts/lerobot_dataset_viz.py b/src/lerobot/scripts/lerobot_dataset_viz.py index 2033b36ba..5c0d31f73 100644 --- a/src/lerobot/scripts/lerobot_dataset_viz.py +++ b/src/lerobot/scripts/lerobot_dataset_viz.py @@ -75,6 +75,7 @@ import torch.utils.data import tqdm from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.utils.constants import OBS_STATE class EpisodeSampler(torch.utils.data.Sampler): @@ -161,8 +162,8 @@ def visualize_dataset( rr.log(f"action/{dim_idx}", rr.Scalar(val.item())) # display each dimension of observed state space (e.g. agent position in joint space) - if "observation.state" in batch: - for dim_idx, val in enumerate(batch["observation.state"][i]): + if OBS_STATE in batch: + for dim_idx, val in enumerate(batch[OBS_STATE][i]): rr.log(f"state/{dim_idx}", rr.Scalar(val.item())) if "next.done" in batch: diff --git a/src/lerobot/scripts/lerobot_eval.py b/src/lerobot/scripts/lerobot_eval.py index ca900f8df..310f771a9 100644 --- a/src/lerobot/scripts/lerobot_eval.py +++ b/src/lerobot/scripts/lerobot_eval.py @@ -81,6 +81,7 @@ from lerobot.envs.utils import ( from lerobot.policies.factory import make_policy, make_pre_post_processors from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.processor import PolicyAction, PolicyProcessorPipeline +from lerobot.utils.constants import OBS_STR from lerobot.utils.io_utils import write_video from lerobot.utils.random_utils import set_seed from lerobot.utils.utils import ( @@ -221,7 +222,7 @@ def rollout( stacked_observations = {} for key in all_observations[0]: stacked_observations[key] = torch.stack([obs[key] for obs in all_observations], dim=1) - ret["observation"] = stacked_observations + ret[OBS_STR] = stacked_observations if hasattr(policy, "use_original_modules"): policy.use_original_modules() @@ -459,8 +460,8 @@ def _compile_episode_data( for k in ep_dict: ep_dict[k] = torch.cat([ep_dict[k], ep_dict[k][-1:]]) - for key in rollout_data["observation"]: - ep_dict[key] = rollout_data["observation"][key][ep_ix, :num_frames] + for key in rollout_data[OBS_STR]: + ep_dict[key] = rollout_data[OBS_STR][key][ep_ix, :num_frames] ep_dicts.append(ep_dict) diff --git a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py index dd4984fab..f1d026a39 100644 --- a/src/lerobot/scripts/lerobot_record.py +++ b/src/lerobot/scripts/lerobot_record.py @@ -109,6 +109,7 @@ from lerobot.teleoperators import ( # noqa: F401 so101_leader, ) from lerobot.teleoperators.keyboard.teleop_keyboard import KeyboardTeleop +from lerobot.utils.constants import OBS_STR from lerobot.utils.control_utils import ( init_keyboard_listener, is_headless, @@ -303,7 +304,7 @@ def record_loop( obs_processed = robot_observation_processor(obs) if policy is not None or dataset is not None: - observation_frame = build_dataset_frame(dataset.features, obs_processed, prefix="observation") + observation_frame = build_dataset_frame(dataset.features, obs_processed, prefix=OBS_STR) # Get action from either policy or teleop if policy is not None and preprocessor is not None and postprocessor is not None: diff --git a/src/lerobot/utils/constants.py b/src/lerobot/utils/constants.py index 464969c72..337817908 100644 --- a/src/lerobot/utils/constants.py +++ b/src/lerobot/utils/constants.py @@ -17,19 +17,21 @@ from pathlib import Path from huggingface_hub.constants import HF_HOME -OBS_ENV_STATE = "observation.environment_state" -OBS_STATE = "observation.state" -OBS_IMAGE = "observation.image" -OBS_IMAGES = "observation.images" -OBS_LANGUAGE = "observation.language" +OBS_STR = "observation" +OBS_PREFIX = OBS_STR + "." +OBS_ENV_STATE = OBS_STR + ".environment_state" +OBS_STATE = OBS_STR + ".state" +OBS_IMAGE = OBS_STR + ".image" +OBS_IMAGES = OBS_IMAGE + "s" +OBS_LANGUAGE = OBS_STR + ".language" +OBS_LANGUAGE_TOKENS = OBS_LANGUAGE + ".tokens" +OBS_LANGUAGE_ATTENTION_MASK = OBS_LANGUAGE + ".attention_mask" + ACTION = "action" REWARD = "next.reward" TRUNCATED = "next.truncated" DONE = "next.done" -OBS_LANGUAGE_TOKENS = OBS_LANGUAGE + ".tokens" -OBS_LANGUAGE_ATTENTION_MASK = OBS_LANGUAGE + ".attention_mask" - ROBOTS = "robots" ROBOT_TYPE = "robot_type" TELEOPERATORS = "teleoperators" diff --git a/src/lerobot/utils/visualization_utils.py b/src/lerobot/utils/visualization_utils.py index 7fc881f26..ae070b7c4 100644 --- a/src/lerobot/utils/visualization_utils.py +++ b/src/lerobot/utils/visualization_utils.py @@ -19,6 +19,8 @@ from typing import Any import numpy as np import rerun as rr +from .constants import OBS_PREFIX, OBS_STR + def init_rerun(session_name: str = "lerobot_control_loop") -> None: """Initializes the Rerun SDK for visualizing the control loop.""" @@ -63,7 +65,7 @@ def log_rerun_data( for k, v in observation.items(): if v is None: continue - key = k if str(k).startswith("observation.") else f"observation.{k}" + key = k if str(k).startswith(OBS_PREFIX) else f"{OBS_STR}.{k}" if _is_scalar(v): rr.log(key, rr.Scalar(float(v))) diff --git a/tests/artifacts/policies/save_policy_to_safetensors.py b/tests/artifacts/policies/save_policy_to_safetensors.py index b0ffa9a31..e130ae144 100644 --- a/tests/artifacts/policies/save_policy_to_safetensors.py +++ b/tests/artifacts/policies/save_policy_to_safetensors.py @@ -24,6 +24,7 @@ from lerobot.configs.train import TrainPipelineConfig from lerobot.datasets.factory import make_dataset from lerobot.optim.factory import make_optimizer_and_scheduler from lerobot.policies.factory import make_policy, make_policy_config, make_pre_post_processors +from lerobot.utils.constants import OBS_STR from lerobot.utils.random_utils import set_seed @@ -92,7 +93,7 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict): # for backward compatibility if k == "task": continue - if k.startswith("observation"): + if k.startswith(OBS_STR): obs[k] = batch[k] if hasattr(train_cfg.policy, "n_action_steps"): diff --git a/tests/async_inference/test_helpers.py b/tests/async_inference/test_helpers.py index f1c7636e2..acf5870d5 100644 --- a/tests/async_inference/test_helpers.py +++ b/tests/async_inference/test_helpers.py @@ -30,6 +30,7 @@ from lerobot.async_inference.helpers import ( resize_robot_observation_image, ) from lerobot.configs.types import FeatureType, PolicyFeature +from lerobot.utils.constants import OBS_IMAGES, OBS_STATE # --------------------------------------------------------------------- # FPSTracker @@ -115,7 +116,7 @@ def test_timed_action_getters(): def test_timed_observation_getters(): """TimedObservation stores & returns timestamp, dict and timestep.""" ts = time.time() - obs_dict = {"observation.state": torch.ones(6)} + obs_dict = {OBS_STATE: torch.ones(6)} to = TimedObservation(timestamp=ts, observation=obs_dict, timestep=0) assert math.isclose(to.get_timestamp(), ts, rel_tol=0, abs_tol=1e-6) @@ -151,7 +152,7 @@ def test_timed_data_deserialization_data_getters(): # ------------------------------------------------------------------ # TimedObservation # ------------------------------------------------------------------ - obs_dict = {"observation.state": torch.arange(4).float()} + obs_dict = {OBS_STATE: torch.arange(4).float()} to_in = TimedObservation(timestamp=ts, observation=obs_dict, timestep=7, must_go=True) to_bytes = pickle.dumps(to_in) # nosec @@ -161,7 +162,7 @@ def test_timed_data_deserialization_data_getters(): assert to_out.get_timestep() == 7 assert to_out.must_go is True assert to_out.get_observation().keys() == obs_dict.keys() - torch.testing.assert_close(to_out.get_observation()["observation.state"], obs_dict["observation.state"]) + torch.testing.assert_close(to_out.get_observation()[OBS_STATE], obs_dict[OBS_STATE]) # --------------------------------------------------------------------- @@ -187,7 +188,7 @@ def test_observations_similar_true(): """Distance below atol → observations considered similar.""" # Create mock lerobot features for the similarity check lerobot_features = { - "observation.state": { + OBS_STATE: { "dtype": "float32", "shape": [4], "names": ["shoulder", "elbow", "wrist", "gripper"], @@ -222,17 +223,17 @@ def _create_mock_robot_observation(): def _create_mock_lerobot_features(): """Create mock lerobot features mapping similar to what hw_to_dataset_features returns.""" return { - "observation.state": { + OBS_STATE: { "dtype": "float32", "shape": [4], "names": ["shoulder", "elbow", "wrist", "gripper"], }, - "observation.images.laptop": { + f"{OBS_IMAGES}.laptop": { "dtype": "image", "shape": [480, 640, 3], "names": ["height", "width", "channels"], }, - "observation.images.phone": { + f"{OBS_IMAGES}.phone": { "dtype": "image", "shape": [480, 640, 3], "names": ["height", "width", "channels"], @@ -243,11 +244,11 @@ def _create_mock_lerobot_features(): def _create_mock_policy_image_features(): """Create mock policy image features with different resolutions.""" return { - "observation.images.laptop": PolicyFeature( + f"{OBS_IMAGES}.laptop": PolicyFeature( type=FeatureType.VISUAL, shape=(3, 224, 224), # Policy expects smaller resolution ), - "observation.images.phone": PolicyFeature( + f"{OBS_IMAGES}.phone": PolicyFeature( type=FeatureType.VISUAL, shape=(3, 160, 160), # Different resolution for second camera ), @@ -306,21 +307,21 @@ def test_prepare_raw_observation(): prepared = prepare_raw_observation(robot_obs, lerobot_features, policy_image_features) # Check that state is properly extracted and batched - assert "observation.state" in prepared - state = prepared["observation.state"] + assert OBS_STATE in prepared + state = prepared[OBS_STATE] assert isinstance(state, torch.Tensor) assert state.shape == (1, 4) # Batched state # Check that images are processed and resized - assert "observation.images.laptop" in prepared - assert "observation.images.phone" in prepared + assert f"{OBS_IMAGES}.laptop" in prepared + assert f"{OBS_IMAGES}.phone" in prepared - laptop_img = prepared["observation.images.laptop"] - phone_img = prepared["observation.images.phone"] + laptop_img = prepared[f"{OBS_IMAGES}.laptop"] + phone_img = prepared[f"{OBS_IMAGES}.phone"] # Check image shapes match policy requirements - assert laptop_img.shape == policy_image_features["observation.images.laptop"].shape - assert phone_img.shape == policy_image_features["observation.images.phone"].shape + assert laptop_img.shape == policy_image_features[f"{OBS_IMAGES}.laptop"].shape + assert phone_img.shape == policy_image_features[f"{OBS_IMAGES}.phone"].shape # Check that images are tensors assert isinstance(laptop_img, torch.Tensor) @@ -337,19 +338,19 @@ def test_raw_observation_to_observation_basic(): observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device) # Check that all expected keys are present - assert "observation.state" in observation - assert "observation.images.laptop" in observation - assert "observation.images.phone" in observation + assert OBS_STATE in observation + assert f"{OBS_IMAGES}.laptop" in observation + assert f"{OBS_IMAGES}.phone" in observation # Check state processing - state = observation["observation.state"] + state = observation[OBS_STATE] assert isinstance(state, torch.Tensor) assert state.device.type == device assert state.shape == (1, 4) # Batched # Check image processing - laptop_img = observation["observation.images.laptop"] - phone_img = observation["observation.images.phone"] + laptop_img = observation[f"{OBS_IMAGES}.laptop"] + phone_img = observation[f"{OBS_IMAGES}.phone"] # Images should have batch dimension: (B, C, H, W) assert laptop_img.shape == (1, 3, 224, 224) @@ -429,19 +430,19 @@ def test_image_processing_pipeline_preserves_content(): robot_obs = {"shoulder": 1.0, "elbow": 1.0, "wrist": 1.0, "gripper": 1.0, "laptop": original_img} lerobot_features = { - "observation.state": { + OBS_STATE: { "dtype": "float32", "shape": [4], "names": ["shoulder", "elbow", "wrist", "gripper"], }, - "observation.images.laptop": { + f"{OBS_IMAGES}.laptop": { "dtype": "image", "shape": [100, 100, 3], "names": ["height", "width", "channels"], }, } policy_image_features = { - "observation.images.laptop": PolicyFeature( + f"{OBS_IMAGES}.laptop": PolicyFeature( type=FeatureType.VISUAL, shape=(3, 50, 50), # Downsamples from 100x100 ) @@ -449,7 +450,7 @@ def test_image_processing_pipeline_preserves_content(): observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, "cpu") - processed_img = observation["observation.images.laptop"].squeeze(0) # Remove batch dim + processed_img = observation[f"{OBS_IMAGES}.laptop"].squeeze(0) # Remove batch dim # Check that the center region has higher values than corners # Due to bilinear interpolation, exact values will change but pattern should remain diff --git a/tests/async_inference/test_policy_server.py b/tests/async_inference/test_policy_server.py index c5c52460f..de441ff09 100644 --- a/tests/async_inference/test_policy_server.py +++ b/tests/async_inference/test_policy_server.py @@ -23,6 +23,7 @@ import pytest import torch from lerobot.configs.types import PolicyFeature +from lerobot.utils.constants import OBS_STATE from tests.utils import require_package # ----------------------------------------------------------------------------- @@ -44,7 +45,7 @@ class MockPolicy: def predict_action_chunk(self, observation: dict[str, torch.Tensor]) -> torch.Tensor: """Return a chunk of 20 dummy actions.""" - batch_size = len(observation["observation.state"]) + batch_size = len(observation[OBS_STATE]) return torch.zeros(batch_size, 20, 6) def __init__(self): @@ -77,7 +78,7 @@ def policy_server(): # Add mock lerobot_features that the observation similarity functions need server.lerobot_features = { - "observation.state": { + OBS_STATE: { "dtype": "float32", "shape": [6], "names": ["joint1", "joint2", "joint3", "joint4", "joint5", "joint6"], diff --git a/tests/datasets/test_compute_stats.py b/tests/datasets/test_compute_stats.py index 8f8179c29..982f35c3f 100644 --- a/tests/datasets/test_compute_stats.py +++ b/tests/datasets/test_compute_stats.py @@ -28,6 +28,7 @@ from lerobot.datasets.compute_stats import ( sample_images, sample_indices, ) +from lerobot.utils.constants import OBS_IMAGE, OBS_STATE def mock_load_image_as_numpy(path, dtype, channel_first): @@ -136,21 +137,21 @@ def test_get_feature_stats_single_value(): def test_compute_episode_stats(): episode_data = { - "observation.image": [f"image_{i}.jpg" for i in range(100)], - "observation.state": np.random.rand(100, 10), + OBS_IMAGE: [f"image_{i}.jpg" for i in range(100)], + OBS_STATE: np.random.rand(100, 10), } features = { - "observation.image": {"dtype": "image"}, - "observation.state": {"dtype": "numeric"}, + OBS_IMAGE: {"dtype": "image"}, + OBS_STATE: {"dtype": "numeric"}, } with patch("lerobot.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy): stats = compute_episode_stats(episode_data, features) - assert "observation.image" in stats and "observation.state" in stats - assert stats["observation.image"]["count"].item() == 100 - assert stats["observation.state"]["count"].item() == 100 - assert stats["observation.image"]["mean"].shape == (3, 1, 1) + assert OBS_IMAGE in stats and OBS_STATE in stats + assert stats[OBS_IMAGE]["count"].item() == 100 + assert stats[OBS_STATE]["count"].item() == 100 + assert stats[OBS_IMAGE]["mean"].shape == (3, 1, 1) def test_assert_type_and_shape_valid(): @@ -224,38 +225,38 @@ def test_aggregate_feature_stats(): def test_aggregate_stats(): all_stats = [ { - "observation.image": { + OBS_IMAGE: { "min": [1, 2, 3], "max": [10, 20, 30], "mean": [5.5, 10.5, 15.5], "std": [2.87, 5.87, 8.87], "count": 10, }, - "observation.state": {"min": 1, "max": 10, "mean": 5.5, "std": 2.87, "count": 10}, + OBS_STATE: {"min": 1, "max": 10, "mean": 5.5, "std": 2.87, "count": 10}, "extra_key_0": {"min": 5, "max": 25, "mean": 15, "std": 6, "count": 6}, }, { - "observation.image": { + OBS_IMAGE: { "min": [2, 1, 0], "max": [15, 10, 5], "mean": [8.5, 5.5, 2.5], "std": [3.42, 2.42, 1.42], "count": 15, }, - "observation.state": {"min": 2, "max": 15, "mean": 8.5, "std": 3.42, "count": 15}, + OBS_STATE: {"min": 2, "max": 15, "mean": 8.5, "std": 3.42, "count": 15}, "extra_key_1": {"min": 0, "max": 20, "mean": 10, "std": 5, "count": 5}, }, ] expected_agg_stats = { - "observation.image": { + OBS_IMAGE: { "min": [1, 1, 0], "max": [15, 20, 30], "mean": [7.3, 7.5, 7.7], "std": [3.5317, 4.8267, 8.5581], "count": 25, }, - "observation.state": { + OBS_STATE: { "min": 1, "max": 15, "mean": 7.3, @@ -283,7 +284,7 @@ def test_aggregate_stats(): for fkey, stats in ep_stats.items(): for k in stats: stats[k] = np.array(stats[k], dtype=np.int64 if k == "count" else np.float32) - if fkey == "observation.image" and k != "count": + if fkey == OBS_IMAGE and k != "count": stats[k] = stats[k].reshape(3, 1, 1) # for normalization on image channels else: stats[k] = stats[k].reshape(1) @@ -292,7 +293,7 @@ def test_aggregate_stats(): for fkey, stats in expected_agg_stats.items(): for k in stats: stats[k] = np.array(stats[k], dtype=np.int64 if k == "count" else np.float32) - if fkey == "observation.image" and k != "count": + if fkey == OBS_IMAGE and k != "count": stats[k] = stats[k].reshape(3, 1, 1) # for normalization on image channels else: stats[k] = stats[k].reshape(1) diff --git a/tests/datasets/test_dataset_utils.py b/tests/datasets/test_dataset_utils.py index f1ffd800a..c0b07ca65 100644 --- a/tests/datasets/test_dataset_utils.py +++ b/tests/datasets/test_dataset_utils.py @@ -21,6 +21,7 @@ from huggingface_hub import DatasetCard from lerobot.datasets.push_dataset_to_hub.utils import calculate_episode_data_index from lerobot.datasets.utils import combine_feature_dicts, create_lerobot_dataset_card, hf_transform_to_torch +from lerobot.utils.constants import OBS_IMAGES def test_default_parameters(): @@ -96,14 +97,14 @@ def test_merge_multiple_groups_order_and_dedup(): def test_non_vector_last_wins_for_images(): # Non-vector (images) with same name should be overwritten by the last image specified g1 = { - "observation.images.front": { + f"{OBS_IMAGES}.front": { "dtype": "image", "shape": (3, 480, 640), "names": ["channels", "height", "width"], } } g2 = { - "observation.images.front": { + f"{OBS_IMAGES}.front": { "dtype": "image", "shape": (3, 720, 1280), "names": ["channels", "height", "width"], @@ -111,8 +112,8 @@ def test_non_vector_last_wins_for_images(): } out = combine_feature_dicts(g1, g2) - assert out["observation.images.front"]["shape"] == (3, 720, 1280) - assert out["observation.images.front"]["dtype"] == "image" + assert out[f"{OBS_IMAGES}.front"]["shape"] == (3, 720, 1280) + assert out[f"{OBS_IMAGES}.front"]["dtype"] == "image" def test_dtype_mismatch_raises(): diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index d1d6dbdb2..1d461c8ba 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -46,6 +46,7 @@ from lerobot.datasets.utils import ( from lerobot.envs.factory import make_env_config from lerobot.policies.factory import make_policy_config from lerobot.robots import make_robot_from_config +from lerobot.utils.constants import OBS_IMAGES, OBS_STATE, OBS_STR from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID from tests.mocks.mock_robot import MockRobotConfig from tests.utils import require_x86_64_kernel @@ -75,7 +76,7 @@ def test_same_attributes_defined(tmp_path, lerobot_dataset_factory): # Instantiate both ways robot = make_robot_from_config(MockRobotConfig()) action_features = hw_to_dataset_features(robot.action_features, "action", True) - obs_features = hw_to_dataset_features(robot.observation_features, "observation", True) + obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR, True) dataset_features = {**action_features, **obs_features} root_create = tmp_path / "create" dataset_create = LeRobotDataset.create( @@ -397,7 +398,7 @@ def test_factory(env_name, repo_id, policy_name): ("frame_index", 0, True), ("timestamp", 0, True), # TODO(rcadene): should we rename it agent_pos? - ("observation.state", 1, True), + (OBS_STATE, 1, True), ("next.reward", 0, False), ("next.done", 0, False), ] @@ -662,7 +663,7 @@ def test_check_cached_episodes_sufficient(tmp_path, lerobot_dataset_factory): def test_update_chunk_settings(tmp_path, empty_lerobot_dataset_factory): """Test the update_chunk_settings functionality for both LeRobotDataset and LeRobotDatasetMetadata.""" features = { - "observation.state": { + OBS_STATE: { "dtype": "float32", "shape": (6,), "names": ["shoulder_pan", "shoulder_lift", "elbow", "wrist_1", "wrist_2", "wrist_3"], @@ -769,7 +770,7 @@ def test_update_chunk_settings(tmp_path, empty_lerobot_dataset_factory): def test_update_chunk_settings_video_dataset(tmp_path): """Test update_chunk_settings with a video dataset to ensure video-specific logic works.""" features = { - "observation.images.cam": { + f"{OBS_IMAGES}.cam": { "dtype": "video", "shape": (480, 640, 3), "names": ["height", "width", "channels"], diff --git a/tests/policies/hilserl/test_modeling_classifier.py b/tests/policies/hilserl/test_modeling_classifier.py index 0be1b9c7c..7a8782230 100644 --- a/tests/policies/hilserl/test_modeling_classifier.py +++ b/tests/policies/hilserl/test_modeling_classifier.py @@ -19,6 +19,7 @@ import torch from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig from lerobot.policies.sac.reward_model.modeling_classifier import ClassifierOutput +from lerobot.utils.constants import OBS_IMAGE from tests.utils import require_package @@ -41,7 +42,7 @@ def test_binary_classifier_with_default_params(): config = RewardClassifierConfig() config.input_features = { - "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), + OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), } config.output_features = { "next.reward": PolicyFeature(type=FeatureType.REWARD, shape=(1,)), @@ -56,7 +57,7 @@ def test_binary_classifier_with_default_params(): batch_size = 10 input = { - "observation.image": torch.rand((batch_size, 3, 128, 128)), + OBS_IMAGE: torch.rand((batch_size, 3, 128, 128)), "next.reward": torch.randint(low=0, high=2, size=(batch_size,)).float(), } @@ -83,7 +84,7 @@ def test_multiclass_classifier(): num_classes = 5 config = RewardClassifierConfig() config.input_features = { - "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), + OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), } config.output_features = { "next.reward": PolicyFeature(type=FeatureType.REWARD, shape=(num_classes,)), @@ -95,7 +96,7 @@ def test_multiclass_classifier(): batch_size = 10 input = { - "observation.image": torch.rand((batch_size, 3, 128, 128)), + OBS_IMAGE: torch.rand((batch_size, 3, 128, 128)), "next.reward": torch.rand((batch_size, num_classes)), } diff --git a/tests/policies/test_policies.py b/tests/policies/test_policies.py index b577e5763..7752ad63f 100644 --- a/tests/policies/test_policies.py +++ b/tests/policies/test_policies.py @@ -41,7 +41,7 @@ from lerobot.policies.factory import ( make_pre_post_processors, ) from lerobot.policies.pretrained import PreTrainedPolicy -from lerobot.utils.constants import ACTION, OBS_STATE +from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE from lerobot.utils.random_utils import seeded_context from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats from tests.utils import DEVICE, require_cpu, require_env, require_x86_64_kernel @@ -52,7 +52,7 @@ def dummy_dataset_metadata(lerobot_dataset_metadata_factory, info_factory, tmp_p # Create only one camera input which is squared to fit all current policy constraints # e.g. vqbet and tdmpc works with one camera only, and tdmpc requires it to be squared camera_features = { - "observation.images.laptop": { + f"{OBS_IMAGES}.laptop": { "shape": (84, 84, 3), "names": ["height", "width", "channels"], "info": None, @@ -64,7 +64,7 @@ def dummy_dataset_metadata(lerobot_dataset_metadata_factory, info_factory, tmp_p "shape": (6,), "names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"], }, - "observation.state": { + OBS_STATE: { "dtype": "float32", "shape": (6,), "names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"], @@ -281,7 +281,7 @@ def test_multikey_construction(multikey: bool): preventing erroneous creation of the policy object. """ input_features = { - "observation.state": PolicyFeature( + OBS_STATE: PolicyFeature( type=FeatureType.STATE, shape=(10,), ), @@ -297,9 +297,9 @@ def test_multikey_construction(multikey: bool): """Simulates the complete state/action is constructed from more granular multiple keys, of the same type as the overall state/action""" input_features = {} - input_features["observation.state.subset1"] = PolicyFeature(type=FeatureType.STATE, shape=(5,)) - input_features["observation.state.subset2"] = PolicyFeature(type=FeatureType.STATE, shape=(5,)) - input_features["observation.state"] = PolicyFeature(type=FeatureType.STATE, shape=(10,)) + input_features[f"{OBS_STATE}.subset1"] = PolicyFeature(type=FeatureType.STATE, shape=(5,)) + input_features[f"{OBS_STATE}.subset2"] = PolicyFeature(type=FeatureType.STATE, shape=(5,)) + input_features[OBS_STATE] = PolicyFeature(type=FeatureType.STATE, shape=(10,)) output_features = {} output_features["action.first_three_motors"] = PolicyFeature(type=FeatureType.ACTION, shape=(3,)) diff --git a/tests/policies/test_sac_config.py b/tests/policies/test_sac_config.py index a67815eed..59ed4af65 100644 --- a/tests/policies/test_sac_config.py +++ b/tests/policies/test_sac_config.py @@ -25,6 +25,7 @@ from lerobot.policies.sac.configuration_sac import ( PolicyConfig, SACConfig, ) +from lerobot.utils.constants import OBS_IMAGE, OBS_STATE def test_sac_config_default_initialization(): @@ -37,11 +38,11 @@ def test_sac_config_default_initialization(): "ACTION": NormalizationMode.MIN_MAX, } assert config.dataset_stats == { - "observation.image": { + OBS_IMAGE: { "mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225], }, - "observation.state": { + OBS_STATE: { "min": [0.0, 0.0], "max": [1.0, 1.0], }, @@ -90,11 +91,11 @@ def test_sac_config_default_initialization(): # Dataset stats defaults expected_dataset_stats = { - "observation.image": { + OBS_IMAGE: { "mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225], }, - "observation.state": { + OBS_STATE: { "min": [0.0, 0.0], "max": [1.0, 1.0], }, @@ -191,7 +192,7 @@ def test_sac_config_custom_initialization(): def test_validate_features(): config = SACConfig( - input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,))}, + input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,))}, output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(3,))}, ) config.validate_features() @@ -210,7 +211,7 @@ def test_validate_features_missing_observation(): def test_validate_features_missing_action(): config = SACConfig( - input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,))}, + input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,))}, output_features={"wrong_key": PolicyFeature(type=FeatureType.ACTION, shape=(3,))}, ) with pytest.raises(ValueError, match="You must provide 'action' in the output features"): diff --git a/tests/policies/test_sac_policy.py b/tests/policies/test_sac_policy.py index 7891c2e52..71e45e055 100644 --- a/tests/policies/test_sac_policy.py +++ b/tests/policies/test_sac_policy.py @@ -23,6 +23,7 @@ from torch import Tensor, nn from lerobot.configs.types import FeatureType, PolicyFeature from lerobot.policies.sac.configuration_sac import SACConfig from lerobot.policies.sac.modeling_sac import MLP, SACPolicy +from lerobot.utils.constants import OBS_IMAGE, OBS_STATE from lerobot.utils.random_utils import seeded_context, set_seed try: @@ -85,14 +86,14 @@ def test_sac_policy_with_default_args(): def create_dummy_state(batch_size: int, state_dim: int = 10) -> Tensor: return { - "observation.state": torch.randn(batch_size, state_dim), + OBS_STATE: torch.randn(batch_size, state_dim), } def create_dummy_with_visual_input(batch_size: int, state_dim: int = 10) -> Tensor: return { - "observation.image": torch.randn(batch_size, 3, 84, 84), - "observation.state": torch.randn(batch_size, state_dim), + OBS_IMAGE: torch.randn(batch_size, 3, 84, 84), + OBS_STATE: torch.randn(batch_size, state_dim), } @@ -126,14 +127,14 @@ def create_train_batch_with_visual_input( def create_observation_batch(batch_size: int = 8, state_dim: int = 10) -> dict[str, Tensor]: return { - "observation.state": torch.randn(batch_size, state_dim), + OBS_STATE: torch.randn(batch_size, state_dim), } def create_observation_batch_with_visual_input(batch_size: int = 8, state_dim: int = 10) -> dict[str, Tensor]: return { - "observation.state": torch.randn(batch_size, state_dim), - "observation.image": torch.randn(batch_size, 3, 84, 84), + OBS_STATE: torch.randn(batch_size, state_dim), + OBS_IMAGE: torch.randn(batch_size, 3, 84, 84), } @@ -180,10 +181,10 @@ def create_default_config( action_dim += 1 config = SACConfig( - input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))}, + input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))}, output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(continuous_action_dim,))}, dataset_stats={ - "observation.state": { + OBS_STATE: { "min": [0.0] * state_dim, "max": [1.0] * state_dim, }, @@ -205,8 +206,8 @@ def create_config_with_visual_input( continuous_action_dim=continuous_action_dim, has_discrete_action=has_discrete_action, ) - config.input_features["observation.image"] = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 84, 84)) - config.dataset_stats["observation.image"] = { + config.input_features[OBS_IMAGE] = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 84, 84)) + config.dataset_stats[OBS_IMAGE] = { "mean": torch.randn(3, 1, 1), "std": torch.randn(3, 1, 1), } diff --git a/tests/processor/test_act_processor.py b/tests/processor/test_act_processor.py index 00a4dbb96..134cff684 100644 --- a/tests/processor/test_act_processor.py +++ b/tests/processor/test_act_processor.py @@ -342,7 +342,7 @@ def test_act_processor_batch_consistency(): batch = transition_to_batch(transition) processed = preprocessor(batch) - assert processed["observation.state"].shape[0] == 1 # Batched + assert processed[OBS_STATE].shape[0] == 1 # Batched # Test already batched data observation_batched = {OBS_STATE: torch.randn(8, 7)} # Batch of 8 diff --git a/tests/processor/test_batch_conversion.py b/tests/processor/test_batch_conversion.py index 631ad7899..8bf24db02 100644 --- a/tests/processor/test_batch_conversion.py +++ b/tests/processor/test_batch_conversion.py @@ -2,14 +2,15 @@ import torch from lerobot.processor import DataProcessorPipeline, TransitionKey from lerobot.processor.converters import batch_to_transition, transition_to_batch +from lerobot.utils.constants import OBS_IMAGE, OBS_PREFIX, OBS_STATE def _dummy_batch(): """Create a dummy batch using the new format with observation.* and next.* keys.""" return { - "observation.image.left": torch.randn(1, 3, 128, 128), - "observation.image.right": torch.randn(1, 3, 128, 128), - "observation.state": torch.tensor([[0.1, 0.2, 0.3, 0.4]]), + f"{OBS_IMAGE}.left": torch.randn(1, 3, 128, 128), + f"{OBS_IMAGE}.right": torch.randn(1, 3, 128, 128), + OBS_STATE: torch.tensor([[0.1, 0.2, 0.3, 0.4]]), "action": torch.tensor([[0.5]]), "next.reward": 1.0, "next.done": False, @@ -25,15 +26,15 @@ def test_observation_grouping_roundtrip(): batch_out = proc(batch_in) # Check that all observation.* keys are preserved - original_obs_keys = {k: v for k, v in batch_in.items() if k.startswith("observation.")} - reconstructed_obs_keys = {k: v for k, v in batch_out.items() if k.startswith("observation.")} + original_obs_keys = {k: v for k, v in batch_in.items() if k.startswith(OBS_PREFIX)} + reconstructed_obs_keys = {k: v for k, v in batch_out.items() if k.startswith(OBS_PREFIX)} assert set(original_obs_keys.keys()) == set(reconstructed_obs_keys.keys()) # Check tensor values - assert torch.allclose(batch_out["observation.image.left"], batch_in["observation.image.left"]) - assert torch.allclose(batch_out["observation.image.right"], batch_in["observation.image.right"]) - assert torch.allclose(batch_out["observation.state"], batch_in["observation.state"]) + assert torch.allclose(batch_out[f"{OBS_IMAGE}.left"], batch_in[f"{OBS_IMAGE}.left"]) + assert torch.allclose(batch_out[f"{OBS_IMAGE}.right"], batch_in[f"{OBS_IMAGE}.right"]) + assert torch.allclose(batch_out[OBS_STATE], batch_in[OBS_STATE]) # Check other fields assert torch.allclose(batch_out["action"], batch_in["action"]) @@ -46,9 +47,9 @@ def test_observation_grouping_roundtrip(): def test_batch_to_transition_observation_grouping(): """Test that batch_to_transition correctly groups observation.* keys.""" batch = { - "observation.image.top": torch.randn(1, 3, 128, 128), - "observation.image.left": torch.randn(1, 3, 128, 128), - "observation.state": [1, 2, 3, 4], + f"{OBS_IMAGE}.top": torch.randn(1, 3, 128, 128), + f"{OBS_IMAGE}.left": torch.randn(1, 3, 128, 128), + OBS_STATE: [1, 2, 3, 4], "action": torch.tensor([0.1, 0.2, 0.3, 0.4]), "next.reward": 1.5, "next.done": True, @@ -60,18 +61,18 @@ def test_batch_to_transition_observation_grouping(): # Check observation is a dict with all observation.* keys assert isinstance(transition[TransitionKey.OBSERVATION], dict) - assert "observation.image.top" in transition[TransitionKey.OBSERVATION] - assert "observation.image.left" in transition[TransitionKey.OBSERVATION] - assert "observation.state" in transition[TransitionKey.OBSERVATION] + assert f"{OBS_IMAGE}.top" in transition[TransitionKey.OBSERVATION] + assert f"{OBS_IMAGE}.left" in transition[TransitionKey.OBSERVATION] + assert OBS_STATE in transition[TransitionKey.OBSERVATION] # Check values are preserved assert torch.allclose( - transition[TransitionKey.OBSERVATION]["observation.image.top"], batch["observation.image.top"] + transition[TransitionKey.OBSERVATION][f"{OBS_IMAGE}.top"], batch[f"{OBS_IMAGE}.top"] ) assert torch.allclose( - transition[TransitionKey.OBSERVATION]["observation.image.left"], batch["observation.image.left"] + transition[TransitionKey.OBSERVATION][f"{OBS_IMAGE}.left"], batch[f"{OBS_IMAGE}.left"] ) - assert transition[TransitionKey.OBSERVATION]["observation.state"] == [1, 2, 3, 4] + assert transition[TransitionKey.OBSERVATION][OBS_STATE] == [1, 2, 3, 4] # Check other fields assert torch.allclose(transition[TransitionKey.ACTION], torch.tensor([0.1, 0.2, 0.3, 0.4])) @@ -85,9 +86,9 @@ def test_batch_to_transition_observation_grouping(): def test_transition_to_batch_observation_flattening(): """Test that transition_to_batch correctly flattens observation dict.""" observation_dict = { - "observation.image.top": torch.randn(1, 3, 128, 128), - "observation.image.left": torch.randn(1, 3, 128, 128), - "observation.state": [1, 2, 3, 4], + f"{OBS_IMAGE}.top": torch.randn(1, 3, 128, 128), + f"{OBS_IMAGE}.left": torch.randn(1, 3, 128, 128), + OBS_STATE: [1, 2, 3, 4], } transition = { @@ -103,14 +104,14 @@ def test_transition_to_batch_observation_flattening(): batch = transition_to_batch(transition) # Check that observation.* keys are flattened back to batch - assert "observation.image.top" in batch - assert "observation.image.left" in batch - assert "observation.state" in batch + assert f"{OBS_IMAGE}.top" in batch + assert f"{OBS_IMAGE}.left" in batch + assert OBS_STATE in batch # Check values are preserved - assert torch.allclose(batch["observation.image.top"], observation_dict["observation.image.top"]) - assert torch.allclose(batch["observation.image.left"], observation_dict["observation.image.left"]) - assert batch["observation.state"] == [1, 2, 3, 4] + assert torch.allclose(batch[f"{OBS_IMAGE}.top"], observation_dict[f"{OBS_IMAGE}.top"]) + assert torch.allclose(batch[f"{OBS_IMAGE}.left"], observation_dict[f"{OBS_IMAGE}.left"]) + assert batch[OBS_STATE] == [1, 2, 3, 4] # Check other fields are mapped to next.* format assert batch["action"] == "action_data" @@ -153,12 +154,12 @@ def test_no_observation_keys(): def test_minimal_batch(): """Test with minimal batch containing only observation.* and action.""" - batch = {"observation.state": "minimal_state", "action": torch.tensor([0.5])} + batch = {OBS_STATE: "minimal_state", "action": torch.tensor([0.5])} transition = batch_to_transition(batch) # Check observation - assert transition[TransitionKey.OBSERVATION] == {"observation.state": "minimal_state"} + assert transition[TransitionKey.OBSERVATION] == {OBS_STATE: "minimal_state"} assert torch.allclose(transition[TransitionKey.ACTION], torch.tensor([0.5])) # Check defaults @@ -170,7 +171,7 @@ def test_minimal_batch(): # Round trip reconstructed_batch = transition_to_batch(transition) - assert reconstructed_batch["observation.state"] == "minimal_state" + assert reconstructed_batch[OBS_STATE] == "minimal_state" assert torch.allclose(reconstructed_batch["action"], torch.tensor([0.5])) assert reconstructed_batch["next.reward"] == 0.0 assert not reconstructed_batch["next.done"] @@ -205,9 +206,9 @@ def test_empty_batch(): def test_complex_nested_observation(): """Test with complex nested observation data.""" batch = { - "observation.image.top": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567890}, - "observation.image.left": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567891}, - "observation.state": torch.randn(7), + f"{OBS_IMAGE}.top": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567890}, + f"{OBS_IMAGE}.left": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567891}, + OBS_STATE: torch.randn(7), "action": torch.randn(8), "next.reward": 3.14, "next.done": False, @@ -219,20 +220,20 @@ def test_complex_nested_observation(): reconstructed_batch = transition_to_batch(transition) # Check that all observation keys are preserved - original_obs_keys = {k for k in batch if k.startswith("observation.")} - reconstructed_obs_keys = {k for k in reconstructed_batch if k.startswith("observation.")} + original_obs_keys = {k for k in batch if k.startswith(OBS_PREFIX)} + reconstructed_obs_keys = {k for k in reconstructed_batch if k.startswith(OBS_PREFIX)} assert original_obs_keys == reconstructed_obs_keys # Check tensor values - assert torch.allclose(batch["observation.state"], reconstructed_batch["observation.state"]) + assert torch.allclose(batch[OBS_STATE], reconstructed_batch[OBS_STATE]) # Check nested dict with tensors assert torch.allclose( - batch["observation.image.top"]["image"], reconstructed_batch["observation.image.top"]["image"] + batch[f"{OBS_IMAGE}.top"]["image"], reconstructed_batch[f"{OBS_IMAGE}.top"]["image"] ) assert torch.allclose( - batch["observation.image.left"]["image"], reconstructed_batch["observation.image.left"]["image"] + batch[f"{OBS_IMAGE}.left"]["image"], reconstructed_batch[f"{OBS_IMAGE}.left"]["image"] ) # Check action tensor @@ -264,7 +265,7 @@ def test_custom_converter(): processor = DataProcessorPipeline(steps=[], to_transition=to_tr, to_output=to_batch) batch = { - "observation.state": torch.randn(1, 4), + OBS_STATE: torch.randn(1, 4), "action": torch.randn(1, 2), "next.reward": 1.0, "next.done": False, @@ -274,5 +275,5 @@ def test_custom_converter(): # Check the reward was doubled by our custom converter assert result["next.reward"] == 2.0 - assert torch.allclose(result["observation.state"], batch["observation.state"]) + assert torch.allclose(result[OBS_STATE], batch[OBS_STATE]) assert torch.allclose(result["action"], batch["action"]) diff --git a/tests/processor/test_converters.py b/tests/processor/test_converters.py index fc91951de..b03d49214 100644 --- a/tests/processor/test_converters.py +++ b/tests/processor/test_converters.py @@ -9,6 +9,7 @@ from lerobot.processor.converters import ( to_tensor, transition_to_batch, ) +from lerobot.utils.constants import OBS_STATE, OBS_STR # Tests for the unified to_tensor function @@ -118,16 +119,16 @@ def test_to_tensor_dictionaries(): # Nested dictionary nested = { "action": {"mean": [0.1, 0.2], "std": [1.0, 2.0]}, - "observation": {"mean": np.array([0.5, 0.6]), "count": 10}, + OBS_STR: {"mean": np.array([0.5, 0.6]), "count": 10}, } result = to_tensor(nested) assert isinstance(result, dict) assert isinstance(result["action"], dict) - assert isinstance(result["observation"], dict) + assert isinstance(result[OBS_STR], dict) assert isinstance(result["action"]["mean"], torch.Tensor) - assert isinstance(result["observation"]["mean"], torch.Tensor) + assert isinstance(result[OBS_STR]["mean"], torch.Tensor) assert torch.allclose(result["action"]["mean"], torch.tensor([0.1, 0.2])) - assert torch.allclose(result["observation"]["mean"], torch.tensor([0.5, 0.6])) + assert torch.allclose(result[OBS_STR]["mean"], torch.tensor([0.5, 0.6])) def test_to_tensor_none_filtering(): @@ -198,7 +199,7 @@ def test_batch_to_transition_with_index_fields(): # Create batch with index and task_index fields batch = { - "observation.state": torch.randn(1, 7), + OBS_STATE: torch.randn(1, 7), "action": torch.randn(1, 4), "next.reward": 1.5, "next.done": False, @@ -231,7 +232,7 @@ def testtransition_to_batch_with_index_fields(): # Create transition with index and task_index in complementary_data transition = create_transition( - observation={"observation.state": torch.randn(1, 7)}, + observation={OBS_STATE: torch.randn(1, 7)}, action=torch.randn(1, 4), reward=1.5, done=False, @@ -260,7 +261,7 @@ def test_batch_to_transition_without_index_fields(): # Batch without index/task_index batch = { - "observation.state": torch.randn(1, 7), + OBS_STATE: torch.randn(1, 7), "action": torch.randn(1, 4), "task": ["pick_cube"], } @@ -279,7 +280,7 @@ def test_transition_to_batch_without_index_fields(): # Transition without index/task_index transition = create_transition( - observation={"observation.state": torch.randn(1, 7)}, + observation={OBS_STATE: torch.randn(1, 7)}, action=torch.randn(1, 4), complementary_data={"task": ["navigate"]}, ) diff --git a/tests/processor/test_device_processor.py b/tests/processor/test_device_processor.py index 10ee313d7..36081e021 100644 --- a/tests/processor/test_device_processor.py +++ b/tests/processor/test_device_processor.py @@ -21,6 +21,7 @@ import torch from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature from lerobot.processor import DataProcessorPipeline, DeviceProcessorStep, TransitionKey from lerobot.processor.converters import create_transition, identity_transition +from lerobot.utils.constants import OBS_IMAGE, OBS_STATE def test_basic_functionality(): @@ -28,7 +29,7 @@ def test_basic_functionality(): processor = DeviceProcessorStep(device="cpu") # Create a transition with CPU tensors - observation = {"observation.state": torch.randn(10), "observation.image": torch.randn(3, 224, 224)} + observation = {OBS_STATE: torch.randn(10), OBS_IMAGE: torch.randn(3, 224, 224)} action = torch.randn(5) reward = torch.tensor(1.0) done = torch.tensor(False) @@ -41,8 +42,8 @@ def test_basic_functionality(): result = processor(transition) # Check that all tensors are on CPU - assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cpu" - assert result[TransitionKey.OBSERVATION]["observation.image"].device.type == "cpu" + assert result[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cpu" + assert result[TransitionKey.OBSERVATION][OBS_IMAGE].device.type == "cpu" assert result[TransitionKey.ACTION].device.type == "cpu" assert result[TransitionKey.REWARD].device.type == "cpu" assert result[TransitionKey.DONE].device.type == "cpu" @@ -55,7 +56,7 @@ def test_cuda_functionality(): processor = DeviceProcessorStep(device="cuda") # Create a transition with CPU tensors - observation = {"observation.state": torch.randn(10), "observation.image": torch.randn(3, 224, 224)} + observation = {OBS_STATE: torch.randn(10), OBS_IMAGE: torch.randn(3, 224, 224)} action = torch.randn(5) reward = torch.tensor(1.0) done = torch.tensor(False) @@ -68,8 +69,8 @@ def test_cuda_functionality(): result = processor(transition) # Check that all tensors are on CUDA - assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cuda" - assert result[TransitionKey.OBSERVATION]["observation.image"].device.type == "cuda" + assert result[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cuda" + assert result[TransitionKey.OBSERVATION][OBS_IMAGE].device.type == "cuda" assert result[TransitionKey.ACTION].device.type == "cuda" assert result[TransitionKey.REWARD].device.type == "cuda" assert result[TransitionKey.DONE].device.type == "cuda" @@ -81,14 +82,14 @@ def test_specific_cuda_device(): """Test device processor with specific CUDA device.""" processor = DeviceProcessorStep(device="cuda:0") - observation = {"observation.state": torch.randn(10)} + observation = {OBS_STATE: torch.randn(10)} action = torch.randn(5) transition = create_transition(observation=observation, action=action) result = processor(transition) - assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cuda" - assert result[TransitionKey.OBSERVATION]["observation.state"].device.index == 0 + assert result[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cuda" + assert result[TransitionKey.OBSERVATION][OBS_STATE].device.index == 0 assert result[TransitionKey.ACTION].device.type == "cuda" assert result[TransitionKey.ACTION].device.index == 0 @@ -98,7 +99,7 @@ def test_non_tensor_values(): processor = DeviceProcessorStep(device="cpu") observation = { - "observation.state": torch.randn(10), + OBS_STATE: torch.randn(10), "observation.metadata": {"key": "value"}, # Non-tensor data "observation.list": [1, 2, 3], # Non-tensor data } @@ -110,7 +111,7 @@ def test_non_tensor_values(): result = processor(transition) # Check tensors are processed - assert isinstance(result[TransitionKey.OBSERVATION]["observation.state"], torch.Tensor) + assert isinstance(result[TransitionKey.OBSERVATION][OBS_STATE], torch.Tensor) assert isinstance(result[TransitionKey.ACTION], torch.Tensor) # Check non-tensor values are preserved @@ -130,9 +131,9 @@ def test_none_values(): assert result[TransitionKey.ACTION].device.type == "cpu" # Test with None action - transition = create_transition(observation={"observation.state": torch.randn(10)}, action=None) + transition = create_transition(observation={OBS_STATE: torch.randn(10)}, action=None) result = processor(transition) - assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cpu" + assert result[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cpu" assert result[TransitionKey.ACTION] is None @@ -271,9 +272,7 @@ def test_features(): processor = DeviceProcessorStep(device="cpu") features = { - PipelineFeatureType.OBSERVATION: { - "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,)) - }, + PipelineFeatureType.OBSERVATION: {OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,))}, PipelineFeatureType.ACTION: {"action": PolicyFeature(type=FeatureType.ACTION, shape=(5,))}, } @@ -376,7 +375,7 @@ def test_reward_done_truncated_types(): # Test with scalar values (not tensors) transition = create_transition( - observation={"observation.state": torch.randn(5)}, + observation={OBS_STATE: torch.randn(5)}, action=torch.randn(3), reward=1.0, # float done=False, # bool @@ -392,7 +391,7 @@ def test_reward_done_truncated_types(): # Test with tensor values transition = create_transition( - observation={"observation.state": torch.randn(5)}, + observation={OBS_STATE: torch.randn(5)}, action=torch.randn(3), reward=torch.tensor(1.0), done=torch.tensor(False), @@ -422,7 +421,7 @@ def test_complementary_data_preserved(): } transition = create_transition( - observation={"observation.state": torch.randn(5)}, complementary_data=complementary_data + observation={OBS_STATE: torch.randn(5)}, complementary_data=complementary_data ) result = processor(transition) @@ -491,13 +490,13 @@ def test_float_dtype_bfloat16(): """Test conversion to bfloat16.""" processor = DeviceProcessorStep(device="cpu", float_dtype="bfloat16") - observation = {"observation.state": torch.randn(5, dtype=torch.float32)} + observation = {OBS_STATE: torch.randn(5, dtype=torch.float32)} action = torch.randn(3, dtype=torch.float64) transition = create_transition(observation=observation, action=action) result = processor(transition) - assert result[TransitionKey.OBSERVATION]["observation.state"].dtype == torch.bfloat16 + assert result[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.bfloat16 assert result[TransitionKey.ACTION].dtype == torch.bfloat16 @@ -505,13 +504,13 @@ def test_float_dtype_float64(): """Test conversion to float64.""" processor = DeviceProcessorStep(device="cpu", float_dtype="float64") - observation = {"observation.state": torch.randn(5, dtype=torch.float16)} + observation = {OBS_STATE: torch.randn(5, dtype=torch.float16)} action = torch.randn(3, dtype=torch.float32) transition = create_transition(observation=observation, action=action) result = processor(transition) - assert result[TransitionKey.OBSERVATION]["observation.state"].dtype == torch.float64 + assert result[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.float64 assert result[TransitionKey.ACTION].dtype == torch.float64 @@ -541,8 +540,8 @@ def test_float_dtype_with_mixed_tensors(): processor = DeviceProcessorStep(device="cpu", float_dtype="float32") observation = { - "observation.image": torch.randint(0, 255, (3, 64, 64), dtype=torch.uint8), # Should not convert - "observation.state": torch.randn(10, dtype=torch.float64), # Should convert + OBS_IMAGE: torch.randint(0, 255, (3, 64, 64), dtype=torch.uint8), # Should not convert + OBS_STATE: torch.randn(10, dtype=torch.float64), # Should convert "observation.mask": torch.tensor([True, False, True], dtype=torch.bool), # Should not convert "observation.indices": torch.tensor([1, 2, 3], dtype=torch.long), # Should not convert } @@ -552,8 +551,8 @@ def test_float_dtype_with_mixed_tensors(): result = processor(transition) # Check conversions - assert result[TransitionKey.OBSERVATION]["observation.image"].dtype == torch.uint8 # Unchanged - assert result[TransitionKey.OBSERVATION]["observation.state"].dtype == torch.float32 # Converted + assert result[TransitionKey.OBSERVATION][OBS_IMAGE].dtype == torch.uint8 # Unchanged + assert result[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.float32 # Converted assert result[TransitionKey.OBSERVATION]["observation.mask"].dtype == torch.bool # Unchanged assert result[TransitionKey.OBSERVATION]["observation.indices"].dtype == torch.long # Unchanged assert result[TransitionKey.ACTION].dtype == torch.float32 # Converted @@ -612,7 +611,7 @@ def test_complementary_data_index_fields(): "episode_id": 123, # Non-tensor field } transition = create_transition( - observation={"observation.state": torch.randn(1, 7)}, + observation={OBS_STATE: torch.randn(1, 7)}, action=torch.randn(1, 4), complementary_data=complementary_data, ) @@ -736,7 +735,7 @@ def test_complementary_data_full_pipeline_cuda(): processor = DeviceProcessorStep(device="cuda:0", float_dtype="float16") # Create full transition with mixed CPU tensors - observation = {"observation.state": torch.randn(1, 7, dtype=torch.float32)} + observation = {OBS_STATE: torch.randn(1, 7, dtype=torch.float32)} action = torch.randn(1, 4, dtype=torch.float32) reward = torch.tensor(1.5, dtype=torch.float32) done = torch.tensor(False) @@ -757,7 +756,7 @@ def test_complementary_data_full_pipeline_cuda(): result = processor(transition) # Check all components moved to CUDA - assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cuda" + assert result[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cuda" assert result[TransitionKey.ACTION].device.type == "cuda" assert result[TransitionKey.REWARD].device.type == "cuda" assert result[TransitionKey.DONE].device.type == "cuda" @@ -768,7 +767,7 @@ def test_complementary_data_full_pipeline_cuda(): assert processed_comp_data["task_index"].device.type == "cuda" # Check float conversion happened for float tensors - assert result[TransitionKey.OBSERVATION]["observation.state"].dtype == torch.float16 + assert result[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.float16 assert result[TransitionKey.ACTION].dtype == torch.float16 assert result[TransitionKey.REWARD].dtype == torch.float16 @@ -782,7 +781,7 @@ def test_complementary_data_empty(): processor = DeviceProcessorStep(device="cpu") transition = create_transition( - observation={"observation.state": torch.randn(1, 7)}, + observation={OBS_STATE: torch.randn(1, 7)}, complementary_data={}, ) @@ -797,7 +796,7 @@ def test_complementary_data_none(): processor = DeviceProcessorStep(device="cpu") transition = create_transition( - observation={"observation.state": torch.randn(1, 7)}, + observation={OBS_STATE: torch.randn(1, 7)}, complementary_data=None, ) @@ -814,8 +813,8 @@ def test_preserves_gpu_placement(): # Create tensors already on GPU observation = { - "observation.state": torch.randn(10).cuda(), # Already on GPU - "observation.image": torch.randn(3, 224, 224).cuda(), # Already on GPU + OBS_STATE: torch.randn(10).cuda(), # Already on GPU + OBS_IMAGE: torch.randn(3, 224, 224).cuda(), # Already on GPU } action = torch.randn(5).cuda() # Already on GPU @@ -823,14 +822,12 @@ def test_preserves_gpu_placement(): result = processor(transition) # Check that tensors remain on their original GPU - assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cuda" - assert result[TransitionKey.OBSERVATION]["observation.image"].device.type == "cuda" + assert result[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cuda" + assert result[TransitionKey.OBSERVATION][OBS_IMAGE].device.type == "cuda" assert result[TransitionKey.ACTION].device.type == "cuda" # Verify no unnecessary copies were made (same data pointer) - assert torch.equal( - result[TransitionKey.OBSERVATION]["observation.state"], observation["observation.state"] - ) + assert torch.equal(result[TransitionKey.OBSERVATION][OBS_STATE], observation[OBS_STATE]) @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs") @@ -842,8 +839,8 @@ def test_multi_gpu_preservation(): # Create tensors on cuda:1 (simulating Accelerate placement) cuda1_device = torch.device("cuda:1") observation = { - "observation.state": torch.randn(10).to(cuda1_device), - "observation.image": torch.randn(3, 224, 224).to(cuda1_device), + OBS_STATE: torch.randn(10).to(cuda1_device), + OBS_IMAGE: torch.randn(3, 224, 224).to(cuda1_device), } action = torch.randn(5).to(cuda1_device) @@ -851,20 +848,20 @@ def test_multi_gpu_preservation(): result = processor_gpu(transition) # Check that tensors remain on cuda:1 (not moved to cuda:0) - assert result[TransitionKey.OBSERVATION]["observation.state"].device == cuda1_device - assert result[TransitionKey.OBSERVATION]["observation.image"].device == cuda1_device + assert result[TransitionKey.OBSERVATION][OBS_STATE].device == cuda1_device + assert result[TransitionKey.OBSERVATION][OBS_IMAGE].device == cuda1_device assert result[TransitionKey.ACTION].device == cuda1_device # Test 2: GPU-to-CPU should move to CPU (not preserve GPU) processor_cpu = DeviceProcessorStep(device="cpu") transition_gpu = create_transition( - observation={"observation.state": torch.randn(10).cuda()}, action=torch.randn(5).cuda() + observation={OBS_STATE: torch.randn(10).cuda()}, action=torch.randn(5).cuda() ) result_cpu = processor_cpu(transition_gpu) # Check that tensors are moved to CPU - assert result_cpu[TransitionKey.OBSERVATION]["observation.state"].device.type == "cpu" + assert result_cpu[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cpu" assert result_cpu[TransitionKey.ACTION].device.type == "cpu" @@ -933,14 +930,14 @@ def test_simulated_accelerate_scenario(): # Simulate data already placed by Accelerate device = torch.device(f"cuda:{gpu_id}") - observation = {"observation.state": torch.randn(1, 10).to(device)} + observation = {OBS_STATE: torch.randn(1, 10).to(device)} action = torch.randn(1, 5).to(device) transition = create_transition(observation=observation, action=action) result = processor(transition) # Verify data stays on the GPU where Accelerate placed it - assert result[TransitionKey.OBSERVATION]["observation.state"].device == device + assert result[TransitionKey.OBSERVATION][OBS_STATE].device == device assert result[TransitionKey.ACTION].device == device @@ -1081,7 +1078,7 @@ def test_mps_float64_with_complementary_data(): } transition = create_transition( - observation={"observation.state": torch.randn(5, dtype=torch.float64)}, + observation={OBS_STATE: torch.randn(5, dtype=torch.float64)}, action=torch.randn(3, dtype=torch.float64), complementary_data=complementary_data, ) @@ -1089,7 +1086,7 @@ def test_mps_float64_with_complementary_data(): result = processor(transition) # Check that all tensors are on MPS device - assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "mps" + assert result[TransitionKey.OBSERVATION][OBS_STATE].device.type == "mps" assert result[TransitionKey.ACTION].device.type == "mps" processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] @@ -1099,7 +1096,7 @@ def test_mps_float64_with_complementary_data(): assert processed_comp_data["float32_tensor"].device.type == "mps" # Check dtype conversions - assert result[TransitionKey.OBSERVATION]["observation.state"].dtype == torch.float32 # Converted + assert result[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.float32 # Converted assert result[TransitionKey.ACTION].dtype == torch.float32 # Converted assert processed_comp_data["float64_tensor"].dtype == torch.float32 # Converted assert processed_comp_data["float32_tensor"].dtype == torch.float32 # Unchanged diff --git a/tests/processor/test_migration_detection.py b/tests/processor/test_migration_detection.py index 6bed8289d..b46cc6bdd 100644 --- a/tests/processor/test_migration_detection.py +++ b/tests/processor/test_migration_detection.py @@ -25,6 +25,7 @@ from pathlib import Path import pytest from lerobot.processor.pipeline import DataProcessorPipeline, ProcessorMigrationError +from lerobot.utils.constants import OBS_STATE def test_is_processor_config_valid_configs(): @@ -111,7 +112,7 @@ def test_should_suggest_migration_with_model_config_only(): # Create a model config (like old LeRobot format) model_config = { "type": "act", - "input_features": {"observation.state": {"shape": [7]}}, + "input_features": {OBS_STATE: {"shape": [7]}}, "output_features": {"action": {"shape": [7]}}, "hidden_dim": 256, "n_obs_steps": 1, diff --git a/tests/processor/test_normalize_processor.py b/tests/processor/test_normalize_processor.py index 5d7791919..616f33db9 100644 --- a/tests/processor/test_normalize_processor.py +++ b/tests/processor/test_normalize_processor.py @@ -29,22 +29,23 @@ from lerobot.processor import ( hotswap_stats, ) from lerobot.processor.converters import create_transition, identity_transition, to_tensor +from lerobot.utils.constants import OBS_IMAGE, OBS_STATE, OBS_STR from lerobot.utils.utils import auto_select_torch_device def test_numpy_conversion(): stats = { - "observation.image": { + OBS_IMAGE: { "mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2]), } } tensor_stats = to_tensor(stats) - assert isinstance(tensor_stats["observation.image"]["mean"], torch.Tensor) - assert isinstance(tensor_stats["observation.image"]["std"], torch.Tensor) - assert torch.allclose(tensor_stats["observation.image"]["mean"], torch.tensor([0.5, 0.5, 0.5])) - assert torch.allclose(tensor_stats["observation.image"]["std"], torch.tensor([0.2, 0.2, 0.2])) + assert isinstance(tensor_stats[OBS_IMAGE]["mean"], torch.Tensor) + assert isinstance(tensor_stats[OBS_IMAGE]["std"], torch.Tensor) + assert torch.allclose(tensor_stats[OBS_IMAGE]["mean"], torch.tensor([0.5, 0.5, 0.5])) + assert torch.allclose(tensor_stats[OBS_IMAGE]["std"], torch.tensor([0.2, 0.2, 0.2])) def test_tensor_conversion(): @@ -75,15 +76,15 @@ def test_scalar_conversion(): def test_list_conversion(): stats = { - "observation.state": { + OBS_STATE: { "min": [0.0, -1.0, -2.0], "max": [1.0, 1.0, 2.0], } } tensor_stats = to_tensor(stats) - assert torch.allclose(tensor_stats["observation.state"]["min"], torch.tensor([0.0, -1.0, -2.0])) - assert torch.allclose(tensor_stats["observation.state"]["max"], torch.tensor([1.0, 1.0, 2.0])) + assert torch.allclose(tensor_stats[OBS_STATE]["min"], torch.tensor([0.0, -1.0, -2.0])) + assert torch.allclose(tensor_stats[OBS_STATE]["max"], torch.tensor([1.0, 1.0, 2.0])) def test_unsupported_type(): @@ -99,8 +100,8 @@ def test_unsupported_type(): # Helper functions to create feature maps and norm maps def _create_observation_features(): return { - "observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), - "observation.state": PolicyFeature(FeatureType.STATE, (2,)), + OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), + OBS_STATE: PolicyFeature(FeatureType.STATE, (2,)), } @@ -115,11 +116,11 @@ def _create_observation_norm_map(): @pytest.fixture def observation_stats(): return { - "observation.image": { + OBS_IMAGE: { "mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2]), }, - "observation.state": { + OBS_STATE: { "min": np.array([0.0, -1.0]), "max": np.array([1.0, 1.0]), }, @@ -136,8 +137,8 @@ def observation_normalizer(observation_stats): def test_mean_std_normalization(observation_normalizer): observation = { - "observation.image": torch.tensor([0.7, 0.5, 0.3]), - "observation.state": torch.tensor([0.5, 0.0]), + OBS_IMAGE: torch.tensor([0.7, 0.5, 0.3]), + OBS_STATE: torch.tensor([0.5, 0.0]), } transition = create_transition(observation=observation) @@ -146,12 +147,12 @@ def test_mean_std_normalization(observation_normalizer): # Check mean/std normalization expected_image = (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2 - assert torch.allclose(normalized_obs["observation.image"], expected_image) + assert torch.allclose(normalized_obs[OBS_IMAGE], expected_image) def test_min_max_normalization(observation_normalizer): observation = { - "observation.state": torch.tensor([0.5, 0.0]), + OBS_STATE: torch.tensor([0.5, 0.0]), } transition = create_transition(observation=observation) @@ -162,7 +163,7 @@ def test_min_max_normalization(observation_normalizer): # For state[0]: 2 * (0.5 - 0.0) / (1.0 - 0.0) - 1 = 0.0 # For state[1]: 2 * (0.0 - (-1.0)) / (1.0 - (-1.0)) - 1 = 0.0 expected_state = torch.tensor([0.0, 0.0]) - assert torch.allclose(normalized_obs["observation.state"], expected_state, atol=1e-6) + assert torch.allclose(normalized_obs[OBS_STATE], expected_state, atol=1e-6) def test_selective_normalization(observation_stats): @@ -172,12 +173,12 @@ def test_selective_normalization(observation_stats): features=features, norm_map=norm_map, stats=observation_stats, - normalize_observation_keys={"observation.image"}, + normalize_observation_keys={OBS_IMAGE}, ) observation = { - "observation.image": torch.tensor([0.7, 0.5, 0.3]), - "observation.state": torch.tensor([0.5, 0.0]), + OBS_IMAGE: torch.tensor([0.7, 0.5, 0.3]), + OBS_STATE: torch.tensor([0.5, 0.0]), } transition = create_transition(observation=observation) @@ -185,9 +186,9 @@ def test_selective_normalization(observation_stats): normalized_obs = normalized_transition[TransitionKey.OBSERVATION] # Only image should be normalized - assert torch.allclose(normalized_obs["observation.image"], (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2) + assert torch.allclose(normalized_obs[OBS_IMAGE], (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2) # State should remain unchanged - assert torch.allclose(normalized_obs["observation.state"], observation["observation.state"]) + assert torch.allclose(normalized_obs[OBS_STATE], observation[OBS_STATE]) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @@ -196,26 +197,26 @@ def test_device_compatibility(observation_stats): norm_map = _create_observation_norm_map() normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=observation_stats) observation = { - "observation.image": torch.tensor([0.7, 0.5, 0.3]).cuda(), + OBS_IMAGE: torch.tensor([0.7, 0.5, 0.3]).cuda(), } transition = create_transition(observation=observation) normalized_transition = normalizer(transition) normalized_obs = normalized_transition[TransitionKey.OBSERVATION] - assert normalized_obs["observation.image"].device.type == "cuda" + assert normalized_obs[OBS_IMAGE].device.type == "cuda" def test_from_lerobot_dataset(): # Mock dataset mock_dataset = Mock() mock_dataset.meta.stats = { - "observation.image": {"mean": [0.5], "std": [0.2]}, + OBS_IMAGE: {"mean": [0.5], "std": [0.2]}, "action": {"mean": [0.0], "std": [1.0]}, } features = { - "observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), + OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), "action": PolicyFeature(FeatureType.ACTION, (1,)), } norm_map = { @@ -226,7 +227,7 @@ def test_from_lerobot_dataset(): normalizer = NormalizerProcessorStep.from_lerobot_dataset(mock_dataset, features, norm_map) # Both observation and action statistics should be present in tensor stats - assert "observation.image" in normalizer._tensor_stats + assert OBS_IMAGE in normalizer._tensor_stats assert "action" in normalizer._tensor_stats @@ -242,13 +243,13 @@ def test_state_dict_save_load(observation_normalizer): new_normalizer.load_state_dict(state_dict) # Test that it works the same - observation = {"observation.image": torch.tensor([0.7, 0.5, 0.3])} + observation = {OBS_IMAGE: torch.tensor([0.7, 0.5, 0.3])} transition = create_transition(observation=observation) result1 = observation_normalizer(transition)[TransitionKey.OBSERVATION] result2 = new_normalizer(transition)[TransitionKey.OBSERVATION] - assert torch.allclose(result1["observation.image"], result2["observation.image"]) + assert torch.allclose(result1[OBS_IMAGE], result2[OBS_IMAGE]) # Fixtures for ActionUnnormalizer tests @@ -375,11 +376,11 @@ def test_action_from_lerobot_dataset(): @pytest.fixture def full_stats(): return { - "observation.image": { + OBS_IMAGE: { "mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2]), }, - "observation.state": { + OBS_STATE: { "min": np.array([0.0, -1.0]), "max": np.array([1.0, 1.0]), }, @@ -392,8 +393,8 @@ def full_stats(): def _create_full_features(): return { - "observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), - "observation.state": PolicyFeature(FeatureType.STATE, (2,)), + OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), + OBS_STATE: PolicyFeature(FeatureType.STATE, (2,)), "action": PolicyFeature(FeatureType.ACTION, (2,)), } @@ -415,8 +416,8 @@ def normalizer_processor(full_stats): def test_combined_normalization(normalizer_processor): observation = { - "observation.image": torch.tensor([0.7, 0.5, 0.3]), - "observation.state": torch.tensor([0.5, 0.0]), + OBS_IMAGE: torch.tensor([0.7, 0.5, 0.3]), + OBS_STATE: torch.tensor([0.5, 0.0]), } action = torch.tensor([1.0, -0.5]) transition = create_transition( @@ -434,7 +435,7 @@ def test_combined_normalization(normalizer_processor): # Check normalized observations processed_obs = processed_transition[TransitionKey.OBSERVATION] expected_image = (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2 - assert torch.allclose(processed_obs["observation.image"], expected_image) + assert torch.allclose(processed_obs[OBS_IMAGE], expected_image) # Check normalized action processed_action = processed_transition[TransitionKey.ACTION] @@ -455,11 +456,11 @@ def test_processor_from_lerobot_dataset(full_stats): norm_map = _create_full_norm_map() processor = NormalizerProcessorStep.from_lerobot_dataset( - mock_dataset, features, norm_map, normalize_observation_keys={"observation.image"} + mock_dataset, features, norm_map, normalize_observation_keys={OBS_IMAGE} ) - assert processor.normalize_observation_keys == {"observation.image"} - assert "observation.image" in processor._tensor_stats + assert processor.normalize_observation_keys == {OBS_IMAGE} + assert OBS_IMAGE in processor._tensor_stats assert "action" in processor._tensor_stats @@ -470,17 +471,17 @@ def test_get_config(full_stats): features=features, norm_map=norm_map, stats=full_stats, - normalize_observation_keys={"observation.image"}, + normalize_observation_keys={OBS_IMAGE}, eps=1e-6, ) config = processor.get_config() expected_config = { - "normalize_observation_keys": ["observation.image"], + "normalize_observation_keys": [OBS_IMAGE], "eps": 1e-6, "features": { - "observation.image": {"type": "VISUAL", "shape": (3, 96, 96)}, - "observation.state": {"type": "STATE", "shape": (2,)}, + OBS_IMAGE: {"type": "VISUAL", "shape": (3, 96, 96)}, + OBS_STATE: {"type": "STATE", "shape": (2,)}, "action": {"type": "ACTION", "shape": (2,)}, }, "norm_map": { @@ -499,8 +500,8 @@ def test_integration_with_robot_processor(normalizer_processor): ) observation = { - "observation.image": torch.tensor([0.7, 0.5, 0.3]), - "observation.state": torch.tensor([0.5, 0.0]), + OBS_IMAGE: torch.tensor([0.7, 0.5, 0.3]), + OBS_STATE: torch.tensor([0.5, 0.0]), } action = torch.tensor([1.0, -0.5]) transition = create_transition( @@ -522,8 +523,8 @@ def test_integration_with_robot_processor(normalizer_processor): # Edge case tests def test_empty_observation(): - stats = {"observation.image": {"mean": [0.5], "std": [0.2]}} - features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))} + stats = {OBS_IMAGE: {"mean": [0.5], "std": [0.2]}} + features = {OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3, 96, 96))} norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) @@ -534,37 +535,35 @@ def test_empty_observation(): def test_empty_stats(): - features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))} + features = {OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3, 96, 96))} norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats={}) - observation = {"observation.image": torch.tensor([0.5])} + observation = {OBS_IMAGE: torch.tensor([0.5])} transition = create_transition(observation=observation) result = normalizer(transition) # Should return observation unchanged since no stats are available - assert torch.allclose( - result[TransitionKey.OBSERVATION]["observation.image"], observation["observation.image"] - ) + assert torch.allclose(result[TransitionKey.OBSERVATION][OBS_IMAGE], observation[OBS_IMAGE]) def test_partial_stats(): """If statistics are incomplete, the value should pass through unchanged.""" - stats = {"observation.image": {"mean": [0.5]}} # Missing std / (min,max) - features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))} + stats = {OBS_IMAGE: {"mean": [0.5]}} # Missing std / (min,max) + features = {OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3, 96, 96))} norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) - observation = {"observation.image": torch.tensor([0.7])} + observation = {OBS_IMAGE: torch.tensor([0.7])} transition = create_transition(observation=observation) processed = normalizer(transition)[TransitionKey.OBSERVATION] - assert torch.allclose(processed["observation.image"], observation["observation.image"]) + assert torch.allclose(processed[OBS_IMAGE], observation[OBS_IMAGE]) def test_missing_action_stats_no_error(): mock_dataset = Mock() - mock_dataset.meta.stats = {"observation.image": {"mean": [0.5], "std": [0.2]}} + mock_dataset.meta.stats = {OBS_IMAGE: {"mean": [0.5], "std": [0.2]}} - features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))} + features = {OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3, 96, 96))} norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} processor = UnnormalizerProcessorStep.from_lerobot_dataset(mock_dataset, features, norm_map) @@ -580,7 +579,7 @@ def test_serialization_roundtrip(full_stats): features=features, norm_map=norm_map, stats=full_stats, - normalize_observation_keys={"observation.image"}, + normalize_observation_keys={OBS_IMAGE}, eps=1e-6, ) @@ -598,8 +597,8 @@ def test_serialization_roundtrip(full_stats): # Test that both processors work the same way observation = { - "observation.image": torch.tensor([0.7, 0.5, 0.3]), - "observation.state": torch.tensor([0.5, 0.0]), + OBS_IMAGE: torch.tensor([0.7, 0.5, 0.3]), + OBS_STATE: torch.tensor([0.5, 0.0]), } action = torch.tensor([1.0, -0.5]) transition = create_transition( @@ -617,8 +616,8 @@ def test_serialization_roundtrip(full_stats): # Compare results assert torch.allclose( - result1[TransitionKey.OBSERVATION]["observation.image"], - result2[TransitionKey.OBSERVATION]["observation.image"], + result1[TransitionKey.OBSERVATION][OBS_IMAGE], + result2[TransitionKey.OBSERVATION][OBS_IMAGE], ) assert torch.allclose(result1[TransitionKey.ACTION], result2[TransitionKey.ACTION]) @@ -644,23 +643,23 @@ def test_serialization_roundtrip(full_stats): def test_identity_normalization_observations(): """Test that IDENTITY mode skips normalization for observations.""" features = { - "observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), - "observation.state": PolicyFeature(FeatureType.STATE, (2,)), + OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), + OBS_STATE: PolicyFeature(FeatureType.STATE, (2,)), } norm_map = { FeatureType.VISUAL: NormalizationMode.IDENTITY, # IDENTITY mode FeatureType.STATE: NormalizationMode.MEAN_STD, # Normal mode for comparison } stats = { - "observation.image": {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]}, - "observation.state": {"mean": [0.0, 0.0], "std": [1.0, 1.0]}, + OBS_IMAGE: {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]}, + OBS_STATE: {"mean": [0.0, 0.0], "std": [1.0, 1.0]}, } normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) observation = { - "observation.image": torch.tensor([0.7, 0.5, 0.3]), - "observation.state": torch.tensor([1.0, -0.5]), + OBS_IMAGE: torch.tensor([0.7, 0.5, 0.3]), + OBS_STATE: torch.tensor([1.0, -0.5]), } transition = create_transition(observation=observation) @@ -668,11 +667,11 @@ def test_identity_normalization_observations(): normalized_obs = normalized_transition[TransitionKey.OBSERVATION] # Image should remain unchanged (IDENTITY) - assert torch.allclose(normalized_obs["observation.image"], observation["observation.image"]) + assert torch.allclose(normalized_obs[OBS_IMAGE], observation[OBS_IMAGE]) # State should be normalized (MEAN_STD) expected_state = (torch.tensor([1.0, -0.5]) - torch.tensor([0.0, 0.0])) / torch.tensor([1.0, 1.0]) - assert torch.allclose(normalized_obs["observation.state"], expected_state) + assert torch.allclose(normalized_obs[OBS_STATE], expected_state) def test_identity_normalization_actions(): @@ -695,23 +694,23 @@ def test_identity_normalization_actions(): def test_identity_unnormalization_observations(): """Test that IDENTITY mode skips unnormalization for observations.""" features = { - "observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), - "observation.state": PolicyFeature(FeatureType.STATE, (2,)), + OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), + OBS_STATE: PolicyFeature(FeatureType.STATE, (2,)), } norm_map = { FeatureType.VISUAL: NormalizationMode.IDENTITY, # IDENTITY mode FeatureType.STATE: NormalizationMode.MIN_MAX, # Normal mode for comparison } stats = { - "observation.image": {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]}, - "observation.state": {"min": [-1.0, -1.0], "max": [1.0, 1.0]}, + OBS_IMAGE: {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]}, + OBS_STATE: {"min": [-1.0, -1.0], "max": [1.0, 1.0]}, } unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) observation = { - "observation.image": torch.tensor([0.7, 0.5, 0.3]), - "observation.state": torch.tensor([0.0, -1.0]), # Normalized values in [-1, 1] + OBS_IMAGE: torch.tensor([0.7, 0.5, 0.3]), + OBS_STATE: torch.tensor([0.0, -1.0]), # Normalized values in [-1, 1] } transition = create_transition(observation=observation) @@ -719,13 +718,13 @@ def test_identity_unnormalization_observations(): unnormalized_obs = unnormalized_transition[TransitionKey.OBSERVATION] # Image should remain unchanged (IDENTITY) - assert torch.allclose(unnormalized_obs["observation.image"], observation["observation.image"]) + assert torch.allclose(unnormalized_obs[OBS_IMAGE], observation[OBS_IMAGE]) # State should be unnormalized (MIN_MAX) # (0.0 + 1) / 2 * (1.0 - (-1.0)) + (-1.0) = 0.0 # (-1.0 + 1) / 2 * (1.0 - (-1.0)) + (-1.0) = -1.0 expected_state = torch.tensor([0.0, -1.0]) - assert torch.allclose(unnormalized_obs["observation.state"], expected_state) + assert torch.allclose(unnormalized_obs[OBS_STATE], expected_state) def test_identity_unnormalization_actions(): @@ -748,7 +747,7 @@ def test_identity_unnormalization_actions(): def test_identity_with_missing_stats(): """Test that IDENTITY mode works even when stats are missing.""" features = { - "observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), + OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), "action": PolicyFeature(FeatureType.ACTION, (2,)), } norm_map = { @@ -760,7 +759,7 @@ def test_identity_with_missing_stats(): normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) - observation = {"observation.image": torch.tensor([0.7, 0.5, 0.3])} + observation = {OBS_IMAGE: torch.tensor([0.7, 0.5, 0.3])} action = torch.tensor([1.0, -0.5]) transition = create_transition(observation=observation, action=action) @@ -769,13 +768,13 @@ def test_identity_with_missing_stats(): unnormalized_transition = unnormalizer(transition) assert torch.allclose( - normalized_transition[TransitionKey.OBSERVATION]["observation.image"], - observation["observation.image"], + normalized_transition[TransitionKey.OBSERVATION][OBS_IMAGE], + observation[OBS_IMAGE], ) assert torch.allclose(normalized_transition[TransitionKey.ACTION], action) assert torch.allclose( - unnormalized_transition[TransitionKey.OBSERVATION]["observation.image"], - observation["observation.image"], + unnormalized_transition[TransitionKey.OBSERVATION][OBS_IMAGE], + observation[OBS_IMAGE], ) assert torch.allclose(unnormalized_transition[TransitionKey.ACTION], action) @@ -783,8 +782,8 @@ def test_identity_with_missing_stats(): def test_identity_mixed_with_other_modes(): """Test IDENTITY mode mixed with other normalization modes.""" features = { - "observation.image": PolicyFeature(FeatureType.VISUAL, (3,)), - "observation.state": PolicyFeature(FeatureType.STATE, (2,)), + OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3,)), + OBS_STATE: PolicyFeature(FeatureType.STATE, (2,)), "action": PolicyFeature(FeatureType.ACTION, (2,)), } norm_map = { @@ -793,16 +792,16 @@ def test_identity_mixed_with_other_modes(): FeatureType.ACTION: NormalizationMode.MIN_MAX, } stats = { - "observation.image": {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]}, # Will be ignored - "observation.state": {"mean": [0.0, 0.0], "std": [1.0, 1.0]}, + OBS_IMAGE: {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]}, # Will be ignored + OBS_STATE: {"mean": [0.0, 0.0], "std": [1.0, 1.0]}, "action": {"min": [-1.0, -1.0], "max": [1.0, 1.0]}, } normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) observation = { - "observation.image": torch.tensor([0.7, 0.5, 0.3]), - "observation.state": torch.tensor([1.0, -0.5]), + OBS_IMAGE: torch.tensor([0.7, 0.5, 0.3]), + OBS_STATE: torch.tensor([1.0, -0.5]), } action = torch.tensor([0.5, 0.0]) transition = create_transition(observation=observation, action=action) @@ -812,11 +811,11 @@ def test_identity_mixed_with_other_modes(): normalized_action = normalized_transition[TransitionKey.ACTION] # Image should remain unchanged (IDENTITY) - assert torch.allclose(normalized_obs["observation.image"], observation["observation.image"]) + assert torch.allclose(normalized_obs[OBS_IMAGE], observation[OBS_IMAGE]) # State should be normalized (MEAN_STD) expected_state = torch.tensor([1.0, -0.5]) # (x - 0) / 1 = x - assert torch.allclose(normalized_obs["observation.state"], expected_state) + assert torch.allclose(normalized_obs[OBS_STATE], expected_state) # Action should be normalized (MIN_MAX) to [-1, 1] # 2 * (0.5 - (-1)) / (1 - (-1)) - 1 = 2 * 1.5 / 2 - 1 = 0.5 @@ -828,23 +827,23 @@ def test_identity_mixed_with_other_modes(): def test_identity_defaults_when_not_in_norm_map(): """Test that IDENTITY is used as default when feature type not in norm_map.""" features = { - "observation.image": PolicyFeature(FeatureType.VISUAL, (3,)), - "observation.state": PolicyFeature(FeatureType.STATE, (2,)), + OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3,)), + OBS_STATE: PolicyFeature(FeatureType.STATE, (2,)), } norm_map = { FeatureType.STATE: NormalizationMode.MEAN_STD, # VISUAL not specified, should default to IDENTITY } stats = { - "observation.image": {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]}, - "observation.state": {"mean": [0.0, 0.0], "std": [1.0, 1.0]}, + OBS_IMAGE: {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]}, + OBS_STATE: {"mean": [0.0, 0.0], "std": [1.0, 1.0]}, } normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) observation = { - "observation.image": torch.tensor([0.7, 0.5, 0.3]), - "observation.state": torch.tensor([1.0, -0.5]), + OBS_IMAGE: torch.tensor([0.7, 0.5, 0.3]), + OBS_STATE: torch.tensor([1.0, -0.5]), } transition = create_transition(observation=observation) @@ -852,17 +851,17 @@ def test_identity_defaults_when_not_in_norm_map(): normalized_obs = normalized_transition[TransitionKey.OBSERVATION] # Image should remain unchanged (defaults to IDENTITY) - assert torch.allclose(normalized_obs["observation.image"], observation["observation.image"]) + assert torch.allclose(normalized_obs[OBS_IMAGE], observation[OBS_IMAGE]) # State should be normalized (explicitly MEAN_STD) expected_state = torch.tensor([1.0, -0.5]) - assert torch.allclose(normalized_obs["observation.state"], expected_state) + assert torch.allclose(normalized_obs[OBS_STATE], expected_state) def test_identity_roundtrip(): """Test that IDENTITY normalization and unnormalization are true inverses.""" features = { - "observation.image": PolicyFeature(FeatureType.VISUAL, (3,)), + OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3,)), "action": PolicyFeature(FeatureType.ACTION, (2,)), } norm_map = { @@ -870,14 +869,14 @@ def test_identity_roundtrip(): FeatureType.ACTION: NormalizationMode.IDENTITY, } stats = { - "observation.image": {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]}, + OBS_IMAGE: {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]}, "action": {"min": [-1.0, -1.0], "max": [1.0, 1.0]}, } normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) - original_observation = {"observation.image": torch.tensor([0.7, 0.5, 0.3])} + original_observation = {OBS_IMAGE: torch.tensor([0.7, 0.5, 0.3])} original_action = torch.tensor([0.5, -0.2]) original_transition = create_transition(observation=original_observation, action=original_action) @@ -886,16 +885,14 @@ def test_identity_roundtrip(): roundtrip = unnormalizer(normalized) # Should be identical to original - assert torch.allclose( - roundtrip[TransitionKey.OBSERVATION]["observation.image"], original_observation["observation.image"] - ) + assert torch.allclose(roundtrip[TransitionKey.OBSERVATION][OBS_IMAGE], original_observation[OBS_IMAGE]) assert torch.allclose(roundtrip[TransitionKey.ACTION], original_action) def test_identity_config_serialization(): """Test that IDENTITY mode is properly saved and loaded in config.""" features = { - "observation.image": PolicyFeature(FeatureType.VISUAL, (3,)), + OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3,)), "action": PolicyFeature(FeatureType.ACTION, (2,)), } norm_map = { @@ -903,7 +900,7 @@ def test_identity_config_serialization(): FeatureType.ACTION: NormalizationMode.MEAN_STD, } stats = { - "observation.image": {"mean": [0.5], "std": [0.2]}, + OBS_IMAGE: {"mean": [0.5], "std": [0.2]}, "action": {"mean": [0.0, 0.0], "std": [1.0, 1.0]}, } @@ -925,7 +922,7 @@ def test_identity_config_serialization(): ) # Test that both work the same way - observation = {"observation.image": torch.tensor([0.7])} + observation = {OBS_IMAGE: torch.tensor([0.7])} action = torch.tensor([1.0, -0.5]) transition = create_transition(observation=observation, action=action) @@ -934,15 +931,15 @@ def test_identity_config_serialization(): # Results should be identical assert torch.allclose( - result1[TransitionKey.OBSERVATION]["observation.image"], - result2[TransitionKey.OBSERVATION]["observation.image"], + result1[TransitionKey.OBSERVATION][OBS_IMAGE], + result2[TransitionKey.OBSERVATION][OBS_IMAGE], ) assert torch.allclose(result1[TransitionKey.ACTION], result2[TransitionKey.ACTION]) # def test_unsupported_normalization_mode_error(): # """Test that unsupported normalization modes raise appropriate errors.""" -# features = {"observation.state": PolicyFeature(FeatureType.STATE, (2,))} +# features = {OBS_STATE: PolicyFeature(FeatureType.STATE, (2,))} # # Create an invalid norm_map (this would never happen in practice, but tests error handling) # from enum import Enum @@ -953,14 +950,14 @@ def test_identity_config_serialization(): # # We can't actually pass an invalid enum to the processor due to type checking, # # but we can test the error by manipulating the norm_map after creation # norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD} -# stats = {"observation.state": {"mean": [0.0, 0.0], "std": [1.0, 1.0]}} +# stats = {OBS_STATE: {"mean": [0.0, 0.0], "std": [1.0, 1.0]}} # normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) # # Manually inject an invalid mode to test error handling # normalizer.norm_map[FeatureType.STATE] = "INVALID_MODE" -# observation = {"observation.state": torch.tensor([1.0, -0.5])} +# observation = {OBS_STATE: torch.tensor([1.0, -0.5])} # transition = create_transition(observation=observation) # with pytest.raises(ValueError, match="Unsupported normalization mode"): @@ -971,19 +968,19 @@ def test_hotswap_stats_basic_functionality(): """Test that hotswap_stats correctly updates stats in normalizer/unnormalizer steps.""" # Create initial stats initial_stats = { - "observation.image": {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])}, + OBS_IMAGE: {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])}, "action": {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])}, } # Create new stats for hotswapping new_stats = { - "observation.image": {"mean": np.array([0.3, 0.3, 0.3]), "std": np.array([0.1, 0.1, 0.1])}, + OBS_IMAGE: {"mean": np.array([0.3, 0.3, 0.3]), "std": np.array([0.1, 0.1, 0.1])}, "action": {"mean": np.array([0.1, 0.1]), "std": np.array([0.5, 0.5])}, } # Create features and norm_map features = { - "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), + OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), "action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)), } norm_map = { @@ -1021,15 +1018,15 @@ def test_hotswap_stats_basic_functionality(): def test_hotswap_stats_deep_copy(): """Test that hotswap_stats creates a deep copy and doesn't modify the original processor.""" initial_stats = { - "observation.image": {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])}, + OBS_IMAGE: {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])}, } new_stats = { - "observation.image": {"mean": np.array([0.3, 0.3, 0.3]), "std": np.array([0.1, 0.1, 0.1])}, + OBS_IMAGE: {"mean": np.array([0.3, 0.3, 0.3]), "std": np.array([0.1, 0.1, 0.1])}, } features = { - "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), + OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), } norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} @@ -1060,15 +1057,15 @@ def test_hotswap_stats_deep_copy(): def test_hotswap_stats_only_affects_normalizer_steps(): """Test that hotswap_stats only modifies NormalizerProcessorStep and UnnormalizerProcessorStep steps.""" stats = { - "observation.image": {"mean": np.array([0.5]), "std": np.array([0.2])}, + OBS_IMAGE: {"mean": np.array([0.5]), "std": np.array([0.2])}, } new_stats = { - "observation.image": {"mean": np.array([0.3]), "std": np.array([0.1])}, + OBS_IMAGE: {"mean": np.array([0.3]), "std": np.array([0.1])}, } features = { - "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), + OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), } norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} @@ -1093,13 +1090,13 @@ def test_hotswap_stats_only_affects_normalizer_steps(): def test_hotswap_stats_empty_stats(): """Test hotswap_stats with empty stats dictionary.""" initial_stats = { - "observation.image": {"mean": np.array([0.5]), "std": np.array([0.2])}, + OBS_IMAGE: {"mean": np.array([0.5]), "std": np.array([0.2])}, } empty_stats = {} features = { - "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), + OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), } norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} @@ -1117,7 +1114,7 @@ def test_hotswap_stats_empty_stats(): def test_hotswap_stats_no_normalizer_steps(): """Test hotswap_stats with a processor that has no normalizer/unnormalizer steps.""" stats = { - "observation.image": {"mean": np.array([0.5]), "std": np.array([0.2])}, + OBS_IMAGE: {"mean": np.array([0.5]), "std": np.array([0.2])}, } # Create processor with only identity steps @@ -1139,18 +1136,18 @@ def test_hotswap_stats_no_normalizer_steps(): def test_hotswap_stats_preserves_other_attributes(): """Test that hotswap_stats preserves other processor attributes like features and norm_map.""" initial_stats = { - "observation.image": {"mean": np.array([0.5]), "std": np.array([0.2])}, + OBS_IMAGE: {"mean": np.array([0.5]), "std": np.array([0.2])}, } new_stats = { - "observation.image": {"mean": np.array([0.3]), "std": np.array([0.1])}, + OBS_IMAGE: {"mean": np.array([0.3]), "std": np.array([0.1])}, } features = { - "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), + OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), } norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} - normalize_observation_keys = {"observation.image"} + normalize_observation_keys = {OBS_IMAGE} eps = 1e-6 normalizer = NormalizerProcessorStep( @@ -1179,17 +1176,17 @@ def test_hotswap_stats_preserves_other_attributes(): def test_hotswap_stats_multiple_normalizer_types(): """Test hotswap_stats with multiple normalizer and unnormalizer steps.""" initial_stats = { - "observation.image": {"mean": np.array([0.5]), "std": np.array([0.2])}, + OBS_IMAGE: {"mean": np.array([0.5]), "std": np.array([0.2])}, "action": {"min": np.array([-1.0]), "max": np.array([1.0])}, } new_stats = { - "observation.image": {"mean": np.array([0.3]), "std": np.array([0.1])}, + OBS_IMAGE: {"mean": np.array([0.3]), "std": np.array([0.1])}, "action": {"min": np.array([-2.0]), "max": np.array([2.0])}, } features = { - "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), + OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), "action": PolicyFeature(type=FeatureType.ACTION, shape=(1,)), } norm_map = { @@ -1224,12 +1221,12 @@ def test_hotswap_stats_multiple_normalizer_types(): def test_hotswap_stats_with_different_data_types(): """Test hotswap_stats with various data types in stats.""" initial_stats = { - "observation.image": {"mean": np.array([0.5]), "std": np.array([0.2])}, + OBS_IMAGE: {"mean": np.array([0.5]), "std": np.array([0.2])}, } # New stats with different data types (int, float, list, tuple) new_stats = { - "observation.image": { + OBS_IMAGE: { "mean": [0.3, 0.4, 0.5], # list "std": (0.1, 0.2, 0.3), # tuple "min": 0, # int @@ -1242,7 +1239,7 @@ def test_hotswap_stats_with_different_data_types(): } features = { - "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), + OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), "action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)), } norm_map = { @@ -1261,43 +1258,43 @@ def test_hotswap_stats_with_different_data_types(): # Check that tensor conversion worked correctly tensor_stats = new_processor.steps[0]._tensor_stats - assert isinstance(tensor_stats["observation.image"]["mean"], torch.Tensor) - assert isinstance(tensor_stats["observation.image"]["std"], torch.Tensor) - assert isinstance(tensor_stats["observation.image"]["min"], torch.Tensor) - assert isinstance(tensor_stats["observation.image"]["max"], torch.Tensor) + assert isinstance(tensor_stats[OBS_IMAGE]["mean"], torch.Tensor) + assert isinstance(tensor_stats[OBS_IMAGE]["std"], torch.Tensor) + assert isinstance(tensor_stats[OBS_IMAGE]["min"], torch.Tensor) + assert isinstance(tensor_stats[OBS_IMAGE]["max"], torch.Tensor) assert isinstance(tensor_stats["action"]["mean"], torch.Tensor) assert isinstance(tensor_stats["action"]["std"], torch.Tensor) # Check values - torch.testing.assert_close(tensor_stats["observation.image"]["mean"], torch.tensor([0.3, 0.4, 0.5])) - torch.testing.assert_close(tensor_stats["observation.image"]["std"], torch.tensor([0.1, 0.2, 0.3])) - torch.testing.assert_close(tensor_stats["observation.image"]["min"], torch.tensor(0.0)) - torch.testing.assert_close(tensor_stats["observation.image"]["max"], torch.tensor(1.0)) + torch.testing.assert_close(tensor_stats[OBS_IMAGE]["mean"], torch.tensor([0.3, 0.4, 0.5])) + torch.testing.assert_close(tensor_stats[OBS_IMAGE]["std"], torch.tensor([0.1, 0.2, 0.3])) + torch.testing.assert_close(tensor_stats[OBS_IMAGE]["min"], torch.tensor(0.0)) + torch.testing.assert_close(tensor_stats[OBS_IMAGE]["max"], torch.tensor(1.0)) def test_hotswap_stats_functional_test(): """Test that hotswapped processor actually works functionally.""" # Create test data observation = { - "observation.image": torch.tensor([[[0.6, 0.7], [0.8, 0.9]], [[0.5, 0.6], [0.7, 0.8]]]), + OBS_IMAGE: torch.tensor([[[0.6, 0.7], [0.8, 0.9]], [[0.5, 0.6], [0.7, 0.8]]]), } action = torch.tensor([0.5, -0.5]) transition = create_transition(observation=observation, action=action) # Initial stats initial_stats = { - "observation.image": {"mean": np.array([0.5, 0.4]), "std": np.array([0.2, 0.3])}, + OBS_IMAGE: {"mean": np.array([0.5, 0.4]), "std": np.array([0.2, 0.3])}, "action": {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])}, } # New stats new_stats = { - "observation.image": {"mean": np.array([0.3, 0.2]), "std": np.array([0.1, 0.2])}, + OBS_IMAGE: {"mean": np.array([0.3, 0.2]), "std": np.array([0.1, 0.2])}, "action": {"mean": np.array([0.1, -0.1]), "std": np.array([0.5, 0.5])}, } features = { - "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(2, 2, 2)), + OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(2, 2, 2)), "action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)), } norm_map = { @@ -1322,8 +1319,8 @@ def test_hotswap_stats_functional_test(): # Results should be different since normalization changed assert not torch.allclose( - original_result["observation"]["observation.image"], - new_result["observation"]["observation.image"], + original_result[OBS_STR][OBS_IMAGE], + new_result[OBS_STR][OBS_IMAGE], rtol=1e-3, atol=1e-3, ) @@ -1331,60 +1328,54 @@ def test_hotswap_stats_functional_test(): # Verify that the new processor is actually using the new stats by checking internal state assert new_processor.steps[0].stats == new_stats - assert torch.allclose( - new_processor.steps[0]._tensor_stats["observation.image"]["mean"], torch.tensor([0.3, 0.2]) - ) - assert torch.allclose( - new_processor.steps[0]._tensor_stats["observation.image"]["std"], torch.tensor([0.1, 0.2]) - ) + assert torch.allclose(new_processor.steps[0]._tensor_stats[OBS_IMAGE]["mean"], torch.tensor([0.3, 0.2])) + assert torch.allclose(new_processor.steps[0]._tensor_stats[OBS_IMAGE]["std"], torch.tensor([0.1, 0.2])) assert torch.allclose(new_processor.steps[0]._tensor_stats["action"]["mean"], torch.tensor([0.1, -0.1])) assert torch.allclose(new_processor.steps[0]._tensor_stats["action"]["std"], torch.tensor([0.5, 0.5])) # Test that normalization actually happens (output should not equal input) - assert not torch.allclose( - new_result["observation"]["observation.image"], observation["observation.image"] - ) + assert not torch.allclose(new_result[OBS_STR][OBS_IMAGE], observation[OBS_IMAGE]) assert not torch.allclose(new_result["action"], action) def test_zero_std_uses_eps(): """When std == 0, (x-mean)/(std+eps) is well-defined; x==mean should map to 0.""" - features = {"observation.state": PolicyFeature(FeatureType.STATE, (1,))} + features = {OBS_STATE: PolicyFeature(FeatureType.STATE, (1,))} norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD} - stats = {"observation.state": {"mean": np.array([0.5]), "std": np.array([0.0])}} + stats = {OBS_STATE: {"mean": np.array([0.5]), "std": np.array([0.0])}} normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats, eps=1e-6) - observation = {"observation.state": torch.tensor([0.5])} # equals mean + observation = {OBS_STATE: torch.tensor([0.5])} # equals mean out = normalizer(create_transition(observation=observation)) - assert torch.allclose(out[TransitionKey.OBSERVATION]["observation.state"], torch.tensor([0.0])) + assert torch.allclose(out[TransitionKey.OBSERVATION][OBS_STATE], torch.tensor([0.0])) def test_min_equals_max_maps_to_minus_one(): """When min == max, MIN_MAX path maps to -1 after [-1,1] scaling for x==min.""" - features = {"observation.state": PolicyFeature(FeatureType.STATE, (1,))} + features = {OBS_STATE: PolicyFeature(FeatureType.STATE, (1,))} norm_map = {FeatureType.STATE: NormalizationMode.MIN_MAX} - stats = {"observation.state": {"min": np.array([2.0]), "max": np.array([2.0])}} + stats = {OBS_STATE: {"min": np.array([2.0]), "max": np.array([2.0])}} normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats, eps=1e-6) - observation = {"observation.state": torch.tensor([2.0])} + observation = {OBS_STATE: torch.tensor([2.0])} out = normalizer(create_transition(observation=observation)) - assert torch.allclose(out[TransitionKey.OBSERVATION]["observation.state"], torch.tensor([-1.0])) + assert torch.allclose(out[TransitionKey.OBSERVATION][OBS_STATE], torch.tensor([-1.0])) def test_action_normalized_despite_normalize_observation_keys(): """Action normalization is independent of normalize_observation_keys filter for observations.""" features = { - "observation.state": PolicyFeature(FeatureType.STATE, (1,)), + OBS_STATE: PolicyFeature(FeatureType.STATE, (1,)), "action": PolicyFeature(FeatureType.ACTION, (2,)), } norm_map = {FeatureType.STATE: NormalizationMode.IDENTITY, FeatureType.ACTION: NormalizationMode.MEAN_STD} stats = {"action": {"mean": np.array([1.0, -1.0]), "std": np.array([2.0, 4.0])}} normalizer = NormalizerProcessorStep( - features=features, norm_map=norm_map, stats=stats, normalize_observation_keys={"observation.state"} + features=features, norm_map=norm_map, stats=stats, normalize_observation_keys={OBS_STATE} ) transition = create_transition( - observation={"observation.state": torch.tensor([3.0])}, action=torch.tensor([3.0, 3.0]) + observation={OBS_STATE: torch.tensor([3.0])}, action=torch.tensor([3.0, 3.0]) ) out = normalizer(transition) # (3-1)/2 = 1.0 ; (3-(-1))/4 = 1.0 @@ -1421,12 +1412,12 @@ def test_unnormalize_observations_mean_std_and_min_max(): def test_unknown_observation_keys_ignored(): - features = {"observation.state": PolicyFeature(FeatureType.STATE, (1,))} + features = {OBS_STATE: PolicyFeature(FeatureType.STATE, (1,))} norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD} - stats = {"observation.state": {"mean": np.array([0.0]), "std": np.array([1.0])}} + stats = {OBS_STATE: {"mean": np.array([0.0]), "std": np.array([1.0])}} normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) - obs = {"observation.state": torch.tensor([1.0]), "observation.unknown": torch.tensor([5.0])} + obs = {OBS_STATE: torch.tensor([1.0]), "observation.unknown": torch.tensor([5.0])} tr = create_transition(observation=obs) out = normalizer(tr) @@ -1447,13 +1438,13 @@ def test_batched_action_normalization(): def test_complementary_data_preservation(): - features = {"observation.state": PolicyFeature(FeatureType.STATE, (1,))} + features = {OBS_STATE: PolicyFeature(FeatureType.STATE, (1,))} norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD} - stats = {"observation.state": {"mean": np.array([0.0]), "std": np.array([1.0])}} + stats = {OBS_STATE: {"mean": np.array([0.0]), "std": np.array([1.0])}} normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) comp = {"existing": 123} - tr = create_transition(observation={"observation.state": torch.tensor([1.0])}, complementary_data=comp) + tr = create_transition(observation={OBS_STATE: torch.tensor([1.0])}, complementary_data=comp) out = normalizer(tr) new_comp = out[TransitionKey.COMPLEMENTARY_DATA] assert new_comp["existing"] == 123 @@ -1461,36 +1452,34 @@ def test_complementary_data_preservation(): def test_roundtrip_normalize_unnormalize_non_identity(): features = { - "observation.state": PolicyFeature(FeatureType.STATE, (2,)), + OBS_STATE: PolicyFeature(FeatureType.STATE, (2,)), "action": PolicyFeature(FeatureType.ACTION, (2,)), } norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD, FeatureType.ACTION: NormalizationMode.MIN_MAX} stats = { - "observation.state": {"mean": np.array([1.0, -1.0]), "std": np.array([2.0, 4.0])}, + OBS_STATE: {"mean": np.array([1.0, -1.0]), "std": np.array([2.0, 4.0])}, "action": {"min": np.array([-2.0, 0.0]), "max": np.array([2.0, 4.0])}, } normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) # Add a time dimension in action for broadcasting check (B,T,D) - obs = {"observation.state": torch.tensor([[3.0, 3.0], [1.0, -1.0]])} + obs = {OBS_STATE: torch.tensor([[3.0, 3.0], [1.0, -1.0]])} act = torch.tensor([[[0.0, -1.0], [1.0, 1.0]]]) # shape (1,2,2) already in [-1,1] tr = create_transition(observation=obs, action=act) out = unnormalizer(normalizer(tr)) - assert torch.allclose( - out[TransitionKey.OBSERVATION]["observation.state"], obs["observation.state"], atol=1e-5 - ) + assert torch.allclose(out[TransitionKey.OBSERVATION][OBS_STATE], obs[OBS_STATE], atol=1e-5) assert torch.allclose(out[TransitionKey.ACTION], act, atol=1e-5) def test_dtype_adaptation_bfloat16_input_float32_normalizer(): """Test automatic dtype adaptation: NormalizerProcessor(float32) adapts to bfloat16 input → bfloat16 output""" - features = {"observation.state": PolicyFeature(FeatureType.STATE, (5,))} + features = {OBS_STATE: PolicyFeature(FeatureType.STATE, (5,))} norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD} stats = { - "observation.state": { + OBS_STATE: { "mean": np.array([0.0, 0.0, 0.0, 0.0, 0.0]), "std": np.array([1.0, 1.0, 1.0, 1.0, 1.0]), } @@ -1503,11 +1492,11 @@ def test_dtype_adaptation_bfloat16_input_float32_normalizer(): # Verify initial configuration assert normalizer.dtype == torch.float32 - for stat_tensor in normalizer._tensor_stats["observation.state"].values(): + for stat_tensor in normalizer._tensor_stats[OBS_STATE].values(): assert stat_tensor.dtype == torch.float32 # Create bfloat16 input tensor - observation = {"observation.state": torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.bfloat16)} + observation = {OBS_STATE: torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.bfloat16)} transition = create_transition(observation=observation) # Process the transition @@ -1516,11 +1505,11 @@ def test_dtype_adaptation_bfloat16_input_float32_normalizer(): # Verify that: # 1. Stats were automatically adapted to bfloat16 assert normalizer.dtype == torch.bfloat16 - for stat_tensor in normalizer._tensor_stats["observation.state"].values(): + for stat_tensor in normalizer._tensor_stats[OBS_STATE].values(): assert stat_tensor.dtype == torch.bfloat16 # 2. Output is in bfloat16 - output_tensor = result[TransitionKey.OBSERVATION]["observation.state"] + output_tensor = result[TransitionKey.OBSERVATION][OBS_STATE] assert output_tensor.dtype == torch.bfloat16 # 3. Normalization was applied correctly (mean should be close to original - mean) / std @@ -1540,18 +1529,18 @@ def test_stats_override_preservation_in_load_state_dict(): """ # Create original stats original_stats = { - "observation.image": {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])}, + OBS_IMAGE: {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])}, "action": {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])}, } # Create override stats (what user wants to use) override_stats = { - "observation.image": {"mean": np.array([0.3, 0.3, 0.3]), "std": np.array([0.1, 0.1, 0.1])}, + OBS_IMAGE: {"mean": np.array([0.3, 0.3, 0.3]), "std": np.array([0.1, 0.1, 0.1])}, "action": {"mean": np.array([0.1, 0.1]), "std": np.array([0.5, 0.5])}, } features = { - "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), + OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), "action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)), } norm_map = { @@ -1611,12 +1600,12 @@ def test_stats_without_override_loads_normally(): load_state_dict works as before. """ original_stats = { - "observation.image": {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])}, + OBS_IMAGE: {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])}, "action": {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])}, } features = { - "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), + OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), "action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)), } norm_map = { @@ -1651,12 +1640,12 @@ def test_stats_without_override_loads_normally(): def test_stats_explicit_provided_flag_detection(): """Test that the _stats_explicitly_provided flag is set correctly in different scenarios.""" features = { - "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), + OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), } norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} # Test 1: Explicitly provided stats (non-empty dict) - stats = {"observation.image": {"mean": [0.5], "std": [0.2]}} + stats = {OBS_IMAGE: {"mean": [0.5], "std": [0.2]}} normalizer1 = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) assert normalizer1._stats_explicitly_provided is True @@ -1684,7 +1673,7 @@ def test_pipeline_from_pretrained_with_stats_overrides(): # Create test data features = { - "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 32, 32)), + OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 32, 32)), "action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)), } norm_map = { @@ -1693,12 +1682,12 @@ def test_pipeline_from_pretrained_with_stats_overrides(): } original_stats = { - "observation.image": {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])}, + OBS_IMAGE: {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])}, "action": {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])}, } override_stats = { - "observation.image": {"mean": np.array([0.3, 0.3, 0.3]), "std": np.array([0.1, 0.1, 0.1])}, + OBS_IMAGE: {"mean": np.array([0.3, 0.3, 0.3]), "std": np.array([0.1, 0.1, 0.1])}, "action": {"mean": np.array([0.1, 0.1]), "std": np.array([0.5, 0.5])}, } @@ -1740,7 +1729,7 @@ def test_pipeline_from_pretrained_with_stats_overrides(): # Test that the override stats are actually used in processing observation = { - "observation.image": torch.tensor([0.7, 0.5, 0.3]), + OBS_IMAGE: torch.tensor([0.7, 0.5, 0.3]), } action = torch.tensor([1.0, -0.5]) transition = create_transition(observation=observation, action=action) @@ -1770,9 +1759,9 @@ def test_dtype_adaptation_device_processor_bfloat16_normalizer_float32(): """Test policy pipeline scenario: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → bfloat16 output""" from lerobot.processor import DeviceProcessorStep - features = {"observation.state": PolicyFeature(FeatureType.STATE, (3,))} + features = {OBS_STATE: PolicyFeature(FeatureType.STATE, (3,))} norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD} - stats = {"observation.state": {"mean": np.array([0.0, 0.0, 0.0]), "std": np.array([1.0, 1.0, 1.0])}} + stats = {OBS_STATE: {"mean": np.array([0.0, 0.0, 0.0]), "std": np.array([1.0, 1.0, 1.0])}} # Create pipeline: DeviceProcessor(bfloat16) → NormalizerProcessor(float32) device_processor = DeviceProcessorStep(device=str(auto_select_torch_device()), float_dtype="bfloat16") @@ -1784,18 +1773,18 @@ def test_dtype_adaptation_device_processor_bfloat16_normalizer_float32(): assert normalizer.dtype == torch.float32 # Create CPU input - observation = {"observation.state": torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)} + observation = {OBS_STATE: torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)} transition = create_transition(observation=observation) # Step 1: DeviceProcessor converts to bfloat16 + moves to CUDA processed_1 = device_processor(transition) - intermediate_tensor = processed_1[TransitionKey.OBSERVATION]["observation.state"] + intermediate_tensor = processed_1[TransitionKey.OBSERVATION][OBS_STATE] assert intermediate_tensor.dtype == torch.bfloat16 assert intermediate_tensor.device.type == str(auto_select_torch_device()) # Step 2: NormalizerProcessor receives bfloat16 input and adapts final_result = normalizer(processed_1) - final_tensor = final_result[TransitionKey.OBSERVATION]["observation.state"] + final_tensor = final_result[TransitionKey.OBSERVATION][OBS_STATE] # Verify final output is bfloat16 (automatic adaptation worked) assert final_tensor.dtype == torch.bfloat16 @@ -1803,7 +1792,7 @@ def test_dtype_adaptation_device_processor_bfloat16_normalizer_float32(): # Verify normalizer adapted its internal state assert normalizer.dtype == torch.bfloat16 - for stat_tensor in normalizer._tensor_stats["observation.state"].values(): + for stat_tensor in normalizer._tensor_stats[OBS_STATE].values(): assert stat_tensor.dtype == torch.bfloat16 assert stat_tensor.device.type == str(auto_select_torch_device()) @@ -1821,8 +1810,8 @@ def test_stats_reconstruction_after_load_state_dict(): # Create normalizer with stats features = { - "observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), - "observation.state": PolicyFeature(FeatureType.STATE, (2,)), + OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), + OBS_STATE: PolicyFeature(FeatureType.STATE, (2,)), "action": PolicyFeature(FeatureType.ACTION, (2,)), } norm_map = { @@ -1831,11 +1820,11 @@ def test_stats_reconstruction_after_load_state_dict(): FeatureType.ACTION: NormalizationMode.MEAN_STD, } stats = { - "observation.image": { + OBS_IMAGE: { "mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2]), }, - "observation.state": { + OBS_STATE: { "min": np.array([0.0, -1.0]), "max": np.array([1.0, 1.0]), }, @@ -1861,15 +1850,15 @@ def test_stats_reconstruction_after_load_state_dict(): assert new_normalizer.stats != {} # Check that all expected keys are present - assert "observation.image" in new_normalizer.stats - assert "observation.state" in new_normalizer.stats + assert OBS_IMAGE in new_normalizer.stats + assert OBS_STATE in new_normalizer.stats assert "action" in new_normalizer.stats # Check that values are correct (converted back from tensors) - np.testing.assert_allclose(new_normalizer.stats["observation.image"]["mean"], [0.5, 0.5, 0.5]) - np.testing.assert_allclose(new_normalizer.stats["observation.image"]["std"], [0.2, 0.2, 0.2]) - np.testing.assert_allclose(new_normalizer.stats["observation.state"]["min"], [0.0, -1.0]) - np.testing.assert_allclose(new_normalizer.stats["observation.state"]["max"], [1.0, 1.0]) + np.testing.assert_allclose(new_normalizer.stats[OBS_IMAGE]["mean"], [0.5, 0.5, 0.5]) + np.testing.assert_allclose(new_normalizer.stats[OBS_IMAGE]["std"], [0.2, 0.2, 0.2]) + np.testing.assert_allclose(new_normalizer.stats[OBS_STATE]["min"], [0.0, -1.0]) + np.testing.assert_allclose(new_normalizer.stats[OBS_STATE]["max"], [1.0, 1.0]) np.testing.assert_allclose(new_normalizer.stats["action"]["mean"], [0.0, 0.0]) np.testing.assert_allclose(new_normalizer.stats["action"]["std"], [1.0, 2.0]) @@ -1885,8 +1874,8 @@ def test_stats_reconstruction_after_load_state_dict(): # Test 2: hotswap_stats should work new_stats = { - "observation.image": {"mean": [0.3, 0.3, 0.3], "std": [0.1, 0.1, 0.1]}, - "observation.state": {"min": [-1.0, -2.0], "max": [2.0, 2.0]}, + OBS_IMAGE: {"mean": [0.3, 0.3, 0.3], "std": [0.1, 0.1, 0.1]}, + OBS_STATE: {"min": [-1.0, -2.0], "max": [2.0, 2.0]}, "action": {"mean": [0.1, 0.1], "std": [0.5, 0.5]}, } @@ -1900,8 +1889,8 @@ def test_stats_reconstruction_after_load_state_dict(): # Test 3: The normalizer should work functionally the same as the original observation = { - "observation.image": torch.tensor([0.7, 0.5, 0.3]), - "observation.state": torch.tensor([0.5, 0.0]), + OBS_IMAGE: torch.tensor([0.7, 0.5, 0.3]), + OBS_STATE: torch.tensor([0.5, 0.0]), } action = torch.tensor([1.0, -0.5]) transition = create_transition(observation=observation, action=action) @@ -1911,11 +1900,11 @@ def test_stats_reconstruction_after_load_state_dict(): # Results should be identical (within floating point precision) torch.testing.assert_close( - original_result[TransitionKey.OBSERVATION]["observation.image"], - new_result[TransitionKey.OBSERVATION]["observation.image"], + original_result[TransitionKey.OBSERVATION][OBS_IMAGE], + new_result[TransitionKey.OBSERVATION][OBS_IMAGE], ) torch.testing.assert_close( - original_result[TransitionKey.OBSERVATION]["observation.state"], - new_result[TransitionKey.OBSERVATION]["observation.state"], + original_result[TransitionKey.OBSERVATION][OBS_STATE], + new_result[TransitionKey.OBSERVATION][OBS_STATE], ) torch.testing.assert_close(original_result[TransitionKey.ACTION], new_result[TransitionKey.ACTION]) diff --git a/tests/processor/test_observation_processor.py b/tests/processor/test_observation_processor.py index 6abc9edef..11b58a66c 100644 --- a/tests/processor/test_observation_processor.py +++ b/tests/processor/test_observation_processor.py @@ -39,8 +39,8 @@ def test_process_single_image(): processed_obs = result[TransitionKey.OBSERVATION] # Check that the image was processed correctly - assert "observation.image" in processed_obs - processed_img = processed_obs["observation.image"] + assert OBS_IMAGE in processed_obs + processed_img = processed_obs[OBS_IMAGE] # Check shape: should be (1, 3, 64, 64) - batch, channels, height, width assert processed_img.shape == (1, 3, 64, 64) @@ -66,12 +66,12 @@ def test_process_image_dict(): processed_obs = result[TransitionKey.OBSERVATION] # Check that both images were processed - assert "observation.images.camera1" in processed_obs - assert "observation.images.camera2" in processed_obs + assert f"{OBS_IMAGES}.camera1" in processed_obs + assert f"{OBS_IMAGES}.camera2" in processed_obs # Check shapes - assert processed_obs["observation.images.camera1"].shape == (1, 3, 32, 32) - assert processed_obs["observation.images.camera2"].shape == (1, 3, 48, 48) + assert processed_obs[f"{OBS_IMAGES}.camera1"].shape == (1, 3, 32, 32) + assert processed_obs[f"{OBS_IMAGES}.camera2"].shape == (1, 3, 48, 48) def test_process_batched_image(): @@ -88,7 +88,7 @@ def test_process_batched_image(): processed_obs = result[TransitionKey.OBSERVATION] # Check that batch dimension is preserved - assert processed_obs["observation.image"].shape == (2, 3, 64, 64) + assert processed_obs[OBS_IMAGE].shape == (2, 3, 64, 64) def test_invalid_image_format(): @@ -173,10 +173,10 @@ def test_process_environment_state(): processed_obs = result[TransitionKey.OBSERVATION] # Check that environment_state was renamed and processed - assert "observation.environment_state" in processed_obs + assert OBS_ENV_STATE in processed_obs assert "environment_state" not in processed_obs - processed_state = processed_obs["observation.environment_state"] + processed_state = processed_obs[OBS_ENV_STATE] assert processed_state.shape == (1, 3) # Batch dimension added assert processed_state.dtype == torch.float32 torch.testing.assert_close(processed_state, torch.tensor([[1.0, 2.0, 3.0]])) @@ -194,10 +194,10 @@ def test_process_agent_pos(): processed_obs = result[TransitionKey.OBSERVATION] # Check that agent_pos was renamed and processed - assert "observation.state" in processed_obs + assert OBS_STATE in processed_obs assert "agent_pos" not in processed_obs - processed_state = processed_obs["observation.state"] + processed_state = processed_obs[OBS_STATE] assert processed_state.shape == (1, 3) # Batch dimension added assert processed_state.dtype == torch.float32 torch.testing.assert_close(processed_state, torch.tensor([[0.5, -0.5, 1.0]])) @@ -217,8 +217,8 @@ def test_process_batched_states(): processed_obs = result[TransitionKey.OBSERVATION] # Check that batch dimensions are preserved - assert processed_obs["observation.environment_state"].shape == (2, 2) - assert processed_obs["observation.state"].shape == (2, 2) + assert processed_obs[OBS_ENV_STATE].shape == (2, 2) + assert processed_obs[OBS_STATE].shape == (2, 2) def test_process_both_states(): @@ -235,8 +235,8 @@ def test_process_both_states(): processed_obs = result[TransitionKey.OBSERVATION] # Check that both states were processed - assert "observation.environment_state" in processed_obs - assert "observation.state" in processed_obs + assert OBS_ENV_STATE in processed_obs + assert OBS_STATE in processed_obs # Check that original keys were removed assert "environment_state" not in processed_obs @@ -281,12 +281,12 @@ def test_complete_observation_processing(): processed_obs = result[TransitionKey.OBSERVATION] # Check that image was processed - assert "observation.image" in processed_obs - assert processed_obs["observation.image"].shape == (1, 3, 32, 32) + assert OBS_IMAGE in processed_obs + assert processed_obs[OBS_IMAGE].shape == (1, 3, 32, 32) # Check that states were processed - assert "observation.environment_state" in processed_obs - assert "observation.state" in processed_obs + assert OBS_ENV_STATE in processed_obs + assert OBS_STATE in processed_obs # Check that original keys were removed assert "pixels" not in processed_obs @@ -308,7 +308,7 @@ def test_image_only_processing(): result = processor(transition) processed_obs = result[TransitionKey.OBSERVATION] - assert "observation.image" in processed_obs + assert OBS_IMAGE in processed_obs assert len(processed_obs) == 1 @@ -323,7 +323,7 @@ def test_state_only_processing(): result = processor(transition) processed_obs = result[TransitionKey.OBSERVATION] - assert "observation.state" in processed_obs + assert OBS_STATE in processed_obs assert "agent_pos" not in processed_obs @@ -504,7 +504,7 @@ def test_state_processor_features_prefixed_inputs(policy_feature_factory): proc = VanillaObservationProcessorStep() features = { PipelineFeatureType.OBSERVATION: { - "observation.environment_state": policy_feature_factory(FeatureType.STATE, (2,)), + OBS_ENV_STATE: policy_feature_factory(FeatureType.STATE, (2,)), "observation.agent_pos": policy_feature_factory(FeatureType.STATE, (4,)), }, } @@ -513,7 +513,7 @@ def test_state_processor_features_prefixed_inputs(policy_feature_factory): assert ( OBS_ENV_STATE in out[PipelineFeatureType.OBSERVATION] and out[PipelineFeatureType.OBSERVATION][OBS_ENV_STATE] - == features[PipelineFeatureType.OBSERVATION]["observation.environment_state"] + == features[PipelineFeatureType.OBSERVATION][OBS_ENV_STATE] ) assert ( OBS_STATE in out[PipelineFeatureType.OBSERVATION] diff --git a/tests/processor/test_pipeline.py b/tests/processor/test_pipeline.py index 0d17fed00..6d056e4dc 100644 --- a/tests/processor/test_pipeline.py +++ b/tests/processor/test_pipeline.py @@ -35,6 +35,7 @@ from lerobot.processor import ( TransitionKey, ) from lerobot.processor.converters import create_transition, identity_transition +from lerobot.utils.constants import OBS_IMAGE, OBS_IMAGES, OBS_STATE from tests.conftest import assert_contract_is_typed @@ -255,7 +256,7 @@ def test_step_through_with_dict(): pipeline = DataProcessorPipeline([step1, step2]) batch = { - "observation.image": None, + OBS_IMAGE: None, "action": None, "next.reward": 0.0, "next.done": False, @@ -1840,7 +1841,7 @@ def test_save_load_with_custom_converter_functions(): # Verify it uses default converters by checking with standard batch format batch = { - "observation.image": torch.randn(1, 3, 32, 32), + OBS_IMAGE: torch.randn(1, 3, 32, 32), "action": torch.randn(1, 7), "next.reward": torch.tensor([1.0]), "next.done": torch.tensor([False]), @@ -1851,7 +1852,7 @@ def test_save_load_with_custom_converter_functions(): # Should work with standard format (wouldn't work with custom converter) result = loaded(batch) # With new behavior, default to_output is _default_transition_to_batch, so result is batch dict - assert "observation.image" in result + assert OBS_IMAGE in result class NonCompliantStep: @@ -2075,10 +2076,10 @@ class AddObservationStateFeatures(ProcessorStep): self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: # State features (mix EE and a joint state) - features[PipelineFeatureType.OBSERVATION]["observation.state.ee.x"] = float - features[PipelineFeatureType.OBSERVATION]["observation.state.j1.pos"] = float + features[PipelineFeatureType.OBSERVATION][f"{OBS_STATE}.ee.x"] = float + features[PipelineFeatureType.OBSERVATION][f"{OBS_STATE}.j1.pos"] = float if self.add_front_image: - features[PipelineFeatureType.OBSERVATION]["observation.images.front"] = self.front_image_shape + features[PipelineFeatureType.OBSERVATION][f"{OBS_IMAGES}.front"] = self.front_image_shape return features @@ -2094,7 +2095,7 @@ def test_aggregate_joint_action_only(): ) # Expect only "action" with joint names - assert "action" in out and "observation.state" not in out + assert "action" in out and OBS_STATE not in out assert out["action"]["dtype"] == "float32" assert set(out["action"]["names"]) == {"j1.pos", "j2.pos"} assert out["action"]["shape"] == (len(out["action"]["names"]),) @@ -2108,7 +2109,7 @@ def test_aggregate_ee_action_and_observation_with_videos(): pipeline=rp, initial_features={PipelineFeatureType.OBSERVATION: initial, PipelineFeatureType.ACTION: {}}, use_videos=True, - patterns=["action.ee", "observation.state"], + patterns=["action.ee", OBS_STATE], ) # Action should pack only EE names @@ -2117,13 +2118,13 @@ def test_aggregate_ee_action_and_observation_with_videos(): assert out["action"]["dtype"] == "float32" # Observation state should pack both ee.x and j1.pos as a vector - assert "observation.state" in out - assert set(out["observation.state"]["names"]) == {"ee.x", "j1.pos"} - assert out["observation.state"]["dtype"] == "float32" + assert OBS_STATE in out + assert set(out[OBS_STATE]["names"]) == {"ee.x", "j1.pos"} + assert out[OBS_STATE]["dtype"] == "float32" # Cameras from initial_features appear as videos for cam in ("front", "side"): - key = f"observation.images.{cam}" + key = f"{OBS_IMAGES}.{cam}" assert key in out assert out[key]["dtype"] == "video" assert out[key]["shape"] == initial[cam] @@ -2156,8 +2157,8 @@ def test_aggregate_images_when_use_videos_false(): patterns=None, ) - key = "observation.images.back" - key_front = "observation.images.front" + key = f"{OBS_IMAGES}.back" + key_front = f"{OBS_IMAGES}.front" assert key not in out assert key_front not in out @@ -2173,8 +2174,8 @@ def test_aggregate_images_when_use_videos_true(): patterns=None, ) - key = "observation.images.front" - key_back = "observation.images.back" + key = f"{OBS_IMAGES}.front" + key_back = f"{OBS_IMAGES}.back" assert key in out assert key_back in out assert out[key]["dtype"] == "video" @@ -2194,9 +2195,9 @@ def test_initial_camera_not_overridden_by_step_image(): pipeline=rp, initial_features={PipelineFeatureType.ACTION: {}, PipelineFeatureType.OBSERVATION: initial}, use_videos=True, - patterns=["observation.images.front"], + patterns=[f"{OBS_IMAGES}.front"], ) - key = "observation.images.front" + key = f"{OBS_IMAGES}.front" assert key in out assert out[key]["shape"] == (240, 320, 3) # from the step, not from initial diff --git a/tests/processor/test_rename_processor.py b/tests/processor/test_rename_processor.py index 5f2b48576..c6aa303f1 100644 --- a/tests/processor/test_rename_processor.py +++ b/tests/processor/test_rename_processor.py @@ -28,6 +28,7 @@ from lerobot.processor import ( ) from lerobot.processor.converters import create_transition, identity_transition from lerobot.processor.rename_processor import rename_stats +from lerobot.utils.constants import OBS_IMAGE, OBS_IMAGES, OBS_STATE from tests.conftest import assert_contract_is_typed @@ -121,13 +122,13 @@ def test_overlapping_rename(): def test_partial_rename(): """Test renaming only some keys.""" rename_map = { - "observation.state": "observation.proprio_state", - "pixels": "observation.image", + OBS_STATE: "observation.proprio_state", + "pixels": OBS_IMAGE, } processor = RenameObservationsProcessorStep(rename_map=rename_map) observation = { - "observation.state": torch.randn(10), + OBS_STATE: torch.randn(10), "pixels": np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8), "reward": 1.0, "info": {"episode": 1}, @@ -139,8 +140,8 @@ def test_partial_rename(): # Check renamed keys assert "observation.proprio_state" in processed_obs - assert "observation.image" in processed_obs - assert "observation.state" not in processed_obs + assert OBS_IMAGE in processed_obs + assert OBS_STATE not in processed_obs assert "pixels" not in processed_obs # Check unchanged keys @@ -174,8 +175,8 @@ def test_state_dict(): def test_integration_with_robot_processor(): """Test integration with RobotProcessor pipeline.""" rename_map = { - "agent_pos": "observation.state", - "pixels": "observation.image", + "agent_pos": OBS_STATE, + "pixels": OBS_IMAGE, } rename_processor = RenameObservationsProcessorStep(rename_map=rename_map) @@ -196,8 +197,8 @@ def test_integration_with_robot_processor(): processed_obs = result[TransitionKey.OBSERVATION] # Check renaming worked through pipeline - assert "observation.state" in processed_obs - assert "observation.image" in processed_obs + assert OBS_STATE in processed_obs + assert OBS_IMAGE in processed_obs assert "agent_pos" not in processed_obs assert "pixels" not in processed_obs assert processed_obs["other_data"] == "preserve_me" @@ -210,8 +211,8 @@ def test_integration_with_robot_processor(): def test_save_and_load_pretrained(): """Test saving and loading processor with RobotProcessor.""" rename_map = { - "old_state": "observation.state", - "old_image": "observation.image", + "old_state": OBS_STATE, + "old_image": OBS_IMAGE, } processor = RenameObservationsProcessorStep(rename_map=rename_map) pipeline = DataProcessorPipeline([processor], name="TestRenameProcessorStep") @@ -253,10 +254,10 @@ def test_save_and_load_pretrained(): result = loaded_pipeline(transition) processed_obs = result[TransitionKey.OBSERVATION] - assert "observation.state" in processed_obs - assert "observation.image" in processed_obs - assert processed_obs["observation.state"] == [1, 2, 3] - assert processed_obs["observation.image"] == "image_data" + assert OBS_STATE in processed_obs + assert OBS_IMAGE in processed_obs + assert processed_obs[OBS_STATE] == [1, 2, 3] + assert processed_obs[OBS_IMAGE] == "image_data" def test_registry_functionality(): @@ -317,8 +318,8 @@ def test_chained_rename_processors(): # Second processor: rename to final format processor2 = RenameObservationsProcessorStep( rename_map={ - "agent_position": "observation.state", - "camera_image": "observation.image", + "agent_position": OBS_STATE, + "camera_image": OBS_IMAGE, } ) @@ -342,8 +343,8 @@ def test_chained_rename_processors(): # After second processor final_obs = results[2][TransitionKey.OBSERVATION] - assert "observation.state" in final_obs - assert "observation.image" in final_obs + assert OBS_STATE in final_obs + assert OBS_IMAGE in final_obs assert final_obs["extra"] == "keep_me" # Original keys should be gone @@ -356,15 +357,15 @@ def test_chained_rename_processors(): def test_nested_observation_rename(): """Test renaming with nested observation structures.""" rename_map = { - "observation.images.left": "observation.camera.left_view", - "observation.images.right": "observation.camera.right_view", + f"{OBS_IMAGES}.left": "observation.camera.left_view", + f"{OBS_IMAGES}.right": "observation.camera.right_view", "observation.proprio": "observation.proprioception", } processor = RenameObservationsProcessorStep(rename_map=rename_map) observation = { - "observation.images.left": torch.randn(3, 64, 64), - "observation.images.right": torch.randn(3, 64, 64), + f"{OBS_IMAGES}.left": torch.randn(3, 64, 64), + f"{OBS_IMAGES}.right": torch.randn(3, 64, 64), "observation.proprio": torch.randn(7), "observation.gripper": torch.tensor([0.0]), # Not renamed } @@ -382,8 +383,8 @@ def test_nested_observation_rename(): assert "observation.gripper" in processed_obs # Check old keys removed - assert "observation.images.left" not in processed_obs - assert "observation.images.right" not in processed_obs + assert f"{OBS_IMAGES}.left" not in processed_obs + assert f"{OBS_IMAGES}.right" not in processed_obs assert "observation.proprio" not in processed_obs @@ -464,7 +465,7 @@ def test_features_chained_processors(policy_feature_factory): # Chain two rename processors at the contract level processor1 = RenameObservationsProcessorStep(rename_map={"pos": "agent_position", "img": "camera_image"}) processor2 = RenameObservationsProcessorStep( - rename_map={"agent_position": "observation.state", "camera_image": "observation.image"} + rename_map={"agent_position": OBS_STATE, "camera_image": OBS_IMAGE} ) pipeline = DataProcessorPipeline([processor1, processor2]) @@ -477,27 +478,21 @@ def test_features_chained_processors(policy_feature_factory): } out = pipeline.transform_features(initial_features=spec) - assert set(out[PipelineFeatureType.OBSERVATION]) == {"observation.state", "observation.image", "extra"} - assert ( - out[PipelineFeatureType.OBSERVATION]["observation.state"] - == spec[PipelineFeatureType.OBSERVATION]["pos"] - ) - assert ( - out[PipelineFeatureType.OBSERVATION]["observation.image"] - == spec[PipelineFeatureType.OBSERVATION]["img"] - ) + assert set(out[PipelineFeatureType.OBSERVATION]) == {OBS_STATE, OBS_IMAGE, "extra"} + assert out[PipelineFeatureType.OBSERVATION][OBS_STATE] == spec[PipelineFeatureType.OBSERVATION]["pos"] + assert out[PipelineFeatureType.OBSERVATION][OBS_IMAGE] == spec[PipelineFeatureType.OBSERVATION]["img"] assert out[PipelineFeatureType.OBSERVATION]["extra"] == spec[PipelineFeatureType.OBSERVATION]["extra"] assert_contract_is_typed(out) def test_rename_stats_basic(): orig = { - "observation.state": {"mean": np.array([0.0]), "std": np.array([1.0])}, + OBS_STATE: {"mean": np.array([0.0]), "std": np.array([1.0])}, "action": {"mean": np.array([0.0])}, } - mapping = {"observation.state": "observation.robot_state"} + mapping = {OBS_STATE: "observation.robot_state"} renamed = rename_stats(orig, mapping) - assert "observation.robot_state" in renamed and "observation.state" not in renamed + assert "observation.robot_state" in renamed and OBS_STATE not in renamed # Ensure deep copy: mutate original and verify renamed unaffected - orig["observation.state"]["mean"][0] = 42.0 + orig[OBS_STATE]["mean"][0] = 42.0 assert renamed["observation.robot_state"]["mean"][0] != 42.0 diff --git a/tests/processor/test_tokenizer_processor.py b/tests/processor/test_tokenizer_processor.py index 9e6c8de2f..35bbcfd8a 100644 --- a/tests/processor/test_tokenizer_processor.py +++ b/tests/processor/test_tokenizer_processor.py @@ -11,7 +11,7 @@ import torch from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature from lerobot.processor import DataProcessorPipeline, TokenizerProcessorStep, TransitionKey from lerobot.processor.converters import create_transition, identity_transition -from lerobot.utils.constants import OBS_LANGUAGE +from lerobot.utils.constants import OBS_IMAGE, OBS_LANGUAGE, OBS_STATE from tests.utils import require_package @@ -503,16 +503,14 @@ def test_features_basic(): processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=128) input_features = { - PipelineFeatureType.OBSERVATION: { - "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,)) - }, + PipelineFeatureType.OBSERVATION: {OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,))}, PipelineFeatureType.ACTION: {"action": PolicyFeature(type=FeatureType.ACTION, shape=(5,))}, } output_features = processor.transform_features(input_features) # Check that original features are preserved - assert "observation.state" in output_features[PipelineFeatureType.OBSERVATION] + assert OBS_STATE in output_features[PipelineFeatureType.OBSERVATION] assert "action" in output_features[PipelineFeatureType.ACTION] # Check that tokenized features are added @@ -797,7 +795,7 @@ def test_device_detection_cpu(): processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) # Create transition with CPU tensors - observation = {"observation.state": torch.randn(10)} # CPU tensor + observation = {OBS_STATE: torch.randn(10)} # CPU tensor action = torch.randn(5) # CPU tensor transition = create_transition( observation=observation, action=action, complementary_data={"task": "test task"} @@ -821,7 +819,7 @@ def test_device_detection_cuda(): processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) # Create transition with CUDA tensors - observation = {"observation.state": torch.randn(10).cuda()} # CUDA tensor + observation = {OBS_STATE: torch.randn(10).cuda()} # CUDA tensor action = torch.randn(5).cuda() # CUDA tensor transition = create_transition( observation=observation, action=action, complementary_data={"task": "test task"} @@ -847,7 +845,7 @@ def test_device_detection_multi_gpu(): # Test with tensors on cuda:1 device = torch.device("cuda:1") - observation = {"observation.state": torch.randn(10).to(device)} + observation = {OBS_STATE: torch.randn(10).to(device)} action = torch.randn(5).to(device) transition = create_transition( observation=observation, action=action, complementary_data={"task": "multi gpu test"} @@ -943,7 +941,7 @@ def test_device_detection_preserves_dtype(): processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) # Create transition with float tensor (to test dtype isn't affected) - observation = {"observation.state": torch.randn(10, dtype=torch.float16)} + observation = {OBS_STATE: torch.randn(10, dtype=torch.float16)} transition = create_transition(observation=observation, complementary_data={"task": "dtype test"}) result = processor(transition) @@ -977,7 +975,7 @@ def test_integration_with_device_processor(mock_auto_tokenizer): # Start with CPU tensors transition = create_transition( - observation={"observation.state": torch.randn(10)}, # CPU + observation={OBS_STATE: torch.randn(10)}, # CPU action=torch.randn(5), # CPU complementary_data={"task": "pipeline test"}, ) @@ -985,7 +983,7 @@ def test_integration_with_device_processor(mock_auto_tokenizer): result = robot_processor(transition) # All tensors should end up on CUDA (moved by DeviceProcessorStep) - assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cuda" + assert result[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cuda" assert result[TransitionKey.ACTION].device.type == "cuda" # Tokenized tensors should also be on CUDA @@ -1005,8 +1003,8 @@ def test_simulated_accelerate_scenario(): # Simulate Accelerate scenario: batch already on GPU device = torch.device("cuda:0") observation = { - "observation.state": torch.randn(1, 10).to(device), # Batched, on GPU - "observation.image": torch.randn(1, 3, 224, 224).to(device), # Batched, on GPU + OBS_STATE: torch.randn(1, 10).to(device), # Batched, on GPU + OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device), # Batched, on GPU } action = torch.randn(1, 5).to(device) # Batched, on GPU diff --git a/tests/rl/test_actor.py b/tests/rl/test_actor.py index aa9913bb2..ec67f1889 100644 --- a/tests/rl/test_actor.py +++ b/tests/rl/test_actor.py @@ -21,6 +21,7 @@ import pytest import torch from torch.multiprocessing import Event, Queue +from lerobot.utils.constants import OBS_STR from lerobot.utils.transition import Transition from tests.utils import require_package @@ -110,12 +111,12 @@ def test_push_transitions_to_transport_queue(): transitions = [] for i in range(3): transition = Transition( - state={"observation": torch.randn(3, 64, 64), "state": torch.randn(10)}, + state={OBS_STR: torch.randn(3, 64, 64), "state": torch.randn(10)}, action=torch.randn(5), reward=torch.tensor(1.0 + i), done=torch.tensor(False), truncated=torch.tensor(False), - next_state={"observation": torch.randn(3, 64, 64), "state": torch.randn(10)}, + next_state={OBS_STR: torch.randn(3, 64, 64), "state": torch.randn(10)}, complementary_info={"step": torch.tensor(i)}, ) transitions.append(transition) diff --git a/tests/rl/test_actor_learner.py b/tests/rl/test_actor_learner.py index 43a6b0957..5d95dee04 100644 --- a/tests/rl/test_actor_learner.py +++ b/tests/rl/test_actor_learner.py @@ -24,6 +24,7 @@ from torch.multiprocessing import Event, Queue from lerobot.configs.train import TrainRLServerPipelineConfig from lerobot.policies.sac.configuration_sac import SACConfig +from lerobot.utils.constants import OBS_STR from lerobot.utils.transition import Transition from tests.utils import require_package @@ -33,12 +34,12 @@ def create_test_transitions(count: int = 3) -> list[Transition]: transitions = [] for i in range(count): transition = Transition( - state={"observation": torch.randn(3, 64, 64), "state": torch.randn(10)}, + state={OBS_STR: torch.randn(3, 64, 64), "state": torch.randn(10)}, action=torch.randn(5), reward=torch.tensor(1.0 + i), done=torch.tensor(i == count - 1), # Last transition is done truncated=torch.tensor(False), - next_state={"observation": torch.randn(3, 64, 64), "state": torch.randn(10)}, + next_state={OBS_STR: torch.randn(3, 64, 64), "state": torch.randn(10)}, complementary_info={"step": torch.tensor(i), "episode_id": i // 2}, ) transitions.append(transition) diff --git a/tests/utils/test_replay_buffer.py b/tests/utils/test_replay_buffer.py index b5254f393..6820d321f 100644 --- a/tests/utils/test_replay_buffer.py +++ b/tests/utils/test_replay_buffer.py @@ -22,11 +22,12 @@ import torch from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.rl.buffer import BatchTransition, ReplayBuffer, random_crop_vectorized +from lerobot.utils.constants import OBS_IMAGE, OBS_STATE, OBS_STR from tests.fixtures.constants import DUMMY_REPO_ID def state_dims() -> list[str]: - return ["observation.image", "observation.state"] + return [OBS_IMAGE, OBS_STATE] @pytest.fixture @@ -61,10 +62,10 @@ def create_random_image() -> torch.Tensor: def create_dummy_transition() -> dict: return { - "observation.image": create_random_image(), + OBS_IMAGE: create_random_image(), "action": torch.randn(4), "reward": torch.tensor(1.0), - "observation.state": torch.randn( + OBS_STATE: torch.randn( 10, ), "done": torch.tensor(False), @@ -98,8 +99,8 @@ def create_dataset_from_replay_buffer(tmp_path) -> tuple[LeRobotDataset, ReplayB def create_dummy_state() -> dict: return { - "observation.image": create_random_image(), - "observation.state": torch.randn( + OBS_IMAGE: create_random_image(), + OBS_STATE: torch.randn( 10, ), } @@ -180,7 +181,7 @@ def test_empty_buffer_sample_raises_error(replay_buffer): def test_zero_capacity_buffer_raises_error(): with pytest.raises(ValueError, match="Capacity must be greater than 0."): - ReplayBuffer(0, "cpu", ["observation", "next_observation"]) + ReplayBuffer(0, "cpu", [OBS_STR, "next_observation"]) def test_add_transition(replay_buffer, dummy_state, dummy_action): @@ -203,7 +204,7 @@ def test_add_transition(replay_buffer, dummy_state, dummy_action): def test_add_over_capacity(): - replay_buffer = ReplayBuffer(2, "cpu", ["observation", "next_observation"]) + replay_buffer = ReplayBuffer(2, "cpu", [OBS_STR, "next_observation"]) dummy_state_1 = create_dummy_state() dummy_action_1 = create_dummy_action() @@ -373,7 +374,7 @@ def test_to_lerobot_dataset(tmp_path): assert ds.num_frames == 4 for j, value in enumerate(ds): - print(torch.equal(value["observation.image"], buffer.next_states["observation.image"][j])) + print(torch.equal(value[OBS_IMAGE], buffer.next_states[OBS_IMAGE][j])) for i in range(len(ds)): for feature, value in ds[i].items(): @@ -383,12 +384,12 @@ def test_to_lerobot_dataset(tmp_path): assert torch.equal(value, buffer.rewards[i]) elif feature == "next.done": assert torch.equal(value, buffer.dones[i]) - elif feature == "observation.image": + elif feature == OBS_IMAGE: # Tensor -> numpy is not precise, so we have some diff there # TODO: Check and fix it - torch.testing.assert_close(value, buffer.states["observation.image"][i], rtol=0.3, atol=0.003) - elif feature == "observation.state": - assert torch.equal(value, buffer.states["observation.state"][i]) + torch.testing.assert_close(value, buffer.states[OBS_IMAGE][i], rtol=0.3, atol=0.003) + elif feature == OBS_STATE: + assert torch.equal(value, buffer.states[OBS_STATE][i]) def test_from_lerobot_dataset(tmp_path): @@ -436,14 +437,14 @@ def test_from_lerobot_dataset(tmp_path): ) assert torch.equal( - replay_buffer.states["observation.state"][: len(replay_buffer)], - reconverted_buffer.states["observation.state"][: len(replay_buffer)], + replay_buffer.states[OBS_STATE][: len(replay_buffer)], + reconverted_buffer.states[OBS_STATE][: len(replay_buffer)], ), "State should be the same after converting to dataset and return back" for i in range(4): torch.testing.assert_close( - replay_buffer.states["observation.image"][i], - reconverted_buffer.states["observation.image"][i], + replay_buffer.states[OBS_IMAGE][i], + reconverted_buffer.states[OBS_IMAGE][i], rtol=0.4, atol=0.004, ) @@ -454,16 +455,16 @@ def test_from_lerobot_dataset(tmp_path): next_index = (i + 1) % 4 torch.testing.assert_close( - replay_buffer.states["observation.image"][next_index], - reconverted_buffer.next_states["observation.image"][i], + replay_buffer.states[OBS_IMAGE][next_index], + reconverted_buffer.next_states[OBS_IMAGE][i], rtol=0.4, atol=0.004, ) for i in range(2, 4): assert torch.equal( - replay_buffer.states["observation.state"][i], - reconverted_buffer.next_states["observation.state"][i], + replay_buffer.states[OBS_STATE][i], + reconverted_buffer.next_states[OBS_STATE][i], ) @@ -563,10 +564,8 @@ def test_check_image_augmentations_with_drq_and_dummy_image_augmentation_functio replay_buffer.add(dummy_state, dummy_action, 1.0, dummy_state, False, False) sampled_transitions = replay_buffer.sample(1) - assert torch.all(sampled_transitions["state"]["observation.image"] == 10), ( - "Image augmentations should be applied" - ) - assert torch.all(sampled_transitions["next_state"]["observation.image"] == 10), ( + assert torch.all(sampled_transitions["state"][OBS_IMAGE] == 10), "Image augmentations should be applied" + assert torch.all(sampled_transitions["next_state"][OBS_IMAGE] == 10), ( "Image augmentations should be applied" ) @@ -580,8 +579,8 @@ def test_check_image_augmentations_with_drq_and_default_image_augmentation_funct # Let's check that it doesn't fail and shapes are correct sampled_transitions = replay_buffer.sample(1) - assert sampled_transitions["state"]["observation.image"].shape == (1, 3, 84, 84) - assert sampled_transitions["next_state"]["observation.image"].shape == (1, 3, 84, 84) + assert sampled_transitions["state"][OBS_IMAGE].shape == (1, 3, 84, 84) + assert sampled_transitions["next_state"][OBS_IMAGE].shape == (1, 3, 84, 84) def test_random_crop_vectorized_basic(): @@ -620,7 +619,7 @@ def _populate_buffer_for_async_test(capacity: int = 10) -> ReplayBuffer: buffer = ReplayBuffer( capacity=capacity, device="cpu", - state_keys=["observation.image", "observation.state"], + state_keys=[OBS_IMAGE, OBS_STATE], storage_device="cpu", ) @@ -628,8 +627,8 @@ def _populate_buffer_for_async_test(capacity: int = 10) -> ReplayBuffer: img = torch.ones(3, 128, 128) * i state_vec = torch.arange(11).float() + i state = { - "observation.image": img, - "observation.state": state_vec, + OBS_IMAGE: img, + OBS_STATE: state_vec, } buffer.add( state=state, @@ -648,14 +647,14 @@ def test_async_iterator_shapes_basic(): iterator = buffer.get_iterator(batch_size=batch_size, async_prefetch=True, queue_size=1) batch = next(iterator) - images = batch["state"]["observation.image"] - states = batch["state"]["observation.state"] + images = batch["state"][OBS_IMAGE] + states = batch["state"][OBS_STATE] assert images.shape == (batch_size, 3, 128, 128) assert states.shape == (batch_size, 11) - next_images = batch["next_state"]["observation.image"] - next_states = batch["next_state"]["observation.state"] + next_images = batch["next_state"][OBS_IMAGE] + next_states = batch["next_state"][OBS_STATE] assert next_images.shape == (batch_size, 3, 128, 128) assert next_states.shape == (batch_size, 11) @@ -668,13 +667,13 @@ def test_async_iterator_multiple_iterations(): for _ in range(5): batch = next(iterator) - images = batch["state"]["observation.image"] - states = batch["state"]["observation.state"] + images = batch["state"][OBS_IMAGE] + states = batch["state"][OBS_STATE] assert images.shape == (batch_size, 3, 128, 128) assert states.shape == (batch_size, 11) - next_images = batch["next_state"]["observation.image"] - next_states = batch["next_state"]["observation.state"] + next_images = batch["next_state"][OBS_IMAGE] + next_states = batch["next_state"][OBS_STATE] assert next_images.shape == (batch_size, 3, 128, 128) assert next_states.shape == (batch_size, 11) diff --git a/tests/utils/test_visualization_utils.py b/tests/utils/test_visualization_utils.py index 29b7bf70a..65a97c6a3 100644 --- a/tests/utils/test_visualization_utils.py +++ b/tests/utils/test_visualization_utils.py @@ -6,6 +6,7 @@ import numpy as np import pytest from lerobot.processor import TransitionKey +from lerobot.utils.constants import OBS_STATE @pytest.fixture @@ -72,7 +73,7 @@ def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun): # Build EnvTransition dict obs = { - "observation.state.temperature": np.float32(25.0), + f"{OBS_STATE}.temperature": np.float32(25.0), # CHW image should be converted to HWC for rr.Image "observation.camera": np.zeros((3, 10, 20), dtype=np.uint8), } @@ -97,7 +98,7 @@ def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun): # - action.throttle -> Scalar # - action.vector_0, action.vector_1 -> Scalars expected_keys = { - "observation.state.temperature", + f"{OBS_STATE}.temperature", "observation.camera", "action.throttle", "action.vector_0", @@ -106,7 +107,7 @@ def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun): assert set(_keys(calls)) == expected_keys # Check scalar types and values - temp_obj = _obj_for(calls, "observation.state.temperature") + temp_obj = _obj_for(calls, f"{OBS_STATE}.temperature") assert type(temp_obj).__name__ == "DummyScalar" assert temp_obj.value == pytest.approx(25.0) From 9627765ce20ac7404898394bcd18a48b077ec82c Mon Sep 17 00:00:00 2001 From: Adil Zouitine Date: Fri, 26 Sep 2025 11:53:27 +0200 Subject: [PATCH 2/3] chore(mypy): add mypy configuration and module overrides for gradual type checking (#2052) --- pyproject.toml | 75 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 75 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index d2f1e502a..44e29043b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -267,8 +267,83 @@ default.extend-ignore-identifiers-re = [ # color = true # paths = ["src/lerobot"] +# TODO: Enable mypy gradually module by module across multiple PRs +# Uncomment [tool.mypy] first, then uncomment individual module overrides as they get proper type annotations + # [tool.mypy] # python_version = "3.10" # warn_return_any = true # warn_unused_configs = true # ignore_missing_imports = false +# strict = true +# disallow_untyped_defs = true +# disallow_incomplete_defs = true +# check_untyped_defs = true + +# [[tool.mypy.overrides]] +# module = "lerobot.utils.*" +# # include = "src/lerobot/utils/**/*.py" + +# [[tool.mypy.overrides]] +# module = "lerobot.configs.*" +# # include = "src/lerobot/configs/**/*.py" + +# # Data processing modules +# [[tool.mypy.overrides]] +# module = "lerobot.processor.*" +# # include = "src/lerobot/processor/**/*.py" + +# [[tool.mypy.overrides]] +# module = "lerobot.datasets.*" +# # include = "src/lerobot/datasets/**/*.py" + +# # Core machine learning modules +# [[tool.mypy.overrides]] +# module = "lerobot.optim.*" +# # include = "src/lerobot/optim/**/*.py" + +# [[tool.mypy.overrides]] +# module = "lerobot.model.*" +# # include = "src/lerobot/model/**/*.py" + +# # Hardware interfaces +# [[tool.mypy.overrides]] +# module = "lerobot.cameras.*" +# # include = "src/lerobot/cameras/**/*.py" + +# [[tool.mypy.overrides]] +# module = "lerobot.motors.*" +# # include = "src/lerobot/motors/**/*.py" + +# [[tool.mypy.overrides]] +# module = "lerobot.robots.*" +# # include = "src/lerobot/robots/**/*.py" + +# [[tool.mypy.overrides]] +# module = "lerobot.teleoperators.*" +# # include = "src/lerobot/teleoperators/**/*.py" + +# # Complex modules (enable these last) +# [[tool.mypy.overrides]] +# module = "lerobot.policies.*" +# # include = "src/lerobot/policies/**/*.py" + +# [[tool.mypy.overrides]] +# module = "lerobot.rl.*" +# # include = "src/lerobot/rl/**/*.py" + +# [[tool.mypy.overrides]] +# module = "lerobot.envs.*" +# # include = "src/lerobot/envs/**/*.py" + +# [[tool.mypy.overrides]] +# module = "lerobot.async_inference.*" +# # include = "src/lerobot/async_inference/**/*.py" + +# [[tool.mypy.overrides]] +# module = "lerobot.transport.*" +# # include = "src/lerobot/transport/**/*.py" + +# [[tool.mypy.overrides]] +# module = "lerobot.scripts.*" +# # include = "src/lerobot/scripts/**/*.py" From d2782cf66b0b22d9bfae8912ee2bbc7f63c3615e Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Fri, 26 Sep 2025 13:33:18 +0200 Subject: [PATCH 3/3] chore: replace hard-coded action values with constants throughout all the source code (#2055) * chore: replace hard-coded 'action' values with constants throughout all the source code * chore(tests): replace hard-coded action values with constants throughout all the test code --- examples/backward_compatibility/replay.py | 7 +- examples/lekiwi/evaluate.py | 4 +- examples/lekiwi/record.py | 4 +- examples/lekiwi/replay.py | 5 +- examples/phone_to_so100/replay.py | 5 +- examples/so100_to_so100_EE/replay.py | 5 +- src/lerobot/datasets/factory.py | 4 +- src/lerobot/datasets/pipeline_features.py | 2 +- src/lerobot/datasets/utils.py | 6 +- src/lerobot/envs/configs.py | 16 +- src/lerobot/policies/act/modeling_act.py | 6 +- .../policies/diffusion/modeling_diffusion.py | 10 +- .../conversion_scripts/compare_with_jax.py | 6 +- src/lerobot/policies/sac/configuration_sac.py | 2 +- src/lerobot/policies/sac/modeling_sac.py | 6 +- src/lerobot/policies/tdmpc/modeling_tdmpc.py | 2 +- src/lerobot/processor/converters.py | 8 +- .../processor/migrate_policy_normalization.py | 7 +- src/lerobot/processor/normalize_processor.py | 3 +- src/lerobot/processor/policy_robot_bridge.py | 3 +- src/lerobot/rl/buffer.py | 18 +-- src/lerobot/rl/gym_manipulator.py | 10 +- src/lerobot/rl/learner.py | 11 +- src/lerobot/robots/lekiwi/lekiwi_client.py | 4 +- src/lerobot/scripts/lerobot_dataset_viz.py | 8 +- src/lerobot/scripts/lerobot_eval.py | 8 +- src/lerobot/scripts/lerobot_record.py | 6 +- src/lerobot/scripts/lerobot_replay.py | 7 +- src/lerobot/utils/transition.py | 4 +- tests/datasets/test_dataset_utils.py | 28 ++-- tests/datasets/test_datasets.py | 18 +-- tests/datasets/test_streaming.py | 5 +- tests/fixtures/constants.py | 4 +- tests/policies/test_policies.py | 6 +- tests/policies/test_sac_config.py | 10 +- tests/policies/test_sac_policy.py | 10 +- tests/processor/test_batch_conversion.py | 28 ++-- tests/processor/test_converters.py | 14 +- tests/processor/test_device_processor.py | 4 +- tests/processor/test_migration_detection.py | 4 +- tests/processor/test_normalize_processor.py | 140 +++++++++--------- tests/processor/test_pipeline.py | 28 ++-- tests/processor/test_policy_robot_bridge.py | 15 +- tests/processor/test_rename_processor.py | 4 +- tests/processor/test_tokenizer_processor.py | 6 +- tests/transport/test_transport_utils.py | 3 +- tests/utils/test_replay_buffer.py | 10 +- 47 files changed, 269 insertions(+), 255 deletions(-) diff --git a/examples/backward_compatibility/replay.py b/examples/backward_compatibility/replay.py index 6c680f204..6bca0570f 100644 --- a/examples/backward_compatibility/replay.py +++ b/examples/backward_compatibility/replay.py @@ -44,6 +44,7 @@ from lerobot.robots import ( # noqa: F401 so100_follower, so101_follower, ) +from lerobot.utils.constants import ACTION from lerobot.utils.robot_utils import busy_wait from lerobot.utils.utils import ( init_logging, @@ -78,16 +79,16 @@ def replay(cfg: ReplayConfig): robot = make_robot_from_config(cfg.robot) dataset = LeRobotDataset(cfg.dataset.repo_id, root=cfg.dataset.root, episodes=[cfg.dataset.episode]) - actions = dataset.hf_dataset.select_columns("action") + actions = dataset.hf_dataset.select_columns(ACTION) robot.connect() log_say("Replaying episode", cfg.play_sounds, blocking=True) for idx in range(dataset.num_frames): start_episode_t = time.perf_counter() - action_array = actions[idx]["action"] + action_array = actions[idx][ACTION] action = {} - for i, name in enumerate(dataset.features["action"]["names"]): + for i, name in enumerate(dataset.features[ACTION]["names"]): key = f"{name.removeprefix('main_')}.pos" action[key] = action_array[i].item() diff --git a/examples/lekiwi/evaluate.py b/examples/lekiwi/evaluate.py index 174486eb8..8a62d92a9 100644 --- a/examples/lekiwi/evaluate.py +++ b/examples/lekiwi/evaluate.py @@ -21,7 +21,7 @@ from lerobot.policies.factory import make_pre_post_processors from lerobot.processor import make_default_processors from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig from lerobot.scripts.lerobot_record import record_loop -from lerobot.utils.constants import OBS_STR +from lerobot.utils.constants import ACTION, OBS_STR from lerobot.utils.control_utils import init_keyboard_listener from lerobot.utils.utils import log_say from lerobot.utils.visualization_utils import init_rerun @@ -42,7 +42,7 @@ robot = LeKiwiClient(robot_config) policy = ACTPolicy.from_pretrained(HF_MODEL_ID) # Configure the dataset features -action_features = hw_to_dataset_features(robot.action_features, "action") +action_features = hw_to_dataset_features(robot.action_features, ACTION) obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR) dataset_features = {**action_features, **obs_features} diff --git a/examples/lekiwi/record.py b/examples/lekiwi/record.py index 471cb3668..9070741bf 100644 --- a/examples/lekiwi/record.py +++ b/examples/lekiwi/record.py @@ -22,7 +22,7 @@ from lerobot.robots.lekiwi.lekiwi_client import LeKiwiClient from lerobot.scripts.lerobot_record import record_loop from lerobot.teleoperators.keyboard import KeyboardTeleop, KeyboardTeleopConfig from lerobot.teleoperators.so100_leader import SO100Leader, SO100LeaderConfig -from lerobot.utils.constants import OBS_STR +from lerobot.utils.constants import ACTION, OBS_STR from lerobot.utils.control_utils import init_keyboard_listener from lerobot.utils.utils import log_say from lerobot.utils.visualization_utils import init_rerun @@ -48,7 +48,7 @@ keyboard = KeyboardTeleop(keyboard_config) teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors() # Configure the dataset features -action_features = hw_to_dataset_features(robot.action_features, "action") +action_features = hw_to_dataset_features(robot.action_features, ACTION) obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR) dataset_features = {**action_features, **obs_features} diff --git a/examples/lekiwi/replay.py b/examples/lekiwi/replay.py index 0f8eabdff..3ae915286 100644 --- a/examples/lekiwi/replay.py +++ b/examples/lekiwi/replay.py @@ -19,6 +19,7 @@ import time from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.robots.lekiwi.config_lekiwi import LeKiwiClientConfig from lerobot.robots.lekiwi.lekiwi_client import LeKiwiClient +from lerobot.utils.constants import ACTION from lerobot.utils.robot_utils import busy_wait from lerobot.utils.utils import log_say @@ -34,7 +35,7 @@ robot = LeKiwiClient(robot_config) dataset = LeRobotDataset("/", episodes=[EPISODE_IDX]) # Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0 episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX) -actions = episode_frames.select_columns("action") +actions = episode_frames.select_columns(ACTION) # Connect to the robot robot.connect() @@ -49,7 +50,7 @@ for idx in range(len(episode_frames)): # Get recorded action from dataset action = { - name: float(actions[idx]["action"][i]) for i, name in enumerate(dataset.features["action"]["names"]) + name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"]) } # Send action to robot diff --git a/examples/phone_to_so100/replay.py b/examples/phone_to_so100/replay.py index 80c65a4c2..f1181143c 100644 --- a/examples/phone_to_so100/replay.py +++ b/examples/phone_to_so100/replay.py @@ -28,6 +28,7 @@ from lerobot.robots.so100_follower.robot_kinematic_processor import ( InverseKinematicsEEToJoints, ) from lerobot.robots.so100_follower.so100_follower import SO100Follower +from lerobot.utils.constants import ACTION from lerobot.utils.robot_utils import busy_wait from lerobot.utils.utils import log_say @@ -66,7 +67,7 @@ robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotOb dataset = LeRobotDataset(HF_REPO_ID, episodes=[EPISODE_IDX]) # Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0 episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX) -actions = episode_frames.select_columns("action") +actions = episode_frames.select_columns(ACTION) # Connect to the robot robot.connect() @@ -81,7 +82,7 @@ for idx in range(len(episode_frames)): # Get recorded action from dataset ee_action = { - name: float(actions[idx]["action"][i]) for i, name in enumerate(dataset.features["action"]["names"]) + name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"]) } # Get robot observation diff --git a/examples/so100_to_so100_EE/replay.py b/examples/so100_to_so100_EE/replay.py index 6987f4839..ea78d4e66 100644 --- a/examples/so100_to_so100_EE/replay.py +++ b/examples/so100_to_so100_EE/replay.py @@ -29,6 +29,7 @@ from lerobot.robots.so100_follower.robot_kinematic_processor import ( InverseKinematicsEEToJoints, ) from lerobot.robots.so100_follower.so100_follower import SO100Follower +from lerobot.utils.constants import ACTION from lerobot.utils.robot_utils import busy_wait from lerobot.utils.utils import log_say @@ -67,7 +68,7 @@ robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotOb dataset = LeRobotDataset(HF_REPO_ID, episodes=[EPISODE_IDX]) # Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0 episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX) -actions = episode_frames.select_columns("action") +actions = episode_frames.select_columns(ACTION) # Connect to the robot robot.connect() @@ -82,7 +83,7 @@ for idx in range(len(episode_frames)): # Get recorded action from dataset ee_action = { - name: float(actions[idx]["action"][i]) for i, name in enumerate(dataset.features["action"]["names"]) + name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"]) } # Get robot observation diff --git a/src/lerobot/datasets/factory.py b/src/lerobot/datasets/factory.py index 2bac84aed..f74b6ac4f 100644 --- a/src/lerobot/datasets/factory.py +++ b/src/lerobot/datasets/factory.py @@ -27,7 +27,7 @@ from lerobot.datasets.lerobot_dataset import ( ) from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset from lerobot.datasets.transforms import ImageTransforms -from lerobot.utils.constants import OBS_PREFIX +from lerobot.utils.constants import ACTION, OBS_PREFIX IMAGENET_STATS = { "mean": [[[0.485]], [[0.456]], [[0.406]]], # (c,1,1) @@ -57,7 +57,7 @@ def resolve_delta_timestamps( for key in ds_meta.features: if key == "next.reward" and cfg.reward_delta_indices is not None: delta_timestamps[key] = [i / ds_meta.fps for i in cfg.reward_delta_indices] - if key == "action" and cfg.action_delta_indices is not None: + if key == ACTION and cfg.action_delta_indices is not None: delta_timestamps[key] = [i / ds_meta.fps for i in cfg.action_delta_indices] if key.startswith(OBS_PREFIX) and cfg.observation_delta_indices is not None: delta_timestamps[key] = [i / ds_meta.fps for i in cfg.observation_delta_indices] diff --git a/src/lerobot/datasets/pipeline_features.py b/src/lerobot/datasets/pipeline_features.py index 13555dd31..4fad7bd20 100644 --- a/src/lerobot/datasets/pipeline_features.py +++ b/src/lerobot/datasets/pipeline_features.py @@ -132,7 +132,7 @@ def aggregate_pipeline_dataset_features( # Convert the processed features into the final dataset format. dataset_features = {} if processed_features[ACTION]: - dataset_features.update(hw_to_dataset_features(processed_features["action"], ACTION, use_videos)) + dataset_features.update(hw_to_dataset_features(processed_features[ACTION], ACTION, use_videos)) if processed_features[OBS_STR]: dataset_features.update(hw_to_dataset_features(processed_features[OBS_STR], OBS_STR, use_videos)) diff --git a/src/lerobot/datasets/utils.py b/src/lerobot/datasets/utils.py index 96ae2eca6..35313bde5 100644 --- a/src/lerobot/datasets/utils.py +++ b/src/lerobot/datasets/utils.py @@ -43,7 +43,7 @@ from lerobot.datasets.backward_compatibility import ( BackwardCompatibilityError, ForwardCompatibilityError, ) -from lerobot.utils.constants import OBS_ENV_STATE, OBS_STR +from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_STR from lerobot.utils.utils import is_valid_numpy_dtype_string DEFAULT_CHUNK_SIZE = 1000 # Max number of files per chunk @@ -646,7 +646,7 @@ def hw_to_dataset_features( } cam_fts = {key: shape for key, shape in hw_features.items() if isinstance(shape, tuple)} - if joint_fts and prefix == "action": + if joint_fts and prefix == ACTION: features[prefix] = { "dtype": "float32", "shape": (len(joint_fts),), @@ -733,7 +733,7 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea type = FeatureType.ENV elif key.startswith(OBS_STR): type = FeatureType.STATE - elif key.startswith("action"): + elif key.startswith(ACTION): type = FeatureType.ACTION else: continue diff --git a/src/lerobot/envs/configs.py b/src/lerobot/envs/configs.py index 4456c51a5..8cbc597dc 100644 --- a/src/lerobot/envs/configs.py +++ b/src/lerobot/envs/configs.py @@ -53,12 +53,12 @@ class AlohaEnv(EnvConfig): render_mode: str = "rgb_array" features: dict[str, PolicyFeature] = field( default_factory=lambda: { - "action": PolicyFeature(type=FeatureType.ACTION, shape=(14,)), + ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(14,)), } ) features_map: dict[str, str] = field( default_factory=lambda: { - "action": ACTION, + ACTION: ACTION, "agent_pos": OBS_STATE, "top": f"{OBS_IMAGE}.top", "pixels/top": f"{OBS_IMAGES}.top", @@ -93,13 +93,13 @@ class PushtEnv(EnvConfig): visualization_height: int = 384 features: dict[str, PolicyFeature] = field( default_factory=lambda: { - "action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)), + ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(2,)), "agent_pos": PolicyFeature(type=FeatureType.STATE, shape=(2,)), } ) features_map: dict[str, str] = field( default_factory=lambda: { - "action": ACTION, + ACTION: ACTION, "agent_pos": OBS_STATE, "environment_state": OBS_ENV_STATE, "pixels": OBS_IMAGE, @@ -135,13 +135,13 @@ class XarmEnv(EnvConfig): visualization_height: int = 384 features: dict[str, PolicyFeature] = field( default_factory=lambda: { - "action": PolicyFeature(type=FeatureType.ACTION, shape=(4,)), + ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(4,)), "pixels": PolicyFeature(type=FeatureType.VISUAL, shape=(84, 84, 3)), } ) features_map: dict[str, str] = field( default_factory=lambda: { - "action": ACTION, + ACTION: ACTION, "agent_pos": OBS_STATE, "pixels": OBS_IMAGE, } @@ -259,12 +259,12 @@ class LiberoEnv(EnvConfig): camera_name_mapping: dict[str, str] | None = (None,) features: dict[str, PolicyFeature] = field( default_factory=lambda: { - "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)), + ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(7,)), } ) features_map: dict[str, str] = field( default_factory=lambda: { - "action": ACTION, + ACTION: ACTION, "agent_pos": OBS_STATE, "pixels/agentview_image": f"{OBS_IMAGES}.image", "pixels/robot0_eye_in_hand_image": f"{OBS_IMAGES}.image2", diff --git a/src/lerobot/policies/act/modeling_act.py b/src/lerobot/policies/act/modeling_act.py index f8261bb7f..e987f9070 100644 --- a/src/lerobot/policies/act/modeling_act.py +++ b/src/lerobot/policies/act/modeling_act.py @@ -394,7 +394,7 @@ class ACT(nn.Module): latent dimension. """ if self.config.use_vae and self.training: - assert "action" in batch, ( + assert ACTION in batch, ( "actions must be provided when using the variational objective in training mode." ) @@ -404,7 +404,7 @@ class ACT(nn.Module): batch_size = batch[OBS_ENV_STATE].shape[0] # Prepare the latent for input to the transformer encoder. - if self.config.use_vae and "action" in batch and self.training: + if self.config.use_vae and ACTION in batch and self.training: # Prepare the input to the VAE encoder: [cls, *joint_space_configuration, *action_sequence]. cls_embed = einops.repeat( self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size @@ -412,7 +412,7 @@ class ACT(nn.Module): if self.config.robot_state_feature: robot_state_embed = self.vae_encoder_robot_state_input_proj(batch[OBS_STATE]) robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D) - action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D) + action_embed = self.vae_encoder_action_input_proj(batch[ACTION]) # (B, S, D) if self.config.robot_state_feature: vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D) diff --git a/src/lerobot/policies/diffusion/modeling_diffusion.py b/src/lerobot/policies/diffusion/modeling_diffusion.py index af1327ba2..ad808d7c7 100644 --- a/src/lerobot/policies/diffusion/modeling_diffusion.py +++ b/src/lerobot/policies/diffusion/modeling_diffusion.py @@ -82,7 +82,7 @@ class DiffusionPolicy(PreTrainedPolicy): """Clear observation and action queues. Should be called on `env.reset()`""" self._queues = { OBS_STATE: deque(maxlen=self.config.n_obs_steps), - "action": deque(maxlen=self.config.n_action_steps), + ACTION: deque(maxlen=self.config.n_action_steps), } if self.config.image_features: self._queues[OBS_IMAGES] = deque(maxlen=self.config.n_obs_steps) @@ -306,10 +306,10 @@ class DiffusionModel(nn.Module): } """ # Input validation. - assert set(batch).issuperset({OBS_STATE, "action", "action_is_pad"}) + assert set(batch).issuperset({OBS_STATE, ACTION, "action_is_pad"}) assert OBS_IMAGES in batch or OBS_ENV_STATE in batch n_obs_steps = batch[OBS_STATE].shape[1] - horizon = batch["action"].shape[1] + horizon = batch[ACTION].shape[1] assert horizon == self.config.horizon assert n_obs_steps == self.config.n_obs_steps @@ -317,7 +317,7 @@ class DiffusionModel(nn.Module): global_cond = self._prepare_global_conditioning(batch) # (B, global_cond_dim) # Forward diffusion. - trajectory = batch["action"] + trajectory = batch[ACTION] # Sample noise to add to the trajectory. eps = torch.randn(trajectory.shape, device=trajectory.device) # Sample a random noising timestep for each item in the batch. @@ -338,7 +338,7 @@ class DiffusionModel(nn.Module): if self.config.prediction_type == "epsilon": target = eps elif self.config.prediction_type == "sample": - target = batch["action"] + target = batch[ACTION] else: raise ValueError(f"Unsupported prediction type {self.config.prediction_type}") diff --git a/src/lerobot/policies/pi0/conversion_scripts/compare_with_jax.py b/src/lerobot/policies/pi0/conversion_scripts/compare_with_jax.py index fe9865697..dad7d002e 100644 --- a/src/lerobot/policies/pi0/conversion_scripts/compare_with_jax.py +++ b/src/lerobot/policies/pi0/conversion_scripts/compare_with_jax.py @@ -21,7 +21,7 @@ import torch from lerobot.configs.policies import PreTrainedConfig from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata from lerobot.policies.factory import make_policy -from lerobot.utils.constants import OBS_IMAGES, OBS_STATE +from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE def display(tensor: torch.Tensor): @@ -73,7 +73,7 @@ def main(): for cam_key, uint_chw_array in example["images"].items(): batch[f"{OBS_IMAGES}.{cam_key}"] = torch.from_numpy(uint_chw_array) / 255.0 batch[OBS_STATE] = torch.from_numpy(example["state"]) - batch["action"] = torch.from_numpy(outputs["actions"]) + batch[ACTION] = torch.from_numpy(outputs["actions"]) batch["task"] = example["prompt"] if model_name == "pi0_aloha_towel": @@ -117,7 +117,7 @@ def main(): actions.append(action) actions = torch.stack(actions, dim=1) - pi_actions = batch["action"] + pi_actions = batch[ACTION] print("actions") display(actions) print() diff --git a/src/lerobot/policies/sac/configuration_sac.py b/src/lerobot/policies/sac/configuration_sac.py index a42758b85..6b5ad5b59 100644 --- a/src/lerobot/policies/sac/configuration_sac.py +++ b/src/lerobot/policies/sac/configuration_sac.py @@ -225,7 +225,7 @@ class SACConfig(PreTrainedConfig): "You must provide either 'observation.state' or an image observation (key starting with 'observation.image') in the input features" ) - if "action" not in self.output_features: + if ACTION not in self.output_features: raise ValueError("You must provide 'action' in the output features") @property diff --git a/src/lerobot/policies/sac/modeling_sac.py b/src/lerobot/policies/sac/modeling_sac.py index a6ed79d4e..c66044406 100644 --- a/src/lerobot/policies/sac/modeling_sac.py +++ b/src/lerobot/policies/sac/modeling_sac.py @@ -31,7 +31,7 @@ from torch.distributions import MultivariateNormal, TanhTransform, Transform, Tr from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.sac.configuration_sac import SACConfig, is_image_feature from lerobot.policies.utils import get_device_from_parameters -from lerobot.utils.constants import OBS_ENV_STATE, OBS_STATE +from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_STATE DISCRETE_DIMENSION_INDEX = -1 # Gripper is always the last dimension @@ -51,7 +51,7 @@ class SACPolicy( self.config = config # Determine action dimension and initialize all components - continuous_action_dim = config.output_features["action"].shape[0] + continuous_action_dim = config.output_features[ACTION].shape[0] self._init_encoders() self._init_critics(continuous_action_dim) self._init_actor(continuous_action_dim) @@ -158,7 +158,7 @@ class SACPolicy( The computed loss tensor """ # Extract common components from batch - actions: Tensor = batch["action"] + actions: Tensor = batch[ACTION] observations: dict[str, Tensor] = batch["state"] observation_features: Tensor = batch.get("observation_feature") diff --git a/src/lerobot/policies/tdmpc/modeling_tdmpc.py b/src/lerobot/policies/tdmpc/modeling_tdmpc.py index 4b5e8b7bd..195cf6154 100644 --- a/src/lerobot/policies/tdmpc/modeling_tdmpc.py +++ b/src/lerobot/policies/tdmpc/modeling_tdmpc.py @@ -92,7 +92,7 @@ class TDMPCPolicy(PreTrainedPolicy): """ self._queues = { OBS_STATE: deque(maxlen=1), - "action": deque(maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)), + ACTION: deque(maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)), } if self.config.image_features: self._queues[OBS_IMAGE] = deque(maxlen=1) diff --git a/src/lerobot/processor/converters.py b/src/lerobot/processor/converters.py index 2e80cf4bb..68f9dd6fa 100644 --- a/src/lerobot/processor/converters.py +++ b/src/lerobot/processor/converters.py @@ -23,7 +23,7 @@ from typing import Any import numpy as np import torch -from lerobot.utils.constants import OBS_PREFIX +from lerobot.utils.constants import ACTION, OBS_PREFIX from .core import EnvTransition, PolicyAction, RobotAction, RobotObservation, TransitionKey @@ -344,7 +344,7 @@ def batch_to_transition(batch: dict[str, Any]) -> EnvTransition: if not isinstance(batch, dict): raise ValueError(f"EnvTransition must be a dictionary. Got {type(batch).__name__}") - action = batch.get("action") + action = batch.get(ACTION) if action is not None and not isinstance(action, PolicyAction): raise ValueError(f"Action should be a PolicyAction type got {type(action)}") @@ -354,7 +354,7 @@ def batch_to_transition(batch: dict[str, Any]) -> EnvTransition: return create_transition( observation=observation_keys if observation_keys else None, - action=batch.get("action"), + action=batch.get(ACTION), reward=batch.get("next.reward", 0.0), done=batch.get("next.done", False), truncated=batch.get("next.truncated", False), @@ -379,7 +379,7 @@ def transition_to_batch(transition: EnvTransition) -> dict[str, Any]: raise ValueError(f"Transition should be a EnvTransition type (dict) got {type(transition)}") batch = { - "action": transition.get(TransitionKey.ACTION), + ACTION: transition.get(TransitionKey.ACTION), "next.reward": transition.get(TransitionKey.REWARD, 0.0), "next.done": transition.get(TransitionKey.DONE, False), "next.truncated": transition.get(TransitionKey.TRUNCATED, False), diff --git a/src/lerobot/processor/migrate_policy_normalization.py b/src/lerobot/processor/migrate_policy_normalization.py index 131f799d6..319145d1a 100644 --- a/src/lerobot/processor/migrate_policy_normalization.py +++ b/src/lerobot/processor/migrate_policy_normalization.py @@ -59,6 +59,7 @@ from safetensors.torch import load_file as load_safetensors from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from lerobot.policies.factory import get_policy_class, make_policy_config, make_pre_post_processors +from lerobot.utils.constants import ACTION def extract_normalization_stats(state_dict: dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]: @@ -196,7 +197,7 @@ def detect_features_and_norm_modes( feature_type = FeatureType.VISUAL elif "state" in key: feature_type = FeatureType.STATE - elif "action" in key: + elif ACTION in key: feature_type = FeatureType.ACTION else: feature_type = FeatureType.STATE # Default @@ -215,7 +216,7 @@ def detect_features_and_norm_modes( feature_type = FeatureType.VISUAL elif "state" in key or "joint" in key or "position" in key: feature_type = FeatureType.STATE - elif "action" in key: + elif ACTION in key: feature_type = FeatureType.ACTION else: feature_type = FeatureType.STATE @@ -321,7 +322,7 @@ def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[ feature_type = FeatureType.VISUAL elif "state" in key: feature_type = FeatureType.STATE - elif "action" in key: + elif ACTION in key: feature_type = FeatureType.ACTION else: feature_type = FeatureType.STATE diff --git a/src/lerobot/processor/normalize_processor.py b/src/lerobot/processor/normalize_processor.py index bece54f0b..c4ded722f 100644 --- a/src/lerobot/processor/normalize_processor.py +++ b/src/lerobot/processor/normalize_processor.py @@ -26,6 +26,7 @@ from torch import Tensor from lerobot.configs.types import FeatureType, NormalizationMode, PipelineFeatureType, PolicyFeature from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.utils.constants import ACTION from .converters import from_tensor_to_numpy, to_tensor from .core import EnvTransition, PolicyAction, TransitionKey @@ -272,7 +273,7 @@ class _NormalizationMixin: Returns: The transformed action tensor. """ - processed_action = self._apply_transform(action, "action", FeatureType.ACTION, inverse=inverse) + processed_action = self._apply_transform(action, ACTION, FeatureType.ACTION, inverse=inverse) return processed_action def _apply_transform( diff --git a/src/lerobot/processor/policy_robot_bridge.py b/src/lerobot/processor/policy_robot_bridge.py index 74c534998..845ee065a 100644 --- a/src/lerobot/processor/policy_robot_bridge.py +++ b/src/lerobot/processor/policy_robot_bridge.py @@ -5,6 +5,7 @@ import torch from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature from lerobot.processor import ActionProcessorStep, PolicyAction, ProcessorStepRegistry, RobotAction +from lerobot.utils.constants import ACTION @dataclass @@ -23,7 +24,7 @@ class RobotActionToPolicyActionProcessorStep(ActionProcessorStep): return asdict(self) def transform_features(self, features): - features[PipelineFeatureType.ACTION]["action"] = PolicyFeature( + features[PipelineFeatureType.ACTION][ACTION] = PolicyFeature( type=FeatureType.ACTION, shape=(len(self.motor_names),) ) return features diff --git a/src/lerobot/rl/buffer.py b/src/lerobot/rl/buffer.py index fbf36de36..b572bbce5 100644 --- a/src/lerobot/rl/buffer.py +++ b/src/lerobot/rl/buffer.py @@ -24,7 +24,7 @@ import torch.nn.functional as F # noqa: N812 from tqdm import tqdm from lerobot.datasets.lerobot_dataset import LeRobotDataset -from lerobot.utils.constants import OBS_IMAGE +from lerobot.utils.constants import ACTION, OBS_IMAGE from lerobot.utils.transition import Transition @@ -467,7 +467,7 @@ class ReplayBuffer: if list_transition: first_transition = list_transition[0] first_state = {k: v.to(device) for k, v in first_transition["state"].items()} - first_action = first_transition["action"].to(device) + first_action = first_transition[ACTION].to(device) # Get complementary info if available first_complementary_info = None @@ -492,7 +492,7 @@ class ReplayBuffer: elif isinstance(v, torch.Tensor): data[k] = v.to(storage_device) - action = data["action"] + action = data[ACTION] replay_buffer.add( state=data["state"], @@ -530,8 +530,8 @@ class ReplayBuffer: # Add "action" sample_action = self.actions[0] - act_info = guess_feature_info(t=sample_action, name="action") - features["action"] = act_info + act_info = guess_feature_info(t=sample_action, name=ACTION) + features[ACTION] = act_info # Add "reward" and "done" features["next.reward"] = {"dtype": "float32", "shape": (1,)} @@ -577,7 +577,7 @@ class ReplayBuffer: frame_dict[key] = self.states[key][actual_idx].cpu() # Fill action, reward, done - frame_dict["action"] = self.actions[actual_idx].cpu() + frame_dict[ACTION] = self.actions[actual_idx].cpu() frame_dict["next.reward"] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu() frame_dict["next.done"] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu() frame_dict["task"] = task_name @@ -668,7 +668,7 @@ class ReplayBuffer: current_state[key] = val.unsqueeze(0) # Add batch dimension # ----- 2) Action ----- - action = current_sample["action"].unsqueeze(0) # Add batch dimension + action = current_sample[ACTION].unsqueeze(0) # Add batch dimension # ----- 3) Reward and done ----- reward = float(current_sample["next.reward"].item()) # ensure float @@ -788,8 +788,8 @@ def concatenate_batch_transitions( } # Concatenate basic fields - left_batch_transitions["action"] = torch.cat( - [left_batch_transitions["action"], right_batch_transition["action"]], dim=0 + left_batch_transitions[ACTION] = torch.cat( + [left_batch_transitions[ACTION], right_batch_transition[ACTION]], dim=0 ) left_batch_transitions["reward"] = torch.cat( [left_batch_transitions["reward"], right_batch_transition["reward"]], dim=0 diff --git a/src/lerobot/rl/gym_manipulator.py b/src/lerobot/rl/gym_manipulator.py index 393135708..fa9f4e3e1 100644 --- a/src/lerobot/rl/gym_manipulator.py +++ b/src/lerobot/rl/gym_manipulator.py @@ -73,7 +73,7 @@ from lerobot.teleoperators import ( ) from lerobot.teleoperators.teleoperator import Teleoperator from lerobot.teleoperators.utils import TeleopEvents -from lerobot.utils.constants import OBS_IMAGES, OBS_STATE +from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE from lerobot.utils.robot_utils import busy_wait from lerobot.utils.utils import log_say @@ -601,7 +601,7 @@ def control_loop( if cfg.mode == "record": action_features = teleop_device.action_features features = { - "action": action_features, + ACTION: action_features, "next.reward": {"dtype": "float32", "shape": (1,), "names": None}, "next.done": {"dtype": "bool", "shape": (1,), "names": None}, } @@ -672,7 +672,7 @@ def control_loop( ) frame = { **observations, - "action": action_to_record.cpu(), + ACTION: action_to_record.cpu(), "next.reward": np.array([transition[TransitionKey.REWARD]], dtype=np.float32), "next.done": np.array([terminated or truncated], dtype=bool), } @@ -733,7 +733,7 @@ def replay_trajectory( download_videos=False, ) episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == cfg.dataset.replay_episode) - actions = episode_frames.select_columns("action") + actions = episode_frames.select_columns(ACTION) _, info = env.reset() @@ -741,7 +741,7 @@ def replay_trajectory( start_time = time.perf_counter() transition = create_transition( observation=env.get_raw_joint_positions() if hasattr(env, "get_raw_joint_positions") else {}, - action=action_data["action"], + action=action_data[ACTION], ) transition = action_processor(transition) env.step(transition[TransitionKey.ACTION]) diff --git a/src/lerobot/rl/learner.py b/src/lerobot/rl/learner.py index 0faa460ef..b7cfdb30c 100644 --- a/src/lerobot/rl/learner.py +++ b/src/lerobot/rl/learner.py @@ -80,6 +80,7 @@ from lerobot.transport.utils import ( state_to_bytes, ) from lerobot.utils.constants import ( + ACTION, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK, PRETRAINED_MODEL_DIR, @@ -402,7 +403,7 @@ def add_actor_information_and_train( left_batch_transitions=batch, right_batch_transition=batch_offline ) - actions = batch["action"] + actions = batch[ACTION] rewards = batch["reward"] observations = batch["state"] next_observations = batch["next_state"] @@ -415,7 +416,7 @@ def add_actor_information_and_train( # Create a batch dictionary with all required elements for the forward method forward_batch = { - "action": actions, + ACTION: actions, "reward": rewards, "state": observations, "next_state": next_observations, @@ -460,7 +461,7 @@ def add_actor_information_and_train( left_batch_transitions=batch, right_batch_transition=batch_offline ) - actions = batch["action"] + actions = batch[ACTION] rewards = batch["reward"] observations = batch["state"] next_observations = batch["next_state"] @@ -474,7 +475,7 @@ def add_actor_information_and_train( # Create a batch dictionary with all required elements for the forward method forward_batch = { - "action": actions, + ACTION: actions, "reward": rewards, "state": observations, "next_state": next_observations, @@ -1155,7 +1156,7 @@ def process_transitions( # Skip transitions with NaN values if check_nan_in_transition( observations=transition["state"], - actions=transition["action"], + actions=transition[ACTION], next_state=transition["next_state"], ): logging.warning("[LEARNER] NaN detected in transition, skipping") diff --git a/src/lerobot/robots/lekiwi/lekiwi_client.py b/src/lerobot/robots/lekiwi/lekiwi_client.py index 392d6d575..19744e244 100644 --- a/src/lerobot/robots/lekiwi/lekiwi_client.py +++ b/src/lerobot/robots/lekiwi/lekiwi_client.py @@ -23,7 +23,7 @@ from typing import Any import cv2 import numpy as np -from lerobot.utils.constants import OBS_STATE +from lerobot.utils.constants import ACTION, OBS_STATE from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from ..robot import Robot @@ -330,7 +330,7 @@ class LeKiwiClient(Robot): actions = np.array([action.get(k, 0.0) for k in self._state_order], dtype=np.float32) action_sent = {key: actions[i] for i, key in enumerate(self._state_order)} - action_sent["action"] = actions + action_sent[ACTION] = actions return action_sent def disconnect(self): diff --git a/src/lerobot/scripts/lerobot_dataset_viz.py b/src/lerobot/scripts/lerobot_dataset_viz.py index 5c0d31f73..adff5c085 100644 --- a/src/lerobot/scripts/lerobot_dataset_viz.py +++ b/src/lerobot/scripts/lerobot_dataset_viz.py @@ -75,7 +75,7 @@ import torch.utils.data import tqdm from lerobot.datasets.lerobot_dataset import LeRobotDataset -from lerobot.utils.constants import OBS_STATE +from lerobot.utils.constants import ACTION, OBS_STATE class EpisodeSampler(torch.utils.data.Sampler): @@ -157,9 +157,9 @@ def visualize_dataset( rr.log(key, rr.Image(to_hwc_uint8_numpy(batch[key][i]))) # display each dimension of action space (e.g. actuators command) - if "action" in batch: - for dim_idx, val in enumerate(batch["action"][i]): - rr.log(f"action/{dim_idx}", rr.Scalar(val.item())) + if ACTION in batch: + for dim_idx, val in enumerate(batch[ACTION][i]): + rr.log(f"{ACTION}/{dim_idx}", rr.Scalar(val.item())) # display each dimension of observed state space (e.g. agent position in joint space) if OBS_STATE in batch: diff --git a/src/lerobot/scripts/lerobot_eval.py b/src/lerobot/scripts/lerobot_eval.py index 310f771a9..882aeacc3 100644 --- a/src/lerobot/scripts/lerobot_eval.py +++ b/src/lerobot/scripts/lerobot_eval.py @@ -81,7 +81,7 @@ from lerobot.envs.utils import ( from lerobot.policies.factory import make_policy, make_pre_post_processors from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.processor import PolicyAction, PolicyProcessorPipeline -from lerobot.utils.constants import OBS_STR +from lerobot.utils.constants import ACTION, OBS_STR from lerobot.utils.io_utils import write_video from lerobot.utils.random_utils import set_seed from lerobot.utils.utils import ( @@ -213,7 +213,7 @@ def rollout( # Stack the sequence along the first dimension so that we have (batch, sequence, *) tensors. ret = { - "action": torch.stack(all_actions, dim=1), + ACTION: torch.stack(all_actions, dim=1), "reward": torch.stack(all_rewards, dim=1), "success": torch.stack(all_successes, dim=1), "done": torch.stack(all_dones, dim=1), @@ -440,14 +440,14 @@ def _compile_episode_data( """ ep_dicts = [] total_frames = 0 - for ep_ix in range(rollout_data["action"].shape[0]): + for ep_ix in range(rollout_data[ACTION].shape[0]): # + 2 to include the first done frame and the last observation frame. num_frames = done_indices[ep_ix].item() + 2 total_frames += num_frames # Here we do `num_frames - 1` as we don't want to include the last observation frame just yet. ep_dict = { - "action": rollout_data["action"][ep_ix, : num_frames - 1], + ACTION: rollout_data[ACTION][ep_ix, : num_frames - 1], "episode_index": torch.tensor([start_episode_index + ep_ix] * (num_frames - 1)), "frame_index": torch.arange(0, num_frames - 1, 1), "timestamp": torch.arange(0, num_frames - 1, 1) / fps, diff --git a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py index f1d026a39..d097a9d2f 100644 --- a/src/lerobot/scripts/lerobot_record.py +++ b/src/lerobot/scripts/lerobot_record.py @@ -109,7 +109,7 @@ from lerobot.teleoperators import ( # noqa: F401 so101_leader, ) from lerobot.teleoperators.keyboard.teleop_keyboard import KeyboardTeleop -from lerobot.utils.constants import OBS_STR +from lerobot.utils.constants import ACTION, OBS_STR from lerobot.utils.control_utils import ( init_keyboard_listener, is_headless, @@ -319,7 +319,7 @@ def record_loop( robot_type=robot.robot_type, ) - action_names = dataset.features["action"]["names"] + action_names = dataset.features[ACTION]["names"] act_processed_policy: RobotAction = { f"{name}": float(action_values[i]) for i, name in enumerate(action_names) } @@ -361,7 +361,7 @@ def record_loop( # Write to dataset if dataset is not None: - action_frame = build_dataset_frame(dataset.features, action_values, prefix="action") + action_frame = build_dataset_frame(dataset.features, action_values, prefix=ACTION) frame = {**observation_frame, **action_frame, "task": single_task} dataset.add_frame(frame) diff --git a/src/lerobot/scripts/lerobot_replay.py b/src/lerobot/scripts/lerobot_replay.py index 6761e3f4f..b899745b6 100644 --- a/src/lerobot/scripts/lerobot_replay.py +++ b/src/lerobot/scripts/lerobot_replay.py @@ -60,6 +60,7 @@ from lerobot.robots import ( # noqa: F401 so100_follower, so101_follower, ) +from lerobot.utils.constants import ACTION from lerobot.utils.robot_utils import busy_wait from lerobot.utils.utils import ( init_logging, @@ -99,7 +100,7 @@ def replay(cfg: ReplayConfig): # Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0 episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == cfg.dataset.episode) - actions = episode_frames.select_columns("action") + actions = episode_frames.select_columns(ACTION) robot.connect() @@ -107,9 +108,9 @@ def replay(cfg: ReplayConfig): for idx in range(len(episode_frames)): start_episode_t = time.perf_counter() - action_array = actions[idx]["action"] + action_array = actions[idx][ACTION] action = {} - for i, name in enumerate(dataset.features["action"]["names"]): + for i, name in enumerate(dataset.features[ACTION]["names"]): action[name] = action_array[i] robot_obs = robot.get_observation() diff --git a/src/lerobot/utils/transition.py b/src/lerobot/utils/transition.py index db413c388..e874bd096 100644 --- a/src/lerobot/utils/transition.py +++ b/src/lerobot/utils/transition.py @@ -18,6 +18,8 @@ from typing import TypedDict import torch +from lerobot.utils.constants import ACTION + class Transition(TypedDict): state: dict[str, torch.Tensor] @@ -39,7 +41,7 @@ def move_transition_to_device(transition: Transition, device: str = "cpu") -> Tr } # Move action to device - transition["action"] = transition["action"].to(device, non_blocking=non_blocking) + transition[ACTION] = transition[ACTION].to(device, non_blocking=non_blocking) # Move reward and done if they are tensors if isinstance(transition["reward"], torch.Tensor): diff --git a/tests/datasets/test_dataset_utils.py b/tests/datasets/test_dataset_utils.py index c0b07ca65..99b832e55 100644 --- a/tests/datasets/test_dataset_utils.py +++ b/tests/datasets/test_dataset_utils.py @@ -21,7 +21,7 @@ from huggingface_hub import DatasetCard from lerobot.datasets.push_dataset_to_hub.utils import calculate_episode_data_index from lerobot.datasets.utils import combine_feature_dicts, create_lerobot_dataset_card, hf_transform_to_torch -from lerobot.utils.constants import OBS_IMAGES +from lerobot.utils.constants import ACTION, OBS_IMAGES def test_default_parameters(): @@ -59,14 +59,14 @@ def test_calculate_episode_data_index(): def test_merge_simple_vectors(): g1 = { - "action": { + ACTION: { "dtype": "float32", "shape": (2,), "names": ["ee.x", "ee.y"], } } g2 = { - "action": { + ACTION: { "dtype": "float32", "shape": (2,), "names": ["ee.y", "ee.z"], @@ -75,23 +75,23 @@ def test_merge_simple_vectors(): out = combine_feature_dicts(g1, g2) - assert "action" in out - assert out["action"]["dtype"] == "float32" + assert ACTION in out + assert out[ACTION]["dtype"] == "float32" # Names merged with preserved order and de-dupuplication - assert out["action"]["names"] == ["ee.x", "ee.y", "ee.z"] + assert out[ACTION]["names"] == ["ee.x", "ee.y", "ee.z"] # Shape correctly recomputed from names length - assert out["action"]["shape"] == (3,) + assert out[ACTION]["shape"] == (3,) def test_merge_multiple_groups_order_and_dedup(): - g1 = {"action": {"dtype": "float32", "shape": (2,), "names": ["a", "b"]}} - g2 = {"action": {"dtype": "float32", "shape": (2,), "names": ["b", "c"]}} - g3 = {"action": {"dtype": "float32", "shape": (3,), "names": ["a", "c", "d"]}} + g1 = {ACTION: {"dtype": "float32", "shape": (2,), "names": ["a", "b"]}} + g2 = {ACTION: {"dtype": "float32", "shape": (2,), "names": ["b", "c"]}} + g3 = {ACTION: {"dtype": "float32", "shape": (3,), "names": ["a", "c", "d"]}} out = combine_feature_dicts(g1, g2, g3) - assert out["action"]["names"] == ["a", "b", "c", "d"] - assert out["action"]["shape"] == (4,) + assert out[ACTION]["names"] == ["a", "b", "c", "d"] + assert out[ACTION]["shape"] == (4,) def test_non_vector_last_wins_for_images(): @@ -117,8 +117,8 @@ def test_non_vector_last_wins_for_images(): def test_dtype_mismatch_raises(): - g1 = {"action": {"dtype": "float32", "shape": (1,), "names": ["a"]}} - g2 = {"action": {"dtype": "float64", "shape": (1,), "names": ["b"]}} + g1 = {ACTION: {"dtype": "float32", "shape": (1,), "names": ["a"]}} + g2 = {ACTION: {"dtype": "float64", "shape": (1,), "names": ["b"]}} with pytest.raises(ValueError, match="dtype mismatch for 'action'"): _ = combine_feature_dicts(g1, g2) diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index 1d461c8ba..fcfef677b 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -46,7 +46,7 @@ from lerobot.datasets.utils import ( from lerobot.envs.factory import make_env_config from lerobot.policies.factory import make_policy_config from lerobot.robots import make_robot_from_config -from lerobot.utils.constants import OBS_IMAGES, OBS_STATE, OBS_STR +from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE, OBS_STR from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID from tests.mocks.mock_robot import MockRobotConfig from tests.utils import require_x86_64_kernel @@ -75,7 +75,7 @@ def test_same_attributes_defined(tmp_path, lerobot_dataset_factory): """ # Instantiate both ways robot = make_robot_from_config(MockRobotConfig()) - action_features = hw_to_dataset_features(robot.action_features, "action", True) + action_features = hw_to_dataset_features(robot.action_features, ACTION, True) obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR, True) dataset_features = {**action_features, **obs_features} root_create = tmp_path / "create" @@ -393,7 +393,7 @@ def test_factory(env_name, repo_id, policy_name): item = dataset[0] keys_ndim_required = [ - ("action", 1, True), + (ACTION, 1, True), ("episode_index", 0, True), ("frame_index", 0, True), ("timestamp", 0, True), @@ -668,7 +668,7 @@ def test_update_chunk_settings(tmp_path, empty_lerobot_dataset_factory): "shape": (6,), "names": ["shoulder_pan", "shoulder_lift", "elbow", "wrist_1", "wrist_2", "wrist_3"], }, - "action": { + ACTION: { "dtype": "float32", "shape": (6,), "names": ["shoulder_pan", "shoulder_lift", "elbow", "wrist_1", "wrist_2", "wrist_3"], @@ -775,7 +775,7 @@ def test_update_chunk_settings_video_dataset(tmp_path): "shape": (480, 640, 3), "names": ["height", "width", "channels"], }, - "action": {"dtype": "float32", "shape": (6,), "names": ["j1", "j2", "j3", "j4", "j5", "j6"]}, + ACTION: {"dtype": "float32", "shape": (6,), "names": ["j1", "j2", "j3", "j4", "j5", "j6"]}, } # Create video dataset @@ -842,7 +842,7 @@ def test_multi_episode_metadata_consistency(tmp_path, empty_lerobot_dataset_fact """Test episode metadata consistency across multiple episodes.""" features = { "state": {"dtype": "float32", "shape": (3,), "names": ["x", "y", "z"]}, - "action": {"dtype": "float32", "shape": (2,), "names": ["v", "w"]}, + ACTION: {"dtype": "float32", "shape": (2,), "names": ["v", "w"]}, } dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False) @@ -852,7 +852,7 @@ def test_multi_episode_metadata_consistency(tmp_path, empty_lerobot_dataset_fact for episode_idx in range(num_episodes): for _ in range(frames_per_episode[episode_idx]): - dataset.add_frame({"state": torch.randn(3), "action": torch.randn(2), "task": tasks[episode_idx]}) + dataset.add_frame({"state": torch.randn(3), ACTION: torch.randn(2), "task": tasks[episode_idx]}) dataset.save_episode() # Load and validate episode metadata @@ -927,7 +927,7 @@ def test_statistics_metadata_validation(tmp_path, empty_lerobot_dataset_factory) """Test that statistics are properly computed and stored for all features.""" features = { "state": {"dtype": "float32", "shape": (2,), "names": ["pos", "vel"]}, - "action": {"dtype": "float32", "shape": (1,), "names": ["force"]}, + ACTION: {"dtype": "float32", "shape": (1,), "names": ["force"]}, } dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False) @@ -941,7 +941,7 @@ def test_statistics_metadata_validation(tmp_path, empty_lerobot_dataset_factory) for frame_idx in range(frames_per_episode[episode_idx]): state_data = torch.tensor([frame_idx * 0.1, frame_idx * 0.2], dtype=torch.float32) action_data = torch.tensor([frame_idx * 0.05], dtype=torch.float32) - dataset.add_frame({"state": state_data, "action": action_data, "task": "stats_test"}) + dataset.add_frame({"state": state_data, ACTION: action_data, "task": "stats_test"}) dataset.save_episode() loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root) diff --git a/tests/datasets/test_streaming.py b/tests/datasets/test_streaming.py index 506be3ecf..1bd4c1787 100644 --- a/tests/datasets/test_streaming.py +++ b/tests/datasets/test_streaming.py @@ -19,6 +19,7 @@ import torch from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset from lerobot.datasets.utils import safe_shard +from lerobot.utils.constants import ACTION from tests.fixtures.constants import DUMMY_REPO_ID @@ -234,7 +235,7 @@ def test_frames_with_delta_consistency(tmp_path, lerobot_dataset_factory, state_ delta_timestamps = { camera_key: state_deltas, "state": state_deltas, - "action": action_deltas, + ACTION: action_deltas, } ds = lerobot_dataset_factory( @@ -319,7 +320,7 @@ def test_frames_with_delta_consistency_with_shards( delta_timestamps = { camera_key: state_deltas, "state": state_deltas, - "action": action_deltas, + ACTION: action_deltas, } ds = lerobot_dataset_factory( diff --git a/tests/fixtures/constants.py b/tests/fixtures/constants.py index 973c5b050..35d8776ce 100644 --- a/tests/fixtures/constants.py +++ b/tests/fixtures/constants.py @@ -11,13 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from lerobot.utils.constants import HF_LEROBOT_HOME +from lerobot.utils.constants import ACTION, HF_LEROBOT_HOME LEROBOT_TEST_DIR = HF_LEROBOT_HOME / "_testing" DUMMY_REPO_ID = "dummy/repo" DUMMY_ROBOT_TYPE = "dummy_robot" DUMMY_MOTOR_FEATURES = { - "action": { + ACTION: { "dtype": "float32", "shape": (6,), "names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"], diff --git a/tests/policies/test_policies.py b/tests/policies/test_policies.py index 7752ad63f..34fa89390 100644 --- a/tests/policies/test_policies.py +++ b/tests/policies/test_policies.py @@ -59,7 +59,7 @@ def dummy_dataset_metadata(lerobot_dataset_metadata_factory, info_factory, tmp_p }, } motor_features = { - "action": { + ACTION: { "dtype": "float32", "shape": (6,), "names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"], @@ -287,7 +287,7 @@ def test_multikey_construction(multikey: bool): ), } output_features = { - "action": PolicyFeature( + ACTION: PolicyFeature( type=FeatureType.ACTION, shape=(5,), ), @@ -304,7 +304,7 @@ def test_multikey_construction(multikey: bool): output_features = {} output_features["action.first_three_motors"] = PolicyFeature(type=FeatureType.ACTION, shape=(3,)) output_features["action.last_two_motors"] = PolicyFeature(type=FeatureType.ACTION, shape=(2,)) - output_features["action"] = PolicyFeature( + output_features[ACTION] = PolicyFeature( type=FeatureType.ACTION, shape=(5,), ) diff --git a/tests/policies/test_sac_config.py b/tests/policies/test_sac_config.py index 59ed4af65..be6a8d26e 100644 --- a/tests/policies/test_sac_config.py +++ b/tests/policies/test_sac_config.py @@ -25,7 +25,7 @@ from lerobot.policies.sac.configuration_sac import ( PolicyConfig, SACConfig, ) -from lerobot.utils.constants import OBS_IMAGE, OBS_STATE +from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE def test_sac_config_default_initialization(): @@ -46,7 +46,7 @@ def test_sac_config_default_initialization(): "min": [0.0, 0.0], "max": [1.0, 1.0], }, - "action": { + ACTION: { "min": [0.0, 0.0, 0.0], "max": [1.0, 1.0, 1.0], }, @@ -99,7 +99,7 @@ def test_sac_config_default_initialization(): "min": [0.0, 0.0], "max": [1.0, 1.0], }, - "action": { + ACTION: { "min": [0.0, 0.0, 0.0], "max": [1.0, 1.0, 1.0], }, @@ -193,7 +193,7 @@ def test_sac_config_custom_initialization(): def test_validate_features(): config = SACConfig( input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,))}, - output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(3,))}, + output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(3,))}, ) config.validate_features() @@ -201,7 +201,7 @@ def test_validate_features(): def test_validate_features_missing_observation(): config = SACConfig( input_features={"wrong_key": PolicyFeature(type=FeatureType.STATE, shape=(10,))}, - output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(3,))}, + output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(3,))}, ) with pytest.raises( ValueError, match="You must provide either 'observation.state' or an image observation" diff --git a/tests/policies/test_sac_policy.py b/tests/policies/test_sac_policy.py index 71e45e055..8576883bd 100644 --- a/tests/policies/test_sac_policy.py +++ b/tests/policies/test_sac_policy.py @@ -23,7 +23,7 @@ from torch import Tensor, nn from lerobot.configs.types import FeatureType, PolicyFeature from lerobot.policies.sac.configuration_sac import SACConfig from lerobot.policies.sac.modeling_sac import MLP, SACPolicy -from lerobot.utils.constants import OBS_IMAGE, OBS_STATE +from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE from lerobot.utils.random_utils import seeded_context, set_seed try: @@ -105,7 +105,7 @@ def create_default_train_batch( batch_size: int = 8, state_dim: int = 10, action_dim: int = 10 ) -> dict[str, Tensor]: return { - "action": create_dummy_action(batch_size, action_dim), + ACTION: create_dummy_action(batch_size, action_dim), "reward": torch.randn(batch_size), "state": create_dummy_state(batch_size, state_dim), "next_state": create_dummy_state(batch_size, state_dim), @@ -117,7 +117,7 @@ def create_train_batch_with_visual_input( batch_size: int = 8, state_dim: int = 10, action_dim: int = 10 ) -> dict[str, Tensor]: return { - "action": create_dummy_action(batch_size, action_dim), + ACTION: create_dummy_action(batch_size, action_dim), "reward": torch.randn(batch_size), "state": create_dummy_with_visual_input(batch_size, state_dim), "next_state": create_dummy_with_visual_input(batch_size, state_dim), @@ -182,13 +182,13 @@ def create_default_config( config = SACConfig( input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))}, - output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(continuous_action_dim,))}, + output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(continuous_action_dim,))}, dataset_stats={ OBS_STATE: { "min": [0.0] * state_dim, "max": [1.0] * state_dim, }, - "action": { + ACTION: { "min": [0.0] * continuous_action_dim, "max": [1.0] * continuous_action_dim, }, diff --git a/tests/processor/test_batch_conversion.py b/tests/processor/test_batch_conversion.py index 8bf24db02..0f7018972 100644 --- a/tests/processor/test_batch_conversion.py +++ b/tests/processor/test_batch_conversion.py @@ -2,7 +2,7 @@ import torch from lerobot.processor import DataProcessorPipeline, TransitionKey from lerobot.processor.converters import batch_to_transition, transition_to_batch -from lerobot.utils.constants import OBS_IMAGE, OBS_PREFIX, OBS_STATE +from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_PREFIX, OBS_STATE def _dummy_batch(): @@ -11,7 +11,7 @@ def _dummy_batch(): f"{OBS_IMAGE}.left": torch.randn(1, 3, 128, 128), f"{OBS_IMAGE}.right": torch.randn(1, 3, 128, 128), OBS_STATE: torch.tensor([[0.1, 0.2, 0.3, 0.4]]), - "action": torch.tensor([[0.5]]), + ACTION: torch.tensor([[0.5]]), "next.reward": 1.0, "next.done": False, "next.truncated": False, @@ -37,7 +37,7 @@ def test_observation_grouping_roundtrip(): assert torch.allclose(batch_out[OBS_STATE], batch_in[OBS_STATE]) # Check other fields - assert torch.allclose(batch_out["action"], batch_in["action"]) + assert torch.allclose(batch_out[ACTION], batch_in[ACTION]) assert batch_out["next.reward"] == batch_in["next.reward"] assert batch_out["next.done"] == batch_in["next.done"] assert batch_out["next.truncated"] == batch_in["next.truncated"] @@ -50,7 +50,7 @@ def test_batch_to_transition_observation_grouping(): f"{OBS_IMAGE}.top": torch.randn(1, 3, 128, 128), f"{OBS_IMAGE}.left": torch.randn(1, 3, 128, 128), OBS_STATE: [1, 2, 3, 4], - "action": torch.tensor([0.1, 0.2, 0.3, 0.4]), + ACTION: torch.tensor([0.1, 0.2, 0.3, 0.4]), "next.reward": 1.5, "next.done": True, "next.truncated": False, @@ -114,7 +114,7 @@ def test_transition_to_batch_observation_flattening(): assert batch[OBS_STATE] == [1, 2, 3, 4] # Check other fields are mapped to next.* format - assert batch["action"] == "action_data" + assert batch[ACTION] == "action_data" assert batch["next.reward"] == 1.5 assert batch["next.done"] assert not batch["next.truncated"] @@ -124,7 +124,7 @@ def test_transition_to_batch_observation_flattening(): def test_no_observation_keys(): """Test behavior when there are no observation.* keys.""" batch = { - "action": torch.tensor([1.0, 2.0]), + ACTION: torch.tensor([1.0, 2.0]), "next.reward": 2.0, "next.done": False, "next.truncated": True, @@ -145,7 +145,7 @@ def test_no_observation_keys(): # Round trip should work reconstructed_batch = transition_to_batch(transition) - assert torch.allclose(reconstructed_batch["action"], torch.tensor([1.0, 2.0])) + assert torch.allclose(reconstructed_batch[ACTION], torch.tensor([1.0, 2.0])) assert reconstructed_batch["next.reward"] == 2.0 assert not reconstructed_batch["next.done"] assert reconstructed_batch["next.truncated"] @@ -154,7 +154,7 @@ def test_no_observation_keys(): def test_minimal_batch(): """Test with minimal batch containing only observation.* and action.""" - batch = {OBS_STATE: "minimal_state", "action": torch.tensor([0.5])} + batch = {OBS_STATE: "minimal_state", ACTION: torch.tensor([0.5])} transition = batch_to_transition(batch) @@ -172,7 +172,7 @@ def test_minimal_batch(): # Round trip reconstructed_batch = transition_to_batch(transition) assert reconstructed_batch[OBS_STATE] == "minimal_state" - assert torch.allclose(reconstructed_batch["action"], torch.tensor([0.5])) + assert torch.allclose(reconstructed_batch[ACTION], torch.tensor([0.5])) assert reconstructed_batch["next.reward"] == 0.0 assert not reconstructed_batch["next.done"] assert not reconstructed_batch["next.truncated"] @@ -196,7 +196,7 @@ def test_empty_batch(): # Round trip reconstructed_batch = transition_to_batch(transition) - assert reconstructed_batch["action"] is None + assert reconstructed_batch[ACTION] is None assert reconstructed_batch["next.reward"] == 0.0 assert not reconstructed_batch["next.done"] assert not reconstructed_batch["next.truncated"] @@ -209,7 +209,7 @@ def test_complex_nested_observation(): f"{OBS_IMAGE}.top": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567890}, f"{OBS_IMAGE}.left": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567891}, OBS_STATE: torch.randn(7), - "action": torch.randn(8), + ACTION: torch.randn(8), "next.reward": 3.14, "next.done": False, "next.truncated": True, @@ -237,7 +237,7 @@ def test_complex_nested_observation(): ) # Check action tensor - assert torch.allclose(batch["action"], reconstructed_batch["action"]) + assert torch.allclose(batch[ACTION], reconstructed_batch[ACTION]) # Check other fields assert batch["next.reward"] == reconstructed_batch["next.reward"] @@ -266,7 +266,7 @@ def test_custom_converter(): batch = { OBS_STATE: torch.randn(1, 4), - "action": torch.randn(1, 2), + ACTION: torch.randn(1, 2), "next.reward": 1.0, "next.done": False, } @@ -276,4 +276,4 @@ def test_custom_converter(): # Check the reward was doubled by our custom converter assert result["next.reward"] == 2.0 assert torch.allclose(result[OBS_STATE], batch[OBS_STATE]) - assert torch.allclose(result["action"], batch["action"]) + assert torch.allclose(result[ACTION], batch[ACTION]) diff --git a/tests/processor/test_converters.py b/tests/processor/test_converters.py index b03d49214..d347858dc 100644 --- a/tests/processor/test_converters.py +++ b/tests/processor/test_converters.py @@ -9,7 +9,7 @@ from lerobot.processor.converters import ( to_tensor, transition_to_batch, ) -from lerobot.utils.constants import OBS_STATE, OBS_STR +from lerobot.utils.constants import ACTION, OBS_STATE, OBS_STR # Tests for the unified to_tensor function @@ -118,16 +118,16 @@ def test_to_tensor_dictionaries(): # Nested dictionary nested = { - "action": {"mean": [0.1, 0.2], "std": [1.0, 2.0]}, + ACTION: {"mean": [0.1, 0.2], "std": [1.0, 2.0]}, OBS_STR: {"mean": np.array([0.5, 0.6]), "count": 10}, } result = to_tensor(nested) assert isinstance(result, dict) - assert isinstance(result["action"], dict) + assert isinstance(result[ACTION], dict) assert isinstance(result[OBS_STR], dict) - assert isinstance(result["action"]["mean"], torch.Tensor) + assert isinstance(result[ACTION]["mean"], torch.Tensor) assert isinstance(result[OBS_STR]["mean"], torch.Tensor) - assert torch.allclose(result["action"]["mean"], torch.tensor([0.1, 0.2])) + assert torch.allclose(result[ACTION]["mean"], torch.tensor([0.1, 0.2])) assert torch.allclose(result[OBS_STR]["mean"], torch.tensor([0.5, 0.6])) @@ -200,7 +200,7 @@ def test_batch_to_transition_with_index_fields(): # Create batch with index and task_index fields batch = { OBS_STATE: torch.randn(1, 7), - "action": torch.randn(1, 4), + ACTION: torch.randn(1, 4), "next.reward": 1.5, "next.done": False, "task": ["pick_cube"], @@ -262,7 +262,7 @@ def test_batch_to_transition_without_index_fields(): # Batch without index/task_index batch = { OBS_STATE: torch.randn(1, 7), - "action": torch.randn(1, 4), + ACTION: torch.randn(1, 4), "task": ["pick_cube"], } diff --git a/tests/processor/test_device_processor.py b/tests/processor/test_device_processor.py index 36081e021..bb7d467bf 100644 --- a/tests/processor/test_device_processor.py +++ b/tests/processor/test_device_processor.py @@ -21,7 +21,7 @@ import torch from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature from lerobot.processor import DataProcessorPipeline, DeviceProcessorStep, TransitionKey from lerobot.processor.converters import create_transition, identity_transition -from lerobot.utils.constants import OBS_IMAGE, OBS_STATE +from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE def test_basic_functionality(): @@ -273,7 +273,7 @@ def test_features(): features = { PipelineFeatureType.OBSERVATION: {OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,))}, - PipelineFeatureType.ACTION: {"action": PolicyFeature(type=FeatureType.ACTION, shape=(5,))}, + PipelineFeatureType.ACTION: {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(5,))}, } result = processor.transform_features(features) diff --git a/tests/processor/test_migration_detection.py b/tests/processor/test_migration_detection.py index b46cc6bdd..1ddc87d1e 100644 --- a/tests/processor/test_migration_detection.py +++ b/tests/processor/test_migration_detection.py @@ -25,7 +25,7 @@ from pathlib import Path import pytest from lerobot.processor.pipeline import DataProcessorPipeline, ProcessorMigrationError -from lerobot.utils.constants import OBS_STATE +from lerobot.utils.constants import ACTION, OBS_STATE def test_is_processor_config_valid_configs(): @@ -113,7 +113,7 @@ def test_should_suggest_migration_with_model_config_only(): model_config = { "type": "act", "input_features": {OBS_STATE: {"shape": [7]}}, - "output_features": {"action": {"shape": [7]}}, + "output_features": {ACTION: {"shape": [7]}}, "hidden_dim": 256, "n_obs_steps": 1, "n_action_steps": 1, diff --git a/tests/processor/test_normalize_processor.py b/tests/processor/test_normalize_processor.py index 616f33db9..98c9e0b23 100644 --- a/tests/processor/test_normalize_processor.py +++ b/tests/processor/test_normalize_processor.py @@ -29,7 +29,7 @@ from lerobot.processor import ( hotswap_stats, ) from lerobot.processor.converters import create_transition, identity_transition, to_tensor -from lerobot.utils.constants import OBS_IMAGE, OBS_STATE, OBS_STR +from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE, OBS_STR from lerobot.utils.utils import auto_select_torch_device @@ -50,15 +50,15 @@ def test_numpy_conversion(): def test_tensor_conversion(): stats = { - "action": { + ACTION: { "mean": torch.tensor([0.0, 0.0]), "std": torch.tensor([1.0, 1.0]), } } tensor_stats = to_tensor(stats) - assert tensor_stats["action"]["mean"].dtype == torch.float32 - assert tensor_stats["action"]["std"].dtype == torch.float32 + assert tensor_stats[ACTION]["mean"].dtype == torch.float32 + assert tensor_stats[ACTION]["std"].dtype == torch.float32 def test_scalar_conversion(): @@ -212,12 +212,12 @@ def test_from_lerobot_dataset(): mock_dataset = Mock() mock_dataset.meta.stats = { OBS_IMAGE: {"mean": [0.5], "std": [0.2]}, - "action": {"mean": [0.0], "std": [1.0]}, + ACTION: {"mean": [0.0], "std": [1.0]}, } features = { OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), - "action": PolicyFeature(FeatureType.ACTION, (1,)), + ACTION: PolicyFeature(FeatureType.ACTION, (1,)), } norm_map = { FeatureType.VISUAL: NormalizationMode.MEAN_STD, @@ -228,7 +228,7 @@ def test_from_lerobot_dataset(): # Both observation and action statistics should be present in tensor stats assert OBS_IMAGE in normalizer._tensor_stats - assert "action" in normalizer._tensor_stats + assert ACTION in normalizer._tensor_stats def test_state_dict_save_load(observation_normalizer): @@ -271,7 +271,7 @@ def action_stats_min_max(): def _create_action_features(): return { - "action": PolicyFeature(FeatureType.ACTION, (3,)), + ACTION: PolicyFeature(FeatureType.ACTION, (3,)), } @@ -291,7 +291,7 @@ def test_mean_std_unnormalization(action_stats_mean_std): features = _create_action_features() norm_map = _create_action_norm_map_mean_std() unnormalizer = UnnormalizerProcessorStep( - features=features, norm_map=norm_map, stats={"action": action_stats_mean_std} + features=features, norm_map=norm_map, stats={ACTION: action_stats_mean_std} ) normalized_action = torch.tensor([1.0, -0.5, 2.0]) @@ -309,7 +309,7 @@ def test_min_max_unnormalization(action_stats_min_max): features = _create_action_features() norm_map = _create_action_norm_map_min_max() unnormalizer = UnnormalizerProcessorStep( - features=features, norm_map=norm_map, stats={"action": action_stats_min_max} + features=features, norm_map=norm_map, stats={ACTION: action_stats_min_max} ) # Actions in [-1, 1] @@ -335,7 +335,7 @@ def test_tensor_action_input(action_stats_mean_std): features = _create_action_features() norm_map = _create_action_norm_map_mean_std() unnormalizer = UnnormalizerProcessorStep( - features=features, norm_map=norm_map, stats={"action": action_stats_mean_std} + features=features, norm_map=norm_map, stats={ACTION: action_stats_mean_std} ) normalized_action = torch.tensor([1.0, -0.5, 2.0], dtype=torch.float32) @@ -353,7 +353,7 @@ def test_none_action(action_stats_mean_std): features = _create_action_features() norm_map = _create_action_norm_map_mean_std() unnormalizer = UnnormalizerProcessorStep( - features=features, norm_map=norm_map, stats={"action": action_stats_mean_std} + features=features, norm_map=norm_map, stats={ACTION: action_stats_mean_std} ) transition = create_transition() @@ -365,11 +365,11 @@ def test_none_action(action_stats_mean_std): def test_action_from_lerobot_dataset(): mock_dataset = Mock() - mock_dataset.meta.stats = {"action": {"mean": [0.0], "std": [1.0]}} - features = {"action": PolicyFeature(FeatureType.ACTION, (1,))} + mock_dataset.meta.stats = {ACTION: {"mean": [0.0], "std": [1.0]}} + features = {ACTION: PolicyFeature(FeatureType.ACTION, (1,))} norm_map = {FeatureType.ACTION: NormalizationMode.MEAN_STD} unnormalizer = UnnormalizerProcessorStep.from_lerobot_dataset(mock_dataset, features, norm_map) - assert "mean" in unnormalizer._tensor_stats["action"] + assert "mean" in unnormalizer._tensor_stats[ACTION] # Fixtures for NormalizerProcessorStep tests @@ -384,7 +384,7 @@ def full_stats(): "min": np.array([0.0, -1.0]), "max": np.array([1.0, 1.0]), }, - "action": { + ACTION: { "mean": np.array([0.0, 0.0]), "std": np.array([1.0, 2.0]), }, @@ -395,7 +395,7 @@ def _create_full_features(): return { OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), OBS_STATE: PolicyFeature(FeatureType.STATE, (2,)), - "action": PolicyFeature(FeatureType.ACTION, (2,)), + ACTION: PolicyFeature(FeatureType.ACTION, (2,)), } @@ -461,7 +461,7 @@ def test_processor_from_lerobot_dataset(full_stats): assert processor.normalize_observation_keys == {OBS_IMAGE} assert OBS_IMAGE in processor._tensor_stats - assert "action" in processor._tensor_stats + assert ACTION in processor._tensor_stats def test_get_config(full_stats): @@ -482,7 +482,7 @@ def test_get_config(full_stats): "features": { OBS_IMAGE: {"type": "VISUAL", "shape": (3, 96, 96)}, OBS_STATE: {"type": "STATE", "shape": (2,)}, - "action": {"type": "ACTION", "shape": (2,)}, + ACTION: {"type": "ACTION", "shape": (2,)}, }, "norm_map": { "VISUAL": "MEAN_STD", @@ -568,7 +568,7 @@ def test_missing_action_stats_no_error(): processor = UnnormalizerProcessorStep.from_lerobot_dataset(mock_dataset, features, norm_map) # The tensor stats should not contain the 'action' key - assert "action" not in processor._tensor_stats + assert ACTION not in processor._tensor_stats def test_serialization_roundtrip(full_stats): @@ -676,9 +676,9 @@ def test_identity_normalization_observations(): def test_identity_normalization_actions(): """Test that IDENTITY mode skips normalization for actions.""" - features = {"action": PolicyFeature(FeatureType.ACTION, (2,))} + features = {ACTION: PolicyFeature(FeatureType.ACTION, (2,))} norm_map = {FeatureType.ACTION: NormalizationMode.IDENTITY} - stats = {"action": {"mean": [0.0, 0.0], "std": [1.0, 2.0]}} + stats = {ACTION: {"mean": [0.0, 0.0], "std": [1.0, 2.0]}} normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) @@ -729,9 +729,9 @@ def test_identity_unnormalization_observations(): def test_identity_unnormalization_actions(): """Test that IDENTITY mode skips unnormalization for actions.""" - features = {"action": PolicyFeature(FeatureType.ACTION, (2,))} + features = {ACTION: PolicyFeature(FeatureType.ACTION, (2,))} norm_map = {FeatureType.ACTION: NormalizationMode.IDENTITY} - stats = {"action": {"min": [-1.0, -2.0], "max": [1.0, 2.0]}} + stats = {ACTION: {"min": [-1.0, -2.0], "max": [1.0, 2.0]}} unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) @@ -748,7 +748,7 @@ def test_identity_with_missing_stats(): """Test that IDENTITY mode works even when stats are missing.""" features = { OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), - "action": PolicyFeature(FeatureType.ACTION, (2,)), + ACTION: PolicyFeature(FeatureType.ACTION, (2,)), } norm_map = { FeatureType.VISUAL: NormalizationMode.IDENTITY, @@ -784,7 +784,7 @@ def test_identity_mixed_with_other_modes(): features = { OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3,)), OBS_STATE: PolicyFeature(FeatureType.STATE, (2,)), - "action": PolicyFeature(FeatureType.ACTION, (2,)), + ACTION: PolicyFeature(FeatureType.ACTION, (2,)), } norm_map = { FeatureType.VISUAL: NormalizationMode.IDENTITY, @@ -794,7 +794,7 @@ def test_identity_mixed_with_other_modes(): stats = { OBS_IMAGE: {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]}, # Will be ignored OBS_STATE: {"mean": [0.0, 0.0], "std": [1.0, 1.0]}, - "action": {"min": [-1.0, -1.0], "max": [1.0, 1.0]}, + ACTION: {"min": [-1.0, -1.0], "max": [1.0, 1.0]}, } normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) @@ -862,7 +862,7 @@ def test_identity_roundtrip(): """Test that IDENTITY normalization and unnormalization are true inverses.""" features = { OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3,)), - "action": PolicyFeature(FeatureType.ACTION, (2,)), + ACTION: PolicyFeature(FeatureType.ACTION, (2,)), } norm_map = { FeatureType.VISUAL: NormalizationMode.IDENTITY, @@ -870,7 +870,7 @@ def test_identity_roundtrip(): } stats = { OBS_IMAGE: {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]}, - "action": {"min": [-1.0, -1.0], "max": [1.0, 1.0]}, + ACTION: {"min": [-1.0, -1.0], "max": [1.0, 1.0]}, } normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) @@ -893,7 +893,7 @@ def test_identity_config_serialization(): """Test that IDENTITY mode is properly saved and loaded in config.""" features = { OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3,)), - "action": PolicyFeature(FeatureType.ACTION, (2,)), + ACTION: PolicyFeature(FeatureType.ACTION, (2,)), } norm_map = { FeatureType.VISUAL: NormalizationMode.IDENTITY, @@ -901,7 +901,7 @@ def test_identity_config_serialization(): } stats = { OBS_IMAGE: {"mean": [0.5], "std": [0.2]}, - "action": {"mean": [0.0, 0.0], "std": [1.0, 1.0]}, + ACTION: {"mean": [0.0, 0.0], "std": [1.0, 1.0]}, } normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) @@ -969,19 +969,19 @@ def test_hotswap_stats_basic_functionality(): # Create initial stats initial_stats = { OBS_IMAGE: {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])}, - "action": {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])}, + ACTION: {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])}, } # Create new stats for hotswapping new_stats = { OBS_IMAGE: {"mean": np.array([0.3, 0.3, 0.3]), "std": np.array([0.1, 0.1, 0.1])}, - "action": {"mean": np.array([0.1, 0.1]), "std": np.array([0.5, 0.5])}, + ACTION: {"mean": np.array([0.1, 0.1]), "std": np.array([0.5, 0.5])}, } # Create features and norm_map features = { OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), - "action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)), + ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(2,)), } norm_map = { FeatureType.VISUAL: NormalizationMode.MEAN_STD, @@ -1177,17 +1177,17 @@ def test_hotswap_stats_multiple_normalizer_types(): """Test hotswap_stats with multiple normalizer and unnormalizer steps.""" initial_stats = { OBS_IMAGE: {"mean": np.array([0.5]), "std": np.array([0.2])}, - "action": {"min": np.array([-1.0]), "max": np.array([1.0])}, + ACTION: {"min": np.array([-1.0]), "max": np.array([1.0])}, } new_stats = { OBS_IMAGE: {"mean": np.array([0.3]), "std": np.array([0.1])}, - "action": {"min": np.array([-2.0]), "max": np.array([2.0])}, + ACTION: {"min": np.array([-2.0]), "max": np.array([2.0])}, } features = { OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), - "action": PolicyFeature(type=FeatureType.ACTION, shape=(1,)), + ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(1,)), } norm_map = { FeatureType.VISUAL: NormalizationMode.MEAN_STD, @@ -1232,7 +1232,7 @@ def test_hotswap_stats_with_different_data_types(): "min": 0, # int "max": 1.0, # float }, - "action": { + ACTION: { "mean": np.array([0.1, 0.2]), # numpy array "std": torch.tensor([0.5, 0.6]), # torch tensor }, @@ -1240,7 +1240,7 @@ def test_hotswap_stats_with_different_data_types(): features = { OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), - "action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)), + ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(2,)), } norm_map = { FeatureType.VISUAL: NormalizationMode.MEAN_STD, @@ -1262,8 +1262,8 @@ def test_hotswap_stats_with_different_data_types(): assert isinstance(tensor_stats[OBS_IMAGE]["std"], torch.Tensor) assert isinstance(tensor_stats[OBS_IMAGE]["min"], torch.Tensor) assert isinstance(tensor_stats[OBS_IMAGE]["max"], torch.Tensor) - assert isinstance(tensor_stats["action"]["mean"], torch.Tensor) - assert isinstance(tensor_stats["action"]["std"], torch.Tensor) + assert isinstance(tensor_stats[ACTION]["mean"], torch.Tensor) + assert isinstance(tensor_stats[ACTION]["std"], torch.Tensor) # Check values torch.testing.assert_close(tensor_stats[OBS_IMAGE]["mean"], torch.tensor([0.3, 0.4, 0.5])) @@ -1284,18 +1284,18 @@ def test_hotswap_stats_functional_test(): # Initial stats initial_stats = { OBS_IMAGE: {"mean": np.array([0.5, 0.4]), "std": np.array([0.2, 0.3])}, - "action": {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])}, + ACTION: {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])}, } # New stats new_stats = { OBS_IMAGE: {"mean": np.array([0.3, 0.2]), "std": np.array([0.1, 0.2])}, - "action": {"mean": np.array([0.1, -0.1]), "std": np.array([0.5, 0.5])}, + ACTION: {"mean": np.array([0.1, -0.1]), "std": np.array([0.5, 0.5])}, } features = { OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(2, 2, 2)), - "action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)), + ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(2,)), } norm_map = { FeatureType.VISUAL: NormalizationMode.MEAN_STD, @@ -1324,18 +1324,18 @@ def test_hotswap_stats_functional_test(): rtol=1e-3, atol=1e-3, ) - assert not torch.allclose(original_result["action"], new_result["action"], rtol=1e-3, atol=1e-3) + assert not torch.allclose(original_result[ACTION], new_result[ACTION], rtol=1e-3, atol=1e-3) # Verify that the new processor is actually using the new stats by checking internal state assert new_processor.steps[0].stats == new_stats assert torch.allclose(new_processor.steps[0]._tensor_stats[OBS_IMAGE]["mean"], torch.tensor([0.3, 0.2])) assert torch.allclose(new_processor.steps[0]._tensor_stats[OBS_IMAGE]["std"], torch.tensor([0.1, 0.2])) - assert torch.allclose(new_processor.steps[0]._tensor_stats["action"]["mean"], torch.tensor([0.1, -0.1])) - assert torch.allclose(new_processor.steps[0]._tensor_stats["action"]["std"], torch.tensor([0.5, 0.5])) + assert torch.allclose(new_processor.steps[0]._tensor_stats[ACTION]["mean"], torch.tensor([0.1, -0.1])) + assert torch.allclose(new_processor.steps[0]._tensor_stats[ACTION]["std"], torch.tensor([0.5, 0.5])) # Test that normalization actually happens (output should not equal input) assert not torch.allclose(new_result[OBS_STR][OBS_IMAGE], observation[OBS_IMAGE]) - assert not torch.allclose(new_result["action"], action) + assert not torch.allclose(new_result[ACTION], action) def test_zero_std_uses_eps(): @@ -1366,10 +1366,10 @@ def test_action_normalized_despite_normalize_observation_keys(): """Action normalization is independent of normalize_observation_keys filter for observations.""" features = { OBS_STATE: PolicyFeature(FeatureType.STATE, (1,)), - "action": PolicyFeature(FeatureType.ACTION, (2,)), + ACTION: PolicyFeature(FeatureType.ACTION, (2,)), } norm_map = {FeatureType.STATE: NormalizationMode.IDENTITY, FeatureType.ACTION: NormalizationMode.MEAN_STD} - stats = {"action": {"mean": np.array([1.0, -1.0]), "std": np.array([2.0, 4.0])}} + stats = {ACTION: {"mean": np.array([1.0, -1.0]), "std": np.array([2.0, 4.0])}} normalizer = NormalizerProcessorStep( features=features, norm_map=norm_map, stats=stats, normalize_observation_keys={OBS_STATE} ) @@ -1426,9 +1426,9 @@ def test_unknown_observation_keys_ignored(): def test_batched_action_normalization(): - features = {"action": PolicyFeature(FeatureType.ACTION, (2,))} + features = {ACTION: PolicyFeature(FeatureType.ACTION, (2,))} norm_map = {FeatureType.ACTION: NormalizationMode.MEAN_STD} - stats = {"action": {"mean": np.array([1.0, -1.0]), "std": np.array([2.0, 4.0])}} + stats = {ACTION: {"mean": np.array([1.0, -1.0]), "std": np.array([2.0, 4.0])}} normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) actions = torch.tensor([[1.0, -1.0], [3.0, 3.0]]) # first equals mean → zeros; second → [1, 1] @@ -1453,12 +1453,12 @@ def test_complementary_data_preservation(): def test_roundtrip_normalize_unnormalize_non_identity(): features = { OBS_STATE: PolicyFeature(FeatureType.STATE, (2,)), - "action": PolicyFeature(FeatureType.ACTION, (2,)), + ACTION: PolicyFeature(FeatureType.ACTION, (2,)), } norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD, FeatureType.ACTION: NormalizationMode.MIN_MAX} stats = { OBS_STATE: {"mean": np.array([1.0, -1.0]), "std": np.array([2.0, 4.0])}, - "action": {"min": np.array([-2.0, 0.0]), "max": np.array([2.0, 4.0])}, + ACTION: {"min": np.array([-2.0, 0.0]), "max": np.array([2.0, 4.0])}, } normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) @@ -1530,18 +1530,18 @@ def test_stats_override_preservation_in_load_state_dict(): # Create original stats original_stats = { OBS_IMAGE: {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])}, - "action": {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])}, + ACTION: {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])}, } # Create override stats (what user wants to use) override_stats = { OBS_IMAGE: {"mean": np.array([0.3, 0.3, 0.3]), "std": np.array([0.1, 0.1, 0.1])}, - "action": {"mean": np.array([0.1, 0.1]), "std": np.array([0.5, 0.5])}, + ACTION: {"mean": np.array([0.1, 0.1]), "std": np.array([0.5, 0.5])}, } features = { OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), - "action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)), + ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(2,)), } norm_map = { FeatureType.VISUAL: NormalizationMode.MEAN_STD, @@ -1601,12 +1601,12 @@ def test_stats_without_override_loads_normally(): """ original_stats = { OBS_IMAGE: {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])}, - "action": {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])}, + ACTION: {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])}, } features = { OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), - "action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)), + ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(2,)), } norm_map = { FeatureType.VISUAL: NormalizationMode.MEAN_STD, @@ -1674,7 +1674,7 @@ def test_pipeline_from_pretrained_with_stats_overrides(): # Create test data features = { OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 32, 32)), - "action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)), + ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(2,)), } norm_map = { FeatureType.VISUAL: NormalizationMode.MEAN_STD, @@ -1683,12 +1683,12 @@ def test_pipeline_from_pretrained_with_stats_overrides(): original_stats = { OBS_IMAGE: {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])}, - "action": {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])}, + ACTION: {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])}, } override_stats = { OBS_IMAGE: {"mean": np.array([0.3, 0.3, 0.3]), "std": np.array([0.1, 0.1, 0.1])}, - "action": {"mean": np.array([0.1, 0.1]), "std": np.array([0.5, 0.5])}, + ACTION: {"mean": np.array([0.1, 0.1]), "std": np.array([0.5, 0.5])}, } # Create and save a pipeline with the original stats @@ -1751,8 +1751,8 @@ def test_pipeline_from_pretrained_with_stats_overrides(): # The critical part was verified above: loaded_normalizer.stats == override_stats # This confirms that override stats are preserved during load_state_dict. # Let's just verify the pipeline processes data successfully. - assert "action" in override_result - assert isinstance(override_result["action"], torch.Tensor) + assert ACTION in override_result + assert isinstance(override_result[ACTION], torch.Tensor) def test_dtype_adaptation_device_processor_bfloat16_normalizer_float32(): @@ -1812,7 +1812,7 @@ def test_stats_reconstruction_after_load_state_dict(): features = { OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), OBS_STATE: PolicyFeature(FeatureType.STATE, (2,)), - "action": PolicyFeature(FeatureType.ACTION, (2,)), + ACTION: PolicyFeature(FeatureType.ACTION, (2,)), } norm_map = { FeatureType.VISUAL: NormalizationMode.MEAN_STD, @@ -1828,7 +1828,7 @@ def test_stats_reconstruction_after_load_state_dict(): "min": np.array([0.0, -1.0]), "max": np.array([1.0, 1.0]), }, - "action": { + ACTION: { "mean": np.array([0.0, 0.0]), "std": np.array([1.0, 2.0]), }, @@ -1852,15 +1852,15 @@ def test_stats_reconstruction_after_load_state_dict(): # Check that all expected keys are present assert OBS_IMAGE in new_normalizer.stats assert OBS_STATE in new_normalizer.stats - assert "action" in new_normalizer.stats + assert ACTION in new_normalizer.stats # Check that values are correct (converted back from tensors) np.testing.assert_allclose(new_normalizer.stats[OBS_IMAGE]["mean"], [0.5, 0.5, 0.5]) np.testing.assert_allclose(new_normalizer.stats[OBS_IMAGE]["std"], [0.2, 0.2, 0.2]) np.testing.assert_allclose(new_normalizer.stats[OBS_STATE]["min"], [0.0, -1.0]) np.testing.assert_allclose(new_normalizer.stats[OBS_STATE]["max"], [1.0, 1.0]) - np.testing.assert_allclose(new_normalizer.stats["action"]["mean"], [0.0, 0.0]) - np.testing.assert_allclose(new_normalizer.stats["action"]["std"], [1.0, 2.0]) + np.testing.assert_allclose(new_normalizer.stats[ACTION]["mean"], [0.0, 0.0]) + np.testing.assert_allclose(new_normalizer.stats[ACTION]["std"], [1.0, 2.0]) # Test that methods that depend on self.stats work correctly after loading # This would fail before the bug fix because self.stats was empty @@ -1876,7 +1876,7 @@ def test_stats_reconstruction_after_load_state_dict(): new_stats = { OBS_IMAGE: {"mean": [0.3, 0.3, 0.3], "std": [0.1, 0.1, 0.1]}, OBS_STATE: {"min": [-1.0, -2.0], "max": [2.0, 2.0]}, - "action": {"mean": [0.1, 0.1], "std": [0.5, 0.5]}, + ACTION: {"mean": [0.1, 0.1], "std": [0.5, 0.5]}, } pipeline = DataProcessorPipeline([new_normalizer]) diff --git a/tests/processor/test_pipeline.py b/tests/processor/test_pipeline.py index 6d056e4dc..6dbf37450 100644 --- a/tests/processor/test_pipeline.py +++ b/tests/processor/test_pipeline.py @@ -35,7 +35,7 @@ from lerobot.processor import ( TransitionKey, ) from lerobot.processor.converters import create_transition, identity_transition -from lerobot.utils.constants import OBS_IMAGE, OBS_IMAGES, OBS_STATE +from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_IMAGES, OBS_STATE from tests.conftest import assert_contract_is_typed @@ -257,7 +257,7 @@ def test_step_through_with_dict(): batch = { OBS_IMAGE: None, - "action": None, + ACTION: None, "next.reward": 0.0, "next.done": False, "next.truncated": False, @@ -1842,7 +1842,7 @@ def test_save_load_with_custom_converter_functions(): # Verify it uses default converters by checking with standard batch format batch = { OBS_IMAGE: torch.randn(1, 3, 32, 32), - "action": torch.randn(1, 7), + ACTION: torch.randn(1, 7), "next.reward": torch.tensor([1.0]), "next.done": torch.tensor([False]), "next.truncated": torch.tensor([False]), @@ -2094,11 +2094,11 @@ def test_aggregate_joint_action_only(): patterns=["action.j1.pos", "action.j2.pos"], ) - # Expect only "action" with joint names - assert "action" in out and OBS_STATE not in out - assert out["action"]["dtype"] == "float32" - assert set(out["action"]["names"]) == {"j1.pos", "j2.pos"} - assert out["action"]["shape"] == (len(out["action"]["names"]),) + # Expect only ACTION with joint names + assert ACTION in out and OBS_STATE not in out + assert out[ACTION]["dtype"] == "float32" + assert set(out[ACTION]["names"]) == {"j1.pos", "j2.pos"} + assert out[ACTION]["shape"] == (len(out[ACTION]["names"]),) def test_aggregate_ee_action_and_observation_with_videos(): @@ -2113,9 +2113,9 @@ def test_aggregate_ee_action_and_observation_with_videos(): ) # Action should pack only EE names - assert "action" in out - assert set(out["action"]["names"]) == {"ee.x", "ee.y"} - assert out["action"]["dtype"] == "float32" + assert ACTION in out + assert set(out[ACTION]["names"]) == {"ee.x", "ee.y"} + assert out[ACTION]["dtype"] == "float32" # Observation state should pack both ee.x and j1.pos as a vector assert OBS_STATE in out @@ -2140,10 +2140,10 @@ def test_aggregate_both_action_types(): patterns=["action.ee", "action.j1", "action.j2.pos"], ) - assert "action" in out + assert ACTION in out expected = {"ee.x", "ee.y", "j1.pos", "j2.pos"} - assert set(out["action"]["names"]) == expected - assert out["action"]["shape"] == (len(expected),) + assert set(out[ACTION]["names"]) == expected + assert out[ACTION]["shape"] == (len(expected),) def test_aggregate_images_when_use_videos_false(): diff --git a/tests/processor/test_policy_robot_bridge.py b/tests/processor/test_policy_robot_bridge.py index f3bbd9a74..6269c508f 100644 --- a/tests/processor/test_policy_robot_bridge.py +++ b/tests/processor/test_policy_robot_bridge.py @@ -28,6 +28,7 @@ from lerobot.processor import ( RobotActionToPolicyActionProcessorStep, ) from lerobot.processor.converters import identity_transition +from lerobot.utils.constants import ACTION from tests.conftest import assert_contract_is_typed @@ -134,8 +135,8 @@ def test_robot_to_policy_transform_features(): transformed = processor.transform_features(features) - assert "action" in transformed[PipelineFeatureType.ACTION] - action_feature = transformed[PipelineFeatureType.ACTION]["action"] + assert ACTION in transformed[PipelineFeatureType.ACTION] + action_feature = transformed[PipelineFeatureType.ACTION][ACTION] assert action_feature.type == FeatureType.ACTION assert action_feature.shape == (3,) @@ -251,7 +252,7 @@ def test_policy_to_robot_transform_features(): features = { PipelineFeatureType.ACTION: { - "action": {"type": FeatureType.ACTION, "shape": (2,)}, + ACTION: {"type": FeatureType.ACTION, "shape": (2,)}, "other_data": {"type": FeatureType.ENV, "shape": (1,)}, } } @@ -266,7 +267,7 @@ def test_policy_to_robot_transform_features(): assert motor_feature.type == FeatureType.ACTION assert motor_feature.shape == (1,) - assert "action" in transformed[PipelineFeatureType.ACTION] + assert ACTION in transformed[PipelineFeatureType.ACTION] assert "other_data" in transformed[PipelineFeatureType.ACTION] @@ -447,8 +448,8 @@ def test_robot_to_policy_features_contract(policy_feature_factory): assert_contract_is_typed(out) - assert "action" in out[PipelineFeatureType.ACTION] - action_feature = out[PipelineFeatureType.ACTION]["action"] + assert ACTION in out[PipelineFeatureType.ACTION] + action_feature = out[PipelineFeatureType.ACTION][ACTION] assert action_feature.type == FeatureType.ACTION assert action_feature.shape == (2,) @@ -458,7 +459,7 @@ def test_policy_to_robot_features_contract(policy_feature_factory): processor = PolicyActionToRobotActionProcessorStep(motor_names=["m1", "m2", "m3"]) features = { PipelineFeatureType.ACTION: { - "action": policy_feature_factory(FeatureType.ACTION, (3,)), + ACTION: policy_feature_factory(FeatureType.ACTION, (3,)), "other": policy_feature_factory(FeatureType.ENV, (1,)), } } diff --git a/tests/processor/test_rename_processor.py b/tests/processor/test_rename_processor.py index c6aa303f1..efb9f9328 100644 --- a/tests/processor/test_rename_processor.py +++ b/tests/processor/test_rename_processor.py @@ -28,7 +28,7 @@ from lerobot.processor import ( ) from lerobot.processor.converters import create_transition, identity_transition from lerobot.processor.rename_processor import rename_stats -from lerobot.utils.constants import OBS_IMAGE, OBS_IMAGES, OBS_STATE +from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_IMAGES, OBS_STATE from tests.conftest import assert_contract_is_typed @@ -488,7 +488,7 @@ def test_features_chained_processors(policy_feature_factory): def test_rename_stats_basic(): orig = { OBS_STATE: {"mean": np.array([0.0]), "std": np.array([1.0])}, - "action": {"mean": np.array([0.0])}, + ACTION: {"mean": np.array([0.0])}, } mapping = {OBS_STATE: "observation.robot_state"} renamed = rename_stats(orig, mapping) diff --git a/tests/processor/test_tokenizer_processor.py b/tests/processor/test_tokenizer_processor.py index 35bbcfd8a..503f2e036 100644 --- a/tests/processor/test_tokenizer_processor.py +++ b/tests/processor/test_tokenizer_processor.py @@ -11,7 +11,7 @@ import torch from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature from lerobot.processor import DataProcessorPipeline, TokenizerProcessorStep, TransitionKey from lerobot.processor.converters import create_transition, identity_transition -from lerobot.utils.constants import OBS_IMAGE, OBS_LANGUAGE, OBS_STATE +from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_LANGUAGE, OBS_STATE from tests.utils import require_package @@ -504,14 +504,14 @@ def test_features_basic(): input_features = { PipelineFeatureType.OBSERVATION: {OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,))}, - PipelineFeatureType.ACTION: {"action": PolicyFeature(type=FeatureType.ACTION, shape=(5,))}, + PipelineFeatureType.ACTION: {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(5,))}, } output_features = processor.transform_features(input_features) # Check that original features are preserved assert OBS_STATE in output_features[PipelineFeatureType.OBSERVATION] - assert "action" in output_features[PipelineFeatureType.ACTION] + assert ACTION in output_features[PipelineFeatureType.ACTION] # Check that tokenized features are added assert f"{OBS_LANGUAGE}.tokens" in output_features[PipelineFeatureType.OBSERVATION] diff --git a/tests/transport/test_transport_utils.py b/tests/transport/test_transport_utils.py index 79edad4e4..52825a24e 100644 --- a/tests/transport/test_transport_utils.py +++ b/tests/transport/test_transport_utils.py @@ -21,6 +21,7 @@ from pickle import UnpicklingError import pytest import torch +from lerobot.utils.constants import ACTION from lerobot.utils.transition import Transition from tests.utils import require_cuda, require_package @@ -512,7 +513,7 @@ def test_transitions_to_bytes_single_transition(): def assert_transitions_equal(t1: Transition, t2: Transition): """Helper to assert two transitions are equal.""" assert_observation_equal(t1["state"], t2["state"]) - assert torch.allclose(t1["action"], t2["action"]) + assert torch.allclose(t1[ACTION], t2[ACTION]) assert torch.allclose(t1["reward"], t2["reward"]) assert torch.equal(t1["done"], t2["done"]) assert_observation_equal(t1["next_state"], t2["next_state"]) diff --git a/tests/utils/test_replay_buffer.py b/tests/utils/test_replay_buffer.py index 6820d321f..1e6c0df95 100644 --- a/tests/utils/test_replay_buffer.py +++ b/tests/utils/test_replay_buffer.py @@ -22,7 +22,7 @@ import torch from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.rl.buffer import BatchTransition, ReplayBuffer, random_crop_vectorized -from lerobot.utils.constants import OBS_IMAGE, OBS_STATE, OBS_STR +from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE, OBS_STR from tests.fixtures.constants import DUMMY_REPO_ID @@ -63,7 +63,7 @@ def create_random_image() -> torch.Tensor: def create_dummy_transition() -> dict: return { OBS_IMAGE: create_random_image(), - "action": torch.randn(4), + ACTION: torch.randn(4), "reward": torch.tensor(1.0), OBS_STATE: torch.randn( 10, @@ -341,7 +341,7 @@ def test_sample_batch(replay_buffer): f"{k} should be equal to one of the dummy states." ) - for got_action_item in got_batch_transition["action"]: + for got_action_item in got_batch_transition[ACTION]: assert any(torch.equal(got_action_item, dummy_action) for dummy_action in dummy_actions), ( "Actions should be equal to the dummy actions." ) @@ -378,7 +378,7 @@ def test_to_lerobot_dataset(tmp_path): for i in range(len(ds)): for feature, value in ds[i].items(): - if feature == "action": + if feature == ACTION: assert torch.equal(value, buffer.actions[i]) elif feature == "next.reward": assert torch.equal(value, buffer.rewards[i]) @@ -495,7 +495,7 @@ def test_buffer_sample_alignment(): for i in range(50): state_sig = batch["state"]["state_value"][i].item() - action_val = batch["action"][i].item() + action_val = batch[ACTION][i].item() reward_val = batch["reward"][i].item() next_state_sig = batch["next_state"]["state_value"][i].item() is_done = batch["done"][i].item() > 0.5