mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 00:37:10 +00:00
Merge branch 'main' into feature/add-multitask-dit
This commit is contained in:
@@ -40,7 +40,6 @@ from collections.abc import Callable
|
||||
from dataclasses import asdict
|
||||
from pprint import pformat
|
||||
from queue import Queue
|
||||
from typing import Any
|
||||
|
||||
import draccus
|
||||
import grpc
|
||||
@@ -48,6 +47,7 @@ import torch
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
||||
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
||||
from lerobot.processor import RobotAction
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
@@ -351,7 +351,7 @@ class RobotClient:
|
||||
action = {key: action_tensor[i].item() for i, key in enumerate(self.robot.action_features)}
|
||||
return action
|
||||
|
||||
def control_loop_action(self, verbose: bool = False) -> dict[str, Any]:
|
||||
def control_loop_action(self, verbose: bool = False) -> RobotAction:
|
||||
"""Reading and performing actions in local queue"""
|
||||
|
||||
# Lock only for queue operations
|
||||
|
||||
@@ -18,12 +18,12 @@ from typing import Any
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType
|
||||
from lerobot.datasets.utils import hw_to_dataset_features
|
||||
from lerobot.processor import DataProcessorPipeline
|
||||
from lerobot.processor import DataProcessorPipeline, RobotAction, RobotObservation
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE, OBS_STR
|
||||
|
||||
|
||||
def create_initial_features(
|
||||
action: dict[str, Any] | None = None, observation: dict[str, Any] | None = None
|
||||
action: RobotAction | None = None, observation: RobotObservation | None = None
|
||||
) -> dict[PipelineFeatureType, dict[str, Any]]:
|
||||
"""
|
||||
Creates the initial features dict for the dataset from action and observation specs.
|
||||
|
||||
@@ -29,6 +29,8 @@ from gymnasium import spaces
|
||||
from libero.libero import benchmark, get_libero_path
|
||||
from libero.libero.envs import OffScreenRenderEnv
|
||||
|
||||
from lerobot.processor import RobotObservation
|
||||
|
||||
|
||||
def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]:
|
||||
"""Normalize camera_name into a non-empty list of strings."""
|
||||
@@ -237,7 +239,7 @@ class LiberoEnv(gym.Env):
|
||||
env.reset()
|
||||
return env
|
||||
|
||||
def _format_raw_obs(self, raw_obs: dict[str, Any]) -> dict[str, Any]:
|
||||
def _format_raw_obs(self, raw_obs: RobotObservation) -> RobotObservation:
|
||||
images = {}
|
||||
for camera_name in self.camera_name:
|
||||
image = raw_obs[camera_name]
|
||||
@@ -313,7 +315,7 @@ class LiberoEnv(gym.Env):
|
||||
info = {"is_success": False}
|
||||
return observation, info
|
||||
|
||||
def step(self, action: np.ndarray) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]:
|
||||
def step(self, action: np.ndarray) -> tuple[RobotObservation, float, bool, bool, dict[str, Any]]:
|
||||
if action.ndim != 1:
|
||||
raise ValueError(
|
||||
f"Expected action to be 1-D (shape (action_dim,)), "
|
||||
|
||||
@@ -25,6 +25,8 @@ import metaworld.policies as policies
|
||||
import numpy as np
|
||||
from gymnasium import spaces
|
||||
|
||||
from lerobot.processor import RobotObservation
|
||||
|
||||
# ---- Load configuration data from the external JSON file ----
|
||||
CONFIG_PATH = Path(__file__).parent / "metaworld_config.json"
|
||||
try:
|
||||
@@ -161,7 +163,7 @@ class MetaworldEnv(gym.Env):
|
||||
env._freeze_rand_vec = False # otherwise no randomization
|
||||
return env
|
||||
|
||||
def _format_raw_obs(self, raw_obs: np.ndarray) -> dict[str, Any]:
|
||||
def _format_raw_obs(self, raw_obs: np.ndarray) -> RobotObservation:
|
||||
image = None
|
||||
if self._env is not None:
|
||||
image = self._env.render()
|
||||
@@ -196,7 +198,7 @@ class MetaworldEnv(gym.Env):
|
||||
self,
|
||||
seed: int | None = None,
|
||||
**kwargs,
|
||||
) -> tuple[dict[str, Any], dict[str, Any]]:
|
||||
) -> tuple[RobotObservation, dict[str, Any]]:
|
||||
"""
|
||||
Reset the environment to its initial state.
|
||||
|
||||
@@ -204,7 +206,7 @@ class MetaworldEnv(gym.Env):
|
||||
seed (Optional[int]): Random seed for environment initialization.
|
||||
|
||||
Returns:
|
||||
observation (Dict[str, Any]): The initial formatted observation.
|
||||
observation (RobotObservation): The initial formatted observation.
|
||||
info (Dict[str, Any]): Additional info about the reset state.
|
||||
"""
|
||||
super().reset(seed=seed)
|
||||
@@ -216,7 +218,7 @@ class MetaworldEnv(gym.Env):
|
||||
info = {"is_success": False}
|
||||
return observation, info
|
||||
|
||||
def step(self, action: np.ndarray) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]:
|
||||
def step(self, action: np.ndarray) -> tuple[RobotObservation, float, bool, bool, dict[str, Any]]:
|
||||
"""
|
||||
Perform one environment step.
|
||||
|
||||
@@ -224,7 +226,7 @@ class MetaworldEnv(gym.Env):
|
||||
action (np.ndarray): The action to execute, must be 1-D with shape (action_dim,).
|
||||
|
||||
Returns:
|
||||
observation (Dict[str, Any]): The formatted observation after the step.
|
||||
observation (RobotObservation): The formatted observation after the step.
|
||||
reward (float): The scalar reward for this step.
|
||||
terminated (bool): Whether the episode terminated successfully.
|
||||
truncated (bool): Whether the episode was truncated due to a time limit.
|
||||
|
||||
@@ -29,6 +29,7 @@ from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.envs.configs import EnvConfig
|
||||
from lerobot.processor import RobotObservation
|
||||
from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE, OBS_STR
|
||||
from lerobot.utils.utils import get_channel_first_image_shape
|
||||
|
||||
@@ -152,7 +153,7 @@ def check_env_attributes_and_types(env: gym.vector.VectorEnv) -> None:
|
||||
)
|
||||
|
||||
|
||||
def add_envs_task(env: gym.vector.VectorEnv, observation: dict[str, Any]) -> dict[str, Any]:
|
||||
def add_envs_task(env: gym.vector.VectorEnv, observation: RobotObservation) -> RobotObservation:
|
||||
"""Adds task feature to the observation dict with respect to the first environment attribute."""
|
||||
if hasattr(env.envs[0], "task_description"):
|
||||
task_result = env.call("task_description")
|
||||
|
||||
@@ -52,7 +52,11 @@ from lerobot.processor.converters import (
|
||||
transition_to_batch,
|
||||
transition_to_policy_action,
|
||||
)
|
||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
)
|
||||
|
||||
|
||||
def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
@@ -256,7 +260,7 @@ def make_pre_post_processors(
|
||||
}
|
||||
|
||||
# Also ensure postprocessing slices to env action dim and unnormalizes with dataset stats
|
||||
env_action_dim = policy_cfg.output_features["action"].shape[0]
|
||||
env_action_dim = policy_cfg.output_features[ACTION].shape[0]
|
||||
postprocessor_overrides["groot_action_unpack_unnormalize_v1"] = {
|
||||
"stats": kwargs.get("dataset_stats"),
|
||||
"normalize_min_max": True,
|
||||
|
||||
@@ -20,6 +20,7 @@ from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("groot")
|
||||
@@ -137,14 +138,14 @@ class GrootConfig(PreTrainedConfig):
|
||||
"No features of type FeatureType.VISUAL found in input_features."
|
||||
)
|
||||
|
||||
if "observation.state" not in self.input_features:
|
||||
if OBS_STATE not in self.input_features:
|
||||
state_feature = PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(self.max_state_dim,),
|
||||
)
|
||||
self.input_features["observation.state"] = state_feature
|
||||
self.input_features[OBS_STATE] = state_feature
|
||||
else:
|
||||
state_shape = self.input_features["observation.state"].shape
|
||||
state_shape = self.input_features[OBS_STATE].shape
|
||||
state_dim = state_shape[0] if state_shape else 0
|
||||
if state_dim > self.max_state_dim:
|
||||
raise ValueError(
|
||||
@@ -152,14 +153,14 @@ class GrootConfig(PreTrainedConfig):
|
||||
f"Either reduce state dimension or increase max_state_dim in config."
|
||||
)
|
||||
|
||||
if "action" not in self.output_features:
|
||||
if ACTION not in self.output_features:
|
||||
action_feature = PolicyFeature(
|
||||
type=FeatureType.ACTION,
|
||||
shape=(self.max_action_dim,),
|
||||
)
|
||||
self.output_features["action"] = action_feature
|
||||
self.output_features[ACTION] = action_feature
|
||||
else:
|
||||
action_shape = self.output_features["action"].shape
|
||||
action_shape = self.output_features[ACTION].shape
|
||||
action_dim = action_shape[0] if action_shape else 0
|
||||
if action_dim > self.max_action_dim:
|
||||
raise ValueError(
|
||||
|
||||
@@ -46,7 +46,7 @@ from lerobot.policies.groot.action_head.flow_matching_action_head import (
|
||||
FlowmatchingActionHeadConfig,
|
||||
)
|
||||
from lerobot.policies.groot.utils import ensure_eagle_cache_ready
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
from lerobot.utils.constants import ACTION, HF_LEROBOT_HOME
|
||||
|
||||
DEFAULT_VENDOR_EAGLE_PATH = str((Path(__file__).resolve().parent / "eagle2_hg_model").resolve())
|
||||
DEFAULT_TOKENIZER_ASSETS_REPO = "lerobot/eagle2hg-processor-groot-n1p5"
|
||||
@@ -227,8 +227,8 @@ class GR00TN15(PreTrainedModel):
|
||||
|
||||
detected_error = False
|
||||
error_msg = ERROR_MSG
|
||||
if "action" in inputs:
|
||||
action = inputs["action"]
|
||||
if ACTION in inputs:
|
||||
action = inputs[ACTION]
|
||||
# In inference, action may be omitted or None; validate only when it's a tensor.
|
||||
if action is None:
|
||||
pass # allow None during inference
|
||||
|
||||
@@ -41,6 +41,7 @@ from torch import Tensor
|
||||
from lerobot.policies.groot.configuration_groot import GrootConfig
|
||||
from lerobot.policies.groot.groot_n1 import GR00TN15
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.utils.constants import ACTION
|
||||
|
||||
|
||||
class GrootPolicy(PreTrainedPolicy):
|
||||
@@ -147,7 +148,7 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
|
||||
actions = outputs.get("action_pred")
|
||||
|
||||
original_action_dim = self.config.output_features["action"].shape[0]
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
actions = actions[:, :, :original_action_dim]
|
||||
|
||||
return actions
|
||||
|
||||
@@ -51,7 +51,11 @@ from lerobot.processor.converters import (
|
||||
)
|
||||
from lerobot.processor.core import EnvTransition, TransitionKey
|
||||
from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
HF_LEROBOT_HOME,
|
||||
OBS_IMAGE,
|
||||
OBS_IMAGES,
|
||||
OBS_STATE,
|
||||
POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
)
|
||||
@@ -107,9 +111,9 @@ def make_groot_pre_post_processors(
|
||||
# Define feature specs for optional normalization steps
|
||||
_features: dict[str, PolicyFeature] = {
|
||||
# Observation features (only add those we may normalize)
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(state_horizon, max_state_dim)),
|
||||
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_horizon, max_state_dim)),
|
||||
# Action feature
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(action_horizon, max_action_dim)),
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_horizon, max_action_dim)),
|
||||
}
|
||||
|
||||
# Normalize STATE and ACTION with min_max (SO100-like default)
|
||||
@@ -120,7 +124,7 @@ def make_groot_pre_post_processors(
|
||||
|
||||
# Determine env action dimension from config (simple, object-like PolicyFeature)
|
||||
try:
|
||||
env_action_dim = int(config.output_features["action"].shape[0])
|
||||
env_action_dim = int(config.output_features[ACTION].shape[0])
|
||||
except Exception:
|
||||
env_action_dim = 0
|
||||
|
||||
@@ -268,9 +272,9 @@ class GrootPackInputsStep(ProcessorStep):
|
||||
return torch.where(mask, mapped, torch.zeros_like(mapped))
|
||||
|
||||
# 1) Video (B, T=1, V, H, W, C) uint8
|
||||
img_keys = sorted([k for k in obs if k.startswith("observation.images.")])
|
||||
if not img_keys and "observation.image" in obs:
|
||||
img_keys = ["observation.image"]
|
||||
img_keys = sorted([k for k in obs if k.startswith(OBS_IMAGES)])
|
||||
if not img_keys and OBS_IMAGE in obs:
|
||||
img_keys = [OBS_IMAGE]
|
||||
if img_keys:
|
||||
cams = [_to_uint8_np_bhwc(obs[k]) for k in img_keys]
|
||||
video = np.stack(cams, axis=1) # (B, V, H, W, C)
|
||||
@@ -294,14 +298,14 @@ class GrootPackInputsStep(ProcessorStep):
|
||||
comp["language"] = lang
|
||||
|
||||
# 3) State/state_mask -> (B, 1, max_state_dim)
|
||||
if "observation.state" in obs:
|
||||
state = obs["observation.state"] # (B, D)
|
||||
if OBS_STATE in obs:
|
||||
state = obs[OBS_STATE] # (B, D)
|
||||
if state.dim() != 2:
|
||||
raise ValueError(f"state must be (B, D), got {tuple(state.shape)}")
|
||||
bsz, d = state.shape
|
||||
# Normalize BEFORE padding
|
||||
if self.normalize_min_max:
|
||||
state = _min_max_norm(state, "observation.state")
|
||||
state = _min_max_norm(state, OBS_STATE)
|
||||
state = state.unsqueeze(1) # (B, 1, D)
|
||||
if d > self.max_state_dim:
|
||||
state = state[:, :, : self.max_state_dim]
|
||||
@@ -320,11 +324,11 @@ class GrootPackInputsStep(ProcessorStep):
|
||||
# Normalize BEFORE temporal expansion/padding
|
||||
if self.normalize_min_max:
|
||||
if action.dim() == 2:
|
||||
action = _min_max_norm(action, "action")
|
||||
action = _min_max_norm(action, ACTION)
|
||||
elif action.dim() == 3:
|
||||
b, t, d = action.shape
|
||||
flat = action.reshape(b * t, d)
|
||||
flat = _min_max_norm(flat, "action")
|
||||
flat = _min_max_norm(flat, ACTION)
|
||||
action = flat.view(b, t, d)
|
||||
if action.dim() == 2:
|
||||
action = action.unsqueeze(1).repeat(1, self.action_horizon, 1)
|
||||
@@ -590,7 +594,7 @@ class GrootActionUnpackUnnormalizeStep(ProcessorStep):
|
||||
# forward: y = 2 * (x - min) / denom - 1, with y=0 when denom==0
|
||||
# inverse: x = (y+1)/2 * denom + min, and when denom==0 -> x = min
|
||||
if self.normalize_min_max and self.stats is not None:
|
||||
stats_k = self.stats.get("action", {})
|
||||
stats_k = self.stats.get(ACTION, {})
|
||||
d = action.shape[-1]
|
||||
min_v = torch.as_tensor(
|
||||
stats_k.get("min", torch.zeros(d)), dtype=action.dtype, device=action.device
|
||||
|
||||
@@ -21,7 +21,7 @@ from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.utils.constants import OBS_IMAGES
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
|
||||
DEFAULT_IMAGE_SIZE = 224
|
||||
|
||||
@@ -124,19 +124,19 @@ class PI0Config(PreTrainedConfig):
|
||||
)
|
||||
self.input_features[key] = empty_camera
|
||||
|
||||
if "observation.state" not in self.input_features:
|
||||
if OBS_STATE not in self.input_features:
|
||||
state_feature = PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(self.max_state_dim,), # Padded to max_state_dim
|
||||
)
|
||||
self.input_features["observation.state"] = state_feature
|
||||
self.input_features[OBS_STATE] = state_feature
|
||||
|
||||
if "action" not in self.output_features:
|
||||
if ACTION not in self.output_features:
|
||||
action_feature = PolicyFeature(
|
||||
type=FeatureType.ACTION,
|
||||
shape=(self.max_action_dim,), # Padded to max_action_dim
|
||||
)
|
||||
self.output_features["action"] = action_feature
|
||||
self.output_features[ACTION] = action_feature
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(
|
||||
|
||||
@@ -21,6 +21,7 @@ from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
|
||||
DEFAULT_IMAGE_SIZE = 224
|
||||
|
||||
@@ -117,26 +118,26 @@ class PI05Config(PreTrainedConfig):
|
||||
def validate_features(self) -> None:
|
||||
"""Validate and set up input/output features."""
|
||||
for i in range(self.empty_cameras):
|
||||
key = f"observation.images.empty_camera_{i}"
|
||||
key = OBS_IMAGES + f".empty_camera_{i}"
|
||||
empty_camera = PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, *self.image_resolution), # Use configured image resolution
|
||||
)
|
||||
self.input_features[key] = empty_camera
|
||||
|
||||
if "observation.state" not in self.input_features:
|
||||
if OBS_STATE not in self.input_features:
|
||||
state_feature = PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(self.max_state_dim,), # Padded to max_state_dim
|
||||
)
|
||||
self.input_features["observation.state"] = state_feature
|
||||
self.input_features[OBS_STATE] = state_feature
|
||||
|
||||
if "action" not in self.output_features:
|
||||
if ACTION not in self.output_features:
|
||||
action_feature = PolicyFeature(
|
||||
type=FeatureType.ACTION,
|
||||
shape=(self.max_action_dim,), # Padded to max_action_dim
|
||||
)
|
||||
self.output_features["action"] = action_feature
|
||||
self.output_features[ACTION] = action_feature
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(
|
||||
|
||||
@@ -21,6 +21,7 @@ from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
|
||||
DEFAULT_IMAGE_SIZE = 224
|
||||
|
||||
@@ -110,26 +111,26 @@ class PI0FastConfig(PreTrainedConfig):
|
||||
def validate_features(self) -> None:
|
||||
"""Validate and set up input/output features."""
|
||||
for i in range(self.empty_cameras):
|
||||
key = f"observation.images.empty_camera_{i}"
|
||||
key = OBS_IMAGES + f".empty_camera_{i}"
|
||||
empty_camera = PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, *self.image_resolution), # Use configured image resolution
|
||||
)
|
||||
self.input_features[key] = empty_camera
|
||||
|
||||
if "observation.state" not in self.input_features:
|
||||
if OBS_STATE not in self.input_features:
|
||||
state_feature = PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(self.max_state_dim,), # Padded to max_state_dim
|
||||
)
|
||||
self.input_features["observation.state"] = state_feature
|
||||
self.input_features[OBS_STATE] = state_feature
|
||||
|
||||
if "action" not in self.output_features:
|
||||
if ACTION not in self.output_features:
|
||||
action_feature = PolicyFeature(
|
||||
type=FeatureType.ACTION,
|
||||
shape=(self.max_action_dim,), # Padded to max_action_dim
|
||||
)
|
||||
self.output_features["action"] = action_feature
|
||||
self.output_features[ACTION] = action_feature
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(
|
||||
|
||||
@@ -26,6 +26,7 @@ from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("sarm")
|
||||
@@ -86,8 +87,8 @@ class SARMConfig(PreTrainedConfig):
|
||||
|
||||
pretrained_model_path: str | None = None
|
||||
device: str | None = None
|
||||
image_key: str = "observation.images.top" # Key for image used from the dataset
|
||||
state_key: str = "observation.state"
|
||||
image_key: str = OBS_IMAGES + ".top" # Key for image used from the dataset
|
||||
state_key: str = OBS_STATE
|
||||
|
||||
# Populated by the processor (video_features, state_features, text_features)
|
||||
input_features: dict = field(default_factory=lambda: {})
|
||||
|
||||
@@ -40,6 +40,7 @@ from lerobot.policies.sarm.sarm_utils import (
|
||||
normalize_stage_tau,
|
||||
pad_state_to_max_dim,
|
||||
)
|
||||
from lerobot.utils.constants import OBS_STR
|
||||
|
||||
|
||||
class StageTransformer(nn.Module):
|
||||
@@ -721,7 +722,7 @@ class SARMRewardModel(PreTrainedPolicy):
|
||||
Returns:
|
||||
Tuple of (total_loss, output_dict with loss components)
|
||||
"""
|
||||
observation = batch.get("observation", batch)
|
||||
observation = batch.get(OBS_STR, batch)
|
||||
|
||||
# Extract features
|
||||
video_features = observation["video_features"].to(self.device)
|
||||
|
||||
@@ -16,7 +16,6 @@
|
||||
|
||||
import logging
|
||||
from collections import deque
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -140,7 +139,7 @@ def prepare_observation_for_inference(
|
||||
|
||||
|
||||
def build_inference_frame(
|
||||
observation: dict[str, Any],
|
||||
observation: RobotObservation,
|
||||
device: torch.device,
|
||||
ds_features: dict[str, dict],
|
||||
task: str | None = None,
|
||||
|
||||
@@ -18,6 +18,7 @@ from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("wall_x")
|
||||
@@ -105,14 +106,14 @@ class WallXConfig(PreTrainedConfig):
|
||||
"No features of type FeatureType.VISUAL found in input_features."
|
||||
)
|
||||
|
||||
if "observation.state" not in self.input_features:
|
||||
if OBS_STATE not in self.input_features:
|
||||
state_feature = PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(self.max_state_dim,), # Padded to max_state_dim
|
||||
)
|
||||
self.input_features["observation.state"] = state_feature
|
||||
self.input_features[OBS_STATE] = state_feature
|
||||
else:
|
||||
state_shape = self.input_features["observation.state"].shape
|
||||
state_shape = self.input_features[OBS_STATE].shape
|
||||
state_dim = state_shape[0] if state_shape else 0
|
||||
if state_dim > self.max_state_dim:
|
||||
raise ValueError(
|
||||
@@ -120,14 +121,14 @@ class WallXConfig(PreTrainedConfig):
|
||||
f"Either reduce state dimension or increase max_state_dim in config."
|
||||
)
|
||||
|
||||
if "action" not in self.output_features:
|
||||
if ACTION not in self.output_features:
|
||||
action_feature = PolicyFeature(
|
||||
type=FeatureType.ACTION,
|
||||
shape=(self.max_action_dim,), # Padded to max_action_dim
|
||||
)
|
||||
self.output_features["action"] = action_feature
|
||||
self.output_features[ACTION] = action_feature
|
||||
else:
|
||||
action_shape = self.output_features["action"].shape
|
||||
action_shape = self.output_features[ACTION].shape
|
||||
action_dim = action_shape[0] if action_shape else 0
|
||||
if action_dim > self.max_action_dim:
|
||||
raise ValueError(
|
||||
|
||||
@@ -1861,7 +1861,7 @@ class WallXPolicy(PreTrainedPolicy):
|
||||
dim=-1,
|
||||
)
|
||||
else:
|
||||
action_dim = self.config.output_features["action"].shape[0]
|
||||
action_dim = self.config.output_features[ACTION].shape[0]
|
||||
dof_mask = torch.cat(
|
||||
[
|
||||
torch.ones(
|
||||
@@ -1977,7 +1977,7 @@ class WallXPolicy(PreTrainedPolicy):
|
||||
elif self.config.prediction_mode == "fast":
|
||||
output = self.model(
|
||||
**batch,
|
||||
action_dim=self.config.output_features["action"].shape[0],
|
||||
action_dim=self.config.output_features[ACTION].shape[0],
|
||||
pred_horizon=self.config.chunk_size,
|
||||
mode="predict",
|
||||
predict_mode="fast",
|
||||
@@ -1989,7 +1989,7 @@ class WallXPolicy(PreTrainedPolicy):
|
||||
actions = output["predict_action"]
|
||||
|
||||
# Unpad actions to actual action dimension
|
||||
action_dim = self.config.output_features["action"].shape[0]
|
||||
action_dim = self.config.output_features[ACTION].shape[0]
|
||||
actions = actions[:, :, :action_dim]
|
||||
|
||||
return actions
|
||||
|
||||
@@ -41,6 +41,7 @@ from lerobot.processor.converters import policy_action_to_transition, transition
|
||||
from lerobot.processor.core import EnvTransition, TransitionKey
|
||||
from lerobot.utils.constants import (
|
||||
OBS_IMAGES,
|
||||
OBS_PREFIX,
|
||||
OBS_STATE,
|
||||
POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
@@ -137,8 +138,9 @@ class LiberoProcessorStep(ObservationProcessorStep):
|
||||
|
||||
processed_obs[key] = img
|
||||
# Process robot_state into a flat state vector
|
||||
if "observation.robot_state" in processed_obs:
|
||||
robot_state = processed_obs.pop("observation.robot_state")
|
||||
robot_state_str = OBS_PREFIX + "robot_state"
|
||||
if robot_state_str in processed_obs:
|
||||
robot_state = processed_obs.pop(robot_state_str)
|
||||
|
||||
# Extract components
|
||||
eef_pos = robot_state["eef"]["pos"] # (B, 3,)
|
||||
@@ -174,8 +176,8 @@ class LiberoProcessorStep(ObservationProcessorStep):
|
||||
state_feats = {}
|
||||
|
||||
# add our new flattened state
|
||||
state_feats["observation.state"] = PolicyFeature(
|
||||
key="observation.state",
|
||||
state_feats[OBS_STATE] = PolicyFeature(
|
||||
key=OBS_STATE,
|
||||
shape=(20,),
|
||||
dtype="float32",
|
||||
)
|
||||
@@ -247,7 +249,7 @@ class XVLAImageScaleProcessorStep(ProcessorStep):
|
||||
keys_to_scale = self.image_keys
|
||||
if keys_to_scale is None:
|
||||
# Auto-detect image keys
|
||||
keys_to_scale = [k for k in obs if k.startswith("observation.images.")]
|
||||
keys_to_scale = [k for k in obs if k.startswith(OBS_IMAGES)]
|
||||
|
||||
# Scale each image
|
||||
for key in keys_to_scale:
|
||||
@@ -303,7 +305,7 @@ class XVLAImageToFloatProcessorStep(ProcessorStep):
|
||||
keys_to_convert = self.image_keys
|
||||
if keys_to_convert is None:
|
||||
# Auto-detect image keys
|
||||
keys_to_convert = [k for k in obs if k.startswith("observation.images.")]
|
||||
keys_to_convert = [k for k in obs if k.startswith(OBS_IMAGES)]
|
||||
|
||||
# Convert each image
|
||||
for key in keys_to_convert:
|
||||
@@ -376,7 +378,7 @@ class XVLAImageNetNormalizeProcessorStep(ProcessorStep):
|
||||
keys_to_normalize = self.image_keys
|
||||
if keys_to_normalize is None:
|
||||
# Auto-detect image keys
|
||||
keys_to_normalize = [k for k in obs if k.startswith("observation.images.")]
|
||||
keys_to_normalize = [k for k in obs if k.startswith(OBS_IMAGES)]
|
||||
|
||||
# Normalize each image
|
||||
for key in keys_to_normalize:
|
||||
|
||||
@@ -49,7 +49,6 @@ from .hil_processor import (
|
||||
RewardClassifierProcessorStep,
|
||||
TimeLimitProcessorStep,
|
||||
)
|
||||
from .joint_observations_processor import JointVelocityProcessorStep, MotorCurrentProcessorStep
|
||||
from .normalize_processor import NormalizerProcessorStep, UnnormalizerProcessorStep, hotswap_stats
|
||||
from .observation_processor import VanillaObservationProcessorStep
|
||||
from .pipeline import (
|
||||
@@ -94,14 +93,12 @@ __all__ = [
|
||||
"ImageCropResizeProcessorStep",
|
||||
"InfoProcessorStep",
|
||||
"InterventionActionProcessorStep",
|
||||
"JointVelocityProcessorStep",
|
||||
"make_default_processors",
|
||||
"make_default_teleop_action_processor",
|
||||
"make_default_robot_action_processor",
|
||||
"make_default_robot_observation_processor",
|
||||
"MapDeltaActionToRobotActionStep",
|
||||
"MapTensorToDeltaActionDictStep",
|
||||
"MotorCurrentProcessorStep",
|
||||
"NormalizerProcessorStep",
|
||||
"Numpy2TorchActionProcessorStep",
|
||||
"ObservationProcessorStep",
|
||||
|
||||
@@ -23,7 +23,7 @@ from typing import Any
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.utils.constants import ACTION, DONE, OBS_PREFIX, REWARD, TRUNCATED
|
||||
from lerobot.utils.constants import ACTION, DONE, INFO, OBS_PREFIX, REWARD, TRUNCATED
|
||||
|
||||
from .core import EnvTransition, PolicyAction, RobotAction, RobotObservation, TransitionKey
|
||||
|
||||
@@ -176,7 +176,7 @@ def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
|
||||
|
||||
|
||||
def create_transition(
|
||||
observation: dict[str, Any] | None = None,
|
||||
observation: RobotObservation | None = None,
|
||||
action: PolicyAction | RobotAction | None = None,
|
||||
reward: float = 0.0,
|
||||
done: bool = False,
|
||||
@@ -384,7 +384,7 @@ def transition_to_batch(transition: EnvTransition) -> dict[str, Any]:
|
||||
REWARD: transition.get(TransitionKey.REWARD, 0.0),
|
||||
DONE: transition.get(TransitionKey.DONE, False),
|
||||
TRUNCATED: transition.get(TransitionKey.TRUNCATED, False),
|
||||
"info": transition.get(TransitionKey.INFO, {}),
|
||||
INFO: transition.get(TransitionKey.INFO, {}),
|
||||
}
|
||||
|
||||
# Add complementary data.
|
||||
|
||||
@@ -45,7 +45,7 @@ RobotObservation: TypeAlias = dict[str, Any]
|
||||
EnvTransition = TypedDict(
|
||||
"EnvTransition",
|
||||
{
|
||||
TransitionKey.OBSERVATION.value: dict[str, Any] | None,
|
||||
TransitionKey.OBSERVATION.value: RobotObservation | None,
|
||||
TransitionKey.ACTION.value: PolicyAction | RobotAction | EnvAction | None,
|
||||
TransitionKey.REWARD.value: float | torch.Tensor | None,
|
||||
TransitionKey.DONE.value: bool | torch.Tensor | None,
|
||||
|
||||
@@ -18,7 +18,7 @@ from dataclasses import dataclass
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE, OBS_STR
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_PREFIX, OBS_STATE, OBS_STR
|
||||
|
||||
from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
|
||||
|
||||
@@ -60,8 +60,9 @@ class LiberoProcessorStep(ObservationProcessorStep):
|
||||
|
||||
processed_obs[key] = img
|
||||
# Process robot_state into a flat state vector
|
||||
if "observation.robot_state" in processed_obs:
|
||||
robot_state = processed_obs.pop("observation.robot_state")
|
||||
observation_robot_state_str = OBS_PREFIX + "robot_state"
|
||||
if observation_robot_state_str in processed_obs:
|
||||
robot_state = processed_obs.pop(observation_robot_state_str)
|
||||
|
||||
# Extract components
|
||||
eef_pos = robot_state["eef"]["pos"] # (B, 3,)
|
||||
@@ -98,8 +99,8 @@ class LiberoProcessorStep(ObservationProcessorStep):
|
||||
state_feats = {}
|
||||
|
||||
# add our new flattened state
|
||||
state_feats["observation.state"] = PolicyFeature(
|
||||
key="observation.state",
|
||||
state_feats[OBS_STATE] = PolicyFeature(
|
||||
key=OBS_STATE,
|
||||
shape=(8,), # [eef_pos(3), axis_angle(3), gripper(2)]
|
||||
dtype="float32",
|
||||
description=("Concatenated end-effector position (3), axis-angle (3), and gripper qpos (2)."),
|
||||
|
||||
@@ -30,7 +30,7 @@ from lerobot.utils.constants import ACTION
|
||||
|
||||
from .converters import from_tensor_to_numpy, to_tensor
|
||||
from .core import EnvTransition, PolicyAction, TransitionKey
|
||||
from .pipeline import PolicyProcessorPipeline, ProcessorStep, ProcessorStepRegistry
|
||||
from .pipeline import PolicyProcessorPipeline, ProcessorStep, ProcessorStepRegistry, RobotObservation
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -239,7 +239,7 @@ class _NormalizationMixin:
|
||||
config["normalize_observation_keys"] = sorted(self.normalize_observation_keys)
|
||||
return config
|
||||
|
||||
def _normalize_observation(self, observation: dict[str, Any], inverse: bool) -> dict[str, Tensor]:
|
||||
def _normalize_observation(self, observation: RobotObservation, inverse: bool) -> dict[str, Tensor]:
|
||||
"""
|
||||
Applies (un)normalization to all relevant features in an observation dictionary.
|
||||
|
||||
|
||||
@@ -49,7 +49,7 @@ from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.utils.hub import HubMixin
|
||||
|
||||
from .converters import batch_to_transition, create_transition, transition_to_batch
|
||||
from .core import EnvAction, EnvTransition, PolicyAction, RobotAction, TransitionKey
|
||||
from .core import EnvAction, EnvTransition, PolicyAction, RobotAction, RobotObservation, TransitionKey
|
||||
|
||||
# Generic type variables for pipeline input and output.
|
||||
TInput = TypeVar("TInput")
|
||||
@@ -1337,7 +1337,7 @@ class DataProcessorPipeline(HubMixin, Generic[TInput, TOutput]):
|
||||
return features
|
||||
|
||||
# Convenience methods for processing individual parts of a transition.
|
||||
def process_observation(self, observation: dict[str, Any]) -> dict[str, Any]:
|
||||
def process_observation(self, observation: RobotObservation) -> RobotObservation:
|
||||
"""Processes only the observation part of a transition through the pipeline.
|
||||
|
||||
Args:
|
||||
@@ -1440,7 +1440,7 @@ class ObservationProcessorStep(ProcessorStep, ABC):
|
||||
"""An abstract `ProcessorStep` that specifically targets the observation in a transition."""
|
||||
|
||||
@abstractmethod
|
||||
def observation(self, observation: dict[str, Any]) -> dict[str, Any]:
|
||||
def observation(self, observation: RobotObservation) -> RobotObservation:
|
||||
"""Processes an observation dictionary. Subclasses must implement this method.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -38,7 +38,7 @@ from lerobot.utils.constants import (
|
||||
)
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
from .core import EnvTransition, TransitionKey
|
||||
from .core import EnvTransition, RobotObservation, TransitionKey
|
||||
from .pipeline import ActionProcessorStep, ObservationProcessorStep, ProcessorStepRegistry
|
||||
|
||||
# Conditional import for type checking and lazy loading
|
||||
@@ -139,7 +139,7 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
|
||||
return None
|
||||
|
||||
def observation(self, observation: dict[str, Any]) -> dict[str, Any]:
|
||||
def observation(self, observation: RobotObservation) -> RobotObservation:
|
||||
"""
|
||||
Tokenizes the task description and adds it to the observation dictionary.
|
||||
|
||||
|
||||
@@ -38,13 +38,12 @@ from lerobot.processor import (
|
||||
GripperPenaltyProcessorStep,
|
||||
ImageCropResizeProcessorStep,
|
||||
InterventionActionProcessorStep,
|
||||
JointVelocityProcessorStep,
|
||||
MapDeltaActionToRobotActionStep,
|
||||
MapTensorToDeltaActionDictStep,
|
||||
MotorCurrentProcessorStep,
|
||||
Numpy2TorchActionProcessorStep,
|
||||
RewardClassifierProcessorStep,
|
||||
RobotActionToPolicyActionProcessorStep,
|
||||
RobotObservation,
|
||||
TimeLimitProcessorStep,
|
||||
Torch2NumpyActionProcessorStep,
|
||||
TransitionKey,
|
||||
@@ -77,6 +76,8 @@ from lerobot.utils.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, REWARD
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import log_say
|
||||
|
||||
from .joint_observations_processor import JointVelocityProcessorStep, MotorCurrentProcessorStep
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
@@ -163,7 +164,7 @@ class RobotEnv(gym.Env):
|
||||
|
||||
self._setup_spaces()
|
||||
|
||||
def _get_observation(self) -> dict[str, Any]:
|
||||
def _get_observation(self) -> RobotObservation:
|
||||
"""Get current robot observation including joint positions and camera images."""
|
||||
obs_dict = self.robot.get_observation()
|
||||
raw_joint_joint_position = {f"{name}.pos": obs_dict[f"{name}.pos"] for name in self._joint_names}
|
||||
@@ -220,7 +221,7 @@ class RobotEnv(gym.Env):
|
||||
|
||||
def reset(
|
||||
self, *, seed: int | None = None, options: dict[str, Any] | None = None
|
||||
) -> tuple[dict[str, Any], dict[str, Any]]:
|
||||
) -> tuple[RobotObservation, dict[str, Any]]:
|
||||
"""Reset environment to initial state.
|
||||
|
||||
Args:
|
||||
@@ -249,7 +250,7 @@ class RobotEnv(gym.Env):
|
||||
self._raw_joint_positions = {f"{key}.pos": obs[f"{key}.pos"] for key in self._joint_names}
|
||||
return obs, {TeleopEvents.IS_INTERVENTION: False}
|
||||
|
||||
def step(self, action) -> tuple[dict[str, np.ndarray], float, bool, bool, dict[str, Any]]:
|
||||
def step(self, action) -> tuple[RobotObservation, float, bool, bool, dict[str, Any]]:
|
||||
"""Execute one environment step with given action."""
|
||||
joint_targets_dict = {f"{key}.pos": action[i] for i, key in enumerate(self.robot.bus.motors.keys())}
|
||||
|
||||
|
||||
@@ -16,8 +16,8 @@
|
||||
|
||||
import logging
|
||||
from functools import cached_property
|
||||
from typing import Any
|
||||
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.robots.so_follower import SOFollower, SOFollowerRobotConfig
|
||||
|
||||
from ..robot import Robot
|
||||
@@ -116,7 +116,7 @@ class BiSOFollower(Robot):
|
||||
self.left_arm.setup_motors()
|
||||
self.right_arm.setup_motors()
|
||||
|
||||
def get_observation(self) -> dict[str, Any]:
|
||||
def get_observation(self) -> RobotObservation:
|
||||
obs_dict = {}
|
||||
|
||||
# Add "left_" prefix
|
||||
@@ -129,7 +129,7 @@ class BiSOFollower(Robot):
|
||||
|
||||
return obs_dict
|
||||
|
||||
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
# Remove "left_" prefix
|
||||
left_action = {
|
||||
key.removeprefix("left_"): value for key, value in action.items() if key.startswith("left_")
|
||||
|
||||
@@ -18,12 +18,12 @@
|
||||
import base64
|
||||
import logging
|
||||
from functools import cached_property
|
||||
from typing import Any
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import requests
|
||||
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
|
||||
from ..robot import Robot
|
||||
@@ -197,11 +197,11 @@ class EarthRoverMiniPlus(Robot):
|
||||
ACTION_ANGULAR_VEL: float,
|
||||
}
|
||||
|
||||
def get_observation(self) -> dict[str, Any]:
|
||||
def get_observation(self) -> RobotObservation:
|
||||
"""Get current robot observation from SDK.
|
||||
|
||||
Returns:
|
||||
dict: Observation containing:
|
||||
RobotObservation: Observation containing:
|
||||
- front: Front camera image (480, 640, 3) in RGB format
|
||||
- rear: Rear camera image (480, 640, 3) in RGB format
|
||||
- linear.vel: Current speed (0-1, SDK reports only positive speeds)
|
||||
@@ -255,7 +255,7 @@ class EarthRoverMiniPlus(Robot):
|
||||
|
||||
return observation
|
||||
|
||||
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
"""Send action to robot via SDK.
|
||||
|
||||
Args:
|
||||
@@ -264,8 +264,7 @@ class EarthRoverMiniPlus(Robot):
|
||||
- angular.vel: Target angular velocity (-1 to 1)
|
||||
|
||||
Returns:
|
||||
dict: The action that was sent (matches action_features keys)
|
||||
|
||||
RobotAction: The action that was sent (matches action_features keys)
|
||||
Raises:
|
||||
DeviceNotConnectedError: If robot is not connected
|
||||
|
||||
|
||||
@@ -17,7 +17,6 @@
|
||||
import logging
|
||||
import time
|
||||
from functools import cached_property
|
||||
from typing import Any
|
||||
|
||||
from lerobot.cameras.utils import make_cameras_from_configs
|
||||
from lerobot.motors import Motor, MotorNormMode
|
||||
@@ -25,6 +24,7 @@ from lerobot.motors.calibration_gui import RangeFinderGUI
|
||||
from lerobot.motors.feetech import (
|
||||
FeetechMotorsBus,
|
||||
)
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
|
||||
from ..robot import Robot
|
||||
@@ -128,7 +128,7 @@ class HopeJrArm(Robot):
|
||||
self.bus.setup_motor(motor)
|
||||
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
|
||||
|
||||
def get_observation(self) -> dict[str, Any]:
|
||||
def get_observation(self) -> RobotObservation:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
@@ -149,7 +149,7 @@ class HopeJrArm(Robot):
|
||||
|
||||
return obs_dict
|
||||
|
||||
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
|
||||
@@ -17,7 +17,6 @@
|
||||
import logging
|
||||
import time
|
||||
from functools import cached_property
|
||||
from typing import Any
|
||||
|
||||
from lerobot.cameras.utils import make_cameras_from_configs
|
||||
from lerobot.motors import Motor, MotorNormMode
|
||||
@@ -25,6 +24,7 @@ from lerobot.motors.calibration_gui import RangeFinderGUI
|
||||
from lerobot.motors.feetech import (
|
||||
FeetechMotorsBus,
|
||||
)
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
|
||||
from ..robot import Robot
|
||||
@@ -159,7 +159,7 @@ class HopeJrHand(Robot):
|
||||
self.bus.setup_motor(motor)
|
||||
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
|
||||
|
||||
def get_observation(self) -> dict[str, Any]:
|
||||
def get_observation(self) -> RobotObservation:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
@@ -181,7 +181,7 @@ class HopeJrHand(Robot):
|
||||
|
||||
return obs_dict
|
||||
|
||||
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
|
||||
@@ -17,7 +17,6 @@
|
||||
import logging
|
||||
import time
|
||||
from functools import cached_property
|
||||
from typing import Any
|
||||
|
||||
from lerobot.cameras.utils import make_cameras_from_configs
|
||||
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
|
||||
@@ -25,6 +24,7 @@ from lerobot.motors.dynamixel import (
|
||||
DynamixelMotorsBus,
|
||||
OperatingMode,
|
||||
)
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
|
||||
from ..robot import Robot
|
||||
@@ -182,7 +182,7 @@ class KochFollower(Robot):
|
||||
self.bus.setup_motor(motor)
|
||||
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
|
||||
|
||||
def get_observation(self) -> dict[str, Any]:
|
||||
def get_observation(self) -> RobotObservation:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
@@ -202,7 +202,7 @@ class KochFollower(Robot):
|
||||
|
||||
return obs_dict
|
||||
|
||||
def send_action(self, action: dict[str, float]) -> dict[str, float]:
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
"""Command arm to move to a target joint configuration.
|
||||
|
||||
The relative action magnitude may be clipped depending on the configuration parameter
|
||||
@@ -210,10 +210,10 @@ class KochFollower(Robot):
|
||||
Thus, this function always returns the action actually sent.
|
||||
|
||||
Args:
|
||||
action (dict[str, float]): The goal positions for the motors.
|
||||
action (RobotAction): The goal positions for the motors.
|
||||
|
||||
Returns:
|
||||
dict[str, float]: The action sent to the motors, potentially clipped.
|
||||
RobotAction: The action sent to the motors, potentially clipped.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
@@ -28,6 +28,7 @@ from lerobot.motors.feetech import (
|
||||
FeetechMotorsBus,
|
||||
OperatingMode,
|
||||
)
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
|
||||
from ..robot import Robot
|
||||
@@ -338,7 +339,7 @@ class LeKiwi(Robot):
|
||||
"theta.vel": theta,
|
||||
} # m/s and deg/s
|
||||
|
||||
def get_observation(self) -> dict[str, Any]:
|
||||
def get_observation(self) -> RobotObservation:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
@@ -369,7 +370,7 @@ class LeKiwi(Robot):
|
||||
|
||||
return obs_dict
|
||||
|
||||
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
"""Command lekiwi to move to a target joint configuration.
|
||||
|
||||
The relative action magnitude may be clipped depending on the configuration parameter
|
||||
@@ -380,7 +381,7 @@ class LeKiwi(Robot):
|
||||
DeviceNotConnectedError: if robot is not connected.
|
||||
|
||||
Returns:
|
||||
np.ndarray: the action sent to the motors, potentially clipped.
|
||||
RobotAction: the action sent to the motors, potentially clipped.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
@@ -18,11 +18,11 @@ import base64
|
||||
import json
|
||||
import logging
|
||||
from functools import cached_property
|
||||
from typing import Any
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
|
||||
@@ -172,7 +172,7 @@ class LeKiwiClient(Robot):
|
||||
|
||||
return last_msg
|
||||
|
||||
def _parse_observation_json(self, obs_string: str) -> dict[str, Any] | None:
|
||||
def _parse_observation_json(self, obs_string: str) -> RobotObservation | None:
|
||||
"""Parses the JSON observation string."""
|
||||
try:
|
||||
return json.loads(obs_string)
|
||||
@@ -196,15 +196,15 @@ class LeKiwiClient(Robot):
|
||||
return None
|
||||
|
||||
def _remote_state_from_obs(
|
||||
self, observation: dict[str, Any]
|
||||
) -> tuple[dict[str, np.ndarray], dict[str, Any]]:
|
||||
self, observation: RobotObservation
|
||||
) -> tuple[dict[str, np.ndarray], RobotObservation]:
|
||||
"""Extracts frames, and state from the parsed observation."""
|
||||
|
||||
flat_state = {key: observation.get(key, 0.0) for key in self._state_order}
|
||||
|
||||
state_vec = np.array([flat_state[key] for key in self._state_order], dtype=np.float32)
|
||||
|
||||
obs_dict: dict[str, Any] = {**flat_state, OBS_STATE: state_vec}
|
||||
obs_dict: RobotObservation = {**flat_state, OBS_STATE: state_vec}
|
||||
|
||||
# Decode images
|
||||
current_frames: dict[str, np.ndarray] = {}
|
||||
@@ -217,7 +217,7 @@ class LeKiwiClient(Robot):
|
||||
|
||||
return current_frames, obs_dict
|
||||
|
||||
def _get_data(self) -> tuple[dict[str, np.ndarray], dict[str, Any], dict[str, Any]]:
|
||||
def _get_data(self) -> tuple[dict[str, np.ndarray], RobotObservation]:
|
||||
"""
|
||||
Polls the video socket for the latest observation data.
|
||||
|
||||
@@ -252,7 +252,7 @@ class LeKiwiClient(Robot):
|
||||
|
||||
return new_frames, new_state
|
||||
|
||||
def get_observation(self) -> dict[str, Any]:
|
||||
def get_observation(self) -> RobotObservation:
|
||||
"""
|
||||
Capture observations from the remote robot: current follower arm positions,
|
||||
present wheel speeds (converted to body-frame velocities: x, y, theta),
|
||||
@@ -307,12 +307,11 @@ class LeKiwiClient(Robot):
|
||||
def configure(self):
|
||||
pass
|
||||
|
||||
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
"""Command lekiwi to move to a target joint configuration. Translates to motor space + sends over ZMQ
|
||||
|
||||
Args:
|
||||
action (np.ndarray): array containing the goal positions for the motors.
|
||||
|
||||
action (RobotAction): dictionary containing the goal positions for the motors.
|
||||
Raises:
|
||||
DeviceNotConnectedError: if robot is not connected.
|
||||
|
||||
|
||||
@@ -17,7 +17,6 @@
|
||||
import logging
|
||||
import time
|
||||
from functools import cached_property
|
||||
from typing import Any
|
||||
|
||||
from lerobot.cameras.utils import make_cameras_from_configs
|
||||
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
|
||||
@@ -26,6 +25,7 @@ from lerobot.motors.dynamixel import (
|
||||
DynamixelMotorsBus,
|
||||
OperatingMode,
|
||||
)
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
|
||||
from ..robot import Robot
|
||||
@@ -165,7 +165,7 @@ class OmxFollower(Robot):
|
||||
self.bus.setup_motor(motor)
|
||||
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
|
||||
|
||||
def get_observation(self) -> dict[str, Any]:
|
||||
def get_observation(self) -> RobotObservation:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
@@ -185,7 +185,7 @@ class OmxFollower(Robot):
|
||||
|
||||
return obs_dict
|
||||
|
||||
def send_action(self, action: dict[str, float]) -> dict[str, float]:
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
"""Command arm to move to a target joint configuration.
|
||||
|
||||
The relative action magnitude may be clipped depending on the configuration parameter
|
||||
@@ -193,10 +193,10 @@ class OmxFollower(Robot):
|
||||
Thus, this function always returns the action actually sent.
|
||||
|
||||
Args:
|
||||
action (dict[str, float]): The goal positions for the motors.
|
||||
action (RobotAction): The goal positions for the motors.
|
||||
|
||||
Returns:
|
||||
dict[str, float]: The action sent to the motors, potentially clipped.
|
||||
RobotAction: The action sent to the motors, potentially clipped.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
@@ -18,9 +18,8 @@ from __future__ import annotations
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.cameras.utils import make_cameras_from_configs
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.utils.import_utils import _reachy2_sdk_available
|
||||
|
||||
from ..robot import Robot
|
||||
@@ -171,8 +170,8 @@ class Reachy2Robot(Robot):
|
||||
else:
|
||||
return {}
|
||||
|
||||
def get_observation(self) -> dict[str, np.ndarray]:
|
||||
obs_dict: dict[str, Any] = {}
|
||||
def get_observation(self) -> RobotObservation:
|
||||
obs_dict: RobotObservation = {}
|
||||
|
||||
# Read Reachy 2 state
|
||||
before_read_t = time.perf_counter()
|
||||
@@ -185,7 +184,7 @@ class Reachy2Robot(Robot):
|
||||
|
||||
return obs_dict
|
||||
|
||||
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
if self.reachy is not None:
|
||||
if not self.is_connected:
|
||||
raise ConnectionError()
|
||||
|
||||
@@ -15,11 +15,11 @@
|
||||
import abc
|
||||
import builtins
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import draccus
|
||||
|
||||
from lerobot.motors import MotorCalibration
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.utils.constants import HF_LEROBOT_CALIBRATION, ROBOTS
|
||||
|
||||
from .config import RobotConfig
|
||||
@@ -153,28 +153,28 @@ class Robot(abc.ABC):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_observation(self) -> dict[str, Any]:
|
||||
def get_observation(self) -> RobotObservation:
|
||||
"""
|
||||
Retrieve the current observation from the robot.
|
||||
|
||||
Returns:
|
||||
dict[str, Any]: A flat dictionary representing the robot's current sensory state. Its structure
|
||||
RobotObservation: A flat dictionary representing the robot's current sensory state. Its structure
|
||||
should match :pymeth:`observation_features`.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
"""
|
||||
Send an action command to the robot.
|
||||
|
||||
Args:
|
||||
action (dict[str, Any]): Dictionary representing the desired action. Its structure should match
|
||||
action (RobotAction): Dictionary representing the desired action. Its structure should match
|
||||
:pymeth:`action_features`.
|
||||
|
||||
Returns:
|
||||
dict[str, Any]: The action actually sent to the motors potentially clipped or modified, e.g. by
|
||||
RobotAction: The action actually sent to the motors potentially clipped or modified, e.g. by
|
||||
safety limits on velocity.
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -28,6 +28,7 @@ from lerobot.processor import (
|
||||
ProcessorStepRegistry,
|
||||
RobotAction,
|
||||
RobotActionProcessorStep,
|
||||
RobotObservation,
|
||||
TransitionKey,
|
||||
)
|
||||
from lerobot.utils.rotation import Rotation
|
||||
@@ -438,7 +439,7 @@ class ForwardKinematicsJointsToEEObservation(ObservationProcessorStep):
|
||||
kinematics: RobotKinematics
|
||||
motor_names: list[str]
|
||||
|
||||
def observation(self, observation: dict[str, Any]) -> dict[str, Any]:
|
||||
def observation(self, observation: RobotObservation) -> RobotObservation:
|
||||
return compute_forward_kinematics_joints_to_ee(observation, self.kinematics, self.motor_names)
|
||||
|
||||
def transform_features(
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
import logging
|
||||
import time
|
||||
from functools import cached_property
|
||||
from typing import Any, TypeAlias
|
||||
from typing import TypeAlias
|
||||
|
||||
from lerobot.cameras.utils import make_cameras_from_configs
|
||||
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
|
||||
@@ -25,6 +25,7 @@ from lerobot.motors.feetech import (
|
||||
FeetechMotorsBus,
|
||||
OperatingMode,
|
||||
)
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
|
||||
from ..robot import Robot
|
||||
@@ -175,7 +176,7 @@ class SOFollower(Robot):
|
||||
self.bus.setup_motor(motor)
|
||||
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
|
||||
|
||||
def get_observation(self) -> dict[str, Any]:
|
||||
def get_observation(self) -> RobotObservation:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
@@ -195,7 +196,7 @@ class SOFollower(Robot):
|
||||
|
||||
return obs_dict
|
||||
|
||||
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
"""Command arm to move to a target joint configuration.
|
||||
|
||||
The relative action magnitude may be clipped depending on the configuration parameter
|
||||
@@ -206,7 +207,7 @@ class SOFollower(Robot):
|
||||
DeviceNotConnectedError: if robot is not connected.
|
||||
|
||||
Returns:
|
||||
the action sent to the motors, potentially clipped.
|
||||
RobotAction: the action sent to the motors, potentially clipped.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
@@ -26,6 +26,7 @@ import numpy as np
|
||||
|
||||
from lerobot.cameras.utils import make_cameras_from_configs
|
||||
from lerobot.envs.factory import make_env
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.robots.unitree_g1.g1_utils import G1_29_JointIndex
|
||||
|
||||
from ..robot import Robot
|
||||
@@ -269,7 +270,7 @@ class UnitreeG1(Robot):
|
||||
for cam in self._cameras.values():
|
||||
cam.disconnect()
|
||||
|
||||
def get_observation(self) -> dict[str, Any]:
|
||||
def get_observation(self) -> RobotObservation:
|
||||
lowstate = self._lowstate
|
||||
if lowstate is None:
|
||||
return {}
|
||||
@@ -350,7 +351,7 @@ class UnitreeG1(Robot):
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
return {**self._motors_ft, **self._cameras_ft}
|
||||
|
||||
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
for motor in G1_29_JointIndex:
|
||||
key = f"{motor.name}.q"
|
||||
if key in action:
|
||||
|
||||
@@ -177,9 +177,9 @@ def rollout(
|
||||
action = policy.select_action(observation)
|
||||
action = postprocessor(action)
|
||||
|
||||
action_transition = {"action": action}
|
||||
action_transition = {ACTION: action}
|
||||
action_transition = env_postprocessor(action_transition)
|
||||
action = action_transition["action"]
|
||||
action = action_transition[ACTION]
|
||||
|
||||
# Convert to CPU / numpy.
|
||||
action_numpy: np.ndarray = action.to("cpu").numpy()
|
||||
|
||||
@@ -165,7 +165,7 @@ def teleop_loop(
|
||||
# Process action for robot through pipeline
|
||||
robot_action_to_send = robot_action_processor((teleop_action, obs))
|
||||
|
||||
# Send processed action to robot (robot_action_processor.to_output should return dict[str, Any])
|
||||
# Send processed action to robot (robot_action_processor.to_output should return RobotAction)
|
||||
_ = robot.send_action(robot_action_to_send)
|
||||
|
||||
if display_data:
|
||||
|
||||
@@ -63,6 +63,7 @@ else:
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -86,7 +87,7 @@ class TokenizerTrainingConfig:
|
||||
# Whether to apply delta transform (relative actions vs absolute actions)
|
||||
use_delta_transform: bool = False
|
||||
# Dataset key for state observations (default: "observation.state")
|
||||
state_key: str = "observation.state"
|
||||
state_key: str = OBS_STATE
|
||||
# Normalization mode (MEAN_STD, MIN_MAX, QUANTILES, QUANTILE10, IDENTITY)
|
||||
normalization_mode: str = "QUANTILES"
|
||||
# FAST vocabulary size (BPE vocab size)
|
||||
@@ -223,12 +224,10 @@ def process_episode(args):
|
||||
else:
|
||||
# if no state key, use zeros (no delta transform)
|
||||
state = np.zeros_like(
|
||||
frame["action"].numpy() if torch.is_tensor(frame["action"]) else np.array(frame["action"])
|
||||
frame[ACTION].numpy() if torch.is_tensor(frame[ACTION]) else np.array(frame[ACTION])
|
||||
)
|
||||
|
||||
action = (
|
||||
frame["action"].numpy() if torch.is_tensor(frame["action"]) else np.array(frame["action"])
|
||||
)
|
||||
action = frame[ACTION].numpy() if torch.is_tensor(frame[ACTION]) else np.array(frame[ACTION])
|
||||
|
||||
states.append(state)
|
||||
actions.append(action)
|
||||
@@ -468,8 +467,8 @@ def train_tokenizer(cfg: TokenizerTrainingConfig):
|
||||
|
||||
# get normalization stats from dataset
|
||||
norm_stats = dataset.meta.stats
|
||||
if norm_stats is not None and "action" in norm_stats:
|
||||
action_stats = norm_stats["action"]
|
||||
if norm_stats is not None and ACTION in norm_stats:
|
||||
action_stats = norm_stats[ACTION]
|
||||
|
||||
# build encoded dimension indices
|
||||
encoded_dim_indices = []
|
||||
|
||||
@@ -20,6 +20,8 @@ from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.processor import RobotAction
|
||||
|
||||
from ..teleoperator import Teleoperator
|
||||
from ..utils import TeleopEvents
|
||||
from .configuration_gamepad import GamepadTeleopConfig
|
||||
@@ -83,7 +85,7 @@ class GamepadTeleop(Teleoperator):
|
||||
self.gamepad = Gamepad()
|
||||
self.gamepad.start()
|
||||
|
||||
def get_action(self) -> dict[str, Any]:
|
||||
def get_action(self) -> RobotAction:
|
||||
# Update the controller to get fresh inputs
|
||||
self.gamepad.update()
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ import time
|
||||
from queue import Queue
|
||||
from typing import Any
|
||||
|
||||
from lerobot.processor import RobotAction
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
|
||||
from ..teleoperator import Teleoperator
|
||||
@@ -124,7 +125,7 @@ class KeyboardTeleop(Teleoperator):
|
||||
def configure(self):
|
||||
pass
|
||||
|
||||
def get_action(self) -> dict[str, Any]:
|
||||
def get_action(self) -> RobotAction:
|
||||
before_read_t = time.perf_counter()
|
||||
|
||||
if not self.is_connected:
|
||||
@@ -181,7 +182,7 @@ class KeyboardEndEffectorTeleop(KeyboardTeleop):
|
||||
"names": {"delta_x": 0, "delta_y": 1, "delta_z": 2},
|
||||
}
|
||||
|
||||
def get_action(self) -> dict[str, Any]:
|
||||
def get_action(self) -> RobotAction:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(
|
||||
"KeyboardTeleop is not connected. You need to run `connect()` before `get_action()`."
|
||||
@@ -374,12 +375,12 @@ class KeyboardRoverTeleop(KeyboardTeleop):
|
||||
# Only remove key if it's being released
|
||||
self.current_pressed.pop(key_char, None)
|
||||
|
||||
def get_action(self) -> dict[str, Any]:
|
||||
def get_action(self) -> RobotAction:
|
||||
"""
|
||||
Get the current action based on pressed keys.
|
||||
|
||||
Returns:
|
||||
dict with 'linear.vel' and 'angular.vel' keys
|
||||
RobotAction with 'linear.vel' and 'angular.vel' keys
|
||||
"""
|
||||
before_read_t = time.perf_counter()
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ from typing import Any
|
||||
import draccus
|
||||
|
||||
from lerobot.motors.motors_bus import MotorCalibration
|
||||
from lerobot.processor import RobotAction
|
||||
from lerobot.utils.constants import HF_LEROBOT_CALIBRATION, TELEOPERATORS
|
||||
|
||||
from .config import TeleoperatorConfig
|
||||
@@ -150,12 +151,12 @@ class Teleoperator(abc.ABC):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_action(self) -> dict[str, Any]:
|
||||
def get_action(self) -> RobotAction:
|
||||
"""
|
||||
Retrieve the current action from the teleoperator.
|
||||
|
||||
Returns:
|
||||
dict[str, Any]: A flat dictionary representing the teleoperator's current actions. Its
|
||||
RobotAction: A flat dictionary representing the teleoperator's current actions. Its
|
||||
structure should match :pymeth:`action_features`.
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -28,11 +28,13 @@ OBS_LANGUAGE_TOKENS = OBS_LANGUAGE + ".tokens"
|
||||
OBS_LANGUAGE_ATTENTION_MASK = OBS_LANGUAGE + ".attention_mask"
|
||||
|
||||
ACTION = "action"
|
||||
ACTION_PREFIX = ACTION + "."
|
||||
ACTION_TOKENS = ACTION + ".tokens"
|
||||
ACTION_TOKEN_MASK = ACTION + ".token_mask"
|
||||
REWARD = "next.reward"
|
||||
TRUNCATED = "next.truncated"
|
||||
DONE = "next.done"
|
||||
INFO = "info"
|
||||
|
||||
ROBOTS = "robots"
|
||||
TELEOPERATORS = "teleoperators"
|
||||
|
||||
@@ -14,12 +14,13 @@
|
||||
|
||||
import numbers
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import rerun as rr
|
||||
|
||||
from .constants import OBS_PREFIX, OBS_STR
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
|
||||
from .constants import ACTION, ACTION_PREFIX, OBS_PREFIX, OBS_STR
|
||||
|
||||
|
||||
def init_rerun(
|
||||
@@ -50,8 +51,8 @@ def _is_scalar(x):
|
||||
|
||||
|
||||
def log_rerun_data(
|
||||
observation: dict[str, Any] | None = None,
|
||||
action: dict[str, Any] | None = None,
|
||||
observation: RobotObservation | None = None,
|
||||
action: RobotAction | None = None,
|
||||
compress_images: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
@@ -96,7 +97,7 @@ def log_rerun_data(
|
||||
for k, v in action.items():
|
||||
if v is None:
|
||||
continue
|
||||
key = k if str(k).startswith("action.") else f"action.{k}"
|
||||
key = k if str(k).startswith(ACTION_PREFIX) else f"{ACTION}.{k}"
|
||||
|
||||
if _is_scalar(v):
|
||||
rr.log(key, rr.Scalars(float(v)))
|
||||
|
||||
@@ -17,10 +17,10 @@
|
||||
import random
|
||||
from dataclasses import dataclass, field
|
||||
from functools import cached_property
|
||||
from typing import Any
|
||||
|
||||
from lerobot.cameras import CameraConfig, make_cameras_from_configs
|
||||
from lerobot.motors.motors_bus import Motor, MotorNormMode
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.robots import Robot, RobotConfig
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from tests.mocks.mock_motors_bus import MockMotorsBus
|
||||
@@ -119,7 +119,7 @@ class MockRobot(Robot):
|
||||
def configure(self) -> None:
|
||||
pass
|
||||
|
||||
def get_observation(self) -> dict[str, Any]:
|
||||
def get_observation(self) -> RobotObservation:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
@@ -130,7 +130,7 @@ class MockRobot(Robot):
|
||||
f"{motor}.pos": val for motor, val in zip(self.motors, self.config.static_values, strict=True)
|
||||
}
|
||||
|
||||
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ from dataclasses import dataclass
|
||||
from functools import cached_property
|
||||
from typing import Any
|
||||
|
||||
from lerobot.processor import RobotAction
|
||||
from lerobot.teleoperators import Teleoperator, TeleoperatorConfig
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
|
||||
@@ -88,7 +89,7 @@ class MockTeleop(Teleoperator):
|
||||
def configure(self) -> None:
|
||||
pass
|
||||
|
||||
def get_action(self) -> dict[str, Any]:
|
||||
def get_action(self) -> RobotAction:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user