From 15724826dd5b18c18656151a02a24ce5fd690c46 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Tue, 13 Jan 2026 09:49:46 +0100 Subject: [PATCH] chore: use alias & constants (#2785) * chore: use alias and constants * fix(rl): solve circular dependecy * chore: nit right constant * chore: pre-commit * chore(script): conflict tokenizer train --------- Signed-off-by: Steven Palma --- src/lerobot/async_inference/robot_client.py | 4 +-- src/lerobot/datasets/pipeline_features.py | 4 +-- src/lerobot/envs/libero.py | 6 ++-- src/lerobot/envs/metaworld.py | 12 ++++---- src/lerobot/envs/utils.py | 3 +- src/lerobot/policies/factory.py | 8 ++++-- .../policies/groot/configuration_groot.py | 13 +++++---- src/lerobot/policies/groot/groot_n1.py | 6 ++-- src/lerobot/policies/groot/modeling_groot.py | 3 +- src/lerobot/policies/groot/processor_groot.py | 28 +++++++++++-------- src/lerobot/policies/pi0/configuration_pi0.py | 10 +++---- .../policies/pi05/configuration_pi05.py | 11 ++++---- .../pi0_fast/configuration_pi0_fast.py | 11 ++++---- .../policies/sarm/configuration_sarm.py | 5 ++-- src/lerobot/policies/sarm/modeling_sarm.py | 3 +- src/lerobot/policies/utils.py | 3 +- .../policies/wall_x/configuration_wall_x.py | 13 +++++---- .../policies/wall_x/modeling_wall_x.py | 6 ++-- src/lerobot/policies/xvla/processor_xvla.py | 16 ++++++----- src/lerobot/processor/__init__.py | 3 -- src/lerobot/processor/converters.py | 6 ++-- src/lerobot/processor/core.py | 2 +- src/lerobot/processor/env_processor.py | 11 ++++---- src/lerobot/processor/normalize_processor.py | 4 +-- src/lerobot/processor/pipeline.py | 6 ++-- src/lerobot/processor/tokenizer_processor.py | 4 +-- src/lerobot/rl/gym_manipulator.py | 11 ++++---- .../joint_observations_processor.py | 0 .../robots/bi_so_follower/bi_so_follower.py | 6 ++-- .../robot_earthrover_mini_plus.py | 11 ++++---- src/lerobot/robots/hope_jr/hope_jr_arm.py | 6 ++-- src/lerobot/robots/hope_jr/hope_jr_hand.py | 6 ++-- .../robots/koch_follower/koch_follower.py | 10 +++---- src/lerobot/robots/lekiwi/lekiwi.py | 7 +++-- src/lerobot/robots/lekiwi/lekiwi_client.py | 19 ++++++------- .../robots/omx_follower/omx_follower.py | 10 +++---- src/lerobot/robots/reachy2/robot_reachy2.py | 9 +++--- src/lerobot/robots/robot.py | 12 ++++---- .../so_follower/robot_kinematic_processor.py | 3 +- src/lerobot/robots/so_follower/so_follower.py | 9 +++--- src/lerobot/robots/unitree_g1/unitree_g1.py | 5 ++-- src/lerobot/scripts/lerobot_eval.py | 4 +-- src/lerobot/scripts/lerobot_teleoperate.py | 2 +- .../scripts/lerobot_train_tokenizer.py | 13 ++++----- .../teleoperators/gamepad/teleop_gamepad.py | 4 ++- .../teleoperators/keyboard/teleop_keyboard.py | 9 +++--- src/lerobot/teleoperators/teleoperator.py | 5 ++-- src/lerobot/utils/constants.py | 2 ++ src/lerobot/utils/visualization_utils.py | 11 ++++---- tests/mocks/mock_robot.py | 6 ++-- tests/mocks/mock_teleop.py | 3 +- 51 files changed, 206 insertions(+), 178 deletions(-) rename src/lerobot/{processor => rl}/joint_observations_processor.py (100%) diff --git a/src/lerobot/async_inference/robot_client.py b/src/lerobot/async_inference/robot_client.py index eea5585b0..f26639dc1 100644 --- a/src/lerobot/async_inference/robot_client.py +++ b/src/lerobot/async_inference/robot_client.py @@ -40,7 +40,6 @@ from collections.abc import Callable from dataclasses import asdict from pprint import pformat from queue import Queue -from typing import Any import draccus import grpc @@ -48,6 +47,7 @@ import torch from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401 from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401 +from lerobot.processor import RobotAction from lerobot.robots import ( # noqa: F401 Robot, RobotConfig, @@ -351,7 +351,7 @@ class RobotClient: action = {key: action_tensor[i].item() for i, key in enumerate(self.robot.action_features)} return action - def control_loop_action(self, verbose: bool = False) -> dict[str, Any]: + def control_loop_action(self, verbose: bool = False) -> RobotAction: """Reading and performing actions in local queue""" # Lock only for queue operations diff --git a/src/lerobot/datasets/pipeline_features.py b/src/lerobot/datasets/pipeline_features.py index 4fad7bd20..161633f26 100644 --- a/src/lerobot/datasets/pipeline_features.py +++ b/src/lerobot/datasets/pipeline_features.py @@ -18,12 +18,12 @@ from typing import Any from lerobot.configs.types import PipelineFeatureType from lerobot.datasets.utils import hw_to_dataset_features -from lerobot.processor import DataProcessorPipeline +from lerobot.processor import DataProcessorPipeline, RobotAction, RobotObservation from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE, OBS_STR def create_initial_features( - action: dict[str, Any] | None = None, observation: dict[str, Any] | None = None + action: RobotAction | None = None, observation: RobotObservation | None = None ) -> dict[PipelineFeatureType, dict[str, Any]]: """ Creates the initial features dict for the dataset from action and observation specs. diff --git a/src/lerobot/envs/libero.py b/src/lerobot/envs/libero.py index b1eb37377..74882ad18 100644 --- a/src/lerobot/envs/libero.py +++ b/src/lerobot/envs/libero.py @@ -29,6 +29,8 @@ from gymnasium import spaces from libero.libero import benchmark, get_libero_path from libero.libero.envs import OffScreenRenderEnv +from lerobot.processor import RobotObservation + def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]: """Normalize camera_name into a non-empty list of strings.""" @@ -237,7 +239,7 @@ class LiberoEnv(gym.Env): env.reset() return env - def _format_raw_obs(self, raw_obs: dict[str, Any]) -> dict[str, Any]: + def _format_raw_obs(self, raw_obs: RobotObservation) -> RobotObservation: images = {} for camera_name in self.camera_name: image = raw_obs[camera_name] @@ -313,7 +315,7 @@ class LiberoEnv(gym.Env): info = {"is_success": False} return observation, info - def step(self, action: np.ndarray) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]: + def step(self, action: np.ndarray) -> tuple[RobotObservation, float, bool, bool, dict[str, Any]]: if action.ndim != 1: raise ValueError( f"Expected action to be 1-D (shape (action_dim,)), " diff --git a/src/lerobot/envs/metaworld.py b/src/lerobot/envs/metaworld.py index 9190f33ad..4d91e002d 100644 --- a/src/lerobot/envs/metaworld.py +++ b/src/lerobot/envs/metaworld.py @@ -25,6 +25,8 @@ import metaworld.policies as policies import numpy as np from gymnasium import spaces +from lerobot.processor import RobotObservation + # ---- Load configuration data from the external JSON file ---- CONFIG_PATH = Path(__file__).parent / "metaworld_config.json" try: @@ -161,7 +163,7 @@ class MetaworldEnv(gym.Env): env._freeze_rand_vec = False # otherwise no randomization return env - def _format_raw_obs(self, raw_obs: np.ndarray) -> dict[str, Any]: + def _format_raw_obs(self, raw_obs: np.ndarray) -> RobotObservation: image = None if self._env is not None: image = self._env.render() @@ -196,7 +198,7 @@ class MetaworldEnv(gym.Env): self, seed: int | None = None, **kwargs, - ) -> tuple[dict[str, Any], dict[str, Any]]: + ) -> tuple[RobotObservation, dict[str, Any]]: """ Reset the environment to its initial state. @@ -204,7 +206,7 @@ class MetaworldEnv(gym.Env): seed (Optional[int]): Random seed for environment initialization. Returns: - observation (Dict[str, Any]): The initial formatted observation. + observation (RobotObservation): The initial formatted observation. info (Dict[str, Any]): Additional info about the reset state. """ super().reset(seed=seed) @@ -216,7 +218,7 @@ class MetaworldEnv(gym.Env): info = {"is_success": False} return observation, info - def step(self, action: np.ndarray) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]: + def step(self, action: np.ndarray) -> tuple[RobotObservation, float, bool, bool, dict[str, Any]]: """ Perform one environment step. @@ -224,7 +226,7 @@ class MetaworldEnv(gym.Env): action (np.ndarray): The action to execute, must be 1-D with shape (action_dim,). Returns: - observation (Dict[str, Any]): The formatted observation after the step. + observation (RobotObservation): The formatted observation after the step. reward (float): The scalar reward for this step. terminated (bool): Whether the episode terminated successfully. truncated (bool): Whether the episode was truncated due to a time limit. diff --git a/src/lerobot/envs/utils.py b/src/lerobot/envs/utils.py index af814c92a..09431a18d 100644 --- a/src/lerobot/envs/utils.py +++ b/src/lerobot/envs/utils.py @@ -29,6 +29,7 @@ from torch import Tensor from lerobot.configs.types import FeatureType, PolicyFeature from lerobot.envs.configs import EnvConfig +from lerobot.processor import RobotObservation from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE, OBS_STR from lerobot.utils.utils import get_channel_first_image_shape @@ -152,7 +153,7 @@ def check_env_attributes_and_types(env: gym.vector.VectorEnv) -> None: ) -def add_envs_task(env: gym.vector.VectorEnv, observation: dict[str, Any]) -> dict[str, Any]: +def add_envs_task(env: gym.vector.VectorEnv, observation: RobotObservation) -> RobotObservation: """Adds task feature to the observation dict with respect to the first environment attribute.""" if hasattr(env.envs[0], "task_description"): task_result = env.call("task_description") diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index fff08ad37..a593e5bcb 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -51,7 +51,11 @@ from lerobot.processor.converters import ( transition_to_batch, transition_to_policy_action, ) -from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME +from lerobot.utils.constants import ( + ACTION, + POLICY_POSTPROCESSOR_DEFAULT_NAME, + POLICY_PREPROCESSOR_DEFAULT_NAME, +) def get_policy_class(name: str) -> type[PreTrainedPolicy]: @@ -250,7 +254,7 @@ def make_pre_post_processors( } # Also ensure postprocessing slices to env action dim and unnormalizes with dataset stats - env_action_dim = policy_cfg.output_features["action"].shape[0] + env_action_dim = policy_cfg.output_features[ACTION].shape[0] postprocessor_overrides["groot_action_unpack_unnormalize_v1"] = { "stats": kwargs.get("dataset_stats"), "normalize_min_max": True, diff --git a/src/lerobot/policies/groot/configuration_groot.py b/src/lerobot/policies/groot/configuration_groot.py index 8002c69ea..4f3d78222 100644 --- a/src/lerobot/policies/groot/configuration_groot.py +++ b/src/lerobot/policies/groot/configuration_groot.py @@ -20,6 +20,7 @@ from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from lerobot.optim.optimizers import AdamWConfig from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig +from lerobot.utils.constants import ACTION, OBS_STATE @PreTrainedConfig.register_subclass("groot") @@ -137,14 +138,14 @@ class GrootConfig(PreTrainedConfig): "No features of type FeatureType.VISUAL found in input_features." ) - if "observation.state" not in self.input_features: + if OBS_STATE not in self.input_features: state_feature = PolicyFeature( type=FeatureType.STATE, shape=(self.max_state_dim,), ) - self.input_features["observation.state"] = state_feature + self.input_features[OBS_STATE] = state_feature else: - state_shape = self.input_features["observation.state"].shape + state_shape = self.input_features[OBS_STATE].shape state_dim = state_shape[0] if state_shape else 0 if state_dim > self.max_state_dim: raise ValueError( @@ -152,14 +153,14 @@ class GrootConfig(PreTrainedConfig): f"Either reduce state dimension or increase max_state_dim in config." ) - if "action" not in self.output_features: + if ACTION not in self.output_features: action_feature = PolicyFeature( type=FeatureType.ACTION, shape=(self.max_action_dim,), ) - self.output_features["action"] = action_feature + self.output_features[ACTION] = action_feature else: - action_shape = self.output_features["action"].shape + action_shape = self.output_features[ACTION].shape action_dim = action_shape[0] if action_shape else 0 if action_dim > self.max_action_dim: raise ValueError( diff --git a/src/lerobot/policies/groot/groot_n1.py b/src/lerobot/policies/groot/groot_n1.py index 0da26874e..06ff5a04d 100644 --- a/src/lerobot/policies/groot/groot_n1.py +++ b/src/lerobot/policies/groot/groot_n1.py @@ -46,7 +46,7 @@ from lerobot.policies.groot.action_head.flow_matching_action_head import ( FlowmatchingActionHeadConfig, ) from lerobot.policies.groot.utils import ensure_eagle_cache_ready -from lerobot.utils.constants import HF_LEROBOT_HOME +from lerobot.utils.constants import ACTION, HF_LEROBOT_HOME DEFAULT_VENDOR_EAGLE_PATH = str((Path(__file__).resolve().parent / "eagle2_hg_model").resolve()) DEFAULT_TOKENIZER_ASSETS_REPO = "lerobot/eagle2hg-processor-groot-n1p5" @@ -227,8 +227,8 @@ class GR00TN15(PreTrainedModel): detected_error = False error_msg = ERROR_MSG - if "action" in inputs: - action = inputs["action"] + if ACTION in inputs: + action = inputs[ACTION] # In inference, action may be omitted or None; validate only when it's a tensor. if action is None: pass # allow None during inference diff --git a/src/lerobot/policies/groot/modeling_groot.py b/src/lerobot/policies/groot/modeling_groot.py index bdaef37b9..fd9baa9b1 100644 --- a/src/lerobot/policies/groot/modeling_groot.py +++ b/src/lerobot/policies/groot/modeling_groot.py @@ -41,6 +41,7 @@ from torch import Tensor from lerobot.policies.groot.configuration_groot import GrootConfig from lerobot.policies.groot.groot_n1 import GR00TN15 from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.utils.constants import ACTION class GrootPolicy(PreTrainedPolicy): @@ -147,7 +148,7 @@ class GrootPolicy(PreTrainedPolicy): actions = outputs.get("action_pred") - original_action_dim = self.config.output_features["action"].shape[0] + original_action_dim = self.config.output_features[ACTION].shape[0] actions = actions[:, :, :original_action_dim] return actions diff --git a/src/lerobot/policies/groot/processor_groot.py b/src/lerobot/policies/groot/processor_groot.py index d87e43c11..14149cf2f 100644 --- a/src/lerobot/policies/groot/processor_groot.py +++ b/src/lerobot/policies/groot/processor_groot.py @@ -51,7 +51,11 @@ from lerobot.processor.converters import ( ) from lerobot.processor.core import EnvTransition, TransitionKey from lerobot.utils.constants import ( + ACTION, HF_LEROBOT_HOME, + OBS_IMAGE, + OBS_IMAGES, + OBS_STATE, POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME, ) @@ -107,9 +111,9 @@ def make_groot_pre_post_processors( # Define feature specs for optional normalization steps _features: dict[str, PolicyFeature] = { # Observation features (only add those we may normalize) - "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(state_horizon, max_state_dim)), + OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_horizon, max_state_dim)), # Action feature - "action": PolicyFeature(type=FeatureType.ACTION, shape=(action_horizon, max_action_dim)), + ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_horizon, max_action_dim)), } # Normalize STATE and ACTION with min_max (SO100-like default) @@ -120,7 +124,7 @@ def make_groot_pre_post_processors( # Determine env action dimension from config (simple, object-like PolicyFeature) try: - env_action_dim = int(config.output_features["action"].shape[0]) + env_action_dim = int(config.output_features[ACTION].shape[0]) except Exception: env_action_dim = 0 @@ -268,9 +272,9 @@ class GrootPackInputsStep(ProcessorStep): return torch.where(mask, mapped, torch.zeros_like(mapped)) # 1) Video (B, T=1, V, H, W, C) uint8 - img_keys = sorted([k for k in obs if k.startswith("observation.images.")]) - if not img_keys and "observation.image" in obs: - img_keys = ["observation.image"] + img_keys = sorted([k for k in obs if k.startswith(OBS_IMAGES)]) + if not img_keys and OBS_IMAGE in obs: + img_keys = [OBS_IMAGE] if img_keys: cams = [_to_uint8_np_bhwc(obs[k]) for k in img_keys] video = np.stack(cams, axis=1) # (B, V, H, W, C) @@ -294,14 +298,14 @@ class GrootPackInputsStep(ProcessorStep): comp["language"] = lang # 3) State/state_mask -> (B, 1, max_state_dim) - if "observation.state" in obs: - state = obs["observation.state"] # (B, D) + if OBS_STATE in obs: + state = obs[OBS_STATE] # (B, D) if state.dim() != 2: raise ValueError(f"state must be (B, D), got {tuple(state.shape)}") bsz, d = state.shape # Normalize BEFORE padding if self.normalize_min_max: - state = _min_max_norm(state, "observation.state") + state = _min_max_norm(state, OBS_STATE) state = state.unsqueeze(1) # (B, 1, D) if d > self.max_state_dim: state = state[:, :, : self.max_state_dim] @@ -320,11 +324,11 @@ class GrootPackInputsStep(ProcessorStep): # Normalize BEFORE temporal expansion/padding if self.normalize_min_max: if action.dim() == 2: - action = _min_max_norm(action, "action") + action = _min_max_norm(action, ACTION) elif action.dim() == 3: b, t, d = action.shape flat = action.reshape(b * t, d) - flat = _min_max_norm(flat, "action") + flat = _min_max_norm(flat, ACTION) action = flat.view(b, t, d) if action.dim() == 2: action = action.unsqueeze(1).repeat(1, self.action_horizon, 1) @@ -590,7 +594,7 @@ class GrootActionUnpackUnnormalizeStep(ProcessorStep): # forward: y = 2 * (x - min) / denom - 1, with y=0 when denom==0 # inverse: x = (y+1)/2 * denom + min, and when denom==0 -> x = min if self.normalize_min_max and self.stats is not None: - stats_k = self.stats.get("action", {}) + stats_k = self.stats.get(ACTION, {}) d = action.shape[-1] min_v = torch.as_tensor( stats_k.get("min", torch.zeros(d)), dtype=action.dtype, device=action.device diff --git a/src/lerobot/policies/pi0/configuration_pi0.py b/src/lerobot/policies/pi0/configuration_pi0.py index d91145aa7..be9b4530f 100644 --- a/src/lerobot/policies/pi0/configuration_pi0.py +++ b/src/lerobot/policies/pi0/configuration_pi0.py @@ -21,7 +21,7 @@ from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from lerobot.optim.optimizers import AdamWConfig from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig from lerobot.policies.rtc.configuration_rtc import RTCConfig -from lerobot.utils.constants import OBS_IMAGES +from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE DEFAULT_IMAGE_SIZE = 224 @@ -124,19 +124,19 @@ class PI0Config(PreTrainedConfig): ) self.input_features[key] = empty_camera - if "observation.state" not in self.input_features: + if OBS_STATE not in self.input_features: state_feature = PolicyFeature( type=FeatureType.STATE, shape=(self.max_state_dim,), # Padded to max_state_dim ) - self.input_features["observation.state"] = state_feature + self.input_features[OBS_STATE] = state_feature - if "action" not in self.output_features: + if ACTION not in self.output_features: action_feature = PolicyFeature( type=FeatureType.ACTION, shape=(self.max_action_dim,), # Padded to max_action_dim ) - self.output_features["action"] = action_feature + self.output_features[ACTION] = action_feature def get_optimizer_preset(self) -> AdamWConfig: return AdamWConfig( diff --git a/src/lerobot/policies/pi05/configuration_pi05.py b/src/lerobot/policies/pi05/configuration_pi05.py index 4a0e23039..b96e6d196 100644 --- a/src/lerobot/policies/pi05/configuration_pi05.py +++ b/src/lerobot/policies/pi05/configuration_pi05.py @@ -21,6 +21,7 @@ from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from lerobot.optim.optimizers import AdamWConfig from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig from lerobot.policies.rtc.configuration_rtc import RTCConfig +from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE DEFAULT_IMAGE_SIZE = 224 @@ -117,26 +118,26 @@ class PI05Config(PreTrainedConfig): def validate_features(self) -> None: """Validate and set up input/output features.""" for i in range(self.empty_cameras): - key = f"observation.images.empty_camera_{i}" + key = OBS_IMAGES + f".empty_camera_{i}" empty_camera = PolicyFeature( type=FeatureType.VISUAL, shape=(3, *self.image_resolution), # Use configured image resolution ) self.input_features[key] = empty_camera - if "observation.state" not in self.input_features: + if OBS_STATE not in self.input_features: state_feature = PolicyFeature( type=FeatureType.STATE, shape=(self.max_state_dim,), # Padded to max_state_dim ) - self.input_features["observation.state"] = state_feature + self.input_features[OBS_STATE] = state_feature - if "action" not in self.output_features: + if ACTION not in self.output_features: action_feature = PolicyFeature( type=FeatureType.ACTION, shape=(self.max_action_dim,), # Padded to max_action_dim ) - self.output_features["action"] = action_feature + self.output_features[ACTION] = action_feature def get_optimizer_preset(self) -> AdamWConfig: return AdamWConfig( diff --git a/src/lerobot/policies/pi0_fast/configuration_pi0_fast.py b/src/lerobot/policies/pi0_fast/configuration_pi0_fast.py index 42aa4a132..96137e91f 100644 --- a/src/lerobot/policies/pi0_fast/configuration_pi0_fast.py +++ b/src/lerobot/policies/pi0_fast/configuration_pi0_fast.py @@ -21,6 +21,7 @@ from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from lerobot.optim.optimizers import AdamWConfig from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig from lerobot.policies.rtc.configuration_rtc import RTCConfig +from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE DEFAULT_IMAGE_SIZE = 224 @@ -110,26 +111,26 @@ class PI0FastConfig(PreTrainedConfig): def validate_features(self) -> None: """Validate and set up input/output features.""" for i in range(self.empty_cameras): - key = f"observation.images.empty_camera_{i}" + key = OBS_IMAGES + f".empty_camera_{i}" empty_camera = PolicyFeature( type=FeatureType.VISUAL, shape=(3, *self.image_resolution), # Use configured image resolution ) self.input_features[key] = empty_camera - if "observation.state" not in self.input_features: + if OBS_STATE not in self.input_features: state_feature = PolicyFeature( type=FeatureType.STATE, shape=(self.max_state_dim,), # Padded to max_state_dim ) - self.input_features["observation.state"] = state_feature + self.input_features[OBS_STATE] = state_feature - if "action" not in self.output_features: + if ACTION not in self.output_features: action_feature = PolicyFeature( type=FeatureType.ACTION, shape=(self.max_action_dim,), # Padded to max_action_dim ) - self.output_features["action"] = action_feature + self.output_features[ACTION] = action_feature def get_optimizer_preset(self) -> AdamWConfig: return AdamWConfig( diff --git a/src/lerobot/policies/sarm/configuration_sarm.py b/src/lerobot/policies/sarm/configuration_sarm.py index 59cb352d5..673422fe2 100644 --- a/src/lerobot/policies/sarm/configuration_sarm.py +++ b/src/lerobot/policies/sarm/configuration_sarm.py @@ -26,6 +26,7 @@ from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from lerobot.optim.optimizers import AdamWConfig from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig +from lerobot.utils.constants import OBS_IMAGES, OBS_STATE @PreTrainedConfig.register_subclass("sarm") @@ -86,8 +87,8 @@ class SARMConfig(PreTrainedConfig): pretrained_model_path: str | None = None device: str | None = None - image_key: str = "observation.images.top" # Key for image used from the dataset - state_key: str = "observation.state" + image_key: str = OBS_IMAGES + ".top" # Key for image used from the dataset + state_key: str = OBS_STATE # Populated by the processor (video_features, state_features, text_features) input_features: dict = field(default_factory=lambda: {}) diff --git a/src/lerobot/policies/sarm/modeling_sarm.py b/src/lerobot/policies/sarm/modeling_sarm.py index a88b2ad64..6051d90f8 100644 --- a/src/lerobot/policies/sarm/modeling_sarm.py +++ b/src/lerobot/policies/sarm/modeling_sarm.py @@ -40,6 +40,7 @@ from lerobot.policies.sarm.sarm_utils import ( normalize_stage_tau, pad_state_to_max_dim, ) +from lerobot.utils.constants import OBS_STR class StageTransformer(nn.Module): @@ -721,7 +722,7 @@ class SARMRewardModel(PreTrainedPolicy): Returns: Tuple of (total_loss, output_dict with loss components) """ - observation = batch.get("observation", batch) + observation = batch.get(OBS_STR, batch) # Extract features video_features = observation["video_features"].to(self.device) diff --git a/src/lerobot/policies/utils.py b/src/lerobot/policies/utils.py index bfbe2bf1d..1a14b2925 100644 --- a/src/lerobot/policies/utils.py +++ b/src/lerobot/policies/utils.py @@ -16,7 +16,6 @@ import logging from collections import deque -from typing import Any import numpy as np import torch @@ -140,7 +139,7 @@ def prepare_observation_for_inference( def build_inference_frame( - observation: dict[str, Any], + observation: RobotObservation, device: torch.device, ds_features: dict[str, dict], task: str | None = None, diff --git a/src/lerobot/policies/wall_x/configuration_wall_x.py b/src/lerobot/policies/wall_x/configuration_wall_x.py index 0d10a8f98..3962b56f6 100644 --- a/src/lerobot/policies/wall_x/configuration_wall_x.py +++ b/src/lerobot/policies/wall_x/configuration_wall_x.py @@ -18,6 +18,7 @@ from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from lerobot.optim.optimizers import AdamWConfig from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig +from lerobot.utils.constants import ACTION, OBS_STATE @PreTrainedConfig.register_subclass("wall_x") @@ -105,14 +106,14 @@ class WallXConfig(PreTrainedConfig): "No features of type FeatureType.VISUAL found in input_features." ) - if "observation.state" not in self.input_features: + if OBS_STATE not in self.input_features: state_feature = PolicyFeature( type=FeatureType.STATE, shape=(self.max_state_dim,), # Padded to max_state_dim ) - self.input_features["observation.state"] = state_feature + self.input_features[OBS_STATE] = state_feature else: - state_shape = self.input_features["observation.state"].shape + state_shape = self.input_features[OBS_STATE].shape state_dim = state_shape[0] if state_shape else 0 if state_dim > self.max_state_dim: raise ValueError( @@ -120,14 +121,14 @@ class WallXConfig(PreTrainedConfig): f"Either reduce state dimension or increase max_state_dim in config." ) - if "action" not in self.output_features: + if ACTION not in self.output_features: action_feature = PolicyFeature( type=FeatureType.ACTION, shape=(self.max_action_dim,), # Padded to max_action_dim ) - self.output_features["action"] = action_feature + self.output_features[ACTION] = action_feature else: - action_shape = self.output_features["action"].shape + action_shape = self.output_features[ACTION].shape action_dim = action_shape[0] if action_shape else 0 if action_dim > self.max_action_dim: raise ValueError( diff --git a/src/lerobot/policies/wall_x/modeling_wall_x.py b/src/lerobot/policies/wall_x/modeling_wall_x.py index 94ee7897e..ef99bad89 100644 --- a/src/lerobot/policies/wall_x/modeling_wall_x.py +++ b/src/lerobot/policies/wall_x/modeling_wall_x.py @@ -1861,7 +1861,7 @@ class WallXPolicy(PreTrainedPolicy): dim=-1, ) else: - action_dim = self.config.output_features["action"].shape[0] + action_dim = self.config.output_features[ACTION].shape[0] dof_mask = torch.cat( [ torch.ones( @@ -1977,7 +1977,7 @@ class WallXPolicy(PreTrainedPolicy): elif self.config.prediction_mode == "fast": output = self.model( **batch, - action_dim=self.config.output_features["action"].shape[0], + action_dim=self.config.output_features[ACTION].shape[0], pred_horizon=self.config.chunk_size, mode="predict", predict_mode="fast", @@ -1989,7 +1989,7 @@ class WallXPolicy(PreTrainedPolicy): actions = output["predict_action"] # Unpad actions to actual action dimension - action_dim = self.config.output_features["action"].shape[0] + action_dim = self.config.output_features[ACTION].shape[0] actions = actions[:, :, :action_dim] return actions diff --git a/src/lerobot/policies/xvla/processor_xvla.py b/src/lerobot/policies/xvla/processor_xvla.py index 7f7297b9a..c4e3f2d6f 100644 --- a/src/lerobot/policies/xvla/processor_xvla.py +++ b/src/lerobot/policies/xvla/processor_xvla.py @@ -41,6 +41,7 @@ from lerobot.processor.converters import policy_action_to_transition, transition from lerobot.processor.core import EnvTransition, TransitionKey from lerobot.utils.constants import ( OBS_IMAGES, + OBS_PREFIX, OBS_STATE, POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME, @@ -137,8 +138,9 @@ class LiberoProcessorStep(ObservationProcessorStep): processed_obs[key] = img # Process robot_state into a flat state vector - if "observation.robot_state" in processed_obs: - robot_state = processed_obs.pop("observation.robot_state") + robot_state_str = OBS_PREFIX + "robot_state" + if robot_state_str in processed_obs: + robot_state = processed_obs.pop(robot_state_str) # Extract components eef_pos = robot_state["eef"]["pos"] # (B, 3,) @@ -174,8 +176,8 @@ class LiberoProcessorStep(ObservationProcessorStep): state_feats = {} # add our new flattened state - state_feats["observation.state"] = PolicyFeature( - key="observation.state", + state_feats[OBS_STATE] = PolicyFeature( + key=OBS_STATE, shape=(20,), dtype="float32", ) @@ -247,7 +249,7 @@ class XVLAImageScaleProcessorStep(ProcessorStep): keys_to_scale = self.image_keys if keys_to_scale is None: # Auto-detect image keys - keys_to_scale = [k for k in obs if k.startswith("observation.images.")] + keys_to_scale = [k for k in obs if k.startswith(OBS_IMAGES)] # Scale each image for key in keys_to_scale: @@ -303,7 +305,7 @@ class XVLAImageToFloatProcessorStep(ProcessorStep): keys_to_convert = self.image_keys if keys_to_convert is None: # Auto-detect image keys - keys_to_convert = [k for k in obs if k.startswith("observation.images.")] + keys_to_convert = [k for k in obs if k.startswith(OBS_IMAGES)] # Convert each image for key in keys_to_convert: @@ -376,7 +378,7 @@ class XVLAImageNetNormalizeProcessorStep(ProcessorStep): keys_to_normalize = self.image_keys if keys_to_normalize is None: # Auto-detect image keys - keys_to_normalize = [k for k in obs if k.startswith("observation.images.")] + keys_to_normalize = [k for k in obs if k.startswith(OBS_IMAGES)] # Normalize each image for key in keys_to_normalize: diff --git a/src/lerobot/processor/__init__.py b/src/lerobot/processor/__init__.py index 676ba29ee..164f7da03 100644 --- a/src/lerobot/processor/__init__.py +++ b/src/lerobot/processor/__init__.py @@ -49,7 +49,6 @@ from .hil_processor import ( RewardClassifierProcessorStep, TimeLimitProcessorStep, ) -from .joint_observations_processor import JointVelocityProcessorStep, MotorCurrentProcessorStep from .normalize_processor import NormalizerProcessorStep, UnnormalizerProcessorStep, hotswap_stats from .observation_processor import VanillaObservationProcessorStep from .pipeline import ( @@ -94,14 +93,12 @@ __all__ = [ "ImageCropResizeProcessorStep", "InfoProcessorStep", "InterventionActionProcessorStep", - "JointVelocityProcessorStep", "make_default_processors", "make_default_teleop_action_processor", "make_default_robot_action_processor", "make_default_robot_observation_processor", "MapDeltaActionToRobotActionStep", "MapTensorToDeltaActionDictStep", - "MotorCurrentProcessorStep", "NormalizerProcessorStep", "Numpy2TorchActionProcessorStep", "ObservationProcessorStep", diff --git a/src/lerobot/processor/converters.py b/src/lerobot/processor/converters.py index 126be0e36..4f9485fee 100644 --- a/src/lerobot/processor/converters.py +++ b/src/lerobot/processor/converters.py @@ -23,7 +23,7 @@ from typing import Any import numpy as np import torch -from lerobot.utils.constants import ACTION, DONE, OBS_PREFIX, REWARD, TRUNCATED +from lerobot.utils.constants import ACTION, DONE, INFO, OBS_PREFIX, REWARD, TRUNCATED from .core import EnvTransition, PolicyAction, RobotAction, RobotObservation, TransitionKey @@ -176,7 +176,7 @@ def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]: def create_transition( - observation: dict[str, Any] | None = None, + observation: RobotObservation | None = None, action: PolicyAction | RobotAction | None = None, reward: float = 0.0, done: bool = False, @@ -384,7 +384,7 @@ def transition_to_batch(transition: EnvTransition) -> dict[str, Any]: REWARD: transition.get(TransitionKey.REWARD, 0.0), DONE: transition.get(TransitionKey.DONE, False), TRUNCATED: transition.get(TransitionKey.TRUNCATED, False), - "info": transition.get(TransitionKey.INFO, {}), + INFO: transition.get(TransitionKey.INFO, {}), } # Add complementary data. diff --git a/src/lerobot/processor/core.py b/src/lerobot/processor/core.py index 679ba8c54..0b293c9b0 100644 --- a/src/lerobot/processor/core.py +++ b/src/lerobot/processor/core.py @@ -45,7 +45,7 @@ RobotObservation: TypeAlias = dict[str, Any] EnvTransition = TypedDict( "EnvTransition", { - TransitionKey.OBSERVATION.value: dict[str, Any] | None, + TransitionKey.OBSERVATION.value: RobotObservation | None, TransitionKey.ACTION.value: PolicyAction | RobotAction | EnvAction | None, TransitionKey.REWARD.value: float | torch.Tensor | None, TransitionKey.DONE.value: bool | torch.Tensor | None, diff --git a/src/lerobot/processor/env_processor.py b/src/lerobot/processor/env_processor.py index a5210af30..8d42bfdb7 100644 --- a/src/lerobot/processor/env_processor.py +++ b/src/lerobot/processor/env_processor.py @@ -18,7 +18,7 @@ from dataclasses import dataclass import torch from lerobot.configs.types import PipelineFeatureType, PolicyFeature -from lerobot.utils.constants import OBS_IMAGES, OBS_STATE, OBS_STR +from lerobot.utils.constants import OBS_IMAGES, OBS_PREFIX, OBS_STATE, OBS_STR from .pipeline import ObservationProcessorStep, ProcessorStepRegistry @@ -60,8 +60,9 @@ class LiberoProcessorStep(ObservationProcessorStep): processed_obs[key] = img # Process robot_state into a flat state vector - if "observation.robot_state" in processed_obs: - robot_state = processed_obs.pop("observation.robot_state") + observation_robot_state_str = OBS_PREFIX + "robot_state" + if observation_robot_state_str in processed_obs: + robot_state = processed_obs.pop(observation_robot_state_str) # Extract components eef_pos = robot_state["eef"]["pos"] # (B, 3,) @@ -98,8 +99,8 @@ class LiberoProcessorStep(ObservationProcessorStep): state_feats = {} # add our new flattened state - state_feats["observation.state"] = PolicyFeature( - key="observation.state", + state_feats[OBS_STATE] = PolicyFeature( + key=OBS_STATE, shape=(8,), # [eef_pos(3), axis_angle(3), gripper(2)] dtype="float32", description=("Concatenated end-effector position (3), axis-angle (3), and gripper qpos (2)."), diff --git a/src/lerobot/processor/normalize_processor.py b/src/lerobot/processor/normalize_processor.py index 368c9b270..4769b91ac 100644 --- a/src/lerobot/processor/normalize_processor.py +++ b/src/lerobot/processor/normalize_processor.py @@ -30,7 +30,7 @@ from lerobot.utils.constants import ACTION from .converters import from_tensor_to_numpy, to_tensor from .core import EnvTransition, PolicyAction, TransitionKey -from .pipeline import PolicyProcessorPipeline, ProcessorStep, ProcessorStepRegistry +from .pipeline import PolicyProcessorPipeline, ProcessorStep, ProcessorStepRegistry, RobotObservation @dataclass @@ -239,7 +239,7 @@ class _NormalizationMixin: config["normalize_observation_keys"] = sorted(self.normalize_observation_keys) return config - def _normalize_observation(self, observation: dict[str, Any], inverse: bool) -> dict[str, Tensor]: + def _normalize_observation(self, observation: RobotObservation, inverse: bool) -> dict[str, Tensor]: """ Applies (un)normalization to all relevant features in an observation dictionary. diff --git a/src/lerobot/processor/pipeline.py b/src/lerobot/processor/pipeline.py index e14d8b0b9..97ec716ff 100644 --- a/src/lerobot/processor/pipeline.py +++ b/src/lerobot/processor/pipeline.py @@ -49,7 +49,7 @@ from lerobot.configs.types import PipelineFeatureType, PolicyFeature from lerobot.utils.hub import HubMixin from .converters import batch_to_transition, create_transition, transition_to_batch -from .core import EnvAction, EnvTransition, PolicyAction, RobotAction, TransitionKey +from .core import EnvAction, EnvTransition, PolicyAction, RobotAction, RobotObservation, TransitionKey # Generic type variables for pipeline input and output. TInput = TypeVar("TInput") @@ -1337,7 +1337,7 @@ class DataProcessorPipeline(HubMixin, Generic[TInput, TOutput]): return features # Convenience methods for processing individual parts of a transition. - def process_observation(self, observation: dict[str, Any]) -> dict[str, Any]: + def process_observation(self, observation: RobotObservation) -> RobotObservation: """Processes only the observation part of a transition through the pipeline. Args: @@ -1440,7 +1440,7 @@ class ObservationProcessorStep(ProcessorStep, ABC): """An abstract `ProcessorStep` that specifically targets the observation in a transition.""" @abstractmethod - def observation(self, observation: dict[str, Any]) -> dict[str, Any]: + def observation(self, observation: RobotObservation) -> RobotObservation: """Processes an observation dictionary. Subclasses must implement this method. Args: diff --git a/src/lerobot/processor/tokenizer_processor.py b/src/lerobot/processor/tokenizer_processor.py index 93e0395b9..5cd1bebb0 100644 --- a/src/lerobot/processor/tokenizer_processor.py +++ b/src/lerobot/processor/tokenizer_processor.py @@ -38,7 +38,7 @@ from lerobot.utils.constants import ( ) from lerobot.utils.import_utils import _transformers_available -from .core import EnvTransition, TransitionKey +from .core import EnvTransition, RobotObservation, TransitionKey from .pipeline import ActionProcessorStep, ObservationProcessorStep, ProcessorStepRegistry # Conditional import for type checking and lazy loading @@ -139,7 +139,7 @@ class TokenizerProcessorStep(ObservationProcessorStep): return None - def observation(self, observation: dict[str, Any]) -> dict[str, Any]: + def observation(self, observation: RobotObservation) -> RobotObservation: """ Tokenizes the task description and adds it to the observation dictionary. diff --git a/src/lerobot/rl/gym_manipulator.py b/src/lerobot/rl/gym_manipulator.py index 604adb931..3d58ae18f 100644 --- a/src/lerobot/rl/gym_manipulator.py +++ b/src/lerobot/rl/gym_manipulator.py @@ -38,13 +38,12 @@ from lerobot.processor import ( GripperPenaltyProcessorStep, ImageCropResizeProcessorStep, InterventionActionProcessorStep, - JointVelocityProcessorStep, MapDeltaActionToRobotActionStep, MapTensorToDeltaActionDictStep, - MotorCurrentProcessorStep, Numpy2TorchActionProcessorStep, RewardClassifierProcessorStep, RobotActionToPolicyActionProcessorStep, + RobotObservation, TimeLimitProcessorStep, Torch2NumpyActionProcessorStep, TransitionKey, @@ -77,6 +76,8 @@ from lerobot.utils.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, REWARD from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.utils import log_say +from .joint_observations_processor import JointVelocityProcessorStep, MotorCurrentProcessorStep + logging.basicConfig(level=logging.INFO) @@ -163,7 +164,7 @@ class RobotEnv(gym.Env): self._setup_spaces() - def _get_observation(self) -> dict[str, Any]: + def _get_observation(self) -> RobotObservation: """Get current robot observation including joint positions and camera images.""" obs_dict = self.robot.get_observation() raw_joint_joint_position = {f"{name}.pos": obs_dict[f"{name}.pos"] for name in self._joint_names} @@ -220,7 +221,7 @@ class RobotEnv(gym.Env): def reset( self, *, seed: int | None = None, options: dict[str, Any] | None = None - ) -> tuple[dict[str, Any], dict[str, Any]]: + ) -> tuple[RobotObservation, dict[str, Any]]: """Reset environment to initial state. Args: @@ -249,7 +250,7 @@ class RobotEnv(gym.Env): self._raw_joint_positions = {f"{key}.pos": obs[f"{key}.pos"] for key in self._joint_names} return obs, {TeleopEvents.IS_INTERVENTION: False} - def step(self, action) -> tuple[dict[str, np.ndarray], float, bool, bool, dict[str, Any]]: + def step(self, action) -> tuple[RobotObservation, float, bool, bool, dict[str, Any]]: """Execute one environment step with given action.""" joint_targets_dict = {f"{key}.pos": action[i] for i, key in enumerate(self.robot.bus.motors.keys())} diff --git a/src/lerobot/processor/joint_observations_processor.py b/src/lerobot/rl/joint_observations_processor.py similarity index 100% rename from src/lerobot/processor/joint_observations_processor.py rename to src/lerobot/rl/joint_observations_processor.py diff --git a/src/lerobot/robots/bi_so_follower/bi_so_follower.py b/src/lerobot/robots/bi_so_follower/bi_so_follower.py index fa81e7d09..09f849772 100644 --- a/src/lerobot/robots/bi_so_follower/bi_so_follower.py +++ b/src/lerobot/robots/bi_so_follower/bi_so_follower.py @@ -16,8 +16,8 @@ import logging from functools import cached_property -from typing import Any +from lerobot.processor import RobotAction, RobotObservation from lerobot.robots.so_follower import SOFollower, SOFollowerRobotConfig from ..robot import Robot @@ -116,7 +116,7 @@ class BiSOFollower(Robot): self.left_arm.setup_motors() self.right_arm.setup_motors() - def get_observation(self) -> dict[str, Any]: + def get_observation(self) -> RobotObservation: obs_dict = {} # Add "left_" prefix @@ -129,7 +129,7 @@ class BiSOFollower(Robot): return obs_dict - def send_action(self, action: dict[str, Any]) -> dict[str, Any]: + def send_action(self, action: RobotAction) -> RobotAction: # Remove "left_" prefix left_action = { key.removeprefix("left_"): value for key, value in action.items() if key.startswith("left_") diff --git a/src/lerobot/robots/earthrover_mini_plus/robot_earthrover_mini_plus.py b/src/lerobot/robots/earthrover_mini_plus/robot_earthrover_mini_plus.py index 48cb09215..784a95577 100644 --- a/src/lerobot/robots/earthrover_mini_plus/robot_earthrover_mini_plus.py +++ b/src/lerobot/robots/earthrover_mini_plus/robot_earthrover_mini_plus.py @@ -18,12 +18,12 @@ import base64 import logging from functools import cached_property -from typing import Any import cv2 import numpy as np import requests +from lerobot.processor import RobotAction, RobotObservation from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from ..robot import Robot @@ -197,11 +197,11 @@ class EarthRoverMiniPlus(Robot): ACTION_ANGULAR_VEL: float, } - def get_observation(self) -> dict[str, Any]: + def get_observation(self) -> RobotObservation: """Get current robot observation from SDK. Returns: - dict: Observation containing: + RobotObservation: Observation containing: - front: Front camera image (480, 640, 3) in RGB format - rear: Rear camera image (480, 640, 3) in RGB format - linear.vel: Current speed (0-1, SDK reports only positive speeds) @@ -255,7 +255,7 @@ class EarthRoverMiniPlus(Robot): return observation - def send_action(self, action: dict[str, Any]) -> dict[str, Any]: + def send_action(self, action: RobotAction) -> RobotAction: """Send action to robot via SDK. Args: @@ -264,8 +264,7 @@ class EarthRoverMiniPlus(Robot): - angular.vel: Target angular velocity (-1 to 1) Returns: - dict: The action that was sent (matches action_features keys) - + RobotAction: The action that was sent (matches action_features keys) Raises: DeviceNotConnectedError: If robot is not connected diff --git a/src/lerobot/robots/hope_jr/hope_jr_arm.py b/src/lerobot/robots/hope_jr/hope_jr_arm.py index 220a29f8c..4be8a0b17 100644 --- a/src/lerobot/robots/hope_jr/hope_jr_arm.py +++ b/src/lerobot/robots/hope_jr/hope_jr_arm.py @@ -17,7 +17,6 @@ import logging import time from functools import cached_property -from typing import Any from lerobot.cameras.utils import make_cameras_from_configs from lerobot.motors import Motor, MotorNormMode @@ -25,6 +24,7 @@ from lerobot.motors.calibration_gui import RangeFinderGUI from lerobot.motors.feetech import ( FeetechMotorsBus, ) +from lerobot.processor import RobotAction, RobotObservation from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from ..robot import Robot @@ -128,7 +128,7 @@ class HopeJrArm(Robot): self.bus.setup_motor(motor) print(f"'{motor}' motor id set to {self.bus.motors[motor].id}") - def get_observation(self) -> dict[str, Any]: + def get_observation(self) -> RobotObservation: if not self.is_connected: raise DeviceNotConnectedError(f"{self} is not connected.") @@ -149,7 +149,7 @@ class HopeJrArm(Robot): return obs_dict - def send_action(self, action: dict[str, Any]) -> dict[str, Any]: + def send_action(self, action: RobotAction) -> RobotAction: if not self.is_connected: raise DeviceNotConnectedError(f"{self} is not connected.") diff --git a/src/lerobot/robots/hope_jr/hope_jr_hand.py b/src/lerobot/robots/hope_jr/hope_jr_hand.py index 9e960642b..73fb4464f 100644 --- a/src/lerobot/robots/hope_jr/hope_jr_hand.py +++ b/src/lerobot/robots/hope_jr/hope_jr_hand.py @@ -17,7 +17,6 @@ import logging import time from functools import cached_property -from typing import Any from lerobot.cameras.utils import make_cameras_from_configs from lerobot.motors import Motor, MotorNormMode @@ -25,6 +24,7 @@ from lerobot.motors.calibration_gui import RangeFinderGUI from lerobot.motors.feetech import ( FeetechMotorsBus, ) +from lerobot.processor import RobotAction, RobotObservation from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from ..robot import Robot @@ -159,7 +159,7 @@ class HopeJrHand(Robot): self.bus.setup_motor(motor) print(f"'{motor}' motor id set to {self.bus.motors[motor].id}") - def get_observation(self) -> dict[str, Any]: + def get_observation(self) -> RobotObservation: if not self.is_connected: raise DeviceNotConnectedError(f"{self} is not connected.") @@ -181,7 +181,7 @@ class HopeJrHand(Robot): return obs_dict - def send_action(self, action: dict[str, Any]) -> dict[str, Any]: + def send_action(self, action: RobotAction) -> RobotAction: if not self.is_connected: raise DeviceNotConnectedError(f"{self} is not connected.") diff --git a/src/lerobot/robots/koch_follower/koch_follower.py b/src/lerobot/robots/koch_follower/koch_follower.py index 41a57828b..a1d001ba8 100644 --- a/src/lerobot/robots/koch_follower/koch_follower.py +++ b/src/lerobot/robots/koch_follower/koch_follower.py @@ -17,7 +17,6 @@ import logging import time from functools import cached_property -from typing import Any from lerobot.cameras.utils import make_cameras_from_configs from lerobot.motors import Motor, MotorCalibration, MotorNormMode @@ -25,6 +24,7 @@ from lerobot.motors.dynamixel import ( DynamixelMotorsBus, OperatingMode, ) +from lerobot.processor import RobotAction, RobotObservation from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from ..robot import Robot @@ -182,7 +182,7 @@ class KochFollower(Robot): self.bus.setup_motor(motor) print(f"'{motor}' motor id set to {self.bus.motors[motor].id}") - def get_observation(self) -> dict[str, Any]: + def get_observation(self) -> RobotObservation: if not self.is_connected: raise DeviceNotConnectedError(f"{self} is not connected.") @@ -202,7 +202,7 @@ class KochFollower(Robot): return obs_dict - def send_action(self, action: dict[str, float]) -> dict[str, float]: + def send_action(self, action: RobotAction) -> RobotAction: """Command arm to move to a target joint configuration. The relative action magnitude may be clipped depending on the configuration parameter @@ -210,10 +210,10 @@ class KochFollower(Robot): Thus, this function always returns the action actually sent. Args: - action (dict[str, float]): The goal positions for the motors. + action (RobotAction): The goal positions for the motors. Returns: - dict[str, float]: The action sent to the motors, potentially clipped. + RobotAction: The action sent to the motors, potentially clipped. """ if not self.is_connected: raise DeviceNotConnectedError(f"{self} is not connected.") diff --git a/src/lerobot/robots/lekiwi/lekiwi.py b/src/lerobot/robots/lekiwi/lekiwi.py index 86fe017d6..c84e81001 100644 --- a/src/lerobot/robots/lekiwi/lekiwi.py +++ b/src/lerobot/robots/lekiwi/lekiwi.py @@ -28,6 +28,7 @@ from lerobot.motors.feetech import ( FeetechMotorsBus, OperatingMode, ) +from lerobot.processor import RobotAction, RobotObservation from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from ..robot import Robot @@ -338,7 +339,7 @@ class LeKiwi(Robot): "theta.vel": theta, } # m/s and deg/s - def get_observation(self) -> dict[str, Any]: + def get_observation(self) -> RobotObservation: if not self.is_connected: raise DeviceNotConnectedError(f"{self} is not connected.") @@ -369,7 +370,7 @@ class LeKiwi(Robot): return obs_dict - def send_action(self, action: dict[str, Any]) -> dict[str, Any]: + def send_action(self, action: RobotAction) -> RobotAction: """Command lekiwi to move to a target joint configuration. The relative action magnitude may be clipped depending on the configuration parameter @@ -380,7 +381,7 @@ class LeKiwi(Robot): RobotDeviceNotConnectedError: if robot is not connected. Returns: - np.ndarray: the action sent to the motors, potentially clipped. + RobotAction: the action sent to the motors, potentially clipped. """ if not self.is_connected: raise DeviceNotConnectedError(f"{self} is not connected.") diff --git a/src/lerobot/robots/lekiwi/lekiwi_client.py b/src/lerobot/robots/lekiwi/lekiwi_client.py index 19744e244..bb865dc10 100644 --- a/src/lerobot/robots/lekiwi/lekiwi_client.py +++ b/src/lerobot/robots/lekiwi/lekiwi_client.py @@ -18,11 +18,11 @@ import base64 import json import logging from functools import cached_property -from typing import Any import cv2 import numpy as np +from lerobot.processor import RobotAction, RobotObservation from lerobot.utils.constants import ACTION, OBS_STATE from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError @@ -172,7 +172,7 @@ class LeKiwiClient(Robot): return last_msg - def _parse_observation_json(self, obs_string: str) -> dict[str, Any] | None: + def _parse_observation_json(self, obs_string: str) -> RobotObservation | None: """Parses the JSON observation string.""" try: return json.loads(obs_string) @@ -196,15 +196,15 @@ class LeKiwiClient(Robot): return None def _remote_state_from_obs( - self, observation: dict[str, Any] - ) -> tuple[dict[str, np.ndarray], dict[str, Any]]: + self, observation: RobotObservation + ) -> tuple[dict[str, np.ndarray], RobotObservation]: """Extracts frames, and state from the parsed observation.""" flat_state = {key: observation.get(key, 0.0) for key in self._state_order} state_vec = np.array([flat_state[key] for key in self._state_order], dtype=np.float32) - obs_dict: dict[str, Any] = {**flat_state, OBS_STATE: state_vec} + obs_dict: RobotObservation = {**flat_state, OBS_STATE: state_vec} # Decode images current_frames: dict[str, np.ndarray] = {} @@ -217,7 +217,7 @@ class LeKiwiClient(Robot): return current_frames, obs_dict - def _get_data(self) -> tuple[dict[str, np.ndarray], dict[str, Any], dict[str, Any]]: + def _get_data(self) -> tuple[dict[str, np.ndarray], RobotObservation]: """ Polls the video socket for the latest observation data. @@ -252,7 +252,7 @@ class LeKiwiClient(Robot): return new_frames, new_state - def get_observation(self) -> dict[str, Any]: + def get_observation(self) -> RobotObservation: """ Capture observations from the remote robot: current follower arm positions, present wheel speeds (converted to body-frame velocities: x, y, theta), @@ -307,12 +307,11 @@ class LeKiwiClient(Robot): def configure(self): pass - def send_action(self, action: dict[str, Any]) -> dict[str, Any]: + def send_action(self, action: RobotAction) -> RobotAction: """Command lekiwi to move to a target joint configuration. Translates to motor space + sends over ZMQ Args: - action (np.ndarray): array containing the goal positions for the motors. - + action (RobotAction): array containing the goal positions for the motors. Raises: RobotDeviceNotConnectedError: if robot is not connected. diff --git a/src/lerobot/robots/omx_follower/omx_follower.py b/src/lerobot/robots/omx_follower/omx_follower.py index 2dd851377..14668b3a7 100644 --- a/src/lerobot/robots/omx_follower/omx_follower.py +++ b/src/lerobot/robots/omx_follower/omx_follower.py @@ -17,7 +17,6 @@ import logging import time from functools import cached_property -from typing import Any from lerobot.cameras.utils import make_cameras_from_configs from lerobot.motors import Motor, MotorCalibration, MotorNormMode @@ -26,6 +25,7 @@ from lerobot.motors.dynamixel import ( DynamixelMotorsBus, OperatingMode, ) +from lerobot.processor import RobotAction, RobotObservation from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from ..robot import Robot @@ -165,7 +165,7 @@ class OmxFollower(Robot): self.bus.setup_motor(motor) print(f"'{motor}' motor id set to {self.bus.motors[motor].id}") - def get_observation(self) -> dict[str, Any]: + def get_observation(self) -> RobotObservation: if not self.is_connected: raise DeviceNotConnectedError(f"{self} is not connected.") @@ -185,7 +185,7 @@ class OmxFollower(Robot): return obs_dict - def send_action(self, action: dict[str, float]) -> dict[str, float]: + def send_action(self, action: RobotAction) -> RobotAction: """Command arm to move to a target joint configuration. The relative action magnitude may be clipped depending on the configuration parameter @@ -193,10 +193,10 @@ class OmxFollower(Robot): Thus, this function always returns the action actually sent. Args: - action (dict[str, float]): The goal positions for the motors. + action (RobotAction): The goal positions for the motors. Returns: - dict[str, float]: The action sent to the motors, potentially clipped. + RobotAction: The action sent to the motors, potentially clipped. """ if not self.is_connected: raise DeviceNotConnectedError(f"{self} is not connected.") diff --git a/src/lerobot/robots/reachy2/robot_reachy2.py b/src/lerobot/robots/reachy2/robot_reachy2.py index 74742ee8d..6f4eef56c 100644 --- a/src/lerobot/robots/reachy2/robot_reachy2.py +++ b/src/lerobot/robots/reachy2/robot_reachy2.py @@ -18,9 +18,8 @@ from __future__ import annotations import time from typing import TYPE_CHECKING, Any -import numpy as np - from lerobot.cameras.utils import make_cameras_from_configs +from lerobot.processor import RobotAction, RobotObservation from lerobot.utils.import_utils import _reachy2_sdk_available from ..robot import Robot @@ -171,8 +170,8 @@ class Reachy2Robot(Robot): else: return {} - def get_observation(self) -> dict[str, np.ndarray]: - obs_dict: dict[str, Any] = {} + def get_observation(self) -> RobotObservation: + obs_dict: RobotObservation = {} # Read Reachy 2 state before_read_t = time.perf_counter() @@ -185,7 +184,7 @@ class Reachy2Robot(Robot): return obs_dict - def send_action(self, action: dict[str, Any]) -> dict[str, Any]: + def send_action(self, action: RobotAction) -> RobotAction: if self.reachy is not None: if not self.is_connected: raise ConnectionError() diff --git a/src/lerobot/robots/robot.py b/src/lerobot/robots/robot.py index 5e88b915b..d1021daf4 100644 --- a/src/lerobot/robots/robot.py +++ b/src/lerobot/robots/robot.py @@ -15,11 +15,11 @@ import abc import builtins from pathlib import Path -from typing import Any import draccus from lerobot.motors import MotorCalibration +from lerobot.processor import RobotAction, RobotObservation from lerobot.utils.constants import HF_LEROBOT_CALIBRATION, ROBOTS from .config import RobotConfig @@ -153,28 +153,28 @@ class Robot(abc.ABC): pass @abc.abstractmethod - def get_observation(self) -> dict[str, Any]: + def get_observation(self) -> RobotObservation: """ Retrieve the current observation from the robot. Returns: - dict[str, Any]: A flat dictionary representing the robot's current sensory state. Its structure + RobotObservation: A flat dictionary representing the robot's current sensory state. Its structure should match :pymeth:`observation_features`. """ pass @abc.abstractmethod - def send_action(self, action: dict[str, Any]) -> dict[str, Any]: + def send_action(self, action: RobotAction) -> RobotAction: """ Send an action command to the robot. Args: - action (dict[str, Any]): Dictionary representing the desired action. Its structure should match + action (RobotAction): Dictionary representing the desired action. Its structure should match :pymeth:`action_features`. Returns: - dict[str, Any]: The action actually sent to the motors potentially clipped or modified, e.g. by + RobotAction: The action actually sent to the motors potentially clipped or modified, e.g. by safety limits on velocity. """ pass diff --git a/src/lerobot/robots/so_follower/robot_kinematic_processor.py b/src/lerobot/robots/so_follower/robot_kinematic_processor.py index 87e832db6..2aa60e12a 100644 --- a/src/lerobot/robots/so_follower/robot_kinematic_processor.py +++ b/src/lerobot/robots/so_follower/robot_kinematic_processor.py @@ -28,6 +28,7 @@ from lerobot.processor import ( ProcessorStepRegistry, RobotAction, RobotActionProcessorStep, + RobotObservation, TransitionKey, ) from lerobot.utils.rotation import Rotation @@ -438,7 +439,7 @@ class ForwardKinematicsJointsToEEObservation(ObservationProcessorStep): kinematics: RobotKinematics motor_names: list[str] - def observation(self, observation: dict[str, Any]) -> dict[str, Any]: + def observation(self, observation: RobotObservation) -> RobotObservation: return compute_forward_kinematics_joints_to_ee(observation, self.kinematics, self.motor_names) def transform_features( diff --git a/src/lerobot/robots/so_follower/so_follower.py b/src/lerobot/robots/so_follower/so_follower.py index 5e99b33a1..011a0061e 100644 --- a/src/lerobot/robots/so_follower/so_follower.py +++ b/src/lerobot/robots/so_follower/so_follower.py @@ -17,7 +17,7 @@ import logging import time from functools import cached_property -from typing import Any, TypeAlias +from typing import TypeAlias from lerobot.cameras.utils import make_cameras_from_configs from lerobot.motors import Motor, MotorCalibration, MotorNormMode @@ -25,6 +25,7 @@ from lerobot.motors.feetech import ( FeetechMotorsBus, OperatingMode, ) +from lerobot.processor import RobotAction, RobotObservation from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from ..robot import Robot @@ -175,7 +176,7 @@ class SOFollower(Robot): self.bus.setup_motor(motor) print(f"'{motor}' motor id set to {self.bus.motors[motor].id}") - def get_observation(self) -> dict[str, Any]: + def get_observation(self) -> RobotObservation: if not self.is_connected: raise DeviceNotConnectedError(f"{self} is not connected.") @@ -195,7 +196,7 @@ class SOFollower(Robot): return obs_dict - def send_action(self, action: dict[str, Any]) -> dict[str, Any]: + def send_action(self, action: RobotAction) -> RobotAction: """Command arm to move to a target joint configuration. The relative action magnitude may be clipped depending on the configuration parameter @@ -206,7 +207,7 @@ class SOFollower(Robot): RobotDeviceNotConnectedError: if robot is not connected. Returns: - the action sent to the motors, potentially clipped. + RobotAction: the action sent to the motors, potentially clipped. """ if not self.is_connected: raise DeviceNotConnectedError(f"{self} is not connected.") diff --git a/src/lerobot/robots/unitree_g1/unitree_g1.py b/src/lerobot/robots/unitree_g1/unitree_g1.py index 1764f31b5..fa6e0da85 100644 --- a/src/lerobot/robots/unitree_g1/unitree_g1.py +++ b/src/lerobot/robots/unitree_g1/unitree_g1.py @@ -26,6 +26,7 @@ import numpy as np from lerobot.cameras.utils import make_cameras_from_configs from lerobot.envs.factory import make_env +from lerobot.processor import RobotAction, RobotObservation from lerobot.robots.unitree_g1.g1_utils import G1_29_JointIndex from ..robot import Robot @@ -269,7 +270,7 @@ class UnitreeG1(Robot): for cam in self._cameras.values(): cam.disconnect() - def get_observation(self) -> dict[str, Any]: + def get_observation(self) -> RobotObservation: lowstate = self._lowstate if lowstate is None: return {} @@ -350,7 +351,7 @@ class UnitreeG1(Robot): def observation_features(self) -> dict[str, type | tuple]: return {**self._motors_ft, **self._cameras_ft} - def send_action(self, action: dict[str, Any]) -> dict[str, Any]: + def send_action(self, action: RobotAction) -> RobotAction: for motor in G1_29_JointIndex: key = f"{motor.name}.q" if key in action: diff --git a/src/lerobot/scripts/lerobot_eval.py b/src/lerobot/scripts/lerobot_eval.py index 7b45e88e1..e32b80404 100644 --- a/src/lerobot/scripts/lerobot_eval.py +++ b/src/lerobot/scripts/lerobot_eval.py @@ -177,9 +177,9 @@ def rollout( action = policy.select_action(observation) action = postprocessor(action) - action_transition = {"action": action} + action_transition = {ACTION: action} action_transition = env_postprocessor(action_transition) - action = action_transition["action"] + action = action_transition[ACTION] # Convert to CPU / numpy. action_numpy: np.ndarray = action.to("cpu").numpy() diff --git a/src/lerobot/scripts/lerobot_teleoperate.py b/src/lerobot/scripts/lerobot_teleoperate.py index 05d4534d4..18d8863d6 100644 --- a/src/lerobot/scripts/lerobot_teleoperate.py +++ b/src/lerobot/scripts/lerobot_teleoperate.py @@ -165,7 +165,7 @@ def teleop_loop( # Process action for robot through pipeline robot_action_to_send = robot_action_processor((teleop_action, obs)) - # Send processed action to robot (robot_action_processor.to_output should return dict[str, Any]) + # Send processed action to robot (robot_action_processor.to_output should return RobotAction) _ = robot.send_action(robot_action_to_send) if display_data: diff --git a/src/lerobot/scripts/lerobot_train_tokenizer.py b/src/lerobot/scripts/lerobot_train_tokenizer.py index 296447bad..1d8f4644b 100644 --- a/src/lerobot/scripts/lerobot_train_tokenizer.py +++ b/src/lerobot/scripts/lerobot_train_tokenizer.py @@ -63,6 +63,7 @@ else: from lerobot.configs import parser from lerobot.configs.types import NormalizationMode from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.utils.constants import ACTION, OBS_STATE @dataclass @@ -86,7 +87,7 @@ class TokenizerTrainingConfig: # Whether to apply delta transform (relative actions vs absolute actions) use_delta_transform: bool = False # Dataset key for state observations (default: "observation.state") - state_key: str = "observation.state" + state_key: str = OBS_STATE # Normalization mode (MEAN_STD, MIN_MAX, QUANTILES, QUANTILE10, IDENTITY) normalization_mode: str = "QUANTILES" # FAST vocabulary size (BPE vocab size) @@ -223,12 +224,10 @@ def process_episode(args): else: # if no state key, use zeros (no delta transform) state = np.zeros_like( - frame["action"].numpy() if torch.is_tensor(frame["action"]) else np.array(frame["action"]) + frame[ACTION].numpy() if torch.is_tensor(frame[ACTION]) else np.array(frame[ACTION]) ) - action = ( - frame["action"].numpy() if torch.is_tensor(frame["action"]) else np.array(frame["action"]) - ) + action = frame[ACTION].numpy() if torch.is_tensor(frame[ACTION]) else np.array(frame[ACTION]) states.append(state) actions.append(action) @@ -468,8 +467,8 @@ def train_tokenizer(cfg: TokenizerTrainingConfig): # get normalization stats from dataset norm_stats = dataset.meta.stats - if norm_stats is not None and "action" in norm_stats: - action_stats = norm_stats["action"] + if norm_stats is not None and ACTION in norm_stats: + action_stats = norm_stats[ACTION] # build encoded dimension indices encoded_dim_indices = [] diff --git a/src/lerobot/teleoperators/gamepad/teleop_gamepad.py b/src/lerobot/teleoperators/gamepad/teleop_gamepad.py index c7072f4a7..4dbb49c1d 100644 --- a/src/lerobot/teleoperators/gamepad/teleop_gamepad.py +++ b/src/lerobot/teleoperators/gamepad/teleop_gamepad.py @@ -20,6 +20,8 @@ from typing import Any import numpy as np +from lerobot.processor import RobotAction + from ..teleoperator import Teleoperator from ..utils import TeleopEvents from .configuration_gamepad import GamepadTeleopConfig @@ -83,7 +85,7 @@ class GamepadTeleop(Teleoperator): self.gamepad = Gamepad() self.gamepad.start() - def get_action(self) -> dict[str, Any]: + def get_action(self) -> RobotAction: # Update the controller to get fresh inputs self.gamepad.update() diff --git a/src/lerobot/teleoperators/keyboard/teleop_keyboard.py b/src/lerobot/teleoperators/keyboard/teleop_keyboard.py index ec8ea18f4..55c158da8 100644 --- a/src/lerobot/teleoperators/keyboard/teleop_keyboard.py +++ b/src/lerobot/teleoperators/keyboard/teleop_keyboard.py @@ -21,6 +21,7 @@ import time from queue import Queue from typing import Any +from lerobot.processor import RobotAction from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from ..teleoperator import Teleoperator @@ -124,7 +125,7 @@ class KeyboardTeleop(Teleoperator): def configure(self): pass - def get_action(self) -> dict[str, Any]: + def get_action(self) -> RobotAction: before_read_t = time.perf_counter() if not self.is_connected: @@ -181,7 +182,7 @@ class KeyboardEndEffectorTeleop(KeyboardTeleop): "names": {"delta_x": 0, "delta_y": 1, "delta_z": 2}, } - def get_action(self) -> dict[str, Any]: + def get_action(self) -> RobotAction: if not self.is_connected: raise DeviceNotConnectedError( "KeyboardTeleop is not connected. You need to run `connect()` before `get_action()`." @@ -374,12 +375,12 @@ class KeyboardRoverTeleop(KeyboardTeleop): # Only remove key if it's being released self.current_pressed.pop(key_char, None) - def get_action(self) -> dict[str, Any]: + def get_action(self) -> RobotAction: """ Get the current action based on pressed keys. Returns: - dict with 'linear.vel' and 'angular.vel' keys + RobotAction with 'linear.vel' and 'angular.vel' keys """ before_read_t = time.perf_counter() diff --git a/src/lerobot/teleoperators/teleoperator.py b/src/lerobot/teleoperators/teleoperator.py index 95020a962..cd9e3a53d 100644 --- a/src/lerobot/teleoperators/teleoperator.py +++ b/src/lerobot/teleoperators/teleoperator.py @@ -20,6 +20,7 @@ from typing import Any import draccus from lerobot.motors.motors_bus import MotorCalibration +from lerobot.processor import RobotAction from lerobot.utils.constants import HF_LEROBOT_CALIBRATION, TELEOPERATORS from .config import TeleoperatorConfig @@ -150,12 +151,12 @@ class Teleoperator(abc.ABC): pass @abc.abstractmethod - def get_action(self) -> dict[str, Any]: + def get_action(self) -> RobotAction: """ Retrieve the current action from the teleoperator. Returns: - dict[str, Any]: A flat dictionary representing the teleoperator's current actions. Its + RobotAction: A flat dictionary representing the teleoperator's current actions. Its structure should match :pymeth:`observation_features`. """ pass diff --git a/src/lerobot/utils/constants.py b/src/lerobot/utils/constants.py index a96e1596d..43a61b4f7 100644 --- a/src/lerobot/utils/constants.py +++ b/src/lerobot/utils/constants.py @@ -28,11 +28,13 @@ OBS_LANGUAGE_TOKENS = OBS_LANGUAGE + ".tokens" OBS_LANGUAGE_ATTENTION_MASK = OBS_LANGUAGE + ".attention_mask" ACTION = "action" +ACTION_PREFIX = ACTION + "." ACTION_TOKENS = ACTION + ".tokens" ACTION_TOKEN_MASK = ACTION + ".token_mask" REWARD = "next.reward" TRUNCATED = "next.truncated" DONE = "next.done" +INFO = "info" ROBOTS = "robots" TELEOPERATORS = "teleoperators" diff --git a/src/lerobot/utils/visualization_utils.py b/src/lerobot/utils/visualization_utils.py index 9143d0f66..31ca8d247 100644 --- a/src/lerobot/utils/visualization_utils.py +++ b/src/lerobot/utils/visualization_utils.py @@ -14,12 +14,13 @@ import numbers import os -from typing import Any import numpy as np import rerun as rr -from .constants import OBS_PREFIX, OBS_STR +from lerobot.processor import RobotAction, RobotObservation + +from .constants import ACTION, ACTION_PREFIX, OBS_PREFIX, OBS_STR def init_rerun( @@ -50,8 +51,8 @@ def _is_scalar(x): def log_rerun_data( - observation: dict[str, Any] | None = None, - action: dict[str, Any] | None = None, + observation: RobotObservation | None = None, + action: RobotAction | None = None, compress_images: bool = False, ) -> None: """ @@ -96,7 +97,7 @@ def log_rerun_data( for k, v in action.items(): if v is None: continue - key = k if str(k).startswith("action.") else f"action.{k}" + key = k if str(k).startswith(ACTION_PREFIX) else f"{ACTION}.{k}" if _is_scalar(v): rr.log(key, rr.Scalars(float(v))) diff --git a/tests/mocks/mock_robot.py b/tests/mocks/mock_robot.py index b0513fd38..d997cb6d4 100644 --- a/tests/mocks/mock_robot.py +++ b/tests/mocks/mock_robot.py @@ -17,10 +17,10 @@ import random from dataclasses import dataclass, field from functools import cached_property -from typing import Any from lerobot.cameras import CameraConfig, make_cameras_from_configs from lerobot.motors.motors_bus import Motor, MotorNormMode +from lerobot.processor import RobotAction, RobotObservation from lerobot.robots import Robot, RobotConfig from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from tests.mocks.mock_motors_bus import MockMotorsBus @@ -119,7 +119,7 @@ class MockRobot(Robot): def configure(self) -> None: pass - def get_observation(self) -> dict[str, Any]: + def get_observation(self) -> RobotObservation: if not self.is_connected: raise DeviceNotConnectedError(f"{self} is not connected.") @@ -130,7 +130,7 @@ class MockRobot(Robot): f"{motor}.pos": val for motor, val in zip(self.motors, self.config.static_values, strict=True) } - def send_action(self, action: dict[str, Any]) -> dict[str, Any]: + def send_action(self, action: RobotAction) -> RobotAction: if not self.is_connected: raise DeviceNotConnectedError(f"{self} is not connected.") diff --git a/tests/mocks/mock_teleop.py b/tests/mocks/mock_teleop.py index 71b49947c..04479bad9 100644 --- a/tests/mocks/mock_teleop.py +++ b/tests/mocks/mock_teleop.py @@ -19,6 +19,7 @@ from dataclasses import dataclass from functools import cached_property from typing import Any +from lerobot.processor import RobotAction from lerobot.teleoperators import Teleoperator, TeleoperatorConfig from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError @@ -88,7 +89,7 @@ class MockTeleop(Teleoperator): def configure(self) -> None: pass - def get_action(self) -> dict[str, Any]: + def get_action(self) -> RobotAction: if not self.is_connected: raise DeviceNotConnectedError(f"{self} is not connected.")