Merge branch 'main' into feature/add-multitask-dit

This commit is contained in:
Bryson Jones
2026-01-13 10:02:41 -08:00
committed by GitHub
51 changed files with 206 additions and 178 deletions
+2 -2
View File
@@ -40,7 +40,6 @@ from collections.abc import Callable
from dataclasses import asdict
from pprint import pformat
from queue import Queue
from typing import Any
import draccus
import grpc
@@ -48,6 +47,7 @@ import torch
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
from lerobot.processor import RobotAction
from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
@@ -351,7 +351,7 @@ class RobotClient:
action = {key: action_tensor[i].item() for i, key in enumerate(self.robot.action_features)}
return action
def control_loop_action(self, verbose: bool = False) -> dict[str, Any]:
def control_loop_action(self, verbose: bool = False) -> RobotAction:
"""Reading and performing actions in local queue"""
# Lock only for queue operations
+2 -2
View File
@@ -18,12 +18,12 @@ from typing import Any
from lerobot.configs.types import PipelineFeatureType
from lerobot.datasets.utils import hw_to_dataset_features
from lerobot.processor import DataProcessorPipeline
from lerobot.processor import DataProcessorPipeline, RobotAction, RobotObservation
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE, OBS_STR
def create_initial_features(
action: dict[str, Any] | None = None, observation: dict[str, Any] | None = None
action: RobotAction | None = None, observation: RobotObservation | None = None
) -> dict[PipelineFeatureType, dict[str, Any]]:
"""
Creates the initial features dict for the dataset from action and observation specs.
+4 -2
View File
@@ -29,6 +29,8 @@ from gymnasium import spaces
from libero.libero import benchmark, get_libero_path
from libero.libero.envs import OffScreenRenderEnv
from lerobot.processor import RobotObservation
def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]:
"""Normalize camera_name into a non-empty list of strings."""
@@ -237,7 +239,7 @@ class LiberoEnv(gym.Env):
env.reset()
return env
def _format_raw_obs(self, raw_obs: dict[str, Any]) -> dict[str, Any]:
def _format_raw_obs(self, raw_obs: RobotObservation) -> RobotObservation:
images = {}
for camera_name in self.camera_name:
image = raw_obs[camera_name]
@@ -313,7 +315,7 @@ class LiberoEnv(gym.Env):
info = {"is_success": False}
return observation, info
def step(self, action: np.ndarray) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]:
def step(self, action: np.ndarray) -> tuple[RobotObservation, float, bool, bool, dict[str, Any]]:
if action.ndim != 1:
raise ValueError(
f"Expected action to be 1-D (shape (action_dim,)), "
+7 -5
View File
@@ -25,6 +25,8 @@ import metaworld.policies as policies
import numpy as np
from gymnasium import spaces
from lerobot.processor import RobotObservation
# ---- Load configuration data from the external JSON file ----
CONFIG_PATH = Path(__file__).parent / "metaworld_config.json"
try:
@@ -161,7 +163,7 @@ class MetaworldEnv(gym.Env):
env._freeze_rand_vec = False # otherwise no randomization
return env
def _format_raw_obs(self, raw_obs: np.ndarray) -> dict[str, Any]:
def _format_raw_obs(self, raw_obs: np.ndarray) -> RobotObservation:
image = None
if self._env is not None:
image = self._env.render()
@@ -196,7 +198,7 @@ class MetaworldEnv(gym.Env):
self,
seed: int | None = None,
**kwargs,
) -> tuple[dict[str, Any], dict[str, Any]]:
) -> tuple[RobotObservation, dict[str, Any]]:
"""
Reset the environment to its initial state.
@@ -204,7 +206,7 @@ class MetaworldEnv(gym.Env):
seed (Optional[int]): Random seed for environment initialization.
Returns:
observation (Dict[str, Any]): The initial formatted observation.
observation (RobotObservation): The initial formatted observation.
info (Dict[str, Any]): Additional info about the reset state.
"""
super().reset(seed=seed)
@@ -216,7 +218,7 @@ class MetaworldEnv(gym.Env):
info = {"is_success": False}
return observation, info
def step(self, action: np.ndarray) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]:
def step(self, action: np.ndarray) -> tuple[RobotObservation, float, bool, bool, dict[str, Any]]:
"""
Perform one environment step.
@@ -224,7 +226,7 @@ class MetaworldEnv(gym.Env):
action (np.ndarray): The action to execute, must be 1-D with shape (action_dim,).
Returns:
observation (Dict[str, Any]): The formatted observation after the step.
observation (RobotObservation): The formatted observation after the step.
reward (float): The scalar reward for this step.
terminated (bool): Whether the episode terminated successfully.
truncated (bool): Whether the episode was truncated due to a time limit.
+2 -1
View File
@@ -29,6 +29,7 @@ from torch import Tensor
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.envs.configs import EnvConfig
from lerobot.processor import RobotObservation
from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE, OBS_STR
from lerobot.utils.utils import get_channel_first_image_shape
@@ -152,7 +153,7 @@ def check_env_attributes_and_types(env: gym.vector.VectorEnv) -> None:
)
def add_envs_task(env: gym.vector.VectorEnv, observation: dict[str, Any]) -> dict[str, Any]:
def add_envs_task(env: gym.vector.VectorEnv, observation: RobotObservation) -> RobotObservation:
"""Adds task feature to the observation dict with respect to the first environment attribute."""
if hasattr(env.envs[0], "task_description"):
task_result = env.call("task_description")
+6 -2
View File
@@ -52,7 +52,11 @@ from lerobot.processor.converters import (
transition_to_batch,
transition_to_policy_action,
)
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
from lerobot.utils.constants import (
ACTION,
POLICY_POSTPROCESSOR_DEFAULT_NAME,
POLICY_PREPROCESSOR_DEFAULT_NAME,
)
def get_policy_class(name: str) -> type[PreTrainedPolicy]:
@@ -256,7 +260,7 @@ def make_pre_post_processors(
}
# Also ensure postprocessing slices to env action dim and unnormalizes with dataset stats
env_action_dim = policy_cfg.output_features["action"].shape[0]
env_action_dim = policy_cfg.output_features[ACTION].shape[0]
postprocessor_overrides["groot_action_unpack_unnormalize_v1"] = {
"stats": kwargs.get("dataset_stats"),
"normalize_min_max": True,
@@ -20,6 +20,7 @@ from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
from lerobot.utils.constants import ACTION, OBS_STATE
@PreTrainedConfig.register_subclass("groot")
@@ -137,14 +138,14 @@ class GrootConfig(PreTrainedConfig):
"No features of type FeatureType.VISUAL found in input_features."
)
if "observation.state" not in self.input_features:
if OBS_STATE not in self.input_features:
state_feature = PolicyFeature(
type=FeatureType.STATE,
shape=(self.max_state_dim,),
)
self.input_features["observation.state"] = state_feature
self.input_features[OBS_STATE] = state_feature
else:
state_shape = self.input_features["observation.state"].shape
state_shape = self.input_features[OBS_STATE].shape
state_dim = state_shape[0] if state_shape else 0
if state_dim > self.max_state_dim:
raise ValueError(
@@ -152,14 +153,14 @@ class GrootConfig(PreTrainedConfig):
f"Either reduce state dimension or increase max_state_dim in config."
)
if "action" not in self.output_features:
if ACTION not in self.output_features:
action_feature = PolicyFeature(
type=FeatureType.ACTION,
shape=(self.max_action_dim,),
)
self.output_features["action"] = action_feature
self.output_features[ACTION] = action_feature
else:
action_shape = self.output_features["action"].shape
action_shape = self.output_features[ACTION].shape
action_dim = action_shape[0] if action_shape else 0
if action_dim > self.max_action_dim:
raise ValueError(
+3 -3
View File
@@ -46,7 +46,7 @@ from lerobot.policies.groot.action_head.flow_matching_action_head import (
FlowmatchingActionHeadConfig,
)
from lerobot.policies.groot.utils import ensure_eagle_cache_ready
from lerobot.utils.constants import HF_LEROBOT_HOME
from lerobot.utils.constants import ACTION, HF_LEROBOT_HOME
DEFAULT_VENDOR_EAGLE_PATH = str((Path(__file__).resolve().parent / "eagle2_hg_model").resolve())
DEFAULT_TOKENIZER_ASSETS_REPO = "lerobot/eagle2hg-processor-groot-n1p5"
@@ -227,8 +227,8 @@ class GR00TN15(PreTrainedModel):
detected_error = False
error_msg = ERROR_MSG
if "action" in inputs:
action = inputs["action"]
if ACTION in inputs:
action = inputs[ACTION]
# In inference, action may be omitted or None; validate only when it's a tensor.
if action is None:
pass # allow None during inference
+2 -1
View File
@@ -41,6 +41,7 @@ from torch import Tensor
from lerobot.policies.groot.configuration_groot import GrootConfig
from lerobot.policies.groot.groot_n1 import GR00TN15
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.utils.constants import ACTION
class GrootPolicy(PreTrainedPolicy):
@@ -147,7 +148,7 @@ class GrootPolicy(PreTrainedPolicy):
actions = outputs.get("action_pred")
original_action_dim = self.config.output_features["action"].shape[0]
original_action_dim = self.config.output_features[ACTION].shape[0]
actions = actions[:, :, :original_action_dim]
return actions
+16 -12
View File
@@ -51,7 +51,11 @@ from lerobot.processor.converters import (
)
from lerobot.processor.core import EnvTransition, TransitionKey
from lerobot.utils.constants import (
ACTION,
HF_LEROBOT_HOME,
OBS_IMAGE,
OBS_IMAGES,
OBS_STATE,
POLICY_POSTPROCESSOR_DEFAULT_NAME,
POLICY_PREPROCESSOR_DEFAULT_NAME,
)
@@ -107,9 +111,9 @@ def make_groot_pre_post_processors(
# Define feature specs for optional normalization steps
_features: dict[str, PolicyFeature] = {
# Observation features (only add those we may normalize)
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(state_horizon, max_state_dim)),
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_horizon, max_state_dim)),
# Action feature
"action": PolicyFeature(type=FeatureType.ACTION, shape=(action_horizon, max_action_dim)),
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_horizon, max_action_dim)),
}
# Normalize STATE and ACTION with min_max (SO100-like default)
@@ -120,7 +124,7 @@ def make_groot_pre_post_processors(
# Determine env action dimension from config (simple, object-like PolicyFeature)
try:
env_action_dim = int(config.output_features["action"].shape[0])
env_action_dim = int(config.output_features[ACTION].shape[0])
except Exception:
env_action_dim = 0
@@ -268,9 +272,9 @@ class GrootPackInputsStep(ProcessorStep):
return torch.where(mask, mapped, torch.zeros_like(mapped))
# 1) Video (B, T=1, V, H, W, C) uint8
img_keys = sorted([k for k in obs if k.startswith("observation.images.")])
if not img_keys and "observation.image" in obs:
img_keys = ["observation.image"]
img_keys = sorted([k for k in obs if k.startswith(OBS_IMAGES)])
if not img_keys and OBS_IMAGE in obs:
img_keys = [OBS_IMAGE]
if img_keys:
cams = [_to_uint8_np_bhwc(obs[k]) for k in img_keys]
video = np.stack(cams, axis=1) # (B, V, H, W, C)
@@ -294,14 +298,14 @@ class GrootPackInputsStep(ProcessorStep):
comp["language"] = lang
# 3) State/state_mask -> (B, 1, max_state_dim)
if "observation.state" in obs:
state = obs["observation.state"] # (B, D)
if OBS_STATE in obs:
state = obs[OBS_STATE] # (B, D)
if state.dim() != 2:
raise ValueError(f"state must be (B, D), got {tuple(state.shape)}")
bsz, d = state.shape
# Normalize BEFORE padding
if self.normalize_min_max:
state = _min_max_norm(state, "observation.state")
state = _min_max_norm(state, OBS_STATE)
state = state.unsqueeze(1) # (B, 1, D)
if d > self.max_state_dim:
state = state[:, :, : self.max_state_dim]
@@ -320,11 +324,11 @@ class GrootPackInputsStep(ProcessorStep):
# Normalize BEFORE temporal expansion/padding
if self.normalize_min_max:
if action.dim() == 2:
action = _min_max_norm(action, "action")
action = _min_max_norm(action, ACTION)
elif action.dim() == 3:
b, t, d = action.shape
flat = action.reshape(b * t, d)
flat = _min_max_norm(flat, "action")
flat = _min_max_norm(flat, ACTION)
action = flat.view(b, t, d)
if action.dim() == 2:
action = action.unsqueeze(1).repeat(1, self.action_horizon, 1)
@@ -590,7 +594,7 @@ class GrootActionUnpackUnnormalizeStep(ProcessorStep):
# forward: y = 2 * (x - min) / denom - 1, with y=0 when denom==0
# inverse: x = (y+1)/2 * denom + min, and when denom==0 -> x = min
if self.normalize_min_max and self.stats is not None:
stats_k = self.stats.get("action", {})
stats_k = self.stats.get(ACTION, {})
d = action.shape[-1]
min_v = torch.as_tensor(
stats_k.get("min", torch.zeros(d)), dtype=action.dtype, device=action.device
@@ -21,7 +21,7 @@ from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.utils.constants import OBS_IMAGES
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
DEFAULT_IMAGE_SIZE = 224
@@ -124,19 +124,19 @@ class PI0Config(PreTrainedConfig):
)
self.input_features[key] = empty_camera
if "observation.state" not in self.input_features:
if OBS_STATE not in self.input_features:
state_feature = PolicyFeature(
type=FeatureType.STATE,
shape=(self.max_state_dim,), # Padded to max_state_dim
)
self.input_features["observation.state"] = state_feature
self.input_features[OBS_STATE] = state_feature
if "action" not in self.output_features:
if ACTION not in self.output_features:
action_feature = PolicyFeature(
type=FeatureType.ACTION,
shape=(self.max_action_dim,), # Padded to max_action_dim
)
self.output_features["action"] = action_feature
self.output_features[ACTION] = action_feature
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(
@@ -21,6 +21,7 @@ from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
DEFAULT_IMAGE_SIZE = 224
@@ -117,26 +118,26 @@ class PI05Config(PreTrainedConfig):
def validate_features(self) -> None:
"""Validate and set up input/output features."""
for i in range(self.empty_cameras):
key = f"observation.images.empty_camera_{i}"
key = OBS_IMAGES + f".empty_camera_{i}"
empty_camera = PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, *self.image_resolution), # Use configured image resolution
)
self.input_features[key] = empty_camera
if "observation.state" not in self.input_features:
if OBS_STATE not in self.input_features:
state_feature = PolicyFeature(
type=FeatureType.STATE,
shape=(self.max_state_dim,), # Padded to max_state_dim
)
self.input_features["observation.state"] = state_feature
self.input_features[OBS_STATE] = state_feature
if "action" not in self.output_features:
if ACTION not in self.output_features:
action_feature = PolicyFeature(
type=FeatureType.ACTION,
shape=(self.max_action_dim,), # Padded to max_action_dim
)
self.output_features["action"] = action_feature
self.output_features[ACTION] = action_feature
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(
@@ -21,6 +21,7 @@ from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
DEFAULT_IMAGE_SIZE = 224
@@ -110,26 +111,26 @@ class PI0FastConfig(PreTrainedConfig):
def validate_features(self) -> None:
"""Validate and set up input/output features."""
for i in range(self.empty_cameras):
key = f"observation.images.empty_camera_{i}"
key = OBS_IMAGES + f".empty_camera_{i}"
empty_camera = PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, *self.image_resolution), # Use configured image resolution
)
self.input_features[key] = empty_camera
if "observation.state" not in self.input_features:
if OBS_STATE not in self.input_features:
state_feature = PolicyFeature(
type=FeatureType.STATE,
shape=(self.max_state_dim,), # Padded to max_state_dim
)
self.input_features["observation.state"] = state_feature
self.input_features[OBS_STATE] = state_feature
if "action" not in self.output_features:
if ACTION not in self.output_features:
action_feature = PolicyFeature(
type=FeatureType.ACTION,
shape=(self.max_action_dim,), # Padded to max_action_dim
)
self.output_features["action"] = action_feature
self.output_features[ACTION] = action_feature
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(
@@ -26,6 +26,7 @@ from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
@PreTrainedConfig.register_subclass("sarm")
@@ -86,8 +87,8 @@ class SARMConfig(PreTrainedConfig):
pretrained_model_path: str | None = None
device: str | None = None
image_key: str = "observation.images.top" # Key for image used from the dataset
state_key: str = "observation.state"
image_key: str = OBS_IMAGES + ".top" # Key for image used from the dataset
state_key: str = OBS_STATE
# Populated by the processor (video_features, state_features, text_features)
input_features: dict = field(default_factory=lambda: {})
+2 -1
View File
@@ -40,6 +40,7 @@ from lerobot.policies.sarm.sarm_utils import (
normalize_stage_tau,
pad_state_to_max_dim,
)
from lerobot.utils.constants import OBS_STR
class StageTransformer(nn.Module):
@@ -721,7 +722,7 @@ class SARMRewardModel(PreTrainedPolicy):
Returns:
Tuple of (total_loss, output_dict with loss components)
"""
observation = batch.get("observation", batch)
observation = batch.get(OBS_STR, batch)
# Extract features
video_features = observation["video_features"].to(self.device)
+1 -2
View File
@@ -16,7 +16,6 @@
import logging
from collections import deque
from typing import Any
import numpy as np
import torch
@@ -140,7 +139,7 @@ def prepare_observation_for_inference(
def build_inference_frame(
observation: dict[str, Any],
observation: RobotObservation,
device: torch.device,
ds_features: dict[str, dict],
task: str | None = None,
@@ -18,6 +18,7 @@ from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
from lerobot.utils.constants import ACTION, OBS_STATE
@PreTrainedConfig.register_subclass("wall_x")
@@ -105,14 +106,14 @@ class WallXConfig(PreTrainedConfig):
"No features of type FeatureType.VISUAL found in input_features."
)
if "observation.state" not in self.input_features:
if OBS_STATE not in self.input_features:
state_feature = PolicyFeature(
type=FeatureType.STATE,
shape=(self.max_state_dim,), # Padded to max_state_dim
)
self.input_features["observation.state"] = state_feature
self.input_features[OBS_STATE] = state_feature
else:
state_shape = self.input_features["observation.state"].shape
state_shape = self.input_features[OBS_STATE].shape
state_dim = state_shape[0] if state_shape else 0
if state_dim > self.max_state_dim:
raise ValueError(
@@ -120,14 +121,14 @@ class WallXConfig(PreTrainedConfig):
f"Either reduce state dimension or increase max_state_dim in config."
)
if "action" not in self.output_features:
if ACTION not in self.output_features:
action_feature = PolicyFeature(
type=FeatureType.ACTION,
shape=(self.max_action_dim,), # Padded to max_action_dim
)
self.output_features["action"] = action_feature
self.output_features[ACTION] = action_feature
else:
action_shape = self.output_features["action"].shape
action_shape = self.output_features[ACTION].shape
action_dim = action_shape[0] if action_shape else 0
if action_dim > self.max_action_dim:
raise ValueError(
@@ -1861,7 +1861,7 @@ class WallXPolicy(PreTrainedPolicy):
dim=-1,
)
else:
action_dim = self.config.output_features["action"].shape[0]
action_dim = self.config.output_features[ACTION].shape[0]
dof_mask = torch.cat(
[
torch.ones(
@@ -1977,7 +1977,7 @@ class WallXPolicy(PreTrainedPolicy):
elif self.config.prediction_mode == "fast":
output = self.model(
**batch,
action_dim=self.config.output_features["action"].shape[0],
action_dim=self.config.output_features[ACTION].shape[0],
pred_horizon=self.config.chunk_size,
mode="predict",
predict_mode="fast",
@@ -1989,7 +1989,7 @@ class WallXPolicy(PreTrainedPolicy):
actions = output["predict_action"]
# Unpad actions to actual action dimension
action_dim = self.config.output_features["action"].shape[0]
action_dim = self.config.output_features[ACTION].shape[0]
actions = actions[:, :, :action_dim]
return actions
+9 -7
View File
@@ -41,6 +41,7 @@ from lerobot.processor.converters import policy_action_to_transition, transition
from lerobot.processor.core import EnvTransition, TransitionKey
from lerobot.utils.constants import (
OBS_IMAGES,
OBS_PREFIX,
OBS_STATE,
POLICY_POSTPROCESSOR_DEFAULT_NAME,
POLICY_PREPROCESSOR_DEFAULT_NAME,
@@ -137,8 +138,9 @@ class LiberoProcessorStep(ObservationProcessorStep):
processed_obs[key] = img
# Process robot_state into a flat state vector
if "observation.robot_state" in processed_obs:
robot_state = processed_obs.pop("observation.robot_state")
robot_state_str = OBS_PREFIX + "robot_state"
if robot_state_str in processed_obs:
robot_state = processed_obs.pop(robot_state_str)
# Extract components
eef_pos = robot_state["eef"]["pos"] # (B, 3,)
@@ -174,8 +176,8 @@ class LiberoProcessorStep(ObservationProcessorStep):
state_feats = {}
# add our new flattened state
state_feats["observation.state"] = PolicyFeature(
key="observation.state",
state_feats[OBS_STATE] = PolicyFeature(
key=OBS_STATE,
shape=(20,),
dtype="float32",
)
@@ -247,7 +249,7 @@ class XVLAImageScaleProcessorStep(ProcessorStep):
keys_to_scale = self.image_keys
if keys_to_scale is None:
# Auto-detect image keys
keys_to_scale = [k for k in obs if k.startswith("observation.images.")]
keys_to_scale = [k for k in obs if k.startswith(OBS_IMAGES)]
# Scale each image
for key in keys_to_scale:
@@ -303,7 +305,7 @@ class XVLAImageToFloatProcessorStep(ProcessorStep):
keys_to_convert = self.image_keys
if keys_to_convert is None:
# Auto-detect image keys
keys_to_convert = [k for k in obs if k.startswith("observation.images.")]
keys_to_convert = [k for k in obs if k.startswith(OBS_IMAGES)]
# Convert each image
for key in keys_to_convert:
@@ -376,7 +378,7 @@ class XVLAImageNetNormalizeProcessorStep(ProcessorStep):
keys_to_normalize = self.image_keys
if keys_to_normalize is None:
# Auto-detect image keys
keys_to_normalize = [k for k in obs if k.startswith("observation.images.")]
keys_to_normalize = [k for k in obs if k.startswith(OBS_IMAGES)]
# Normalize each image
for key in keys_to_normalize:
-3
View File
@@ -49,7 +49,6 @@ from .hil_processor import (
RewardClassifierProcessorStep,
TimeLimitProcessorStep,
)
from .joint_observations_processor import JointVelocityProcessorStep, MotorCurrentProcessorStep
from .normalize_processor import NormalizerProcessorStep, UnnormalizerProcessorStep, hotswap_stats
from .observation_processor import VanillaObservationProcessorStep
from .pipeline import (
@@ -94,14 +93,12 @@ __all__ = [
"ImageCropResizeProcessorStep",
"InfoProcessorStep",
"InterventionActionProcessorStep",
"JointVelocityProcessorStep",
"make_default_processors",
"make_default_teleop_action_processor",
"make_default_robot_action_processor",
"make_default_robot_observation_processor",
"MapDeltaActionToRobotActionStep",
"MapTensorToDeltaActionDictStep",
"MotorCurrentProcessorStep",
"NormalizerProcessorStep",
"Numpy2TorchActionProcessorStep",
"ObservationProcessorStep",
+3 -3
View File
@@ -23,7 +23,7 @@ from typing import Any
import numpy as np
import torch
from lerobot.utils.constants import ACTION, DONE, OBS_PREFIX, REWARD, TRUNCATED
from lerobot.utils.constants import ACTION, DONE, INFO, OBS_PREFIX, REWARD, TRUNCATED
from .core import EnvTransition, PolicyAction, RobotAction, RobotObservation, TransitionKey
@@ -176,7 +176,7 @@ def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
def create_transition(
observation: dict[str, Any] | None = None,
observation: RobotObservation | None = None,
action: PolicyAction | RobotAction | None = None,
reward: float = 0.0,
done: bool = False,
@@ -384,7 +384,7 @@ def transition_to_batch(transition: EnvTransition) -> dict[str, Any]:
REWARD: transition.get(TransitionKey.REWARD, 0.0),
DONE: transition.get(TransitionKey.DONE, False),
TRUNCATED: transition.get(TransitionKey.TRUNCATED, False),
"info": transition.get(TransitionKey.INFO, {}),
INFO: transition.get(TransitionKey.INFO, {}),
}
# Add complementary data.
+1 -1
View File
@@ -45,7 +45,7 @@ RobotObservation: TypeAlias = dict[str, Any]
EnvTransition = TypedDict(
"EnvTransition",
{
TransitionKey.OBSERVATION.value: dict[str, Any] | None,
TransitionKey.OBSERVATION.value: RobotObservation | None,
TransitionKey.ACTION.value: PolicyAction | RobotAction | EnvAction | None,
TransitionKey.REWARD.value: float | torch.Tensor | None,
TransitionKey.DONE.value: bool | torch.Tensor | None,
+6 -5
View File
@@ -18,7 +18,7 @@ from dataclasses import dataclass
import torch
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE, OBS_STR
from lerobot.utils.constants import OBS_IMAGES, OBS_PREFIX, OBS_STATE, OBS_STR
from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
@@ -60,8 +60,9 @@ class LiberoProcessorStep(ObservationProcessorStep):
processed_obs[key] = img
# Process robot_state into a flat state vector
if "observation.robot_state" in processed_obs:
robot_state = processed_obs.pop("observation.robot_state")
observation_robot_state_str = OBS_PREFIX + "robot_state"
if observation_robot_state_str in processed_obs:
robot_state = processed_obs.pop(observation_robot_state_str)
# Extract components
eef_pos = robot_state["eef"]["pos"] # (B, 3,)
@@ -98,8 +99,8 @@ class LiberoProcessorStep(ObservationProcessorStep):
state_feats = {}
# add our new flattened state
state_feats["observation.state"] = PolicyFeature(
key="observation.state",
state_feats[OBS_STATE] = PolicyFeature(
key=OBS_STATE,
shape=(8,), # [eef_pos(3), axis_angle(3), gripper(2)]
dtype="float32",
description=("Concatenated end-effector position (3), axis-angle (3), and gripper qpos (2)."),
+2 -2
View File
@@ -30,7 +30,7 @@ from lerobot.utils.constants import ACTION
from .converters import from_tensor_to_numpy, to_tensor
from .core import EnvTransition, PolicyAction, TransitionKey
from .pipeline import PolicyProcessorPipeline, ProcessorStep, ProcessorStepRegistry
from .pipeline import PolicyProcessorPipeline, ProcessorStep, ProcessorStepRegistry, RobotObservation
@dataclass
@@ -239,7 +239,7 @@ class _NormalizationMixin:
config["normalize_observation_keys"] = sorted(self.normalize_observation_keys)
return config
def _normalize_observation(self, observation: dict[str, Any], inverse: bool) -> dict[str, Tensor]:
def _normalize_observation(self, observation: RobotObservation, inverse: bool) -> dict[str, Tensor]:
"""
Applies (un)normalization to all relevant features in an observation dictionary.
+3 -3
View File
@@ -49,7 +49,7 @@ from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.utils.hub import HubMixin
from .converters import batch_to_transition, create_transition, transition_to_batch
from .core import EnvAction, EnvTransition, PolicyAction, RobotAction, TransitionKey
from .core import EnvAction, EnvTransition, PolicyAction, RobotAction, RobotObservation, TransitionKey
# Generic type variables for pipeline input and output.
TInput = TypeVar("TInput")
@@ -1337,7 +1337,7 @@ class DataProcessorPipeline(HubMixin, Generic[TInput, TOutput]):
return features
# Convenience methods for processing individual parts of a transition.
def process_observation(self, observation: dict[str, Any]) -> dict[str, Any]:
def process_observation(self, observation: RobotObservation) -> RobotObservation:
"""Processes only the observation part of a transition through the pipeline.
Args:
@@ -1440,7 +1440,7 @@ class ObservationProcessorStep(ProcessorStep, ABC):
"""An abstract `ProcessorStep` that specifically targets the observation in a transition."""
@abstractmethod
def observation(self, observation: dict[str, Any]) -> dict[str, Any]:
def observation(self, observation: RobotObservation) -> RobotObservation:
"""Processes an observation dictionary. Subclasses must implement this method.
Args:
+2 -2
View File
@@ -38,7 +38,7 @@ from lerobot.utils.constants import (
)
from lerobot.utils.import_utils import _transformers_available
from .core import EnvTransition, TransitionKey
from .core import EnvTransition, RobotObservation, TransitionKey
from .pipeline import ActionProcessorStep, ObservationProcessorStep, ProcessorStepRegistry
# Conditional import for type checking and lazy loading
@@ -139,7 +139,7 @@ class TokenizerProcessorStep(ObservationProcessorStep):
return None
def observation(self, observation: dict[str, Any]) -> dict[str, Any]:
def observation(self, observation: RobotObservation) -> RobotObservation:
"""
Tokenizes the task description and adds it to the observation dictionary.
+6 -5
View File
@@ -38,13 +38,12 @@ from lerobot.processor import (
GripperPenaltyProcessorStep,
ImageCropResizeProcessorStep,
InterventionActionProcessorStep,
JointVelocityProcessorStep,
MapDeltaActionToRobotActionStep,
MapTensorToDeltaActionDictStep,
MotorCurrentProcessorStep,
Numpy2TorchActionProcessorStep,
RewardClassifierProcessorStep,
RobotActionToPolicyActionProcessorStep,
RobotObservation,
TimeLimitProcessorStep,
Torch2NumpyActionProcessorStep,
TransitionKey,
@@ -77,6 +76,8 @@ from lerobot.utils.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, REWARD
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import log_say
from .joint_observations_processor import JointVelocityProcessorStep, MotorCurrentProcessorStep
logging.basicConfig(level=logging.INFO)
@@ -163,7 +164,7 @@ class RobotEnv(gym.Env):
self._setup_spaces()
def _get_observation(self) -> dict[str, Any]:
def _get_observation(self) -> RobotObservation:
"""Get current robot observation including joint positions and camera images."""
obs_dict = self.robot.get_observation()
raw_joint_joint_position = {f"{name}.pos": obs_dict[f"{name}.pos"] for name in self._joint_names}
@@ -220,7 +221,7 @@ class RobotEnv(gym.Env):
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[dict[str, Any], dict[str, Any]]:
) -> tuple[RobotObservation, dict[str, Any]]:
"""Reset environment to initial state.
Args:
@@ -249,7 +250,7 @@ class RobotEnv(gym.Env):
self._raw_joint_positions = {f"{key}.pos": obs[f"{key}.pos"] for key in self._joint_names}
return obs, {TeleopEvents.IS_INTERVENTION: False}
def step(self, action) -> tuple[dict[str, np.ndarray], float, bool, bool, dict[str, Any]]:
def step(self, action) -> tuple[RobotObservation, float, bool, bool, dict[str, Any]]:
"""Execute one environment step with given action."""
joint_targets_dict = {f"{key}.pos": action[i] for i, key in enumerate(self.robot.bus.motors.keys())}
@@ -16,8 +16,8 @@
import logging
from functools import cached_property
from typing import Any
from lerobot.processor import RobotAction, RobotObservation
from lerobot.robots.so_follower import SOFollower, SOFollowerRobotConfig
from ..robot import Robot
@@ -116,7 +116,7 @@ class BiSOFollower(Robot):
self.left_arm.setup_motors()
self.right_arm.setup_motors()
def get_observation(self) -> dict[str, Any]:
def get_observation(self) -> RobotObservation:
obs_dict = {}
# Add "left_" prefix
@@ -129,7 +129,7 @@ class BiSOFollower(Robot):
return obs_dict
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
def send_action(self, action: RobotAction) -> RobotAction:
# Remove "left_" prefix
left_action = {
key.removeprefix("left_"): value for key, value in action.items() if key.startswith("left_")
@@ -18,12 +18,12 @@
import base64
import logging
from functools import cached_property
from typing import Any
import cv2
import numpy as np
import requests
from lerobot.processor import RobotAction, RobotObservation
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..robot import Robot
@@ -197,11 +197,11 @@ class EarthRoverMiniPlus(Robot):
ACTION_ANGULAR_VEL: float,
}
def get_observation(self) -> dict[str, Any]:
def get_observation(self) -> RobotObservation:
"""Get current robot observation from SDK.
Returns:
dict: Observation containing:
RobotObservation: Observation containing:
- front: Front camera image (480, 640, 3) in RGB format
- rear: Rear camera image (480, 640, 3) in RGB format
- linear.vel: Current speed (0-1, SDK reports only positive speeds)
@@ -255,7 +255,7 @@ class EarthRoverMiniPlus(Robot):
return observation
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
def send_action(self, action: RobotAction) -> RobotAction:
"""Send action to robot via SDK.
Args:
@@ -264,8 +264,7 @@ class EarthRoverMiniPlus(Robot):
- angular.vel: Target angular velocity (-1 to 1)
Returns:
dict: The action that was sent (matches action_features keys)
RobotAction: The action that was sent (matches action_features keys)
Raises:
DeviceNotConnectedError: If robot is not connected
+3 -3
View File
@@ -17,7 +17,6 @@
import logging
import time
from functools import cached_property
from typing import Any
from lerobot.cameras.utils import make_cameras_from_configs
from lerobot.motors import Motor, MotorNormMode
@@ -25,6 +24,7 @@ from lerobot.motors.calibration_gui import RangeFinderGUI
from lerobot.motors.feetech import (
FeetechMotorsBus,
)
from lerobot.processor import RobotAction, RobotObservation
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..robot import Robot
@@ -128,7 +128,7 @@ class HopeJrArm(Robot):
self.bus.setup_motor(motor)
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
def get_observation(self) -> dict[str, Any]:
def get_observation(self) -> RobotObservation:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
@@ -149,7 +149,7 @@ class HopeJrArm(Robot):
return obs_dict
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
def send_action(self, action: RobotAction) -> RobotAction:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
+3 -3
View File
@@ -17,7 +17,6 @@
import logging
import time
from functools import cached_property
from typing import Any
from lerobot.cameras.utils import make_cameras_from_configs
from lerobot.motors import Motor, MotorNormMode
@@ -25,6 +24,7 @@ from lerobot.motors.calibration_gui import RangeFinderGUI
from lerobot.motors.feetech import (
FeetechMotorsBus,
)
from lerobot.processor import RobotAction, RobotObservation
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..robot import Robot
@@ -159,7 +159,7 @@ class HopeJrHand(Robot):
self.bus.setup_motor(motor)
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
def get_observation(self) -> dict[str, Any]:
def get_observation(self) -> RobotObservation:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
@@ -181,7 +181,7 @@ class HopeJrHand(Robot):
return obs_dict
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
def send_action(self, action: RobotAction) -> RobotAction:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
@@ -17,7 +17,6 @@
import logging
import time
from functools import cached_property
from typing import Any
from lerobot.cameras.utils import make_cameras_from_configs
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
@@ -25,6 +24,7 @@ from lerobot.motors.dynamixel import (
DynamixelMotorsBus,
OperatingMode,
)
from lerobot.processor import RobotAction, RobotObservation
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..robot import Robot
@@ -182,7 +182,7 @@ class KochFollower(Robot):
self.bus.setup_motor(motor)
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
def get_observation(self) -> dict[str, Any]:
def get_observation(self) -> RobotObservation:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
@@ -202,7 +202,7 @@ class KochFollower(Robot):
return obs_dict
def send_action(self, action: dict[str, float]) -> dict[str, float]:
def send_action(self, action: RobotAction) -> RobotAction:
"""Command arm to move to a target joint configuration.
The relative action magnitude may be clipped depending on the configuration parameter
@@ -210,10 +210,10 @@ class KochFollower(Robot):
Thus, this function always returns the action actually sent.
Args:
action (dict[str, float]): The goal positions for the motors.
action (RobotAction): The goal positions for the motors.
Returns:
dict[str, float]: The action sent to the motors, potentially clipped.
RobotAction: The action sent to the motors, potentially clipped.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
+4 -3
View File
@@ -28,6 +28,7 @@ from lerobot.motors.feetech import (
FeetechMotorsBus,
OperatingMode,
)
from lerobot.processor import RobotAction, RobotObservation
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..robot import Robot
@@ -338,7 +339,7 @@ class LeKiwi(Robot):
"theta.vel": theta,
} # m/s and deg/s
def get_observation(self) -> dict[str, Any]:
def get_observation(self) -> RobotObservation:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
@@ -369,7 +370,7 @@ class LeKiwi(Robot):
return obs_dict
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
def send_action(self, action: RobotAction) -> RobotAction:
"""Command lekiwi to move to a target joint configuration.
The relative action magnitude may be clipped depending on the configuration parameter
@@ -380,7 +381,7 @@ class LeKiwi(Robot):
DeviceNotConnectedError: if robot is not connected.
Returns:
np.ndarray: the action sent to the motors, potentially clipped.
RobotAction: the action sent to the motors, potentially clipped.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
+9 -10
View File
@@ -18,11 +18,11 @@ import base64
import json
import logging
from functools import cached_property
from typing import Any
import cv2
import numpy as np
from lerobot.processor import RobotAction, RobotObservation
from lerobot.utils.constants import ACTION, OBS_STATE
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
@@ -172,7 +172,7 @@ class LeKiwiClient(Robot):
return last_msg
def _parse_observation_json(self, obs_string: str) -> dict[str, Any] | None:
def _parse_observation_json(self, obs_string: str) -> RobotObservation | None:
"""Parses the JSON observation string."""
try:
return json.loads(obs_string)
@@ -196,15 +196,15 @@ class LeKiwiClient(Robot):
return None
def _remote_state_from_obs(
self, observation: dict[str, Any]
) -> tuple[dict[str, np.ndarray], dict[str, Any]]:
self, observation: RobotObservation
) -> tuple[dict[str, np.ndarray], RobotObservation]:
"""Extracts frames, and state from the parsed observation."""
flat_state = {key: observation.get(key, 0.0) for key in self._state_order}
state_vec = np.array([flat_state[key] for key in self._state_order], dtype=np.float32)
obs_dict: dict[str, Any] = {**flat_state, OBS_STATE: state_vec}
obs_dict: RobotObservation = {**flat_state, OBS_STATE: state_vec}
# Decode images
current_frames: dict[str, np.ndarray] = {}
@@ -217,7 +217,7 @@ class LeKiwiClient(Robot):
return current_frames, obs_dict
def _get_data(self) -> tuple[dict[str, np.ndarray], dict[str, Any], dict[str, Any]]:
def _get_data(self) -> tuple[dict[str, np.ndarray], RobotObservation]:
"""
Polls the video socket for the latest observation data.
@@ -252,7 +252,7 @@ class LeKiwiClient(Robot):
return new_frames, new_state
def get_observation(self) -> dict[str, Any]:
def get_observation(self) -> RobotObservation:
"""
Capture observations from the remote robot: current follower arm positions,
present wheel speeds (converted to body-frame velocities: x, y, theta),
@@ -307,12 +307,11 @@ class LeKiwiClient(Robot):
def configure(self):
pass
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
def send_action(self, action: RobotAction) -> RobotAction:
"""Command lekiwi to move to a target joint configuration. Translates to motor space + sends over ZMQ
Args:
action (np.ndarray): array containing the goal positions for the motors.
action (RobotAction): dictionary containing the goal positions for the motors.
Raises:
DeviceNotConnectedError: if robot is not connected.
@@ -17,7 +17,6 @@
import logging
import time
from functools import cached_property
from typing import Any
from lerobot.cameras.utils import make_cameras_from_configs
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
@@ -26,6 +25,7 @@ from lerobot.motors.dynamixel import (
DynamixelMotorsBus,
OperatingMode,
)
from lerobot.processor import RobotAction, RobotObservation
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..robot import Robot
@@ -165,7 +165,7 @@ class OmxFollower(Robot):
self.bus.setup_motor(motor)
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
def get_observation(self) -> dict[str, Any]:
def get_observation(self) -> RobotObservation:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
@@ -185,7 +185,7 @@ class OmxFollower(Robot):
return obs_dict
def send_action(self, action: dict[str, float]) -> dict[str, float]:
def send_action(self, action: RobotAction) -> RobotAction:
"""Command arm to move to a target joint configuration.
The relative action magnitude may be clipped depending on the configuration parameter
@@ -193,10 +193,10 @@ class OmxFollower(Robot):
Thus, this function always returns the action actually sent.
Args:
action (dict[str, float]): The goal positions for the motors.
action (RobotAction): The goal positions for the motors.
Returns:
dict[str, float]: The action sent to the motors, potentially clipped.
RobotAction: The action sent to the motors, potentially clipped.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
+4 -5
View File
@@ -18,9 +18,8 @@ from __future__ import annotations
import time
from typing import TYPE_CHECKING, Any
import numpy as np
from lerobot.cameras.utils import make_cameras_from_configs
from lerobot.processor import RobotAction, RobotObservation
from lerobot.utils.import_utils import _reachy2_sdk_available
from ..robot import Robot
@@ -171,8 +170,8 @@ class Reachy2Robot(Robot):
else:
return {}
def get_observation(self) -> dict[str, np.ndarray]:
obs_dict: dict[str, Any] = {}
def get_observation(self) -> RobotObservation:
obs_dict: RobotObservation = {}
# Read Reachy 2 state
before_read_t = time.perf_counter()
@@ -185,7 +184,7 @@ class Reachy2Robot(Robot):
return obs_dict
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
def send_action(self, action: RobotAction) -> RobotAction:
if self.reachy is not None:
if not self.is_connected:
raise ConnectionError()
+6 -6
View File
@@ -15,11 +15,11 @@
import abc
import builtins
from pathlib import Path
from typing import Any
import draccus
from lerobot.motors import MotorCalibration
from lerobot.processor import RobotAction, RobotObservation
from lerobot.utils.constants import HF_LEROBOT_CALIBRATION, ROBOTS
from .config import RobotConfig
@@ -153,28 +153,28 @@ class Robot(abc.ABC):
pass
@abc.abstractmethod
def get_observation(self) -> dict[str, Any]:
def get_observation(self) -> RobotObservation:
"""
Retrieve the current observation from the robot.
Returns:
dict[str, Any]: A flat dictionary representing the robot's current sensory state. Its structure
RobotObservation: A flat dictionary representing the robot's current sensory state. Its structure
should match :pymeth:`observation_features`.
"""
pass
@abc.abstractmethod
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
def send_action(self, action: RobotAction) -> RobotAction:
"""
Send an action command to the robot.
Args:
action (dict[str, Any]): Dictionary representing the desired action. Its structure should match
action (RobotAction): Dictionary representing the desired action. Its structure should match
:pymeth:`action_features`.
Returns:
dict[str, Any]: The action actually sent to the motors potentially clipped or modified, e.g. by
RobotAction: The action actually sent to the motors potentially clipped or modified, e.g. by
safety limits on velocity.
"""
pass
@@ -28,6 +28,7 @@ from lerobot.processor import (
ProcessorStepRegistry,
RobotAction,
RobotActionProcessorStep,
RobotObservation,
TransitionKey,
)
from lerobot.utils.rotation import Rotation
@@ -438,7 +439,7 @@ class ForwardKinematicsJointsToEEObservation(ObservationProcessorStep):
kinematics: RobotKinematics
motor_names: list[str]
def observation(self, observation: dict[str, Any]) -> dict[str, Any]:
def observation(self, observation: RobotObservation) -> RobotObservation:
return compute_forward_kinematics_joints_to_ee(observation, self.kinematics, self.motor_names)
def transform_features(
@@ -17,7 +17,7 @@
import logging
import time
from functools import cached_property
from typing import Any, TypeAlias
from typing import TypeAlias
from lerobot.cameras.utils import make_cameras_from_configs
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
@@ -25,6 +25,7 @@ from lerobot.motors.feetech import (
FeetechMotorsBus,
OperatingMode,
)
from lerobot.processor import RobotAction, RobotObservation
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..robot import Robot
@@ -175,7 +176,7 @@ class SOFollower(Robot):
self.bus.setup_motor(motor)
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
def get_observation(self) -> dict[str, Any]:
def get_observation(self) -> RobotObservation:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
@@ -195,7 +196,7 @@ class SOFollower(Robot):
return obs_dict
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
def send_action(self, action: RobotAction) -> RobotAction:
"""Command arm to move to a target joint configuration.
The relative action magnitude may be clipped depending on the configuration parameter
@@ -206,7 +207,7 @@ class SOFollower(Robot):
DeviceNotConnectedError: if robot is not connected.
Returns:
the action sent to the motors, potentially clipped.
RobotAction: the action sent to the motors, potentially clipped.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
+3 -2
View File
@@ -26,6 +26,7 @@ import numpy as np
from lerobot.cameras.utils import make_cameras_from_configs
from lerobot.envs.factory import make_env
from lerobot.processor import RobotAction, RobotObservation
from lerobot.robots.unitree_g1.g1_utils import G1_29_JointIndex
from ..robot import Robot
@@ -269,7 +270,7 @@ class UnitreeG1(Robot):
for cam in self._cameras.values():
cam.disconnect()
def get_observation(self) -> dict[str, Any]:
def get_observation(self) -> RobotObservation:
lowstate = self._lowstate
if lowstate is None:
return {}
@@ -350,7 +351,7 @@ class UnitreeG1(Robot):
def observation_features(self) -> dict[str, type | tuple]:
return {**self._motors_ft, **self._cameras_ft}
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
def send_action(self, action: RobotAction) -> RobotAction:
for motor in G1_29_JointIndex:
key = f"{motor.name}.q"
if key in action:
+2 -2
View File
@@ -177,9 +177,9 @@ def rollout(
action = policy.select_action(observation)
action = postprocessor(action)
action_transition = {"action": action}
action_transition = {ACTION: action}
action_transition = env_postprocessor(action_transition)
action = action_transition["action"]
action = action_transition[ACTION]
# Convert to CPU / numpy.
action_numpy: np.ndarray = action.to("cpu").numpy()
+1 -1
View File
@@ -165,7 +165,7 @@ def teleop_loop(
# Process action for robot through pipeline
robot_action_to_send = robot_action_processor((teleop_action, obs))
# Send processed action to robot (robot_action_processor.to_output should return dict[str, Any])
# Send processed action to robot (robot_action_processor.to_output should return RobotAction)
_ = robot.send_action(robot_action_to_send)
if display_data:
@@ -63,6 +63,7 @@ else:
from lerobot.configs import parser
from lerobot.configs.types import NormalizationMode
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.utils.constants import ACTION, OBS_STATE
@dataclass
@@ -86,7 +87,7 @@ class TokenizerTrainingConfig:
# Whether to apply delta transform (relative actions vs absolute actions)
use_delta_transform: bool = False
# Dataset key for state observations (default: "observation.state")
state_key: str = "observation.state"
state_key: str = OBS_STATE
# Normalization mode (MEAN_STD, MIN_MAX, QUANTILES, QUANTILE10, IDENTITY)
normalization_mode: str = "QUANTILES"
# FAST vocabulary size (BPE vocab size)
@@ -223,12 +224,10 @@ def process_episode(args):
else:
# if no state key, use zeros (no delta transform)
state = np.zeros_like(
frame["action"].numpy() if torch.is_tensor(frame["action"]) else np.array(frame["action"])
frame[ACTION].numpy() if torch.is_tensor(frame[ACTION]) else np.array(frame[ACTION])
)
action = (
frame["action"].numpy() if torch.is_tensor(frame["action"]) else np.array(frame["action"])
)
action = frame[ACTION].numpy() if torch.is_tensor(frame[ACTION]) else np.array(frame[ACTION])
states.append(state)
actions.append(action)
@@ -468,8 +467,8 @@ def train_tokenizer(cfg: TokenizerTrainingConfig):
# get normalization stats from dataset
norm_stats = dataset.meta.stats
if norm_stats is not None and "action" in norm_stats:
action_stats = norm_stats["action"]
if norm_stats is not None and ACTION in norm_stats:
action_stats = norm_stats[ACTION]
# build encoded dimension indices
encoded_dim_indices = []
@@ -20,6 +20,8 @@ from typing import Any
import numpy as np
from lerobot.processor import RobotAction
from ..teleoperator import Teleoperator
from ..utils import TeleopEvents
from .configuration_gamepad import GamepadTeleopConfig
@@ -83,7 +85,7 @@ class GamepadTeleop(Teleoperator):
self.gamepad = Gamepad()
self.gamepad.start()
def get_action(self) -> dict[str, Any]:
def get_action(self) -> RobotAction:
# Update the controller to get fresh inputs
self.gamepad.update()
@@ -21,6 +21,7 @@ import time
from queue import Queue
from typing import Any
from lerobot.processor import RobotAction
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..teleoperator import Teleoperator
@@ -124,7 +125,7 @@ class KeyboardTeleop(Teleoperator):
def configure(self):
pass
def get_action(self) -> dict[str, Any]:
def get_action(self) -> RobotAction:
before_read_t = time.perf_counter()
if not self.is_connected:
@@ -181,7 +182,7 @@ class KeyboardEndEffectorTeleop(KeyboardTeleop):
"names": {"delta_x": 0, "delta_y": 1, "delta_z": 2},
}
def get_action(self) -> dict[str, Any]:
def get_action(self) -> RobotAction:
if not self.is_connected:
raise DeviceNotConnectedError(
"KeyboardTeleop is not connected. You need to run `connect()` before `get_action()`."
@@ -374,12 +375,12 @@ class KeyboardRoverTeleop(KeyboardTeleop):
# Only remove key if it's being released
self.current_pressed.pop(key_char, None)
def get_action(self) -> dict[str, Any]:
def get_action(self) -> RobotAction:
"""
Get the current action based on pressed keys.
Returns:
dict with 'linear.vel' and 'angular.vel' keys
RobotAction with 'linear.vel' and 'angular.vel' keys
"""
before_read_t = time.perf_counter()
+3 -2
View File
@@ -20,6 +20,7 @@ from typing import Any
import draccus
from lerobot.motors.motors_bus import MotorCalibration
from lerobot.processor import RobotAction
from lerobot.utils.constants import HF_LEROBOT_CALIBRATION, TELEOPERATORS
from .config import TeleoperatorConfig
@@ -150,12 +151,12 @@ class Teleoperator(abc.ABC):
pass
@abc.abstractmethod
def get_action(self) -> dict[str, Any]:
def get_action(self) -> RobotAction:
"""
Retrieve the current action from the teleoperator.
Returns:
dict[str, Any]: A flat dictionary representing the teleoperator's current actions. Its
RobotAction: A flat dictionary representing the teleoperator's current actions. Its
structure should match :pymeth:`action_features`.
"""
pass
+2
View File
@@ -28,11 +28,13 @@ OBS_LANGUAGE_TOKENS = OBS_LANGUAGE + ".tokens"
OBS_LANGUAGE_ATTENTION_MASK = OBS_LANGUAGE + ".attention_mask"
ACTION = "action"
ACTION_PREFIX = ACTION + "."
ACTION_TOKENS = ACTION + ".tokens"
ACTION_TOKEN_MASK = ACTION + ".token_mask"
REWARD = "next.reward"
TRUNCATED = "next.truncated"
DONE = "next.done"
INFO = "info"
ROBOTS = "robots"
TELEOPERATORS = "teleoperators"
+6 -5
View File
@@ -14,12 +14,13 @@
import numbers
import os
from typing import Any
import numpy as np
import rerun as rr
from .constants import OBS_PREFIX, OBS_STR
from lerobot.processor import RobotAction, RobotObservation
from .constants import ACTION, ACTION_PREFIX, OBS_PREFIX, OBS_STR
def init_rerun(
@@ -50,8 +51,8 @@ def _is_scalar(x):
def log_rerun_data(
observation: dict[str, Any] | None = None,
action: dict[str, Any] | None = None,
observation: RobotObservation | None = None,
action: RobotAction | None = None,
compress_images: bool = False,
) -> None:
"""
@@ -96,7 +97,7 @@ def log_rerun_data(
for k, v in action.items():
if v is None:
continue
key = k if str(k).startswith("action.") else f"action.{k}"
key = k if str(k).startswith(ACTION_PREFIX) else f"{ACTION}.{k}"
if _is_scalar(v):
rr.log(key, rr.Scalars(float(v)))
+3 -3
View File
@@ -17,10 +17,10 @@
import random
from dataclasses import dataclass, field
from functools import cached_property
from typing import Any
from lerobot.cameras import CameraConfig, make_cameras_from_configs
from lerobot.motors.motors_bus import Motor, MotorNormMode
from lerobot.processor import RobotAction, RobotObservation
from lerobot.robots import Robot, RobotConfig
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from tests.mocks.mock_motors_bus import MockMotorsBus
@@ -119,7 +119,7 @@ class MockRobot(Robot):
def configure(self) -> None:
pass
def get_observation(self) -> dict[str, Any]:
def get_observation(self) -> RobotObservation:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
@@ -130,7 +130,7 @@ class MockRobot(Robot):
f"{motor}.pos": val for motor, val in zip(self.motors, self.config.static_values, strict=True)
}
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
def send_action(self, action: RobotAction) -> RobotAction:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
+2 -1
View File
@@ -19,6 +19,7 @@ from dataclasses import dataclass
from functools import cached_property
from typing import Any
from lerobot.processor import RobotAction
from lerobot.teleoperators import Teleoperator, TeleoperatorConfig
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
@@ -88,7 +89,7 @@ class MockTeleop(Teleoperator):
def configure(self) -> None:
pass
def get_action(self) -> dict[str, Any]:
def get_action(self) -> RobotAction:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")