diff --git a/docs/source/earthrover_mini_plus.mdx b/docs/source/earthrover_mini_plus.mdx index 7b739ecc1..884e84d8c 100644 --- a/docs/source/earthrover_mini_plus.mdx +++ b/docs/source/earthrover_mini_plus.mdx @@ -204,22 +204,26 @@ Replace `your_username/dataset_name` with your Hugging Face username and a name Your dataset includes: -**Your Actions (2 things)**: +**Your Actions (2 features)**: -- How much you moved forward/backward -- How much you turned left/right +- `linear_velocity`: How much you moved forward/backward +- `angular_velocity`: How much you turned left/right -**Robot Observations (12 things)**: +**Robot Observations (24 features)**: - Front camera video - Rear camera video - Current speed - Battery level -- Which way the robot is facing -- GPS location (latitude, longitude, signal strength) +- Orientation +- GPS (latitude, longitude, signal strength) - Network signal strength - Vibration level -- Lamp status (on/off) +- Lamp state (on/off) +- Accelerometer (x, y, z) +- Gyroscope (x, y, z) +- Magnetometer (x, y, z) +- Wheel RPMs (4 wheels) ### Where Your Data Goes diff --git a/examples/dataset/load_lerobot_dataset.py b/examples/dataset/load_lerobot_dataset.py index 4fda25884..ea3516710 100644 --- a/examples/dataset/load_lerobot_dataset.py +++ b/examples/dataset/load_lerobot_dataset.py @@ -32,7 +32,8 @@ import torch from huggingface_hub import HfApi import lerobot -from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata +from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata +from lerobot.datasets.lerobot_dataset import LeRobotDataset def main(): diff --git a/examples/lekiwi/evaluate.py b/examples/lekiwi/evaluate.py index a3144a442..ef98640aa 100644 --- a/examples/lekiwi/evaluate.py +++ b/examples/lekiwi/evaluate.py @@ -14,8 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from lerobot.datasets.feature_utils import hw_to_dataset_features from lerobot.datasets.lerobot_dataset import LeRobotDataset -from lerobot.datasets.utils import hw_to_dataset_features from lerobot.policies.act.modeling_act import ACTPolicy from lerobot.policies.factory import make_pre_post_processors from lerobot.processor import make_default_processors diff --git a/examples/lekiwi/record.py b/examples/lekiwi/record.py index 9292157f7..ace2e35b8 100644 --- a/examples/lekiwi/record.py +++ b/examples/lekiwi/record.py @@ -14,8 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from lerobot.datasets.feature_utils import hw_to_dataset_features from lerobot.datasets.lerobot_dataset import LeRobotDataset -from lerobot.datasets.utils import hw_to_dataset_features from lerobot.processor import make_default_processors from lerobot.robots.lekiwi.config_lekiwi import LeKiwiClientConfig from lerobot.robots.lekiwi.lekiwi_client import LeKiwiClient diff --git a/examples/phone_to_so100/evaluate.py b/examples/phone_to_so100/evaluate.py index 837217eda..9cd7a98c2 100644 --- a/examples/phone_to_so100/evaluate.py +++ b/examples/phone_to_so100/evaluate.py @@ -16,15 +16,13 @@ from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig from lerobot.configs.types import FeatureType, PolicyFeature +from lerobot.datasets.feature_utils import combine_feature_dicts from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features -from lerobot.datasets.utils import combine_feature_dicts from lerobot.model.kinematics import RobotKinematics from lerobot.policies.act.modeling_act import ACTPolicy from lerobot.policies.factory import make_pre_post_processors from lerobot.processor import ( - RobotAction, - RobotObservation, RobotProcessorPipeline, make_default_teleop_action_processor, ) @@ -40,6 +38,7 @@ from lerobot.robots.so_follower.robot_kinematic_processor import ( InverseKinematicsEEToJoints, ) from lerobot.scripts.lerobot_record import record_loop +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.control_utils import init_keyboard_listener from lerobot.utils.utils import log_say from lerobot.utils.visualization_utils import init_rerun diff --git a/examples/phone_to_so100/record.py b/examples/phone_to_so100/record.py index 1f5005db9..f2a17cd33 100644 --- a/examples/phone_to_so100/record.py +++ b/examples/phone_to_so100/record.py @@ -15,11 +15,11 @@ # limitations under the License. 
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig +from lerobot.datasets.feature_utils import combine_feature_dicts from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features -from lerobot.datasets.utils import combine_feature_dicts from lerobot.model.kinematics import RobotKinematics -from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline +from lerobot.processor import RobotProcessorPipeline from lerobot.processor.converters import ( observation_to_transition, robot_action_observation_to_transition, @@ -38,6 +38,7 @@ from lerobot.scripts.lerobot_record import record_loop from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction from lerobot.teleoperators.phone.teleop_phone import Phone +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.control_utils import init_keyboard_listener from lerobot.utils.utils import log_say from lerobot.utils.visualization_utils import init_rerun diff --git a/examples/phone_to_so100/replay.py b/examples/phone_to_so100/replay.py index 9d7806cf4..7b955cdb7 100644 --- a/examples/phone_to_so100/replay.py +++ b/examples/phone_to_so100/replay.py @@ -18,7 +18,7 @@ import time from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.model.kinematics import RobotKinematics -from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline +from lerobot.processor import RobotProcessorPipeline from lerobot.processor.converters import ( robot_action_observation_to_transition, transition_to_robot_action, @@ -27,6 +27,7 @@ from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig from lerobot.robots.so_follower.robot_kinematic_processor import ( InverseKinematicsEEToJoints, ) +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.constants import ACTION from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.utils import log_say diff --git a/examples/phone_to_so100/teleoperate.py b/examples/phone_to_so100/teleoperate.py index 6eaaec806..7242c39ce 100644 --- a/examples/phone_to_so100/teleoperate.py +++ b/examples/phone_to_so100/teleoperate.py @@ -16,7 +16,7 @@ import time from lerobot.model.kinematics import RobotKinematics -from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline +from lerobot.processor import RobotProcessorPipeline from lerobot.processor.converters import ( robot_action_observation_to_transition, transition_to_robot_action, @@ -31,6 +31,7 @@ from lerobot.robots.so_follower.robot_kinematic_processor import ( from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction from lerobot.teleoperators.phone.teleop_phone import Phone +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.visualization_utils import init_rerun, log_rerun_data diff --git a/examples/port_datasets/port_droid.py b/examples/port_datasets/port_droid.py index a1fb50914..f58bacbe0 100644 --- a/examples/port_datasets/port_droid.py +++ b/examples/port_datasets/port_droid.py @@ -22,7 +22,8 @@ from pathlib import Path import numpy as np import tensorflow_datasets as tfds -from lerobot.datasets.lerobot_dataset import LeRobotDataset, 
LeRobotDatasetMetadata +from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata +from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.utils.utils import get_elapsed_time_in_days_hours_minutes_seconds DROID_SHARDS = 2048 diff --git a/examples/port_datasets/slurm_upload.py b/examples/port_datasets/slurm_upload.py index 55002c0be..7fb01c11b 100644 --- a/examples/port_datasets/slurm_upload.py +++ b/examples/port_datasets/slurm_upload.py @@ -26,7 +26,7 @@ from huggingface_hub import HfApi from huggingface_hub.constants import REPOCARD_NAME from port_droid import DROID_SHARDS -from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDatasetMetadata +from lerobot.datasets.dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata from lerobot.datasets.utils import create_lerobot_dataset_card from lerobot.utils.utils import init_logging @@ -155,7 +155,7 @@ class UploadDataset(PipelineStep): from datasets.utils.tqdm import disable_progress_bars from huggingface_hub import CommitOperationAdd, preupload_lfs_files - from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata + from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata from lerobot.utils.utils import init_logging init_logging() diff --git a/examples/rtc/eval_dataset.py b/examples/rtc/eval_dataset.py index 613fd67d7..a94d4da48 100644 --- a/examples/rtc/eval_dataset.py +++ b/examples/rtc/eval_dataset.py @@ -113,8 +113,9 @@ from lerobot.configs import parser from lerobot.configs.default import DatasetConfig from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.types import RTCAttentionSchedule +from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata from lerobot.datasets.factory import resolve_delta_timestamps -from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata +from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.policies.factory import get_policy_class, make_pre_post_processors from lerobot.policies.rtc.configuration_rtc import RTCConfig from lerobot.policies.rtc.debug_visualizer import RTCDebugVisualizer diff --git a/examples/rtc/eval_with_real_robot.py b/examples/rtc/eval_with_real_robot.py index 9d9e1364a..36da88e1b 100644 --- a/examples/rtc/eval_with_real_robot.py +++ b/examples/rtc/eval_with_real_robot.py @@ -82,7 +82,7 @@ from lerobot.cameras.zmq.configuration_zmq import ZMQCameraConfig # noqa: F401 from lerobot.configs import parser from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.types import RTCAttentionSchedule -from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features +from lerobot.datasets.feature_utils import build_dataset_frame, hw_to_dataset_features from lerobot.policies.factory import get_policy_class, make_pre_post_processors from lerobot.policies.rtc.action_queue import ActionQueue from lerobot.policies.rtc.configuration_rtc import RTCConfig diff --git a/examples/so100_to_so100_EE/evaluate.py b/examples/so100_to_so100_EE/evaluate.py index b614b89f2..638591021 100644 --- a/examples/so100_to_so100_EE/evaluate.py +++ b/examples/so100_to_so100_EE/evaluate.py @@ -16,15 +16,13 @@ from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig from lerobot.configs.types import FeatureType, PolicyFeature +from lerobot.datasets.feature_utils import combine_feature_dicts from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, 
create_initial_features -from lerobot.datasets.utils import combine_feature_dicts from lerobot.model.kinematics import RobotKinematics from lerobot.policies.act.modeling_act import ACTPolicy from lerobot.policies.factory import make_pre_post_processors from lerobot.processor import ( - RobotAction, - RobotObservation, RobotProcessorPipeline, make_default_teleop_action_processor, ) @@ -40,6 +38,7 @@ from lerobot.robots.so_follower.robot_kinematic_processor import ( InverseKinematicsEEToJoints, ) from lerobot.scripts.lerobot_record import record_loop +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.control_utils import init_keyboard_listener from lerobot.utils.utils import log_say from lerobot.utils.visualization_utils import init_rerun diff --git a/examples/so100_to_so100_EE/record.py b/examples/so100_to_so100_EE/record.py index d85a1c5cc..634bd891a 100644 --- a/examples/so100_to_so100_EE/record.py +++ b/examples/so100_to_so100_EE/record.py @@ -16,11 +16,11 @@ from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig +from lerobot.datasets.feature_utils import combine_feature_dicts from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features -from lerobot.datasets.utils import combine_feature_dicts from lerobot.model.kinematics import RobotKinematics -from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline +from lerobot.processor import RobotProcessorPipeline from lerobot.processor.converters import ( observation_to_transition, robot_action_observation_to_transition, @@ -35,6 +35,7 @@ from lerobot.robots.so_follower.robot_kinematic_processor import ( ) from lerobot.scripts.lerobot_record import record_loop from lerobot.teleoperators.so_leader import SO100Leader, SO100LeaderConfig +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.control_utils import init_keyboard_listener from lerobot.utils.utils import log_say from lerobot.utils.visualization_utils import init_rerun diff --git a/examples/so100_to_so100_EE/replay.py b/examples/so100_to_so100_EE/replay.py index 47a2f6635..b042e02dd 100644 --- a/examples/so100_to_so100_EE/replay.py +++ b/examples/so100_to_so100_EE/replay.py @@ -19,7 +19,7 @@ import time from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.model.kinematics import RobotKinematics -from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline +from lerobot.processor import RobotProcessorPipeline from lerobot.processor.converters import ( robot_action_observation_to_transition, transition_to_robot_action, @@ -28,6 +28,7 @@ from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig from lerobot.robots.so_follower.robot_kinematic_processor import ( InverseKinematicsEEToJoints, ) +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.constants import ACTION from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.utils import log_say diff --git a/examples/so100_to_so100_EE/teleoperate.py b/examples/so100_to_so100_EE/teleoperate.py index 71d2899de..af21f079b 100644 --- a/examples/so100_to_so100_EE/teleoperate.py +++ b/examples/so100_to_so100_EE/teleoperate.py @@ -17,7 +17,7 @@ import time from lerobot.model.kinematics import RobotKinematics -from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline +from lerobot.processor import RobotProcessorPipeline from 
lerobot.processor.converters import ( robot_action_observation_to_transition, robot_action_to_transition, @@ -30,6 +30,7 @@ from lerobot.robots.so_follower.robot_kinematic_processor import ( InverseKinematicsEEToJoints, ) from lerobot.teleoperators.so_leader import SO100Leader, SO100LeaderConfig +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.visualization_utils import init_rerun, log_rerun_data diff --git a/examples/training/train_policy.py b/examples/training/train_policy.py index 16f2a4d87..07ec10c92 100644 --- a/examples/training/train_policy.py +++ b/examples/training/train_policy.py @@ -19,8 +19,9 @@ from pathlib import Path import torch from lerobot.configs.types import FeatureType -from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata -from lerobot.datasets.utils import dataset_to_policy_features +from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata +from lerobot.datasets.feature_utils import dataset_to_policy_features +from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy from lerobot.policies.factory import make_pre_post_processors diff --git a/examples/training/train_with_streaming.py b/examples/training/train_with_streaming.py index 185be5b13..973698e74 100644 --- a/examples/training/train_with_streaming.py +++ b/examples/training/train_with_streaming.py @@ -20,9 +20,9 @@ from pathlib import Path import torch from lerobot.configs.types import FeatureType -from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata +from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata +from lerobot.datasets.feature_utils import dataset_to_policy_features from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset -from lerobot.datasets.utils import dataset_to_policy_features from lerobot.policies.act.configuration_act import ACTConfig from lerobot.policies.act.modeling_act import ACTPolicy from lerobot.policies.factory import make_pre_post_processors diff --git a/examples/tutorial/act/act_training_example.py b/examples/tutorial/act/act_training_example.py index fe70f3023..b62c49cac 100644 --- a/examples/tutorial/act/act_training_example.py +++ b/examples/tutorial/act/act_training_example.py @@ -5,8 +5,9 @@ from pathlib import Path import torch from lerobot.configs.types import FeatureType -from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata -from lerobot.datasets.utils import dataset_to_policy_features +from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata +from lerobot.datasets.feature_utils import dataset_to_policy_features +from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.policies.act.configuration_act import ACTConfig from lerobot.policies.act.modeling_act import ACTPolicy from lerobot.policies.factory import make_pre_post_processors diff --git a/examples/tutorial/act/act_using_example.py b/examples/tutorial/act/act_using_example.py index 60bc802d8..15254d8eb 100644 --- a/examples/tutorial/act/act_using_example.py +++ b/examples/tutorial/act/act_using_example.py @@ -1,7 +1,7 @@ import torch from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig -from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata +from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata from 
lerobot.policies.act.modeling_act import ACTPolicy from lerobot.policies.factory import make_pre_post_processors from lerobot.policies.utils import build_inference_frame, make_robot_action diff --git a/examples/tutorial/diffusion/diffusion_training_example.py b/examples/tutorial/diffusion/diffusion_training_example.py index 6db081450..dc6ca68a3 100644 --- a/examples/tutorial/diffusion/diffusion_training_example.py +++ b/examples/tutorial/diffusion/diffusion_training_example.py @@ -5,8 +5,9 @@ from pathlib import Path import torch from lerobot.configs.types import FeatureType -from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata -from lerobot.datasets.utils import dataset_to_policy_features +from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata +from lerobot.datasets.feature_utils import dataset_to_policy_features +from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy from lerobot.policies.factory import make_pre_post_processors diff --git a/examples/tutorial/diffusion/diffusion_using_example.py b/examples/tutorial/diffusion/diffusion_using_example.py index d8ac75cfe..9b31cf359 100644 --- a/examples/tutorial/diffusion/diffusion_using_example.py +++ b/examples/tutorial/diffusion/diffusion_using_example.py @@ -1,7 +1,7 @@ import torch from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig -from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata +from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy from lerobot.policies.factory import make_pre_post_processors from lerobot.policies.utils import build_inference_frame, make_robot_action diff --git a/examples/tutorial/pi0/using_pi0_example.py b/examples/tutorial/pi0/using_pi0_example.py index 056c3d81a..d8cf9dbff 100644 --- a/examples/tutorial/pi0/using_pi0_example.py +++ b/examples/tutorial/pi0/using_pi0_example.py @@ -1,7 +1,7 @@ import torch from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig -from lerobot.datasets.utils import hw_to_dataset_features +from lerobot.datasets.feature_utils import hw_to_dataset_features from lerobot.policies.factory import make_pre_post_processors from lerobot.policies.pi0.modeling_pi0 import PI0Policy from lerobot.policies.utils import build_inference_frame, make_robot_action diff --git a/examples/tutorial/rl/hilserl_example.py b/examples/tutorial/rl/hilserl_example.py index 980ac7985..d367a01ce 100644 --- a/examples/tutorial/rl/hilserl_example.py +++ b/examples/tutorial/rl/hilserl_example.py @@ -6,8 +6,8 @@ from queue import Empty, Full import torch import torch.optim as optim +from lerobot.datasets.feature_utils import hw_to_dataset_features from lerobot.datasets.lerobot_dataset import LeRobotDataset -from lerobot.datasets.utils import hw_to_dataset_features from lerobot.envs.configs import HILSerlProcessorConfig, HILSerlRobotEnvConfig from lerobot.policies.sac.configuration_sac import SACConfig from lerobot.policies.sac.modeling_sac import SACPolicy diff --git a/examples/tutorial/smolvla/using_smolvla_example.py b/examples/tutorial/smolvla/using_smolvla_example.py index ce3aa7bca..b99126efa 100644 --- a/examples/tutorial/smolvla/using_smolvla_example.py +++ b/examples/tutorial/smolvla/using_smolvla_example.py @@ -1,7 +1,7 @@ import torch from 
lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig -from lerobot.datasets.utils import hw_to_dataset_features +from lerobot.datasets.feature_utils import hw_to_dataset_features from lerobot.policies.factory import make_pre_post_processors from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy from lerobot.policies.utils import build_inference_frame, make_robot_action diff --git a/src/lerobot/async_inference/helpers.py b/src/lerobot/async_inference/helpers.py index 8b12920d9..9dd44eb44 100644 --- a/src/lerobot/async_inference/helpers.py +++ b/src/lerobot/async_inference/helpers.py @@ -23,7 +23,7 @@ from typing import Any import torch from lerobot.configs.types import PolicyFeature -from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features +from lerobot.datasets.feature_utils import build_dataset_frame, hw_to_dataset_features # NOTE: Configs need to be loaded for the client to be able to instantiate the policy config from lerobot.policies import ( # noqa: F401 diff --git a/src/lerobot/async_inference/policy_server.py b/src/lerobot/async_inference/policy_server.py index aedce2a74..3f63929df 100644 --- a/src/lerobot/async_inference/policy_server.py +++ b/src/lerobot/async_inference/policy_server.py @@ -39,15 +39,13 @@ import grpc import torch from lerobot.policies.factory import get_policy_class, make_pre_post_processors -from lerobot.processor import ( - PolicyAction, - PolicyProcessorPipeline, -) +from lerobot.processor import PolicyProcessorPipeline from lerobot.transport import ( services_pb2, # type: ignore services_pb2_grpc, # type: ignore ) from lerobot.transport.utils import receive_bytes_in_chunks +from lerobot.types import PolicyAction from .configs import PolicyServerConfig from .constants import SUPPORTED_POLICIES diff --git a/src/lerobot/configs/default.py b/src/lerobot/configs/default.py index 3fb0c6c4e..7f481b9ca 100644 --- a/src/lerobot/configs/default.py +++ b/src/lerobot/configs/default.py @@ -36,6 +36,16 @@ class DatasetConfig: video_backend: str = field(default_factory=get_safe_default_codec) streaming: bool = False + def __post_init__(self) -> None: + if self.episodes is not None: + if any(ep < 0 for ep in self.episodes): + raise ValueError( + f"Episode indices must be non-negative, got: {[ep for ep in self.episodes if ep < 0]}" + ) + if len(self.episodes) != len(set(self.episodes)): + duplicates = sorted({ep for ep in self.episodes if self.episodes.count(ep) > 1}) + raise ValueError(f"Episode indices contain duplicates: {duplicates}") + @dataclass class WandBConfig: diff --git a/src/lerobot/configs/policies.py b/src/lerobot/configs/policies.py index 44b013c29..ce567b8f5 100644 --- a/src/lerobot/configs/policies.py +++ b/src/lerobot/configs/policies.py @@ -30,8 +30,8 @@ from lerobot.configs.types import FeatureType, PolicyFeature from lerobot.optim.optimizers import OptimizerConfig from lerobot.optim.schedulers import LRSchedulerConfig from lerobot.utils.constants import ACTION, OBS_STATE +from lerobot.utils.device_utils import auto_select_torch_device, is_amp_available, is_torch_device_available from lerobot.utils.hub import HubMixin -from lerobot.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available T = TypeVar("T", bound="PreTrainedConfig") logger = getLogger(__name__) diff --git a/src/lerobot/configs/train.py b/src/lerobot/configs/train.py index 9d20afc68..8b8aedb26 100644 --- a/src/lerobot/configs/train.py +++ b/src/lerobot/configs/train.py @@ -51,7 +51,7 @@ class 
TrainPipelineConfig(HubMixin): # AND for the evaluation environments. seed: int | None = 1000 # Set to True to use deterministic cuDNN algorithms for reproducibility. - # This disables cudnn.benchmark and may reduce training speed by ~10-20%. + # This disables cudnn.benchmark and may reduce training speed by ~10-20 percent. cudnn_deterministic: bool = False # Number of workers for the dataloader. num_workers: int = 4 diff --git a/src/lerobot/data_processing/sarm_annotations/subtask_annotation.py b/src/lerobot/data_processing/sarm_annotations/subtask_annotation.py index 67e37bab8..8f3a65e39 100644 --- a/src/lerobot/data_processing/sarm_annotations/subtask_annotation.py +++ b/src/lerobot/data_processing/sarm_annotations/subtask_annotation.py @@ -746,7 +746,8 @@ def save_annotations_to_dataset( dataset_path: Path, annotations: dict[int, SubtaskAnnotation], fps: int, prefix: str = "sparse" ): """Save annotations to LeRobot dataset parquet format.""" - from lerobot.datasets.utils import DEFAULT_EPISODES_PATH, load_episodes + from lerobot.datasets.io_utils import load_episodes + from lerobot.datasets.utils import DEFAULT_EPISODES_PATH episodes_dataset = load_episodes(dataset_path) if not episodes_dataset or len(episodes_dataset) == 0: @@ -840,7 +841,7 @@ def generate_auto_sparse_annotations( def load_annotations_from_dataset(dataset_path: Path, prefix: str = "sparse") -> dict[int, SubtaskAnnotation]: """Load annotations from LeRobot dataset parquet files.""" - from lerobot.datasets.utils import load_episodes + from lerobot.datasets.io_utils import load_episodes episodes_dataset = load_episodes(dataset_path) if not episodes_dataset or len(episodes_dataset) == 0: diff --git a/src/lerobot/datasets/aggregate.py b/src/lerobot/datasets/aggregate.py index b32116233..66f055f04 100644 --- a/src/lerobot/datasets/aggregate.py +++ b/src/lerobot/datasets/aggregate.py @@ -24,7 +24,16 @@ import pandas as pd import tqdm from lerobot.datasets.compute_stats import aggregate_stats -from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata +from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata +from lerobot.datasets.feature_utils import get_hf_features_from_features +from lerobot.datasets.io_utils import ( + get_file_size_in_mb, + get_parquet_file_size_in_mb, + to_parquet_with_hf_images, + write_info, + write_stats, + write_tasks, +) from lerobot.datasets.utils import ( DEFAULT_CHUNK_SIZE, DEFAULT_DATA_FILE_SIZE_IN_MB, @@ -32,14 +41,7 @@ from lerobot.datasets.utils import ( DEFAULT_EPISODES_PATH, DEFAULT_VIDEO_FILE_SIZE_IN_MB, DEFAULT_VIDEO_PATH, - get_file_size_in_mb, - get_hf_features_from_features, - get_parquet_file_size_in_mb, - to_parquet_with_hf_images, update_chunk_file_indices, - write_info, - write_stats, - write_tasks, ) from lerobot.datasets.video_utils import concatenate_video_files, get_video_duration_in_s diff --git a/src/lerobot/datasets/backward_compatibility.py b/src/lerobot/datasets/backward_compatibility.py deleted file mode 100644 index ae95c5f7b..000000000 --- a/src/lerobot/datasets/backward_compatibility.py +++ /dev/null @@ -1,56 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import packaging.version - -V30_MESSAGE = """ -The dataset you requested ({repo_id}) is in {version} format. - -We introduced a new format since v3.0 which is not backward compatible with v2.1. -Please, update your dataset to the new format using this command: -``` -python -m lerobot.datasets.v30.convert_dataset_v21_to_v30 --repo-id={repo_id} -``` - -If you already have a converted version uploaded to the hub, then this error might be because of -an older version in your local cache. Consider deleting the cached version and retrying. - -If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb) -or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose). -""" - -FUTURE_MESSAGE = """ -The dataset you requested ({repo_id}) is only available in {version} format. -As we cannot ensure forward compatibility with it, please update your current version of lerobot. -""" - - -class CompatibilityError(Exception): ... - - -class BackwardCompatibilityError(CompatibilityError): - def __init__(self, repo_id: str, version: packaging.version.Version): - if version.major == 2 and version.minor == 1: - message = V30_MESSAGE.format(repo_id=repo_id, version=version) - else: - raise NotImplementedError( - "Contact the maintainer on [Discord](https://discord.com/invite/s3KuuzsPFb)." - ) - super().__init__(message) - - -class ForwardCompatibilityError(CompatibilityError): - def __init__(self, repo_id: str, version: packaging.version.Version): - message = FUTURE_MESSAGE.format(repo_id=repo_id, version=version) - super().__init__(message) diff --git a/src/lerobot/datasets/compute_stats.py b/src/lerobot/datasets/compute_stats.py index 61e174d5c..5bd95810b 100644 --- a/src/lerobot/datasets/compute_stats.py +++ b/src/lerobot/datasets/compute_stats.py @@ -15,7 +15,7 @@ # limitations under the License. import numpy as np -from lerobot.datasets.utils import load_image_as_numpy +from lerobot.datasets.io_utils import load_image_as_numpy DEFAULT_QUANTILES = [0.01, 0.10, 0.50, 0.90, 0.99] diff --git a/src/lerobot/datasets/dataset_metadata.py b/src/lerobot/datasets/dataset_metadata.py new file mode 100644 index 000000000..560a90a6e --- /dev/null +++ b/src/lerobot/datasets/dataset_metadata.py @@ -0,0 +1,517 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
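+
+# Usage sketch (illustrative; the repo id below is a placeholder). It only exercises the constructor
+# and the read-only properties defined later in this module:
+#
+#     from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
+#
+#     meta = LeRobotDatasetMetadata("user/my_dataset")
+#     print(meta.total_episodes, meta.fps, meta.camera_keys)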
+from pathlib import Path + +import numpy as np +import packaging.version +import pandas as pd +import pyarrow as pa +import pyarrow.parquet as pq +from huggingface_hub import snapshot_download + +from lerobot.datasets.compute_stats import aggregate_stats +from lerobot.datasets.feature_utils import _validate_feature_names, create_empty_dataset_info +from lerobot.datasets.io_utils import ( + get_file_size_in_mb, + load_episodes, + load_info, + load_stats, + load_subtasks, + load_tasks, + write_info, + write_json, + write_stats, + write_tasks, +) +from lerobot.datasets.utils import ( + DEFAULT_EPISODES_PATH, + DEFAULT_FEATURES, + INFO_PATH, + check_version_compatibility, + flatten_dict, + get_safe_version, + is_valid_version, + update_chunk_file_indices, +) +from lerobot.datasets.video_utils import get_video_info +from lerobot.utils.constants import HF_LEROBOT_HOME + +CODEBASE_VERSION = "v3.0" + + +class LeRobotDatasetMetadata: + def __init__( + self, + repo_id: str, + root: str | Path | None = None, + revision: str | None = None, + force_cache_sync: bool = False, + metadata_buffer_size: int = 10, + ): + self.repo_id = repo_id + self.revision = revision if revision else CODEBASE_VERSION + self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id + self.writer = None + self.latest_episode = None + self.metadata_buffer: list[dict] = [] + self.metadata_buffer_size = metadata_buffer_size + + try: + if force_cache_sync: + raise FileNotFoundError + self.load_metadata() + except (FileNotFoundError, NotADirectoryError): + if is_valid_version(self.revision): + self.revision = get_safe_version(self.repo_id, self.revision) + + (self.root / "meta").mkdir(exist_ok=True, parents=True) + self.pull_from_repo(allow_patterns="meta/") + self.load_metadata() + + def _flush_metadata_buffer(self) -> None: + """Write all buffered episode metadata to parquet file.""" + if not hasattr(self, "metadata_buffer") or len(self.metadata_buffer) == 0: + return + + combined_dict = {} + for episode_dict in self.metadata_buffer: + for key, value in episode_dict.items(): + if key not in combined_dict: + combined_dict[key] = [] + # Extract value and serialize numpy arrays + # because PyArrow's from_pydict function doesn't support numpy arrays + val = value[0] if isinstance(value, list) else value + combined_dict[key].append(val.tolist() if isinstance(val, np.ndarray) else val) + + first_ep = self.metadata_buffer[0] + chunk_idx = first_ep["meta/episodes/chunk_index"][0] + file_idx = first_ep["meta/episodes/file_index"][0] + + table = pa.Table.from_pydict(combined_dict) + + if not self.writer: + path = Path(self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)) + path.parent.mkdir(parents=True, exist_ok=True) + self.writer = pq.ParquetWriter( + path, schema=table.schema, compression="snappy", use_dictionary=True + ) + + self.writer.write_table(table) + + self.latest_episode = self.metadata_buffer[-1] + self.metadata_buffer.clear() + + def _close_writer(self) -> None: + """Close and cleanup the parquet writer if it exists.""" + self._flush_metadata_buffer() + + writer = getattr(self, "writer", None) + if writer is not None: + writer.close() + self.writer = None + + def __del__(self): + """ + Trust the user to call .finalize() but as an added safety check call the parquet writer to stop when calling the destructor + """ + self._close_writer() + + def load_metadata(self): + self.info = load_info(self.root) + check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION) + 
self.tasks = load_tasks(self.root) + self.subtasks = load_subtasks(self.root) + self.episodes = load_episodes(self.root) + self.stats = load_stats(self.root) + + def pull_from_repo( + self, + allow_patterns: list[str] | str | None = None, + ignore_patterns: list[str] | str | None = None, + ) -> None: + snapshot_download( + self.repo_id, + repo_type="dataset", + revision=self.revision, + local_dir=self.root, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + ) + + @property + def url_root(self) -> str: + return f"hf://datasets/{self.repo_id}" + + @property + def _version(self) -> packaging.version.Version: + """Codebase version used to create this dataset.""" + return packaging.version.parse(self.info["codebase_version"]) + + def get_data_file_path(self, ep_index: int) -> Path: + if self.episodes is None: + self.episodes = load_episodes(self.root) + if ep_index >= len(self.episodes): + raise IndexError( + f"Episode index {ep_index} out of range. Episodes: {len(self.episodes) if self.episodes else 0}" + ) + ep = self.episodes[ep_index] + chunk_idx = ep["data/chunk_index"] + file_idx = ep["data/file_index"] + fpath = self.data_path.format(chunk_index=chunk_idx, file_index=file_idx) + return Path(fpath) + + def get_video_file_path(self, ep_index: int, vid_key: str) -> Path: + if self.episodes is None: + self.episodes = load_episodes(self.root) + if ep_index >= len(self.episodes): + raise IndexError( + f"Episode index {ep_index} out of range. Episodes: {len(self.episodes) if self.episodes else 0}" + ) + ep = self.episodes[ep_index] + chunk_idx = ep[f"videos/{vid_key}/chunk_index"] + file_idx = ep[f"videos/{vid_key}/file_index"] + fpath = self.video_path.format(video_key=vid_key, chunk_index=chunk_idx, file_index=file_idx) + return Path(fpath) + + @property + def data_path(self) -> str: + """Formattable string for the parquet files.""" + return self.info["data_path"] + + @property + def video_path(self) -> str | None: + """Formattable string for the video files.""" + return self.info["video_path"] + + @property + def robot_type(self) -> str | None: + """Robot type used in recording this dataset.""" + return self.info["robot_type"] + + @property + def fps(self) -> int: + """Frames per second used during data collection.""" + return self.info["fps"] + + @property + def features(self) -> dict[str, dict]: + """All features contained in the dataset.""" + return self.info["features"] + + @property + def image_keys(self) -> list[str]: + """Keys to access visual modalities stored as images.""" + return [key for key, ft in self.features.items() if ft["dtype"] == "image"] + + @property + def video_keys(self) -> list[str]: + """Keys to access visual modalities stored as videos.""" + return [key for key, ft in self.features.items() if ft["dtype"] == "video"] + + @property + def camera_keys(self) -> list[str]: + """Keys to access visual modalities (regardless of their storage method).""" + return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]] + + @property + def names(self) -> dict[str, list | dict]: + """Names of the various dimensions of vector modalities.""" + return {key: ft["names"] for key, ft in self.features.items()} + + @property + def shapes(self) -> dict: + """Shapes for the different features.""" + return {key: tuple(ft["shape"]) for key, ft in self.features.items()} + + @property + def total_episodes(self) -> int: + """Total number of episodes available.""" + return self.info["total_episodes"] + + @property + def total_frames(self) -> int: + 
"""Total number of frames saved in this dataset.""" + return self.info["total_frames"] + + @property + def total_tasks(self) -> int: + """Total number of different tasks performed in this dataset.""" + return self.info["total_tasks"] + + @property + def chunks_size(self) -> int: + """Max number of files per chunk.""" + return self.info["chunks_size"] + + @property + def data_files_size_in_mb(self) -> int: + """Max size of data file in mega bytes.""" + return self.info["data_files_size_in_mb"] + + @property + def video_files_size_in_mb(self) -> int: + """Max size of video file in mega bytes.""" + return self.info["video_files_size_in_mb"] + + def get_task_index(self, task: str) -> int | None: + """ + Given a task in natural language, returns its task_index if the task already exists in the dataset, + otherwise return None. + """ + if task in self.tasks.index: + return int(self.tasks.loc[task].task_index) + else: + return None + + def save_episode_tasks(self, tasks: list[str]): + if len(set(tasks)) != len(tasks): + raise ValueError(f"Tasks are not unique: {tasks}") + + if self.tasks is None: + new_tasks = tasks + task_indices = range(len(tasks)) + self.tasks = pd.DataFrame({"task_index": task_indices}, index=pd.Index(tasks, name="task")) + else: + new_tasks = [task for task in tasks if task not in self.tasks.index] + new_task_indices = range(len(self.tasks), len(self.tasks) + len(new_tasks)) + for task_idx, task in zip(new_task_indices, new_tasks, strict=False): + self.tasks.loc[task] = task_idx + + if len(new_tasks) > 0: + # Update on disk + write_tasks(self.tasks, self.root) + + def _save_episode_metadata(self, episode_dict: dict) -> None: + """Buffer episode metadata and write to parquet in batches for efficiency. + + This function accumulates episode metadata in a buffer and flushes it when the buffer + reaches the configured size. This reduces I/O overhead by writing multiple episodes + at once instead of one row at a time. + + Notes: We both need to update parquet files and HF dataset: + - `pandas` loads parquet file in RAM + - `datasets` relies on a memory mapping from pyarrow (no RAM). It either converts parquet files to a pyarrow cache on disk, + or loads directly from pyarrow cache. 
+ """ + # Convert to list format for each value + episode_dict = {key: [value] for key, value in episode_dict.items()} + num_frames = episode_dict["length"][0] + + if self.latest_episode is None: + # Initialize indices and frame count for a new dataset made of the first episode data + chunk_idx, file_idx = 0, 0 + if self.episodes is not None and len(self.episodes) > 0: + # It means we are resuming recording, so we need to load the latest episode + # Update the indices to avoid overwriting the latest episode + chunk_idx = self.episodes[-1]["meta/episodes/chunk_index"] + file_idx = self.episodes[-1]["meta/episodes/file_index"] + latest_num_frames = self.episodes[-1]["dataset_to_index"] + episode_dict["dataset_from_index"] = [latest_num_frames] + episode_dict["dataset_to_index"] = [latest_num_frames + num_frames] + + # When resuming, move to the next file + chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.chunks_size) + else: + episode_dict["dataset_from_index"] = [0] + episode_dict["dataset_to_index"] = [num_frames] + + episode_dict["meta/episodes/chunk_index"] = [chunk_idx] + episode_dict["meta/episodes/file_index"] = [file_idx] + else: + chunk_idx = self.latest_episode["meta/episodes/chunk_index"][0] + file_idx = self.latest_episode["meta/episodes/file_index"][0] + + latest_path = ( + self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx) + if self.writer is None + else self.writer.where + ) + + if Path(latest_path).exists(): + latest_size_in_mb = get_file_size_in_mb(Path(latest_path)) + latest_num_frames = self.latest_episode["episode_index"][0] + + av_size_per_frame = latest_size_in_mb / latest_num_frames if latest_num_frames > 0 else 0.0 + + if latest_size_in_mb + av_size_per_frame * num_frames >= self.data_files_size_in_mb: + # Size limit is reached, flush buffer and prepare new parquet file + self._flush_metadata_buffer() + chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.chunks_size) + self._close_writer() + + # Update the existing pandas dataframe with new row + episode_dict["meta/episodes/chunk_index"] = [chunk_idx] + episode_dict["meta/episodes/file_index"] = [file_idx] + episode_dict["dataset_from_index"] = [self.latest_episode["dataset_to_index"][0]] + episode_dict["dataset_to_index"] = [self.latest_episode["dataset_to_index"][0] + num_frames] + + # Add to buffer + self.metadata_buffer.append(episode_dict) + self.latest_episode = episode_dict + + if len(self.metadata_buffer) >= self.metadata_buffer_size: + self._flush_metadata_buffer() + + def save_episode( + self, + episode_index: int, + episode_length: int, + episode_tasks: list[str], + episode_stats: dict[str, dict], + episode_metadata: dict, + ) -> None: + episode_dict = { + "episode_index": episode_index, + "tasks": episode_tasks, + "length": episode_length, + } + episode_dict.update(episode_metadata) + episode_dict.update(flatten_dict({"stats": episode_stats})) + self._save_episode_metadata(episode_dict) + + # Update info + self.info["total_episodes"] += 1 + self.info["total_frames"] += episode_length + self.info["total_tasks"] = len(self.tasks) + self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"} + + write_info(self.info, self.root) + + self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats is not None else episode_stats + write_stats(self.stats, self.root) + + def update_video_info(self, video_key: str | None = None) -> None: + """ + Warning: this function writes info from first episode videos, implicitly 
assuming that all videos have + been encoded the same way. Also, this means it assumes the first episode exists. + """ + if video_key is not None and video_key not in self.video_keys: + raise ValueError(f"Video key {video_key} not found in dataset") + + video_keys = [video_key] if video_key is not None else self.video_keys + for key in video_keys: + if not self.features[key].get("info", None): + video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0) + self.info["features"][key]["info"] = get_video_info(video_path) + + def update_chunk_settings( + self, + chunks_size: int | None = None, + data_files_size_in_mb: int | None = None, + video_files_size_in_mb: int | None = None, + ) -> None: + """Update chunk and file size settings after dataset creation. + + This allows users to customize storage organization without modifying the constructor. + These settings control how episodes are chunked and how large files can grow before + creating new ones. + + Args: + chunks_size: Maximum number of files per chunk directory. If None, keeps current value. + data_files_size_in_mb: Maximum size for data parquet files in MB. If None, keeps current value. + video_files_size_in_mb: Maximum size for video files in MB. If None, keeps current value. + """ + if chunks_size is not None: + if chunks_size <= 0: + raise ValueError(f"chunks_size must be positive, got {chunks_size}") + self.info["chunks_size"] = chunks_size + + if data_files_size_in_mb is not None: + if data_files_size_in_mb <= 0: + raise ValueError(f"data_files_size_in_mb must be positive, got {data_files_size_in_mb}") + self.info["data_files_size_in_mb"] = data_files_size_in_mb + + if video_files_size_in_mb is not None: + if video_files_size_in_mb <= 0: + raise ValueError(f"video_files_size_in_mb must be positive, got {video_files_size_in_mb}") + self.info["video_files_size_in_mb"] = video_files_size_in_mb + + # Update the info file on disk + write_info(self.info, self.root) + + def get_chunk_settings(self) -> dict[str, int]: + """Get current chunk and file size settings. + + Returns: + Dict containing chunks_size, data_files_size_in_mb, and video_files_size_in_mb. 
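+
+        Example (illustrative; assumes `meta` is an existing LeRobotDatasetMetadata instance):
+            settings = meta.get_chunk_settings()
+            meta.update_chunk_settings(data_files_size_in_mb=settings["data_files_size_in_mb"] * 2)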
+ """ + return { + "chunks_size": self.chunks_size, + "data_files_size_in_mb": self.data_files_size_in_mb, + "video_files_size_in_mb": self.video_files_size_in_mb, + } + + def __repr__(self): + feature_keys = list(self.features) + return ( + f"{self.__class__.__name__}({{\n" + f" Repository ID: '{self.repo_id}',\n" + f" Total episodes: '{self.total_episodes}',\n" + f" Total frames: '{self.total_frames}',\n" + f" Features: '{feature_keys}',\n" + "})',\n" + ) + + @classmethod + def create( + cls, + repo_id: str, + fps: int, + features: dict, + robot_type: str | None = None, + root: str | Path | None = None, + use_videos: bool = True, + metadata_buffer_size: int = 10, + chunks_size: int | None = None, + data_files_size_in_mb: int | None = None, + video_files_size_in_mb: int | None = None, + ) -> "LeRobotDatasetMetadata": + """Creates metadata for a LeRobotDataset.""" + obj = cls.__new__(cls) + obj.repo_id = repo_id + obj.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id + + obj.root.mkdir(parents=True, exist_ok=False) + + features = {**features, **DEFAULT_FEATURES} + _validate_feature_names(features) + + obj.tasks = None + obj.subtasks = None + obj.episodes = None + obj.stats = None + obj.info = create_empty_dataset_info( + CODEBASE_VERSION, + fps, + features, + use_videos, + robot_type, + chunks_size, + data_files_size_in_mb, + video_files_size_in_mb, + ) + if len(obj.video_keys) > 0 and not use_videos: + raise ValueError( + f"Features contain video keys {obj.video_keys}, but 'use_videos' is set to False. " + "Either remove video features from the features dict, or set 'use_videos=True'." + ) + write_json(obj.info, obj.root / INFO_PATH) + obj.revision = None + obj.writer = None + obj.latest_episode = None + obj.metadata_buffer = [] + obj.metadata_buffer_size = metadata_buffer_size + return obj diff --git a/src/lerobot/datasets/dataset_tools.py b/src/lerobot/datasets/dataset_tools.py index 546b3d67f..87cdc18e5 100644 --- a/src/lerobot/datasets/dataset_tools.py +++ b/src/lerobot/datasets/dataset_tools.py @@ -38,19 +38,22 @@ from tqdm import tqdm from lerobot.datasets.aggregate import aggregate_datasets from lerobot.datasets.compute_stats import aggregate_stats -from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata +from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata +from lerobot.datasets.io_utils import ( + get_parquet_file_size_in_mb, + load_episodes, + write_info, + write_stats, + write_tasks, +) +from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.utils import ( DATA_DIR, DEFAULT_CHUNK_SIZE, DEFAULT_DATA_FILE_SIZE_IN_MB, DEFAULT_DATA_PATH, DEFAULT_EPISODES_PATH, - get_parquet_file_size_in_mb, - load_episodes, update_chunk_file_indices, - write_info, - write_stats, - write_tasks, ) from lerobot.datasets.video_utils import encode_video_frames, get_video_info from lerobot.utils.constants import HF_LEROBOT_HOME, OBS_IMAGE @@ -915,7 +918,8 @@ def _write_parquet(df: pd.DataFrame, path: Path, meta: LeRobotDatasetMetadata) - This ensures images are properly embedded and the file can be loaded correctly by HF datasets. 
""" - from lerobot.datasets.utils import embed_images, get_hf_features_from_features + from lerobot.datasets.feature_utils import get_hf_features_from_features + from lerobot.datasets.io_utils import embed_images hf_features = get_hf_features_from_features(meta.features) ep_dataset = datasets.Dataset.from_dict(df.to_dict(orient="list"), features=hf_features, split="train") diff --git a/src/lerobot/datasets/factory.py b/src/lerobot/datasets/factory.py index 31e939809..76ece8961 100644 --- a/src/lerobot/datasets/factory.py +++ b/src/lerobot/datasets/factory.py @@ -20,11 +20,9 @@ import torch from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.train import TrainPipelineConfig -from lerobot.datasets.lerobot_dataset import ( - LeRobotDataset, - LeRobotDatasetMetadata, - MultiLeRobotDataset, -) +from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata +from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.datasets.multi_dataset import MultiLeRobotDataset from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset from lerobot.datasets.transforms import ImageTransforms from lerobot.utils.constants import ACTION, OBS_PREFIX, REWARD diff --git a/src/lerobot/datasets/feature_utils.py b/src/lerobot/datasets/feature_utils.py new file mode 100644 index 000000000..d9a3c6301 --- /dev/null +++ b/src/lerobot/datasets/feature_utils.py @@ -0,0 +1,552 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pprint import pformat +from typing import Any + +import datasets +import numpy as np +from PIL import Image as PILImage + +from lerobot.configs.types import FeatureType, PolicyFeature +from lerobot.datasets.utils import ( + DEFAULT_CHUNK_SIZE, + DEFAULT_DATA_FILE_SIZE_IN_MB, + DEFAULT_DATA_PATH, + DEFAULT_FEATURES, + DEFAULT_VIDEO_FILE_SIZE_IN_MB, + DEFAULT_VIDEO_PATH, +) +from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_STR +from lerobot.utils.utils import is_valid_numpy_dtype_string + + +def get_hf_features_from_features(features: dict) -> datasets.Features: + """Convert a LeRobot features dictionary to a `datasets.Features` object. + + Args: + features (dict): A LeRobot-style feature dictionary. + + Returns: + datasets.Features: The corresponding Hugging Face `datasets.Features` object. + + Raises: + ValueError: If a feature has an unsupported shape. 
+ """ + hf_features = {} + for key, ft in features.items(): + if ft["dtype"] == "video": + continue + elif ft["dtype"] == "image": + hf_features[key] = datasets.Image() + elif ft["shape"] == (1,): + hf_features[key] = datasets.Value(dtype=ft["dtype"]) + elif len(ft["shape"]) == 1: + hf_features[key] = datasets.Sequence( + length=ft["shape"][0], feature=datasets.Value(dtype=ft["dtype"]) + ) + elif len(ft["shape"]) == 2: + hf_features[key] = datasets.Array2D(shape=ft["shape"], dtype=ft["dtype"]) + elif len(ft["shape"]) == 3: + hf_features[key] = datasets.Array3D(shape=ft["shape"], dtype=ft["dtype"]) + elif len(ft["shape"]) == 4: + hf_features[key] = datasets.Array4D(shape=ft["shape"], dtype=ft["dtype"]) + elif len(ft["shape"]) == 5: + hf_features[key] = datasets.Array5D(shape=ft["shape"], dtype=ft["dtype"]) + else: + raise ValueError(f"Corresponding feature is not valid: {ft}") + + return datasets.Features(hf_features) + + +def _validate_feature_names(features: dict[str, dict]) -> None: + """Validate that feature names do not contain invalid characters. + + Args: + features (dict): The LeRobot features dictionary. + + Raises: + ValueError: If any feature name contains '/'. + """ + invalid_features = {name: ft for name, ft in features.items() if "/" in name} + if invalid_features: + raise ValueError(f"Feature names should not contain '/'. Found '/' in '{invalid_features}'.") + + +def hw_to_dataset_features( + hw_features: dict[str, type | tuple], prefix: str, use_video: bool = True +) -> dict[str, dict]: + """Convert hardware-specific features to a LeRobot dataset feature dictionary. + + This function takes a dictionary describing hardware outputs (like joint states + or camera image shapes) and formats it into the standard LeRobot feature + specification. + + Args: + hw_features (dict): Dictionary mapping feature names to their type (float for + joints) or shape (tuple for images). + prefix (str): The prefix to add to the feature keys (e.g., "observation" + or "action"). + use_video (bool): If True, image features are marked as "video", otherwise "image". + + Returns: + dict: A LeRobot features dictionary. + """ + features = {} + joint_fts = { + key: ftype + for key, ftype in hw_features.items() + if ftype is float or (isinstance(ftype, PolicyFeature) and ftype.type != FeatureType.VISUAL) + } + cam_fts = {key: shape for key, shape in hw_features.items() if isinstance(shape, tuple)} + + if joint_fts and prefix == ACTION: + features[prefix] = { + "dtype": "float32", + "shape": (len(joint_fts),), + "names": list(joint_fts), + } + + if joint_fts and prefix == OBS_STR: + features[f"{prefix}.state"] = { + "dtype": "float32", + "shape": (len(joint_fts),), + "names": list(joint_fts), + } + + for key, shape in cam_fts.items(): + features[f"{prefix}.images.{key}"] = { + "dtype": "video" if use_video else "image", + "shape": shape, + "names": ["height", "width", "channels"], + } + + _validate_feature_names(features) + return features + + +def build_dataset_frame( + ds_features: dict[str, dict], values: dict[str, Any], prefix: str +) -> dict[str, np.ndarray]: + """Construct a single data frame from raw values based on dataset features. + + A "frame" is a dictionary containing all the data for a single timestep, + formatted as numpy arrays according to the feature specification. + + Args: + ds_features (dict): The LeRobot dataset features dictionary. + values (dict): A dictionary of raw values from the hardware/environment. 
+ prefix (str): The prefix to filter features by (e.g., "observation" + or "action"). + + Returns: + dict: A dictionary representing a single frame of data. + """ + frame = {} + for key, ft in ds_features.items(): + if key in DEFAULT_FEATURES or not key.startswith(prefix): + continue + elif ft["dtype"] == "float32" and len(ft["shape"]) == 1: + frame[key] = np.array([values[name] for name in ft["names"]], dtype=np.float32) + elif ft["dtype"] in ["image", "video"]: + frame[key] = values[key.removeprefix(f"{prefix}.images.")] + + return frame + + +def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]: + """Convert dataset features to policy features. + + This function transforms the dataset's feature specification into a format + that a policy can use, classifying features by type (e.g., visual, state, + action) and ensuring correct shapes (e.g., channel-first for images). + + Args: + features (dict): The LeRobot dataset features dictionary. + + Returns: + dict: A dictionary mapping feature keys to `PolicyFeature` objects. + + Raises: + ValueError: If an image feature does not have a 3D shape. + """ + # TODO(aliberts): Implement "type" in dataset features and simplify this + policy_features = {} + for key, ft in features.items(): + shape = ft["shape"] + if ft["dtype"] in ["image", "video"]: + type = FeatureType.VISUAL + if len(shape) != 3: + raise ValueError(f"Number of dimensions of {key} != 3 (shape={shape})") + + names = ft["names"] + # Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets. + if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w) + shape = (shape[2], shape[0], shape[1]) + elif key == OBS_ENV_STATE: + type = FeatureType.ENV + elif key.startswith(OBS_STR): + type = FeatureType.STATE + elif key.startswith(ACTION): + type = FeatureType.ACTION + else: + continue + + policy_features[key] = PolicyFeature( + type=type, + shape=shape, + ) + + return policy_features + + +def combine_feature_dicts(*dicts: dict) -> dict: + """Merge LeRobot grouped feature dicts. + + - For 1D numeric specs (dtype not image/video/string) with "names": we merge the names and recompute the shape. + - For others (e.g. `observation.images.*`), the last one wins (if they are identical). + + Args: + *dicts: A variable number of LeRobot feature dictionaries to merge. + + Returns: + dict: A single merged feature dictionary. + + Raises: + ValueError: If there's a dtype mismatch for a feature being merged. 
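+
+    Example:
+        Illustrative sketch with hypothetical keys: two specs sharing the
+        "observation.state" key are merged by concatenating their names and
+        recomputing the shape:
+
+        >>> a = {"observation.state": {"dtype": "float32", "shape": (2,), "names": ["x", "y"]}}
+        >>> b = {"observation.state": {"dtype": "float32", "shape": (1,), "names": ["gripper"]}}
+        >>> combine_feature_dicts(a, b)["observation.state"]["names"]
+        ['x', 'y', 'gripper']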
+ """ + out: dict = {} + for d in dicts: + for key, value in d.items(): + if not isinstance(value, dict): + out[key] = value + continue + + dtype = value.get("dtype") + shape = value.get("shape") + is_vector = ( + dtype not in ("image", "video", "string") + and isinstance(shape, tuple) + and len(shape) == 1 + and "names" in value + ) + + if is_vector: + # Initialize or retrieve the accumulating dict for this feature key + target = out.setdefault(key, {"dtype": dtype, "names": [], "shape": (0,)}) + # Ensure consistent data types across merged entries + if "dtype" in target and dtype != target["dtype"]: + raise ValueError(f"dtype mismatch for '{key}': {target['dtype']} vs {dtype}") + + # Merge feature names: append only new ones to preserve order without duplicates + seen = set(target["names"]) + for n in value["names"]: + if n not in seen: + target["names"].append(n) + seen.add(n) + # Recompute the shape to reflect the updated number of features + target["shape"] = (len(target["names"]),) + else: + # For images/videos and non-1D entries: override with the latest definition + out[key] = value + return out + + +def create_empty_dataset_info( + codebase_version: str, + fps: int, + features: dict, + use_videos: bool, + robot_type: str | None = None, + chunks_size: int | None = None, + data_files_size_in_mb: int | None = None, + video_files_size_in_mb: int | None = None, +) -> dict: + """Create a template dictionary for a new dataset's `info.json`. + + Args: + codebase_version (str): The version of the LeRobot codebase. + fps (int): The frames per second of the data. + features (dict): The LeRobot features dictionary for the dataset. + use_videos (bool): Whether the dataset will store videos. + robot_type (str | None): The type of robot used, if any. + + Returns: + dict: A dictionary with the initial dataset metadata. + """ + return { + "codebase_version": codebase_version, + "robot_type": robot_type, + "total_episodes": 0, + "total_frames": 0, + "total_tasks": 0, + "chunks_size": chunks_size or DEFAULT_CHUNK_SIZE, + "data_files_size_in_mb": data_files_size_in_mb or DEFAULT_DATA_FILE_SIZE_IN_MB, + "video_files_size_in_mb": video_files_size_in_mb or DEFAULT_VIDEO_FILE_SIZE_IN_MB, + "fps": fps, + "splits": {}, + "data_path": DEFAULT_DATA_PATH, + "video_path": DEFAULT_VIDEO_PATH if use_videos else None, + "features": features, + } + + +def check_delta_timestamps( + delta_timestamps: dict[str, list[float]], fps: int, tolerance_s: float, raise_value_error: bool = True +) -> bool: + """Check if delta timestamps are multiples of 1/fps +/- tolerance. + + This ensures that adding these delta timestamps to any existing timestamp in + the dataset will result in a value that aligns with the dataset's frame rate. + + Args: + delta_timestamps (dict): A dictionary where values are lists of time + deltas in seconds. + fps (int): The frames per second of the dataset. + tolerance_s (float): The allowed tolerance in seconds. + raise_value_error (bool): If True, raises an error on failure. + + Returns: + bool: True if all deltas are valid, False otherwise. + + Raises: + ValueError: If any delta is outside the tolerance and `raise_value_error` is True. 
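+
+    Example:
+        Illustrative check at 30 fps: deltas that are multiples of 1/30 s pass,
+        while 0.012 s is rejected (returning False here because
+        raise_value_error=False):
+
+        >>> check_delta_timestamps({"action": [-1 / 30, 0.0, 1 / 30]}, fps=30, tolerance_s=1e-4)
+        True
+        >>> check_delta_timestamps({"action": [0.012]}, fps=30, tolerance_s=1e-4, raise_value_error=False)
+        False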
+ """ + outside_tolerance = {} + for key, delta_ts in delta_timestamps.items(): + within_tolerance = [abs(ts * fps - round(ts * fps)) / fps <= tolerance_s for ts in delta_ts] + if not all(within_tolerance): + outside_tolerance[key] = [ + ts for ts, is_within in zip(delta_ts, within_tolerance, strict=True) if not is_within + ] + + if len(outside_tolerance) > 0: + if raise_value_error: + raise ValueError( + f""" + The following delta_timestamps are found outside of tolerance range. + Please make sure they are multiples of 1/{fps} +/- tolerance and adjust + their values accordingly. + \n{pformat(outside_tolerance)} + """ + ) + return False + + return True + + +def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dict[str, list[int]]: + """Convert delta timestamps in seconds to delta indices in frames. + + Args: + delta_timestamps (dict): A dictionary of time deltas in seconds. + fps (int): The frames per second of the dataset. + + Returns: + dict: A dictionary of frame delta indices. + """ + delta_indices = {} + for key, delta_ts in delta_timestamps.items(): + delta_indices[key] = [round(d * fps) for d in delta_ts] + + return delta_indices + + +def validate_frame(frame: dict, features: dict) -> None: + expected_features = set(features) - set(DEFAULT_FEATURES) + actual_features = set(frame) + + # task is a special required field that's not part of regular features + if "task" not in actual_features: + raise ValueError("Feature mismatch in `frame` dictionary:\nMissing features: {'task'}\n") + + # Remove task from actual_features for regular feature validation + actual_features_for_validation = actual_features - {"task"} + + error_message = validate_features_presence(actual_features_for_validation, expected_features) + + common_features = actual_features_for_validation & expected_features + for name in common_features: + error_message += validate_feature_dtype_and_shape(name, features[name], frame[name]) + + if error_message: + raise ValueError(error_message) + + +def validate_features_presence(actual_features: set[str], expected_features: set[str]) -> str: + """Check for missing or extra features in a frame. + + Args: + actual_features (set[str]): The set of feature names present in the frame. + expected_features (set[str]): The set of feature names expected in the frame. + + Returns: + str: An error message string if there's a mismatch, otherwise an empty string. + """ + error_message = "" + missing_features = expected_features - actual_features + extra_features = actual_features - expected_features + + if missing_features or extra_features: + error_message += "Feature mismatch in `frame` dictionary:\n" + if missing_features: + error_message += f"Missing features: {missing_features}\n" + if extra_features: + error_message += f"Extra features: {extra_features}\n" + + return error_message + + +def validate_feature_dtype_and_shape( + name: str, feature: dict, value: np.ndarray | PILImage.Image | str +) -> str: + """Validate the dtype and shape of a single feature's value. + + Args: + name (str): The name of the feature. + feature (dict): The feature specification from the LeRobot features dictionary. + value: The value of the feature to validate. + + Returns: + str: An error message if validation fails, otherwise an empty string. + + Raises: + NotImplementedError: If the feature dtype is not supported for validation. 
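+
+    Example:
+        Illustrative sketch with a hypothetical float32 state feature: a matching
+        array yields an empty error string, a mismatched shape yields a message
+        (assumes "float32" is accepted as a valid numpy dtype string):
+
+        >>> ft = {"dtype": "float32", "shape": (2,)}
+        >>> validate_feature_dtype_and_shape("observation.state", ft, np.zeros(2, dtype=np.float32))
+        ''
+        >>> bool(validate_feature_dtype_and_shape("observation.state", ft, np.zeros(3, dtype=np.float32)))
+        True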
+ """ + expected_dtype = feature["dtype"] + expected_shape = feature["shape"] + if is_valid_numpy_dtype_string(expected_dtype): + return validate_feature_numpy_array(name, expected_dtype, expected_shape, value) + elif expected_dtype in ["image", "video"]: + return validate_feature_image_or_video(name, expected_shape, value) + elif expected_dtype == "string": + return validate_feature_string(name, value) + else: + raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.") + + +def validate_feature_numpy_array( + name: str, expected_dtype: str, expected_shape: list[int], value: np.ndarray +) -> str: + """Validate a feature that is expected to be a numpy array. + + Args: + name (str): The name of the feature. + expected_dtype (str): The expected numpy dtype as a string. + expected_shape (list[int]): The expected shape. + value (np.ndarray): The numpy array to validate. + + Returns: + str: An error message if validation fails, otherwise an empty string. + """ + error_message = "" + if isinstance(value, np.ndarray): + actual_dtype = value.dtype + actual_shape = value.shape + + if actual_dtype != np.dtype(expected_dtype): + error_message += f"The feature '{name}' of dtype '{actual_dtype}' is not of the expected dtype '{expected_dtype}'.\n" + + if actual_shape != expected_shape: + error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{expected_shape}'.\n" + else: + error_message += f"The feature '{name}' is not a 'np.ndarray'. Expected type is '{expected_dtype}', but type '{type(value)}' provided instead.\n" + + return error_message + + +def validate_feature_image_or_video( + name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image +) -> str: + """Validate a feature that is expected to be an image or video frame. + + Accepts `np.ndarray` (channel-first or channel-last) or `PIL.Image.Image`. + + Args: + name (str): The name of the feature. + expected_shape (list[str]): The expected shape (C, H, W). + value: The image data to validate. + + Returns: + str: An error message if validation fails, otherwise an empty string. + """ + # Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads. + error_message = "" + if isinstance(value, np.ndarray): + actual_shape = value.shape + c, h, w = expected_shape + if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)): + error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n" + elif isinstance(value, PILImage.Image): + pass + else: + error_message += f"The feature '{name}' is expected to be of type 'PIL.Image' or 'np.ndarray' channel first or channel last, but type '{type(value)}' provided instead.\n" + + return error_message + + +def validate_feature_string(name: str, value: str) -> str: + """Validate a feature that is expected to be a string. + + Args: + name (str): The name of the feature. + value (str): The value to validate. + + Returns: + str: An error message if validation fails, otherwise an empty string. + """ + if not isinstance(value, str): + return f"The feature '{name}' is expected to be of type 'str', but type '{type(value)}' provided instead.\n" + return "" + + +def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: dict) -> None: + """Validate the episode buffer before it's written to disk. 
+ + Ensures the buffer has the required keys, contains at least one frame, and + has features consistent with the dataset's specification. + + Args: + episode_buffer (dict): The buffer containing data for a single episode. + total_episodes (int): The current total number of episodes in the dataset. + features (dict): The LeRobot features dictionary for the dataset. + + Raises: + ValueError: If the buffer is invalid. + NotImplementedError: If the episode index is manually set and doesn't match. + """ + if "size" not in episode_buffer: + raise ValueError("size key not found in episode_buffer") + + if "task" not in episode_buffer: + raise ValueError("task key not found in episode_buffer") + + if episode_buffer["episode_index"] != total_episodes: + # TODO(aliberts): Add option to use existing episode_index + raise NotImplementedError( + "You might have manually provided the episode_buffer with an episode_index that doesn't " + "match the total number of episodes already in the dataset. This is not supported for now." + ) + + if episode_buffer["size"] == 0: + raise ValueError("You must add one or several frames with `add_frame` before calling `add_episode`.") + + buffer_keys = set(episode_buffer.keys()) - {"task", "size"} + if not buffer_keys == set(features): + raise ValueError( + f"Features from `episode_buffer` don't match the ones in `features`." + f"In episode_buffer not in features: {buffer_keys - set(features)}" + f"In features not in episode_buffer: {set(features) - buffer_keys}" + ) diff --git a/src/lerobot/datasets/image_writer.py b/src/lerobot/datasets/image_writer.py index 23bc2efb8..9f40394de 100644 --- a/src/lerobot/datasets/image_writer.py +++ b/src/lerobot/datasets/image_writer.py @@ -13,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging import multiprocessing import queue import threading @@ -22,6 +23,8 @@ import numpy as np import PIL.Image import torch +logger = logging.getLogger(__name__) + def safe_stop_image_writer(func): def wrapper(*args, **kwargs): @@ -31,7 +34,7 @@ def safe_stop_image_writer(func): dataset = kwargs.get("dataset") image_writer = getattr(dataset, "image_writer", None) if dataset else None if image_writer is not None: - print("Waiting for image writer to terminate...") + logger.warning("Waiting for image writer to terminate...") image_writer.stop() raise e @@ -89,8 +92,7 @@ def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path, compress_level PIL.Image.Image object. Side Effects: - Prints an error message to the console if the image writing process - fails for any reason. + Logs an error message if the image writing process fails for any reason. """ try: if isinstance(image, np.ndarray): @@ -101,7 +103,7 @@ def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path, compress_level raise TypeError(f"Unsupported image type: {type(image)}") img.save(fpath, compress_level=compress_level) except Exception as e: - print(f"Error writing image {fpath}: {e}") + logger.error("Error writing image %s: %s", fpath, e) def worker_thread_loop(queue: queue.Queue): diff --git a/src/lerobot/datasets/io_utils.py b/src/lerobot/datasets/io_utils.py new file mode 100644 index 000000000..cee6cfba8 --- /dev/null +++ b/src/lerobot/datasets/io_utils.py @@ -0,0 +1,342 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +from pathlib import Path +from typing import Any + +import datasets +import numpy as np +import pandas +import pandas as pd +import pyarrow.dataset as pa_ds +import pyarrow.parquet as pq +import torch +from datasets import Dataset +from datasets.table import embed_table_storage +from PIL import Image as PILImage +from torchvision import transforms + +from lerobot.datasets.utils import ( + DEFAULT_DATA_FILE_SIZE_IN_MB, + DEFAULT_EPISODES_PATH, + DEFAULT_SUBTASKS_PATH, + DEFAULT_TASKS_PATH, + EPISODES_DIR, + INFO_PATH, + STATS_PATH, + flatten_dict, + serialize_dict, + unflatten_dict, +) +from lerobot.utils.utils import SuppressProgressBars + + +def get_parquet_file_size_in_mb(parquet_path: str | Path) -> float: + metadata = pq.read_metadata(parquet_path) + total_uncompressed_size = 0 + for row_group in range(metadata.num_row_groups): + rg_metadata = metadata.row_group(row_group) + for column in range(rg_metadata.num_columns): + col_metadata = rg_metadata.column(column) + total_uncompressed_size += col_metadata.total_uncompressed_size + return total_uncompressed_size / (1024**2) + + +def get_hf_dataset_size_in_mb(hf_ds: Dataset) -> int: + return hf_ds.data.nbytes // (1024**2) + + +def load_nested_dataset( + pq_dir: Path, features: datasets.Features | None = None, episodes: list[int] | None = None +) -> Dataset: + """Find parquet files in provided directory {pq_dir}/chunk-xxx/file-xxx.parquet + Convert parquet files to pyarrow memory mapped in a cache folder for efficient RAM usage + Concatenate all pyarrow references to return HF Dataset format + + Args: + pq_dir: Directory containing parquet files + features: Optional features schema to ensure consistent loading of complex types like images + episodes: Optional list of episode indices to filter. Uses PyArrow predicate pushdown for efficiency. + """ + paths = sorted(pq_dir.glob("*/*.parquet")) + if len(paths) == 0: + raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}") + + with SuppressProgressBars(): + # We use .from_parquet() memory-mapped loading for efficiency + filters = pa_ds.field("episode_index").isin(episodes) if episodes is not None else None + return Dataset.from_parquet([str(path) for path in paths], filters=filters, features=features) + + +def get_parquet_num_frames(parquet_path: str | Path) -> int: + metadata = pq.read_metadata(parquet_path) + return metadata.num_rows + + +def get_file_size_in_mb(file_path: Path) -> float: + """Get file size on disk in megabytes. + + Args: + file_path (Path): Path to the file. + """ + file_size_bytes = file_path.stat().st_size + return file_size_bytes / (1024**2) + + +def embed_images(dataset: datasets.Dataset) -> datasets.Dataset: + """Embed image bytes into the dataset table before saving to Parquet. + + This function prepares a Hugging Face dataset for serialization by converting + image objects into an embedded format that can be stored in Arrow/Parquet. 
+ + Args: + dataset (datasets.Dataset): The input dataset, possibly containing image features. + + Returns: + datasets.Dataset: The dataset with images embedded in the table storage. + """ + # Embed image bytes into the table before saving to parquet + format = dataset.format + dataset = dataset.with_format("arrow") + dataset = dataset.map(embed_table_storage, batched=False) + dataset = dataset.with_format(**format) + return dataset + + +def load_json(fpath: Path) -> Any: + """Load data from a JSON file. + + Args: + fpath (Path): Path to the JSON file. + + Returns: + Any: The data loaded from the JSON file. + """ + with open(fpath) as f: + return json.load(f) + + +def write_json(data: dict, fpath: Path) -> None: + """Write data to a JSON file. + + Creates parent directories if they don't exist. + + Args: + data (dict): The dictionary to write. + fpath (Path): The path to the output JSON file. + """ + fpath.parent.mkdir(exist_ok=True, parents=True) + with open(fpath, "w") as f: + json.dump(data, f, indent=4, ensure_ascii=False) + + +def write_info(info: dict, local_dir: Path) -> None: + write_json(info, local_dir / INFO_PATH) + + +def load_info(local_dir: Path) -> dict: + """Load dataset info metadata from its standard file path. + + Also converts shape lists to tuples for consistency. + + Args: + local_dir (Path): The root directory of the dataset. + + Returns: + dict: The dataset information dictionary. + """ + info = load_json(local_dir / INFO_PATH) + for ft in info["features"].values(): + ft["shape"] = tuple(ft["shape"]) + return info + + +def write_stats(stats: dict, local_dir: Path) -> None: + """Serialize and write dataset statistics to their standard file path. + + Args: + stats (dict): The statistics dictionary (can contain tensors/numpy arrays). + local_dir (Path): The root directory of the dataset. + """ + serialized_stats = serialize_dict(stats) + write_json(serialized_stats, local_dir / STATS_PATH) + + +def cast_stats_to_numpy(stats: dict) -> dict[str, dict[str, np.ndarray]]: + """Recursively cast numerical values in a stats dictionary to numpy arrays. + + Args: + stats (dict): The statistics dictionary. + + Returns: + dict: The statistics dictionary with values cast to numpy arrays. + """ + stats = {key: np.array(value) for key, value in flatten_dict(stats).items()} + return unflatten_dict(stats) + + +def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]] | None: + """Load dataset statistics and cast numerical values to numpy arrays. + + Returns None if the stats file doesn't exist. + + Args: + local_dir (Path): The root directory of the dataset. + + Returns: + A dictionary of statistics or None if the file is not found. 
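+
+    Example:
+        Illustrative usage with a hypothetical dataset root; a missing stats file
+        simply yields None:
+
+        >>> load_stats(Path("/tmp/not_a_dataset")) is None
+        True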
+ """ + if not (local_dir / STATS_PATH).exists(): + return None + stats = load_json(local_dir / STATS_PATH) + return cast_stats_to_numpy(stats) + + +def write_tasks(tasks: pandas.DataFrame, local_dir: Path) -> None: + path = local_dir / DEFAULT_TASKS_PATH + path.parent.mkdir(parents=True, exist_ok=True) + tasks.to_parquet(path) + + +def load_tasks(local_dir: Path) -> pandas.DataFrame: + tasks = pd.read_parquet(local_dir / DEFAULT_TASKS_PATH) + tasks.index.name = "task" + return tasks + + +def load_subtasks(local_dir: Path) -> pandas.DataFrame | None: + """Load subtasks from subtasks.parquet if it exists.""" + subtasks_path = local_dir / DEFAULT_SUBTASKS_PATH + if subtasks_path.exists(): + return pd.read_parquet(subtasks_path) + return None + + +def write_episodes(episodes: Dataset, local_dir: Path) -> None: + """Write episode metadata to a parquet file in the LeRobot v3.0 format. + This function writes episode-level metadata to a single parquet file. + Used primarily during dataset conversion (v2.1 → v3.0) and in test fixtures. + + Args: + episodes: HuggingFace Dataset containing episode metadata + local_dir: Root directory where the dataset will be stored + """ + episode_size_mb = get_hf_dataset_size_in_mb(episodes) + if episode_size_mb > DEFAULT_DATA_FILE_SIZE_IN_MB: + raise NotImplementedError( + f"Episodes dataset is too large ({episode_size_mb} MB) to write to a single file. " + f"The current limit is {DEFAULT_DATA_FILE_SIZE_IN_MB} MB. " + "This function only supports single-file episode metadata. " + ) + + fpath = local_dir / DEFAULT_EPISODES_PATH.format(chunk_index=0, file_index=0) + fpath.parent.mkdir(parents=True, exist_ok=True) + episodes.to_parquet(fpath) + + +def load_episodes(local_dir: Path) -> datasets.Dataset: + episodes = load_nested_dataset(local_dir / EPISODES_DIR) + # Select episode features/columns containing references to episode data and videos + # (e.g. tasks, dataset_from_index, dataset_to_index, data/chunk_index, data/file_index, etc.) + # This is to speedup access to these data, instead of having to load episode stats. + episodes = episodes.select_columns([key for key in episodes.features if not key.startswith("stats/")]) + return episodes + + +def load_image_as_numpy( + fpath: str | Path, dtype: np.dtype = np.float32, channel_first: bool = True +) -> np.ndarray: + """Load an image from a file into a numpy array. + + Args: + fpath (str | Path): Path to the image file. + dtype (np.dtype): The desired data type of the output array. If floating, + pixels are scaled to [0, 1]. + channel_first (bool): If True, converts the image to (C, H, W) format. + Otherwise, it remains in (H, W, C) format. + + Returns: + np.ndarray: The image as a numpy array. + """ + img = PILImage.open(fpath).convert("RGB") + img_array = np.array(img, dtype=dtype) + if channel_first: # (H, W, C) -> (C, H, W) + img_array = np.transpose(img_array, (2, 0, 1)) + if np.issubdtype(dtype, np.floating): + img_array /= 255.0 + return img_array + + +def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[torch.Tensor | str]]: + """Convert a batch from a Hugging Face dataset to torch tensors. + + This transform function converts items from Hugging Face dataset format (pyarrow) + to torch tensors. Importantly, images are converted from PIL objects (H, W, C, uint8) + to a torch image representation (C, H, W, float32) in the range [0, 1]. Other + types are converted to torch.tensor. + + Args: + items_dict (dict): A dictionary representing a batch of data from a + Hugging Face dataset. 
+ + Returns: + dict: The batch with items converted to torch tensors. + """ + for key in items_dict: + first_item = items_dict[key][0] + if isinstance(first_item, PILImage.Image): + to_tensor = transforms.ToTensor() + items_dict[key] = [to_tensor(img) for img in items_dict[key]] + elif first_item is None: + pass + else: + items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]] + return items_dict + + +def to_parquet_with_hf_images( + df: pandas.DataFrame, path: Path, features: datasets.Features | None = None +) -> None: + """This function correctly writes to parquet a panda DataFrame that contains images encoded by HF dataset. + This way, it can be loaded by HF dataset and correctly formatted images are returned. + + Args: + df: DataFrame to write to parquet. + path: Path to write the parquet file. + features: Optional HuggingFace Features schema. If provided, ensures image columns + are properly typed as Image() in the parquet schema. + """ + # TODO(qlhoest): replace this weird synthax by `df.to_parquet(path)` only + ds = datasets.Dataset.from_dict(df.to_dict(orient="list"), features=features) + ds.to_parquet(path) + + +def item_to_torch(item: dict) -> dict: + """Convert all items in a dictionary to PyTorch tensors where appropriate. + + This function is used to convert an item from a streaming dataset to PyTorch tensors. + + Args: + item (dict): Dictionary of items from a dataset. + + Returns: + dict: Dictionary with all tensor-like items converted to torch.Tensor. + """ + for key, val in item.items(): + if isinstance(val, (np.ndarray | list)) and key not in ["task"]: + # Convert numpy arrays and lists to torch tensors + item[key] = torch.tensor(val) + return item diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index 26f0c769c..8f0600ba8 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -23,526 +23,52 @@ from pathlib import Path import datasets import numpy as np -import packaging.version import pandas as pd import PIL.Image -import pyarrow as pa import pyarrow.parquet as pq import torch import torch.utils from huggingface_hub import HfApi, snapshot_download from huggingface_hub.errors import RevisionNotFoundError -from lerobot.datasets.compute_stats import aggregate_stats, compute_episode_stats -from lerobot.datasets.image_writer import AsyncImageWriter, write_image -from lerobot.datasets.utils import ( - DEFAULT_EPISODES_PATH, - DEFAULT_FEATURES, - DEFAULT_IMAGE_PATH, - INFO_PATH, - _validate_feature_names, +from lerobot.datasets.compute_stats import compute_episode_stats +from lerobot.datasets.dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata +from lerobot.datasets.feature_utils import ( check_delta_timestamps, - check_version_compatibility, - create_empty_dataset_info, - create_lerobot_dataset_card, - embed_images, - flatten_dict, get_delta_indices, - get_file_size_in_mb, get_hf_features_from_features, - get_safe_version, - hf_transform_to_torch, - is_valid_version, - load_episodes, - load_info, - load_nested_dataset, - load_stats, - load_subtasks, - load_tasks, - update_chunk_file_indices, validate_episode_buffer, validate_frame, +) +from lerobot.datasets.image_writer import AsyncImageWriter, write_image +from lerobot.datasets.io_utils import ( + embed_images, + get_file_size_in_mb, + hf_transform_to_torch, + load_episodes, + load_nested_dataset, write_info, - write_json, - write_stats, - write_tasks, +) +from lerobot.datasets.utils 
import ( + DEFAULT_EPISODES_PATH, + DEFAULT_IMAGE_PATH, + create_lerobot_dataset_card, + get_safe_version, + is_valid_version, + update_chunk_file_indices, ) from lerobot.datasets.video_utils import ( StreamingVideoEncoder, - VideoFrame, concatenate_video_files, decode_video_frames, encode_video_frames, get_safe_default_codec, get_video_duration_in_s, - get_video_info, resolve_vcodec, ) from lerobot.utils.constants import HF_LEROBOT_HOME -CODEBASE_VERSION = "v3.0" - - -class LeRobotDatasetMetadata: - def __init__( - self, - repo_id: str, - root: str | Path | None = None, - revision: str | None = None, - force_cache_sync: bool = False, - metadata_buffer_size: int = 10, - ): - self.repo_id = repo_id - self.revision = revision if revision else CODEBASE_VERSION - self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id - self.writer = None - self.latest_episode = None - self.metadata_buffer: list[dict] = [] - self.metadata_buffer_size = metadata_buffer_size - - try: - if force_cache_sync: - raise FileNotFoundError - self.load_metadata() - except (FileNotFoundError, NotADirectoryError): - if is_valid_version(self.revision): - self.revision = get_safe_version(self.repo_id, self.revision) - - (self.root / "meta").mkdir(exist_ok=True, parents=True) - self.pull_from_repo(allow_patterns="meta/") - self.load_metadata() - - def _flush_metadata_buffer(self) -> None: - """Write all buffered episode metadata to parquet file.""" - if not hasattr(self, "metadata_buffer") or len(self.metadata_buffer) == 0: - return - - combined_dict = {} - for episode_dict in self.metadata_buffer: - for key, value in episode_dict.items(): - if key not in combined_dict: - combined_dict[key] = [] - # Extract value and serialize numpy arrays - # because PyArrow's from_pydict function doesn't support numpy arrays - val = value[0] if isinstance(value, list) else value - combined_dict[key].append(val.tolist() if isinstance(val, np.ndarray) else val) - - first_ep = self.metadata_buffer[0] - chunk_idx = first_ep["meta/episodes/chunk_index"][0] - file_idx = first_ep["meta/episodes/file_index"][0] - - table = pa.Table.from_pydict(combined_dict) - - if not self.writer: - path = Path(self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)) - path.parent.mkdir(parents=True, exist_ok=True) - self.writer = pq.ParquetWriter( - path, schema=table.schema, compression="snappy", use_dictionary=True - ) - - self.writer.write_table(table) - - self.latest_episode = self.metadata_buffer[-1] - self.metadata_buffer.clear() - - def _close_writer(self) -> None: - """Close and cleanup the parquet writer if it exists.""" - self._flush_metadata_buffer() - - writer = getattr(self, "writer", None) - if writer is not None: - writer.close() - self.writer = None - - def __del__(self): - """ - Trust the user to call .finalize() but as an added safety check call the parquet writer to stop when calling the destructor - """ - self._close_writer() - - def load_metadata(self): - self.info = load_info(self.root) - check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION) - self.tasks = load_tasks(self.root) - self.subtasks = load_subtasks(self.root) - self.episodes = load_episodes(self.root) - self.stats = load_stats(self.root) - - def pull_from_repo( - self, - allow_patterns: list[str] | str | None = None, - ignore_patterns: list[str] | str | None = None, - ) -> None: - snapshot_download( - self.repo_id, - repo_type="dataset", - revision=self.revision, - local_dir=self.root, - 
allow_patterns=allow_patterns, - ignore_patterns=ignore_patterns, - ) - - @property - def url_root(self) -> str: - return f"hf://datasets/{self.repo_id}" - - @property - def _version(self) -> packaging.version.Version: - """Codebase version used to create this dataset.""" - return packaging.version.parse(self.info["codebase_version"]) - - def get_data_file_path(self, ep_index: int) -> Path: - if self.episodes is None: - self.episodes = load_episodes(self.root) - if ep_index >= len(self.episodes): - raise IndexError( - f"Episode index {ep_index} out of range. Episodes: {len(self.episodes) if self.episodes else 0}" - ) - ep = self.episodes[ep_index] - chunk_idx = ep["data/chunk_index"] - file_idx = ep["data/file_index"] - fpath = self.data_path.format(chunk_index=chunk_idx, file_index=file_idx) - return Path(fpath) - - def get_video_file_path(self, ep_index: int, vid_key: str) -> Path: - if self.episodes is None: - self.episodes = load_episodes(self.root) - if ep_index >= len(self.episodes): - raise IndexError( - f"Episode index {ep_index} out of range. Episodes: {len(self.episodes) if self.episodes else 0}" - ) - ep = self.episodes[ep_index] - chunk_idx = ep[f"videos/{vid_key}/chunk_index"] - file_idx = ep[f"videos/{vid_key}/file_index"] - fpath = self.video_path.format(video_key=vid_key, chunk_index=chunk_idx, file_index=file_idx) - return Path(fpath) - - @property - def data_path(self) -> str: - """Formattable string for the parquet files.""" - return self.info["data_path"] - - @property - def video_path(self) -> str | None: - """Formattable string for the video files.""" - return self.info["video_path"] - - @property - def robot_type(self) -> str | None: - """Robot type used in recording this dataset.""" - return self.info["robot_type"] - - @property - def fps(self) -> int: - """Frames per second used during data collection.""" - return self.info["fps"] - - @property - def features(self) -> dict[str, dict]: - """All features contained in the dataset.""" - return self.info["features"] - - @property - def image_keys(self) -> list[str]: - """Keys to access visual modalities stored as images.""" - return [key for key, ft in self.features.items() if ft["dtype"] == "image"] - - @property - def video_keys(self) -> list[str]: - """Keys to access visual modalities stored as videos.""" - return [key for key, ft in self.features.items() if ft["dtype"] == "video"] - - @property - def camera_keys(self) -> list[str]: - """Keys to access visual modalities (regardless of their storage method).""" - return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]] - - @property - def names(self) -> dict[str, list | dict]: - """Names of the various dimensions of vector modalities.""" - return {key: ft["names"] for key, ft in self.features.items()} - - @property - def shapes(self) -> dict: - """Shapes for the different features.""" - return {key: tuple(ft["shape"]) for key, ft in self.features.items()} - - @property - def total_episodes(self) -> int: - """Total number of episodes available.""" - return self.info["total_episodes"] - - @property - def total_frames(self) -> int: - """Total number of frames saved in this dataset.""" - return self.info["total_frames"] - - @property - def total_tasks(self) -> int: - """Total number of different tasks performed in this dataset.""" - return self.info["total_tasks"] - - @property - def chunks_size(self) -> int: - """Max number of files per chunk.""" - return self.info["chunks_size"] - - @property - def data_files_size_in_mb(self) -> int: - 
"""Max size of data file in mega bytes.""" - return self.info["data_files_size_in_mb"] - - @property - def video_files_size_in_mb(self) -> int: - """Max size of video file in mega bytes.""" - return self.info["video_files_size_in_mb"] - - def get_task_index(self, task: str) -> int | None: - """ - Given a task in natural language, returns its task_index if the task already exists in the dataset, - otherwise return None. - """ - if task in self.tasks.index: - return int(self.tasks.loc[task].task_index) - else: - return None - - def save_episode_tasks(self, tasks: list[str]): - if len(set(tasks)) != len(tasks): - raise ValueError(f"Tasks are not unique: {tasks}") - - if self.tasks is None: - new_tasks = tasks - task_indices = range(len(tasks)) - self.tasks = pd.DataFrame({"task_index": task_indices}, index=pd.Index(tasks, name="task")) - else: - new_tasks = [task for task in tasks if task not in self.tasks.index] - new_task_indices = range(len(self.tasks), len(self.tasks) + len(new_tasks)) - for task_idx, task in zip(new_task_indices, new_tasks, strict=False): - self.tasks.loc[task] = task_idx - - if len(new_tasks) > 0: - # Update on disk - write_tasks(self.tasks, self.root) - - def _save_episode_metadata(self, episode_dict: dict) -> None: - """Buffer episode metadata and write to parquet in batches for efficiency. - - This function accumulates episode metadata in a buffer and flushes it when the buffer - reaches the configured size. This reduces I/O overhead by writing multiple episodes - at once instead of one row at a time. - - Notes: We both need to update parquet files and HF dataset: - - `pandas` loads parquet file in RAM - - `datasets` relies on a memory mapping from pyarrow (no RAM). It either converts parquet files to a pyarrow cache on disk, - or loads directly from pyarrow cache. 
- """ - # Convert to list format for each value - episode_dict = {key: [value] for key, value in episode_dict.items()} - num_frames = episode_dict["length"][0] - - if self.latest_episode is None: - # Initialize indices and frame count for a new dataset made of the first episode data - chunk_idx, file_idx = 0, 0 - if self.episodes is not None and len(self.episodes) > 0: - # It means we are resuming recording, so we need to load the latest episode - # Update the indices to avoid overwriting the latest episode - chunk_idx = self.episodes[-1]["meta/episodes/chunk_index"] - file_idx = self.episodes[-1]["meta/episodes/file_index"] - latest_num_frames = self.episodes[-1]["dataset_to_index"] - episode_dict["dataset_from_index"] = [latest_num_frames] - episode_dict["dataset_to_index"] = [latest_num_frames + num_frames] - - # When resuming, move to the next file - chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.chunks_size) - else: - episode_dict["dataset_from_index"] = [0] - episode_dict["dataset_to_index"] = [num_frames] - - episode_dict["meta/episodes/chunk_index"] = [chunk_idx] - episode_dict["meta/episodes/file_index"] = [file_idx] - else: - chunk_idx = self.latest_episode["meta/episodes/chunk_index"][0] - file_idx = self.latest_episode["meta/episodes/file_index"][0] - - latest_path = ( - self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx) - if self.writer is None - else self.writer.where - ) - - if Path(latest_path).exists(): - latest_size_in_mb = get_file_size_in_mb(Path(latest_path)) - latest_num_frames = self.latest_episode["episode_index"][0] - - av_size_per_frame = latest_size_in_mb / latest_num_frames if latest_num_frames > 0 else 0.0 - - if latest_size_in_mb + av_size_per_frame * num_frames >= self.data_files_size_in_mb: - # Size limit is reached, flush buffer and prepare new parquet file - self._flush_metadata_buffer() - chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.chunks_size) - self._close_writer() - - # Update the existing pandas dataframe with new row - episode_dict["meta/episodes/chunk_index"] = [chunk_idx] - episode_dict["meta/episodes/file_index"] = [file_idx] - episode_dict["dataset_from_index"] = [self.latest_episode["dataset_to_index"][0]] - episode_dict["dataset_to_index"] = [self.latest_episode["dataset_to_index"][0] + num_frames] - - # Add to buffer - self.metadata_buffer.append(episode_dict) - self.latest_episode = episode_dict - - if len(self.metadata_buffer) >= self.metadata_buffer_size: - self._flush_metadata_buffer() - - def save_episode( - self, - episode_index: int, - episode_length: int, - episode_tasks: list[str], - episode_stats: dict[str, dict], - episode_metadata: dict, - ) -> None: - episode_dict = { - "episode_index": episode_index, - "tasks": episode_tasks, - "length": episode_length, - } - episode_dict.update(episode_metadata) - episode_dict.update(flatten_dict({"stats": episode_stats})) - self._save_episode_metadata(episode_dict) - - # Update info - self.info["total_episodes"] += 1 - self.info["total_frames"] += episode_length - self.info["total_tasks"] = len(self.tasks) - self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"} - - write_info(self.info, self.root) - - self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats is not None else episode_stats - write_stats(self.stats, self.root) - - def update_video_info(self, video_key: str | None = None) -> None: - """ - Warning: this function writes info from first episode videos, implicitly 
assuming that all videos have - been encoded the same way. Also, this means it assumes the first episode exists. - """ - if video_key is not None and video_key not in self.video_keys: - raise ValueError(f"Video key {video_key} not found in dataset") - - video_keys = [video_key] if video_key is not None else self.video_keys - for key in video_keys: - if not self.features[key].get("info", None): - video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0) - self.info["features"][key]["info"] = get_video_info(video_path) - - def update_chunk_settings( - self, - chunks_size: int | None = None, - data_files_size_in_mb: int | None = None, - video_files_size_in_mb: int | None = None, - ) -> None: - """Update chunk and file size settings after dataset creation. - - This allows users to customize storage organization without modifying the constructor. - These settings control how episodes are chunked and how large files can grow before - creating new ones. - - Args: - chunks_size: Maximum number of files per chunk directory. If None, keeps current value. - data_files_size_in_mb: Maximum size for data parquet files in MB. If None, keeps current value. - video_files_size_in_mb: Maximum size for video files in MB. If None, keeps current value. - """ - if chunks_size is not None: - if chunks_size <= 0: - raise ValueError(f"chunks_size must be positive, got {chunks_size}") - self.info["chunks_size"] = chunks_size - - if data_files_size_in_mb is not None: - if data_files_size_in_mb <= 0: - raise ValueError(f"data_files_size_in_mb must be positive, got {data_files_size_in_mb}") - self.info["data_files_size_in_mb"] = data_files_size_in_mb - - if video_files_size_in_mb is not None: - if video_files_size_in_mb <= 0: - raise ValueError(f"video_files_size_in_mb must be positive, got {video_files_size_in_mb}") - self.info["video_files_size_in_mb"] = video_files_size_in_mb - - # Update the info file on disk - write_info(self.info, self.root) - - def get_chunk_settings(self) -> dict[str, int]: - """Get current chunk and file size settings. - - Returns: - Dict containing chunks_size, data_files_size_in_mb, and video_files_size_in_mb. 
- """ - return { - "chunks_size": self.chunks_size, - "data_files_size_in_mb": self.data_files_size_in_mb, - "video_files_size_in_mb": self.video_files_size_in_mb, - } - - def __repr__(self): - feature_keys = list(self.features) - return ( - f"{self.__class__.__name__}({{\n" - f" Repository ID: '{self.repo_id}',\n" - f" Total episodes: '{self.total_episodes}',\n" - f" Total frames: '{self.total_frames}',\n" - f" Features: '{feature_keys}',\n" - "})',\n" - ) - - @classmethod - def create( - cls, - repo_id: str, - fps: int, - features: dict, - robot_type: str | None = None, - root: str | Path | None = None, - use_videos: bool = True, - metadata_buffer_size: int = 10, - chunks_size: int | None = None, - data_files_size_in_mb: int | None = None, - video_files_size_in_mb: int | None = None, - ) -> "LeRobotDatasetMetadata": - """Creates metadata for a LeRobotDataset.""" - obj = cls.__new__(cls) - obj.repo_id = repo_id - obj.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id - - obj.root.mkdir(parents=True, exist_ok=False) - - features = {**features, **DEFAULT_FEATURES} - _validate_feature_names(features) - - obj.tasks = None - obj.subtasks = None - obj.episodes = None - obj.stats = None - obj.info = create_empty_dataset_info( - CODEBASE_VERSION, - fps, - features, - use_videos, - robot_type, - chunks_size, - data_files_size_in_mb, - video_files_size_in_mb, - ) - if len(obj.video_keys) > 0 and not use_videos: - raise ValueError() - write_json(obj.info, obj.root / INFO_PATH) - obj.revision = None - obj.writer = None - obj.latest_episode = None - obj.metadata_buffer = [] - obj.metadata_buffer_size = metadata_buffer_size - return obj +logger = logging.getLogger(__name__) def _encode_video_worker( @@ -596,7 +122,7 @@ class LeRobotDataset(torch.utils.data.Dataset): the dataset from that address and load it, pending your dataset is compliant with codebase_version v3.0. If your dataset has been created before this new format, you will be prompted to convert it using our conversion script from v2.1 to v3.0, which you can find at - lerobot/datasets/v30/convert_dataset_v21_to_v30.py. + lerobot/scripts/convert_dataset_v21_to_v30.py. 2. Your dataset doesn't already exists (either on local disk or on the Hub): you can create an empty @@ -1326,7 +852,7 @@ class LeRobotDataset(torch.utils.data.Dataset): temp_path = future.result() results[video_key] = temp_path except Exception as exc: - logging.error(f"Video encoding failed for {video_key}: {exc}") + logger.error(f"Video encoding failed for {video_key}: {exc}") raise exc for video_key in self.meta.video_keys: @@ -1365,7 +891,7 @@ class LeRobotDataset(torch.utils.data.Dataset): if end_episode is None: end_episode = self.num_episodes - logging.info( + logger.info( f"Batch encoding {self.batch_encoding_size} videos for episodes {start_episode} to {end_episode - 1}" ) @@ -1375,7 +901,7 @@ class LeRobotDataset(torch.utils.data.Dataset): episode_df = pd.read_parquet(episode_df_path) for ep_idx in range(start_episode, end_episode): - logging.info(f"Encoding videos for episode {ep_idx}") + logger.info(f"Encoding videos for episode {ep_idx}") if ( self.meta.episodes[ep_idx]["data/chunk_index"] != chunk_idx @@ -1605,7 +1131,7 @@ class LeRobotDataset(torch.utils.data.Dataset): def start_image_writer(self, num_processes: int = 0, num_threads: int = 4) -> None: if isinstance(self.image_writer, AsyncImageWriter): - logging.warning( + logger.warning( "You are starting a new AsyncImageWriter that is replacing an already existing one in the dataset." 
) @@ -1683,7 +1209,6 @@ class LeRobotDataset(torch.utils.data.Dataset): if image_writer_processes or image_writer_threads: obj.start_image_writer(image_writer_processes, image_writer_threads) - # TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer obj.episode_buffer = obj.create_episode_buffer() obj.episodes = None @@ -1717,184 +1242,3 @@ class LeRobotDataset(torch.utils.data.Dataset): obj._streaming_encoder = None return obj - - -class MultiLeRobotDataset(torch.utils.data.Dataset): - """A dataset consisting of multiple underlying `LeRobotDataset`s. - - The underlying `LeRobotDataset`s are effectively concatenated, and this class adopts much of the API - structure of `LeRobotDataset`. - """ - - def __init__( - self, - repo_ids: list[str], - root: str | Path | None = None, - episodes: dict | None = None, - image_transforms: Callable | None = None, - delta_timestamps: dict[str, list[float]] | None = None, - tolerances_s: dict | None = None, - download_videos: bool = True, - video_backend: str | None = None, - ): - super().__init__() - self.repo_ids = repo_ids - self.root = Path(root) if root else HF_LEROBOT_HOME - self.tolerances_s = tolerances_s if tolerances_s else dict.fromkeys(repo_ids, 0.0001) - # Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which - # are handled by this class. - self._datasets = [ - LeRobotDataset( - repo_id, - root=self.root / repo_id, - episodes=episodes[repo_id] if episodes else None, - image_transforms=image_transforms, - delta_timestamps=delta_timestamps, - tolerance_s=self.tolerances_s[repo_id], - download_videos=download_videos, - video_backend=video_backend, - ) - for repo_id in repo_ids - ] - - # Disable any data keys that are not common across all of the datasets. Note: we may relax this - # restriction in future iterations of this class. For now, this is necessary at least for being able - # to use PyTorch's default DataLoader collate function. - self.disabled_features = set() - intersection_features = set(self._datasets[0].features) - for ds in self._datasets: - intersection_features.intersection_update(ds.features) - if len(intersection_features) == 0: - raise RuntimeError( - "Multiple datasets were provided but they had no keys common to all of them. " - "The multi-dataset functionality currently only keeps common keys." - ) - for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True): - extra_keys = set(ds.features).difference(intersection_features) - if extra_keys: - logging.warning( - f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the " - "other datasets." - ) - self.disabled_features.update(extra_keys) - - self.image_transforms = image_transforms - self.delta_timestamps = delta_timestamps - # TODO(rcadene, aliberts): We should not perform this aggregation for datasets - # with multiple robots of different ranges. Instead we should have one normalization - # per robot. - self.stats = aggregate_stats([dataset.meta.stats for dataset in self._datasets]) - - @property - def repo_id_to_index(self): - """Return a mapping from dataset repo_id to a dataset index automatically created by this class. - - This index is incorporated as a data key in the dictionary returned by `__getitem__`. - """ - return {repo_id: i for i, repo_id in enumerate(self.repo_ids)} - - @property - def fps(self) -> int: - """Frames per second used during data collection. 
- - NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info. - """ - return self._datasets[0].meta.info["fps"] - - @property - def video(self) -> bool: - """Returns True if this dataset loads video frames from mp4 files. - - Returns False if it only loads images from png files. - - NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info. - """ - return self._datasets[0].meta.info.get("video", False) - - @property - def features(self) -> datasets.Features: - features = {} - for dataset in self._datasets: - features.update({k: v for k, v in dataset.hf_features.items() if k not in self.disabled_features}) - return features - - @property - def camera_keys(self) -> list[str]: - """Keys to access image and video stream from cameras.""" - keys = [] - for key, feats in self.features.items(): - if isinstance(feats, (datasets.Image | VideoFrame)): - keys.append(key) - return keys - - @property - def video_frame_keys(self) -> list[str]: - """Keys to access video frames that requires to be decoded into images. - - Note: It is empty if the dataset contains images only, - or equal to `self.cameras` if the dataset contains videos only, - or can even be a subset of `self.cameras` in a case of a mixed image/video dataset. - """ - video_frame_keys = [] - for key, feats in self.features.items(): - if isinstance(feats, VideoFrame): - video_frame_keys.append(key) - return video_frame_keys - - @property - def num_frames(self) -> int: - """Number of samples/frames.""" - return sum(d.num_frames for d in self._datasets) - - @property - def num_episodes(self) -> int: - """Number of episodes.""" - return sum(d.num_episodes for d in self._datasets) - - @property - def tolerance_s(self) -> float: - """Tolerance in seconds used to discard loaded frames when their timestamps - are not close enough from the requested frames. It is only used when `delta_timestamps` - is provided or when loading video frames from mp4 files. - """ - # 1e-4 to account for possible numerical error - return 1 / self.fps - 1e-4 - - def __len__(self): - return self.num_frames - - def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: - if idx >= len(self): - raise IndexError(f"Index {idx} out of bounds.") - # Determine which dataset to get an item from based on the index. 
- start_idx = 0 - dataset_idx = 0 - for dataset in self._datasets: - if idx >= start_idx + dataset.num_frames: - start_idx += dataset.num_frames - dataset_idx += 1 - continue - break - else: - raise AssertionError("We expect the loop to break out as long as the index is within bounds.") - item = self._datasets[dataset_idx][idx - start_idx] - item["dataset_index"] = torch.tensor(dataset_idx) - for data_key in self.disabled_features: - if data_key in item: - del item[data_key] - - return item - - def __repr__(self): - return ( - f"{self.__class__.__name__}(\n" - f" Repository IDs: '{self.repo_ids}',\n" - f" Number of Samples: {self.num_frames},\n" - f" Number of Episodes: {self.num_episodes},\n" - f" Type: {'video (.mp4)' if self.video else 'image (.png)'},\n" - f" Recorded Frames per Second: {self.fps},\n" - f" Camera Keys: {self.camera_keys},\n" - f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n" - f" Transformations: {self.image_transforms},\n" - f")" - ) diff --git a/src/lerobot/datasets/multi_dataset.py b/src/lerobot/datasets/multi_dataset.py new file mode 100644 index 000000000..917d5c5eb --- /dev/null +++ b/src/lerobot/datasets/multi_dataset.py @@ -0,0 +1,210 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from collections.abc import Callable +from pathlib import Path + +import datasets +import torch +import torch.utils + +from lerobot.datasets.compute_stats import aggregate_stats +from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.datasets.video_utils import VideoFrame +from lerobot.utils.constants import HF_LEROBOT_HOME + +logger = logging.getLogger(__name__) + + +class MultiLeRobotDataset(torch.utils.data.Dataset): + """A dataset consisting of multiple underlying `LeRobotDataset`s. + + The underlying `LeRobotDataset`s are effectively concatenated, and this class adopts much of the API + structure of `LeRobotDataset`. + """ + + def __init__( + self, + repo_ids: list[str], + root: str | Path | None = None, + episodes: dict | None = None, + image_transforms: Callable | None = None, + delta_timestamps: dict[str, list[float]] | None = None, + tolerances_s: dict | None = None, + download_videos: bool = True, + video_backend: str | None = None, + ): + super().__init__() + self.repo_ids = repo_ids + self.root = Path(root) if root else HF_LEROBOT_HOME + self.tolerances_s = tolerances_s if tolerances_s else dict.fromkeys(repo_ids, 0.0001) + # Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which + # are handled by this class. 
+        self._datasets = [
+            LeRobotDataset(
+                repo_id,
+                root=self.root / repo_id,
+                episodes=episodes[repo_id] if episodes else None,
+                image_transforms=image_transforms,
+                delta_timestamps=delta_timestamps,
+                tolerance_s=self.tolerances_s[repo_id],
+                download_videos=download_videos,
+                video_backend=video_backend,
+            )
+            for repo_id in repo_ids
+        ]
+
+        # Disable any data keys that are not common across all of the datasets. Note: we may relax this
+        # restriction in future iterations of this class. For now, this is necessary at least for being able
+        # to use PyTorch's default DataLoader collate function.
+        self.disabled_features = set()
+        intersection_features = set(self._datasets[0].features)
+        for ds in self._datasets:
+            intersection_features.intersection_update(ds.features)
+        if len(intersection_features) == 0:
+            raise RuntimeError(
+                "Multiple datasets were provided but they had no keys common to all of them. "
+                "The multi-dataset functionality currently only keeps common keys."
+            )
+        for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True):
+            extra_keys = set(ds.features).difference(intersection_features)
+            if extra_keys:
+                logger.warning(
+                    f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
+                    "other datasets."
+                )
+                self.disabled_features.update(extra_keys)
+
+        self.image_transforms = image_transforms
+        self.delta_timestamps = delta_timestamps
+        # TODO(rcadene, aliberts): We should not perform this aggregation for datasets
+        # with multiple robots of different ranges. Instead we should have one normalization
+        # per robot.
+        self.stats = aggregate_stats([dataset.meta.stats for dataset in self._datasets])
+
+    @property
+    def repo_id_to_index(self):
+        """Return a mapping from dataset repo_id to a dataset index automatically created by this class.
+
+        This index is incorporated as a data key in the dictionary returned by `__getitem__`.
+        """
+        return {repo_id: i for i, repo_id in enumerate(self.repo_ids)}
+
+    @property
+    def fps(self) -> int:
+        """Frames per second used during data collection.
+
+        NOTE: For now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
+        """
+        return self._datasets[0].meta.info["fps"]
+
+    @property
+    def video(self) -> bool:
+        """Returns True if this dataset loads video frames from mp4 files.
+
+        Returns False if it only loads images from png files.
+
+        NOTE: For now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
+        """
+        return self._datasets[0].meta.info.get("video", False)
+
+    @property
+    def features(self) -> datasets.Features:
+        features = {}
+        for dataset in self._datasets:
+            features.update({k: v for k, v in dataset.hf_features.items() if k not in self.disabled_features})
+        return features
+
+    @property
+    def camera_keys(self) -> list[str]:
+        """Keys to access image and video streams from cameras."""
+        keys = []
+        for key, feats in self.features.items():
+            if isinstance(feats, (datasets.Image | VideoFrame)):
+                keys.append(key)
+        return keys
+
+    @property
+    def video_frame_keys(self) -> list[str]:
+        """Keys to access video frames that require decoding into images.
+
+        Note: It is empty if the dataset contains images only,
+        or equal to `self.cameras` if the dataset contains videos only,
+        or can even be a subset of `self.cameras` in the case of a mixed image/video dataset.
+ """ + video_frame_keys = [] + for key, feats in self.features.items(): + if isinstance(feats, VideoFrame): + video_frame_keys.append(key) + return video_frame_keys + + @property + def num_frames(self) -> int: + """Number of samples/frames.""" + return sum(d.num_frames for d in self._datasets) + + @property + def num_episodes(self) -> int: + """Number of episodes.""" + return sum(d.num_episodes for d in self._datasets) + + @property + def tolerance_s(self) -> float: + """Tolerance in seconds used to discard loaded frames when their timestamps + are not close enough from the requested frames. It is only used when `delta_timestamps` + is provided or when loading video frames from mp4 files. + """ + # 1e-4 to account for possible numerical error + return 1 / self.fps - 1e-4 + + def __len__(self): + return self.num_frames + + def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: + if idx >= len(self): + raise IndexError(f"Index {idx} out of bounds.") + # Determine which dataset to get an item from based on the index. + start_idx = 0 + dataset_idx = 0 + for dataset in self._datasets: + if idx >= start_idx + dataset.num_frames: + start_idx += dataset.num_frames + dataset_idx += 1 + continue + break + else: + raise AssertionError("We expect the loop to break out as long as the index is within bounds.") + item = self._datasets[dataset_idx][idx - start_idx] + item["dataset_index"] = torch.tensor(dataset_idx) + for data_key in self.disabled_features: + if data_key in item: + del item[data_key] + + return item + + def __repr__(self): + return ( + f"{self.__class__.__name__}(\n" + f" Repository IDs: '{self.repo_ids}',\n" + f" Number of Samples: {self.num_frames},\n" + f" Number of Episodes: {self.num_episodes},\n" + f" Type: {'video (.mp4)' if self.video else 'image (.png)'},\n" + f" Recorded Frames per Second: {self.fps},\n" + f" Camera Keys: {self.camera_keys},\n" + f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n" + f" Transformations: {self.image_transforms},\n" + f")" + ) diff --git a/src/lerobot/datasets/online_buffer.py b/src/lerobot/datasets/online_buffer.py deleted file mode 100644 index 563d800b9..000000000 --- a/src/lerobot/datasets/online_buffer.py +++ /dev/null @@ -1,382 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""An online buffer for the online training loop in train.py - -Note to maintainers: This duplicates some logic from LeRobotDataset and EpisodeAwareSampler. We should -consider converging to one approach. Here we have opted to use numpy.memmap to back the data buffer. It's much -faster than using HuggingFace Datasets as there's no conversion to an intermediate non-python object. Also it -supports in-place slicing and mutation which is very handy for a dynamic buffer. 
-""" - -import os -from pathlib import Path -from typing import Any - -import numpy as np -import torch - -from lerobot.datasets.lerobot_dataset import LeRobotDataset - - -def _make_memmap_safe(**kwargs) -> np.memmap: - """Make a numpy memmap with checks on available disk space first. - - Expected kwargs are: "filename", "dtype" (must by np.dtype), "mode" and "shape" - - For information on dtypes: - https://numpy.org/doc/stable/reference/arrays.dtypes.html#arrays-dtypes-constructing - """ - if kwargs["mode"].startswith("w"): - required_space = kwargs["dtype"].itemsize * np.prod(kwargs["shape"]) # bytes - stats = os.statvfs(Path(kwargs["filename"]).parent) - available_space = stats.f_bavail * stats.f_frsize # bytes - if required_space >= available_space * 0.8: - raise RuntimeError( - f"You're about to take up {required_space} of {available_space} bytes available." - ) - return np.memmap(**kwargs) - - -class OnlineBuffer(torch.utils.data.Dataset): - """FIFO data buffer for the online training loop in train.py. - - Follows the protocol of LeRobotDataset as much as is required to have it be used by the online training - loop in the same way that a LeRobotDataset would be used. - - The underlying data structure will have data inserted in a circular fashion. Always insert after the - last index, and when you reach the end, wrap around to the start. - - The data is stored in a numpy memmap. - """ - - NEXT_INDEX_KEY = "_next_index" - OCCUPANCY_MASK_KEY = "_occupancy_mask" - INDEX_KEY = "index" - FRAME_INDEX_KEY = "frame_index" - EPISODE_INDEX_KEY = "episode_index" - TIMESTAMP_KEY = "timestamp" - IS_PAD_POSTFIX = "_is_pad" - - def __init__( - self, - write_dir: str | Path, - data_spec: dict[str, Any] | None, - buffer_capacity: int | None, - fps: float | None = None, - delta_timestamps: dict[str, list[float]] | dict[str, np.ndarray] | None = None, - ): - """ - The online buffer can be provided from scratch or you can load an existing online buffer by passing - a `write_dir` associated with an existing buffer. - - Args: - write_dir: Where to keep the numpy memmap files. One memmap file will be stored for each data key. - Note that if the files already exist, they are opened in read-write mode (used for training - resumption.) - data_spec: A mapping from data key to data specification, like {data_key: {"shape": tuple[int], - "dtype": np.dtype}}. This should include all the data that you wish to record into the buffer, - but note that "index", "frame_index" and "episode_index" are already accounted for by this - class, so you don't need to include them. - buffer_capacity: How many frames should be stored in the buffer as a maximum. Be aware of your - system's available disk space when choosing this. - fps: Same as the fps concept in LeRobot dataset. Here it needs to be provided for the - delta_timestamps logic. You can pass None if you are not using delta_timestamps. - delta_timestamps: Same as the delta_timestamps concept in LeRobotDataset. This is internally - converted to dict[str, np.ndarray] for optimization purposes. - - """ - self.set_delta_timestamps(delta_timestamps) - self._fps = fps - # Tolerance in seconds used to discard loaded frames when their timestamps are not close enough from - # the requested frames. It is only used when `delta_timestamps` is provided. 
- # minus 1e-4 to account for possible numerical error - self.tolerance_s = 1 / self.fps - 1e-4 if fps is not None else None - self._buffer_capacity = buffer_capacity - data_spec = self._make_data_spec(data_spec, buffer_capacity) - Path(write_dir).mkdir(parents=True, exist_ok=True) - self._data = {} - for k, v in data_spec.items(): - self._data[k] = _make_memmap_safe( - filename=Path(write_dir) / k, - dtype=v["dtype"] if v is not None else None, - mode="r+" if (Path(write_dir) / k).exists() else "w+", - shape=tuple(v["shape"]) if v is not None else None, - ) - - @property - def delta_timestamps(self) -> dict[str, np.ndarray] | None: - return self._delta_timestamps - - def set_delta_timestamps(self, value: dict[str, list[float]] | None): - """Set delta_timestamps converting the values to numpy arrays. - - The conversion is for an optimization in the __getitem__. The loop is much slower if the arrays - need to be converted into numpy arrays. - """ - if value is not None: - self._delta_timestamps = {k: np.array(v) for k, v in value.items()} - else: - self._delta_timestamps = None - - def _make_data_spec(self, data_spec: dict[str, Any], buffer_capacity: int) -> dict[str, dict[str, Any]]: - """Makes the data spec for np.memmap.""" - if any(k.startswith("_") for k in data_spec): - raise ValueError( - "data_spec keys should not start with '_'. This prefix is reserved for internal logic." - ) - preset_keys = { - OnlineBuffer.INDEX_KEY, - OnlineBuffer.FRAME_INDEX_KEY, - OnlineBuffer.EPISODE_INDEX_KEY, - OnlineBuffer.TIMESTAMP_KEY, - } - if len(intersection := set(data_spec).intersection(preset_keys)) > 0: - raise ValueError( - f"data_spec should not contain any of {preset_keys} as these are handled internally. " - f"The provided data_spec has {intersection}." - ) - complete_data_spec = { - # _next_index will be a pointer to the next index that we should start filling from when we add - # more data. - OnlineBuffer.NEXT_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": ()}, - # Since the memmap is initialized with all-zeros, this keeps track of which indices are occupied - # with real data rather than the dummy initialization. - OnlineBuffer.OCCUPANCY_MASK_KEY: {"dtype": np.dtype("?"), "shape": (buffer_capacity,)}, - OnlineBuffer.INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)}, - OnlineBuffer.FRAME_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)}, - OnlineBuffer.EPISODE_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)}, - OnlineBuffer.TIMESTAMP_KEY: {"dtype": np.dtype("float64"), "shape": (buffer_capacity,)}, - } - for k, v in data_spec.items(): - complete_data_spec[k] = {"dtype": v["dtype"], "shape": (buffer_capacity, *v["shape"])} - return complete_data_spec - - def add_data(self, data: dict[str, np.ndarray]): - """Add new data to the buffer, which could potentially mean shifting old data out. - - The new data should contain all the frames (in order) of any number of episodes. The indices should - start from 0 (note to the developer: this can easily be generalized). See the `rollout` and - `eval_policy` functions in `eval.py` for more information on how the data is constructed. - - Shift the incoming data index and episode_index to continue on from the last frame. Note that this - will be done in place! 
- """ - if len(missing_keys := (set(self.data_keys).difference(set(data)))) > 0: - raise ValueError(f"Missing data keys: {missing_keys}") - new_data_length = len(data[self.data_keys[0]]) - if not all(len(data[k]) == new_data_length for k in self.data_keys): - raise ValueError("All data items should have the same length") - - next_index = self._data[OnlineBuffer.NEXT_INDEX_KEY] - - # Sanity check to make sure that the new data indices start from 0. - assert data[OnlineBuffer.EPISODE_INDEX_KEY][0].item() == 0 - assert data[OnlineBuffer.INDEX_KEY][0].item() == 0 - - # Shift the incoming indices if necessary. - if self.num_frames > 0: - last_episode_index = self._data[OnlineBuffer.EPISODE_INDEX_KEY][next_index - 1] - last_data_index = self._data[OnlineBuffer.INDEX_KEY][next_index - 1] - data[OnlineBuffer.EPISODE_INDEX_KEY] += last_episode_index + 1 - data[OnlineBuffer.INDEX_KEY] += last_data_index + 1 - - # Insert the new data starting from next_index. It may be necessary to wrap around to the start. - n_surplus = max(0, new_data_length - (self._buffer_capacity - next_index)) - for k in self.data_keys: - if n_surplus == 0: - slc = slice(next_index, next_index + new_data_length) - self._data[k][slc] = data[k] - self._data[OnlineBuffer.OCCUPANCY_MASK_KEY][slc] = True - else: - self._data[k][next_index:] = data[k][:-n_surplus] - self._data[OnlineBuffer.OCCUPANCY_MASK_KEY][next_index:] = True - self._data[k][:n_surplus] = data[k][-n_surplus:] - if n_surplus == 0: - self._data[OnlineBuffer.NEXT_INDEX_KEY] = next_index + new_data_length - else: - self._data[OnlineBuffer.NEXT_INDEX_KEY] = n_surplus - - @property - def data_keys(self) -> list[str]: - keys = set(self._data) - keys.remove(OnlineBuffer.OCCUPANCY_MASK_KEY) - keys.remove(OnlineBuffer.NEXT_INDEX_KEY) - return sorted(keys) - - @property - def fps(self) -> float | None: - return self._fps - - @property - def num_episodes(self) -> int: - return len( - np.unique(self._data[OnlineBuffer.EPISODE_INDEX_KEY][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]]) - ) - - @property - def num_frames(self) -> int: - return np.count_nonzero(self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]) - - def __len__(self): - return self.num_frames - - def _item_to_tensors(self, item: dict) -> dict: - item_ = {} - for k, v in item.items(): - if isinstance(v, torch.Tensor): - item_[k] = v - elif isinstance(v, np.ndarray): - item_[k] = torch.from_numpy(v) - else: - item_[k] = torch.tensor(v) - return item_ - - def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: - if idx >= len(self) or idx < -len(self): - raise IndexError - - item = {k: v[idx] for k, v in self._data.items() if not k.startswith("_")} - - if self.delta_timestamps is None: - return self._item_to_tensors(item) - - episode_index = item[OnlineBuffer.EPISODE_INDEX_KEY] - current_ts = item[OnlineBuffer.TIMESTAMP_KEY] - episode_data_indices = np.where( - np.bitwise_and( - self._data[OnlineBuffer.EPISODE_INDEX_KEY] == episode_index, - self._data[OnlineBuffer.OCCUPANCY_MASK_KEY], - ) - )[0] - episode_timestamps = self._data[OnlineBuffer.TIMESTAMP_KEY][episode_data_indices] - - for data_key in self.delta_timestamps: - # Note: The logic in this loop is copied from `load_previous_and_future_frames`. - # Get timestamps used as query to retrieve data of previous/future frames. - query_ts = current_ts + self.delta_timestamps[data_key] - - # Compute distances between each query timestamp and all timestamps of all the frames belonging to - # the episode. 
- dist = np.abs(query_ts[:, None] - episode_timestamps[None, :]) - argmin_ = np.argmin(dist, axis=1) - min_ = dist[np.arange(dist.shape[0]), argmin_] - - is_pad = min_ > self.tolerance_s - - # Check violated query timestamps are all outside the episode range. - assert ( - (query_ts[is_pad] < episode_timestamps[0]) | (episode_timestamps[-1] < query_ts[is_pad]) - ).all(), ( - f"One or several timestamps unexpectedly violate the tolerance ({min_} > {self.tolerance_s=}" - ") inside the episode range." - ) - - # Load frames for this data key. - item[data_key] = self._data[data_key][episode_data_indices[argmin_]] - - item[f"{data_key}{OnlineBuffer.IS_PAD_POSTFIX}"] = is_pad - - return self._item_to_tensors(item) - - def get_data_by_key(self, key: str) -> torch.Tensor: - """Returns all data for a given data key as a Tensor.""" - return torch.from_numpy(self._data[key][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]]) - - -def compute_sampler_weights( - offline_dataset: LeRobotDataset, - offline_drop_n_last_frames: int = 0, - online_dataset: OnlineBuffer | None = None, - online_sampling_ratio: float | None = None, - online_drop_n_last_frames: int = 0, -) -> torch.Tensor: - """Compute the sampling weights for the online training dataloader in train.py. - - Args: - offline_dataset: The LeRobotDataset used for offline pre-training. - online_drop_n_last_frames: Number of frames to drop from the end of each offline dataset episode. - online_dataset: The OnlineBuffer used in online training. - online_sampling_ratio: The proportion of data that should be sampled from the online dataset. If an - online dataset is provided, this value must also be provided. - online_drop_n_first_frames: See `offline_drop_n_last_frames`. This is the same, but for the online - dataset. - Returns: - Tensor of weights for [offline_dataset; online_dataset], normalized to 1. - - Notes to maintainers: - - This duplicates some logic from EpisodeAwareSampler. We should consider converging to one approach. - - When used with `torch.utils.data.WeightedRandomSampler`, it could completely replace - `EpisodeAwareSampler` as the online dataset related arguments are optional. The only missing feature - is the ability to turn shuffling off. - - Options `drop_first_n_frames` and `episode_indices_to_use` can be added easily. They were not - included here to avoid adding complexity. - """ - if len(offline_dataset) == 0 and (online_dataset is None or len(online_dataset) == 0): - raise ValueError("At least one of `offline_dataset` or `online_dataset` should be contain data.") - if (online_dataset is None) ^ (online_sampling_ratio is None): - raise ValueError( - "`online_dataset` and `online_sampling_ratio` must be provided together or not at all." 
- ) - offline_sampling_ratio = 0 if online_sampling_ratio is None else 1 - online_sampling_ratio - - weights = [] - - if len(offline_dataset) > 0: - offline_data_mask_indices = [] - for start_index, end_index in zip( - offline_dataset.meta.episodes["dataset_from_index"], - offline_dataset.meta.episodes["dataset_to_index"], - strict=True, - ): - offline_data_mask_indices.extend(range(start_index, end_index - offline_drop_n_last_frames)) - offline_data_mask = torch.zeros(len(offline_dataset), dtype=torch.bool) - offline_data_mask[torch.tensor(offline_data_mask_indices)] = True - weights.append( - torch.full( - size=(len(offline_dataset),), - fill_value=offline_sampling_ratio / offline_data_mask.sum(), - ) - * offline_data_mask - ) - - if online_dataset is not None and len(online_dataset) > 0: - online_data_mask_indices = [] - episode_indices = online_dataset.get_data_by_key("episode_index") - for episode_idx in torch.unique(episode_indices): - where_episode = torch.where(episode_indices == episode_idx) - start_index = where_episode[0][0] - end_index = where_episode[0][-1] + 1 - online_data_mask_indices.extend( - range(start_index.item(), end_index.item() - online_drop_n_last_frames) - ) - online_data_mask = torch.zeros(len(online_dataset), dtype=torch.bool) - online_data_mask[torch.tensor(online_data_mask_indices)] = True - weights.append( - torch.full( - size=(len(online_dataset),), - fill_value=online_sampling_ratio / online_data_mask.sum(), - ) - * online_data_mask - ) - - weights = torch.cat(weights) - - if weights.sum() == 0: - weights += 1 / len(weights) - else: - weights /= weights.sum() - - return weights diff --git a/src/lerobot/datasets/pipeline_features.py b/src/lerobot/datasets/pipeline_features.py index 161633f26..96779fdc6 100644 --- a/src/lerobot/datasets/pipeline_features.py +++ b/src/lerobot/datasets/pipeline_features.py @@ -17,8 +17,9 @@ from collections.abc import Sequence from typing import Any from lerobot.configs.types import PipelineFeatureType -from lerobot.datasets.utils import hw_to_dataset_features -from lerobot.processor import DataProcessorPipeline, RobotAction, RobotObservation +from lerobot.datasets.feature_utils import hw_to_dataset_features +from lerobot.processor import DataProcessorPipeline +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE, OBS_STR @@ -43,11 +44,11 @@ def create_initial_features( return features -# Helper to filter state/action keys based on regex patterns. -def should_keep(key: str, patterns: tuple[str]) -> bool: +# Helper to filter state/action keys based on compiled regex patterns. +def should_keep(key: str, patterns: tuple[re.Pattern] | None) -> bool: if patterns is None: return True - return any(re.search(pat, key) for pat in patterns) + return any(pat.search(key) for pat in patterns) def strip_prefix(key: str, prefixes_to_strip: tuple[str]) -> str: @@ -88,6 +89,8 @@ def aggregate_pipeline_dataset_features( Returns: A dictionary of features formatted for a Hugging Face LeRobot Dataset. """ + compiled_patterns = tuple(re.compile(p) for p in patterns) if patterns is not None else None + all_features = pipeline.transform_features(initial_features) # Intermediate storage for categorized and filtered features. @@ -119,7 +122,7 @@ def aggregate_pipeline_dataset_features( # 2. Apply filtering rules. if is_image and not use_videos: continue - if not is_image and not should_keep(key, patterns): + if not is_image and not should_keep(key, compiled_patterns): continue # 3. 
Add the feature to the appropriate group with a clean name. diff --git a/src/lerobot/datasets/push_dataset_to_hub/utils.py b/src/lerobot/datasets/push_dataset_to_hub/utils.py deleted file mode 100644 index 48214e1bf..000000000 --- a/src/lerobot/datasets/push_dataset_to_hub/utils.py +++ /dev/null @@ -1,73 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import datasets -import torch - - -# TODO(aliberts): remove -def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, torch.Tensor]: - """ - Calculate episode data index for the provided HuggingFace Dataset. Relies on episode_index column of hf_dataset. - - Parameters: - - hf_dataset (datasets.Dataset): A HuggingFace dataset containing the episode index. - - Returns: - - episode_data_index: A dictionary containing the data index for each episode. The dictionary has two keys: - - "from": A tensor containing the starting index of each episode. - - "to": A tensor containing the ending index of each episode. - """ - episode_data_index = {"from": [], "to": []} - - current_episode = None - """ - The episode_index is a list of integers, each representing the episode index of the corresponding example. - For instance, the following is a valid episode_index: - [0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2] - - Below, we iterate through the episode_index and populate the episode_data_index dictionary with the starting and - ending index of each episode. For the episode_index above, the episode_data_index dictionary will look like this: - { - "from": [0, 3, 7], - "to": [3, 7, 12] - } - """ - if len(hf_dataset) == 0: - episode_data_index = { - "from": torch.tensor([]), - "to": torch.tensor([]), - } - return episode_data_index - for idx, episode_idx in enumerate(hf_dataset["episode_index"]): - if episode_idx != current_episode: - # We encountered a new episode, so we append its starting location to the "from" list - episode_data_index["from"].append(idx) - # If this is not the first episode, we append the ending location of the previous episode to the "to" list - if current_episode is not None: - episode_data_index["to"].append(idx) - # Let's keep track of the current episode index - current_episode = episode_idx - else: - # We are still in the same episode, so there is nothing for us to do here - pass - # We have reached the end of the dataset, so we append the ending location of the last episode to the "to" list - episode_data_index["to"].append(idx + 1) - - for k in ["from", "to"]: - episode_data_index[k] = torch.tensor(episode_data_index[k]) - - return episode_data_index diff --git a/src/lerobot/datasets/sampler.py b/src/lerobot/datasets/sampler.py index d0bb20c27..2bf7ab922 100644 --- a/src/lerobot/datasets/sampler.py +++ b/src/lerobot/datasets/sampler.py @@ -13,10 +13,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. +import logging from collections.abc import Iterator import torch +logger = logging.getLogger(__name__) + class EpisodeAwareSampler: def __init__( @@ -39,13 +42,35 @@ class EpisodeAwareSampler: drop_n_last_frames: Number of frames to drop from the end of each episode. shuffle: Whether to shuffle the indices. """ + if drop_n_first_frames < 0: + raise ValueError(f"drop_n_first_frames must be >= 0, got {drop_n_first_frames}") + if drop_n_last_frames < 0: + raise ValueError(f"drop_n_last_frames must be >= 0, got {drop_n_last_frames}") + indices = [] for episode_idx, (start_index, end_index) in enumerate( zip(dataset_from_indices, dataset_to_indices, strict=True) ): if episode_indices_to_use is None or episode_idx in episode_indices_to_use: + ep_length = end_index - start_index + if drop_n_first_frames + drop_n_last_frames >= ep_length: + logger.warning( + "Episode %d has %d frames but drop_n_first_frames=%d and " + "drop_n_last_frames=%d removes all frames. Skipping.", + episode_idx, + ep_length, + drop_n_first_frames, + drop_n_last_frames, + ) + continue indices.extend(range(start_index + drop_n_first_frames, end_index - drop_n_last_frames)) + if not indices: + raise ValueError( + "No valid frames remain after applying drop_n_first_frames and drop_n_last_frames. " + "All episodes were either filtered out or had too few frames." + ) + self.indices = indices self.shuffle = shuffle diff --git a/src/lerobot/datasets/streaming_dataset.py b/src/lerobot/datasets/streaming_dataset.py index 454389d46..62e00558a 100644 --- a/src/lerobot/datasets/streaming_dataset.py +++ b/src/lerobot/datasets/streaming_dataset.py @@ -13,7 +13,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections.abc import Callable, Generator, Iterator +from collections import deque +from collections.abc import Callable, Generator, Iterable, Iterator from pathlib import Path import datasets @@ -21,16 +22,13 @@ import numpy as np import torch from datasets import load_dataset -from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDatasetMetadata +from lerobot.datasets.dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata +from lerobot.datasets.feature_utils import get_delta_indices +from lerobot.datasets.io_utils import item_to_torch from lerobot.datasets.utils import ( - Backtrackable, - LookAheadError, - LookBackError, check_version_compatibility, find_float_index, - get_delta_indices, is_float_in_list, - item_to_torch, safe_shard, ) from lerobot.datasets.video_utils import ( @@ -40,6 +38,164 @@ from lerobot.datasets.video_utils import ( from lerobot.utils.constants import HF_LEROBOT_HOME, LOOKAHEAD_BACKTRACKTABLE, LOOKBACK_BACKTRACKTABLE +class LookBackError(Exception): + """ + Exception raised when trying to look back in the history of a Backtrackable object. + """ + + pass + + +class LookAheadError(Exception): + """ + Exception raised when trying to look ahead in the future of a Backtrackable object. + """ + + pass + + +class Backtrackable[T]: + """ + Wrap any iterator/iterable so you can step back up to `history` items + and look ahead up to `lookahead` items. + + This is useful for streaming datasets where you need to access previous and future items + but can't load the entire dataset into memory. 
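Returning to the `EpisodeAwareSampler` changes above: the new guard skips (with a warning) any episode whose frames would all be removed by `drop_n_first_frames`/`drop_n_last_frames`, and raises only if no frames remain across all episodes. A self-contained sketch of that index selection, using made-up episode boundaries (the `build_sampler_indices` helper is illustrative only):

```python
def build_sampler_indices(
    from_indices: list[int],
    to_indices: list[int],
    drop_n_first_frames: int = 0,
    drop_n_last_frames: int = 0,
) -> list[int]:
    """Mirror the sampler's per-episode index construction: episodes whose frames
    are entirely dropped contribute nothing instead of an empty or negative range."""
    indices = []
    for start_index, end_index in zip(from_indices, to_indices, strict=True):
        if drop_n_first_frames + drop_n_last_frames >= end_index - start_index:
            continue  # the real sampler also logs a warning here
        indices.extend(range(start_index + drop_n_first_frames, end_index - drop_n_last_frames))
    return indices


# Episode 0 spans frames [0, 10) and episode 1 spans [10, 13). Dropping 2 + 2 frames
# removes episode 1 entirely and keeps frames 2..7 of episode 0.
assert build_sampler_indices([0, 10], [10, 13], drop_n_first_frames=2, drop_n_last_frames=2) == list(range(2, 8))
```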
+ + Example: + ------- + ```python + ds = load_dataset("c4", "en", streaming=True, split="train") + rev = Backtrackable(ds, history=3, lookahead=2) + + x0 = next(rev) # forward + x1 = next(rev) + x2 = next(rev) + + # Look ahead + x3_peek = rev.peek_ahead(1) # next item without moving cursor + x4_peek = rev.peek_ahead(2) # two items ahead + + # Look back + x1_again = rev.peek_back(1) # previous item without moving cursor + x0_again = rev.peek_back(2) # two items back + + # Move backward + x1_back = rev.prev() # back one step + next(rev) # returns x2, continues forward from where we were + ``` + """ + + __slots__ = ("_source", "_back_buf", "_ahead_buf", "_cursor", "_history", "_lookahead") + + def __init__(self, iterable: Iterable[T], *, history: int = 1, lookahead: int = 0): + if history < 1: + raise ValueError("history must be >= 1") + if lookahead <= 0: + raise ValueError("lookahead must be > 0") + + self._source: Iterator[T] = iter(iterable) + self._back_buf: deque[T] = deque(maxlen=history) + self._ahead_buf: deque[T] = deque(maxlen=lookahead) if lookahead > 0 else deque() + self._cursor: int = 0 + self._history = history + self._lookahead = lookahead + + def __iter__(self) -> "Backtrackable[T]": + return self + + def __next__(self) -> T: + # If we've stepped back, consume from back buffer first + if self._cursor < 0: # -1 means "last item", etc. + self._cursor += 1 + return self._back_buf[self._cursor] + + # If we have items in the ahead buffer, use them first + item = self._ahead_buf.popleft() if self._ahead_buf else next(self._source) + + # Add current item to back buffer and reset cursor + self._back_buf.append(item) + self._cursor = 0 + return item + + def prev(self) -> T: + """ + Step one item back in history and return it. + Raises IndexError if already at the oldest buffered item. + """ + if len(self._back_buf) + self._cursor <= 1: + raise LookBackError("At start of history") + + self._cursor -= 1 + return self._back_buf[self._cursor] + + def peek_back(self, n: int = 1) -> T: + """ + Look `n` items back (n=1 == previous item) without moving the cursor. + """ + if n < 0 or n + 1 > len(self._back_buf) + self._cursor: + raise LookBackError("peek_back distance out of range") + + return self._back_buf[self._cursor - (n + 1)] + + def peek_ahead(self, n: int = 1) -> T: + """ + Look `n` items ahead (n=1 == next item) without moving the cursor. + Fills the ahead buffer if necessary. + """ + if n < 1: + raise LookAheadError("peek_ahead distance must be 1 or more") + elif n > self._lookahead: + raise LookAheadError("peek_ahead distance exceeds lookahead limit") + + # Fill ahead buffer if we don't have enough items + while len(self._ahead_buf) < n: + try: + item = next(self._source) + self._ahead_buf.append(item) + + except StopIteration as err: + raise LookAheadError("peek_ahead: not enough items in source") from err + + return self._ahead_buf[n - 1] + + def history(self) -> list[T]: + """ + Return a copy of the buffered history (most recent last). + The list length ≤ `history` argument passed at construction. + """ + if self._cursor == 0: + return list(self._back_buf) + + # When cursor<0, slice so the order remains chronological + return list(self._back_buf)[: self._cursor or None] + + def can_peek_back(self, steps: int = 1) -> bool: + """ + Check if we can go back `steps` items without raising an IndexError. + """ + return steps <= len(self._back_buf) + self._cursor + + def can_peek_ahead(self, steps: int = 1) -> bool: + """ + Check if we can peek ahead `steps` items. 
+ This may involve trying to fill the ahead buffer. + """ + if self._lookahead > 0 and steps > self._lookahead: + return False + + # Try to fill ahead buffer to check if we can peek that far + try: + while len(self._ahead_buf) < steps: + if self._lookahead > 0 and len(self._ahead_buf) >= self._lookahead: + return False + item = next(self._source) + self._ahead_buf.append(item) + return True + except StopIteration: + return False + + class StreamingLeRobotDataset(torch.utils.data.IterableDataset): """LeRobotDataset with streaming capabilities. diff --git a/src/lerobot/datasets/utils.py b/src/lerobot/datasets/utils.py index 8bc56a1bd..2e1d360f9 100644 --- a/src/lerobot/datasets/utils.py +++ b/src/lerobot/datasets/utils.py @@ -17,35 +17,57 @@ import contextlib import importlib.resources import json import logging -from collections import deque -from collections.abc import Iterable, Iterator -from pathlib import Path -from pprint import pformat +from collections.abc import Iterator from typing import Any import datasets import numpy as np import packaging.version -import pandas -import pandas as pd -import pyarrow.dataset as pa_ds -import pyarrow.parquet as pq import torch -from datasets import Dataset -from datasets.table import embed_table_storage from huggingface_hub import DatasetCard, DatasetCardData, HfApi from huggingface_hub.errors import RevisionNotFoundError -from PIL import Image as PILImage -from torchvision import transforms -from lerobot.configs.types import FeatureType, PolicyFeature -from lerobot.datasets.backward_compatibility import ( - FUTURE_MESSAGE, - BackwardCompatibilityError, - ForwardCompatibilityError, -) -from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_STR -from lerobot.utils.utils import SuppressProgressBars, is_valid_numpy_dtype_string +V30_MESSAGE = """ +The dataset you requested ({repo_id}) is in {version} format. + +We introduced a new format since v3.0 which is not backward compatible with v2.1. +Please, update your dataset to the new format using this command: +``` +python -m lerobot.scripts.convert_dataset_v21_to_v30 --repo-id={repo_id} +``` + +If you already have a converted version uploaded to the hub, then this error might be because of +an older version in your local cache. Consider deleting the cached version and retrying. + +If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb) +or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose). +""" + +FUTURE_MESSAGE = """ +The dataset you requested ({repo_id}) is only available in {version} format. +As we cannot ensure forward compatibility with it, please update your current version of lerobot. +""" + + +class CompatibilityError(Exception): ... + + +class BackwardCompatibilityError(CompatibilityError): + def __init__(self, repo_id: str, version: packaging.version.Version): + if version.major == 2 and version.minor == 1: + message = V30_MESSAGE.format(repo_id=repo_id, version=version) + else: + raise NotImplementedError( + "Contact the maintainer on [Discord](https://discord.com/invite/s3KuuzsPFb)." 
+ ) + super().__init__(message) + + +class ForwardCompatibilityError(CompatibilityError): + def __init__(self, repo_id: str, version: packaging.version.Version): + message = FUTURE_MESSAGE.format(repo_id=repo_id, version=version) + super().__init__(message) + DEFAULT_CHUNK_SIZE = 1000 # Max number of files per chunk DEFAULT_DATA_FILE_SIZE_IN_MB = 100 # Max size per file @@ -79,21 +101,6 @@ DEFAULT_FEATURES = { } -def get_parquet_file_size_in_mb(parquet_path: str | Path) -> float: - metadata = pq.read_metadata(parquet_path) - total_uncompressed_size = 0 - for row_group in range(metadata.num_row_groups): - rg_metadata = metadata.row_group(row_group) - for column in range(rg_metadata.num_columns): - col_metadata = rg_metadata.column(column) - total_uncompressed_size += col_metadata.total_uncompressed_size - return total_uncompressed_size / (1024**2) - - -def get_hf_dataset_size_in_mb(hf_ds: Dataset) -> int: - return hf_ds.data.nbytes // (1024**2) - - def update_chunk_file_indices(chunk_idx: int, file_idx: int, chunks_size: int) -> tuple[int, int]: if file_idx == chunks_size - 1: file_idx = 0 @@ -103,43 +110,6 @@ def update_chunk_file_indices(chunk_idx: int, file_idx: int, chunks_size: int) - return chunk_idx, file_idx -def load_nested_dataset( - pq_dir: Path, features: datasets.Features | None = None, episodes: list[int] | None = None -) -> Dataset: - """Find parquet files in provided directory {pq_dir}/chunk-xxx/file-xxx.parquet - Convert parquet files to pyarrow memory mapped in a cache folder for efficient RAM usage - Concatenate all pyarrow references to return HF Dataset format - - Args: - pq_dir: Directory containing parquet files - features: Optional features schema to ensure consistent loading of complex types like images - episodes: Optional list of episode indices to filter. Uses PyArrow predicate pushdown for efficiency. - """ - paths = sorted(pq_dir.glob("*/*.parquet")) - if len(paths) == 0: - raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}") - - with SuppressProgressBars(): - # We use .from_parquet() memory-mapped loading for efficiency - filters = pa_ds.field("episode_index").isin(episodes) if episodes is not None else None - return Dataset.from_parquet([str(path) for path in paths], filters=filters, features=features) - - -def get_parquet_num_frames(parquet_path: str | Path) -> int: - metadata = pq.read_metadata(parquet_path) - return metadata.num_rows - - -def get_file_size_in_mb(file_path: Path) -> float: - """Get file size on disk in megabytes. - - Args: - file_path (Path): Path to the file. - """ - file_size_bytes = file_path.stat().st_size - return file_size_bytes / (1024**2) - - def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict: """Flatten a nested dictionary by joining keys with a separator. @@ -222,217 +192,6 @@ def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict: return unflatten_dict(serialized_dict) -def embed_images(dataset: datasets.Dataset) -> datasets.Dataset: - """Embed image bytes into the dataset table before saving to Parquet. - - This function prepares a Hugging Face dataset for serialization by converting - image objects into an embedded format that can be stored in Arrow/Parquet. - - Args: - dataset (datasets.Dataset): The input dataset, possibly containing image features. - - Returns: - datasets.Dataset: The dataset with images embedded in the table storage. 
- """ - # Embed image bytes into the table before saving to parquet - format = dataset.format - dataset = dataset.with_format("arrow") - dataset = dataset.map(embed_table_storage, batched=False) - dataset = dataset.with_format(**format) - return dataset - - -def load_json(fpath: Path) -> Any: - """Load data from a JSON file. - - Args: - fpath (Path): Path to the JSON file. - - Returns: - Any: The data loaded from the JSON file. - """ - with open(fpath) as f: - return json.load(f) - - -def write_json(data: dict, fpath: Path) -> None: - """Write data to a JSON file. - - Creates parent directories if they don't exist. - - Args: - data (dict): The dictionary to write. - fpath (Path): The path to the output JSON file. - """ - fpath.parent.mkdir(exist_ok=True, parents=True) - with open(fpath, "w") as f: - json.dump(data, f, indent=4, ensure_ascii=False) - - -def write_info(info: dict, local_dir: Path) -> None: - write_json(info, local_dir / INFO_PATH) - - -def load_info(local_dir: Path) -> dict: - """Load dataset info metadata from its standard file path. - - Also converts shape lists to tuples for consistency. - - Args: - local_dir (Path): The root directory of the dataset. - - Returns: - dict: The dataset information dictionary. - """ - info = load_json(local_dir / INFO_PATH) - for ft in info["features"].values(): - ft["shape"] = tuple(ft["shape"]) - return info - - -def write_stats(stats: dict, local_dir: Path) -> None: - """Serialize and write dataset statistics to their standard file path. - - Args: - stats (dict): The statistics dictionary (can contain tensors/numpy arrays). - local_dir (Path): The root directory of the dataset. - """ - serialized_stats = serialize_dict(stats) - write_json(serialized_stats, local_dir / STATS_PATH) - - -def cast_stats_to_numpy(stats: dict) -> dict[str, dict[str, np.ndarray]]: - """Recursively cast numerical values in a stats dictionary to numpy arrays. - - Args: - stats (dict): The statistics dictionary. - - Returns: - dict: The statistics dictionary with values cast to numpy arrays. - """ - stats = {key: np.array(value) for key, value in flatten_dict(stats).items()} - return unflatten_dict(stats) - - -def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]] | None: - """Load dataset statistics and cast numerical values to numpy arrays. - - Returns None if the stats file doesn't exist. - - Args: - local_dir (Path): The root directory of the dataset. - - Returns: - A dictionary of statistics or None if the file is not found. - """ - if not (local_dir / STATS_PATH).exists(): - return None - stats = load_json(local_dir / STATS_PATH) - return cast_stats_to_numpy(stats) - - -def write_tasks(tasks: pandas.DataFrame, local_dir: Path) -> None: - path = local_dir / DEFAULT_TASKS_PATH - path.parent.mkdir(parents=True, exist_ok=True) - tasks.to_parquet(path) - - -def load_tasks(local_dir: Path) -> pandas.DataFrame: - tasks = pd.read_parquet(local_dir / DEFAULT_TASKS_PATH) - tasks.index.name = "task" - return tasks - - -def load_subtasks(local_dir: Path) -> pandas.DataFrame | None: - """Load subtasks from subtasks.parquet if it exists.""" - subtasks_path = local_dir / DEFAULT_SUBTASKS_PATH - if subtasks_path.exists(): - return pd.read_parquet(subtasks_path) - return None - - -def write_episodes(episodes: Dataset, local_dir: Path) -> None: - """Write episode metadata to a parquet file in the LeRobot v3.0 format. - This function writes episode-level metadata to a single parquet file. 
- Used primarily during dataset conversion (v2.1 → v3.0) and in test fixtures. - - Args: - episodes: HuggingFace Dataset containing episode metadata - local_dir: Root directory where the dataset will be stored - """ - episode_size_mb = get_hf_dataset_size_in_mb(episodes) - if episode_size_mb > DEFAULT_DATA_FILE_SIZE_IN_MB: - raise NotImplementedError( - f"Episodes dataset is too large ({episode_size_mb} MB) to write to a single file. " - f"The current limit is {DEFAULT_DATA_FILE_SIZE_IN_MB} MB. " - "This function only supports single-file episode metadata. " - ) - - fpath = local_dir / DEFAULT_EPISODES_PATH.format(chunk_index=0, file_index=0) - fpath.parent.mkdir(parents=True, exist_ok=True) - episodes.to_parquet(fpath) - - -def load_episodes(local_dir: Path) -> datasets.Dataset: - episodes = load_nested_dataset(local_dir / EPISODES_DIR) - # Select episode features/columns containing references to episode data and videos - # (e.g. tasks, dataset_from_index, dataset_to_index, data/chunk_index, data/file_index, etc.) - # This is to speedup access to these data, instead of having to load episode stats. - episodes = episodes.select_columns([key for key in episodes.features if not key.startswith("stats/")]) - return episodes - - -def load_image_as_numpy( - fpath: str | Path, dtype: np.dtype = np.float32, channel_first: bool = True -) -> np.ndarray: - """Load an image from a file into a numpy array. - - Args: - fpath (str | Path): Path to the image file. - dtype (np.dtype): The desired data type of the output array. If floating, - pixels are scaled to [0, 1]. - channel_first (bool): If True, converts the image to (C, H, W) format. - Otherwise, it remains in (H, W, C) format. - - Returns: - np.ndarray: The image as a numpy array. - """ - img = PILImage.open(fpath).convert("RGB") - img_array = np.array(img, dtype=dtype) - if channel_first: # (H, W, C) -> (C, H, W) - img_array = np.transpose(img_array, (2, 0, 1)) - if np.issubdtype(dtype, np.floating): - img_array /= 255.0 - return img_array - - -def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[torch.Tensor | str]]: - """Convert a batch from a Hugging Face dataset to torch tensors. - - This transform function converts items from Hugging Face dataset format (pyarrow) - to torch tensors. Importantly, images are converted from PIL objects (H, W, C, uint8) - to a torch image representation (C, H, W, float32) in the range [0, 1]. Other - types are converted to torch.tensor. - - Args: - items_dict (dict): A dictionary representing a batch of data from a - Hugging Face dataset. - - Returns: - dict: The batch with items converted to torch tensors. - """ - for key in items_dict: - first_item = items_dict[key][0] - if isinstance(first_item, PILImage.Image): - to_tensor = transforms.ToTensor() - items_dict[key] = [to_tensor(img) for img in items_dict[key]] - elif first_item is None: - pass - else: - items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]] - return items_dict - - def is_valid_version(version: str) -> bool: """Check if a string is a valid PEP 440 version. @@ -560,337 +319,6 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) -> raise ForwardCompatibilityError(repo_id, min(upper_versions)) -def get_hf_features_from_features(features: dict) -> datasets.Features: - """Convert a LeRobot features dictionary to a `datasets.Features` object. - - Args: - features (dict): A LeRobot-style feature dictionary. 
- - Returns: - datasets.Features: The corresponding Hugging Face `datasets.Features` object. - - Raises: - ValueError: If a feature has an unsupported shape. - """ - hf_features = {} - for key, ft in features.items(): - if ft["dtype"] == "video": - continue - elif ft["dtype"] == "image": - hf_features[key] = datasets.Image() - elif ft["shape"] == (1,): - hf_features[key] = datasets.Value(dtype=ft["dtype"]) - elif len(ft["shape"]) == 1: - hf_features[key] = datasets.Sequence( - length=ft["shape"][0], feature=datasets.Value(dtype=ft["dtype"]) - ) - elif len(ft["shape"]) == 2: - hf_features[key] = datasets.Array2D(shape=ft["shape"], dtype=ft["dtype"]) - elif len(ft["shape"]) == 3: - hf_features[key] = datasets.Array3D(shape=ft["shape"], dtype=ft["dtype"]) - elif len(ft["shape"]) == 4: - hf_features[key] = datasets.Array4D(shape=ft["shape"], dtype=ft["dtype"]) - elif len(ft["shape"]) == 5: - hf_features[key] = datasets.Array5D(shape=ft["shape"], dtype=ft["dtype"]) - else: - raise ValueError(f"Corresponding feature is not valid: {ft}") - - return datasets.Features(hf_features) - - -def _validate_feature_names(features: dict[str, dict]) -> None: - """Validate that feature names do not contain invalid characters. - - Args: - features (dict): The LeRobot features dictionary. - - Raises: - ValueError: If any feature name contains '/'. - """ - invalid_features = {name: ft for name, ft in features.items() if "/" in name} - if invalid_features: - raise ValueError(f"Feature names should not contain '/'. Found '/' in '{invalid_features}'.") - - -def hw_to_dataset_features( - hw_features: dict[str, type | tuple], prefix: str, use_video: bool = True -) -> dict[str, dict]: - """Convert hardware-specific features to a LeRobot dataset feature dictionary. - - This function takes a dictionary describing hardware outputs (like joint states - or camera image shapes) and formats it into the standard LeRobot feature - specification. - - Args: - hw_features (dict): Dictionary mapping feature names to their type (float for - joints) or shape (tuple for images). - prefix (str): The prefix to add to the feature keys (e.g., "observation" - or "action"). - use_video (bool): If True, image features are marked as "video", otherwise "image". - - Returns: - dict: A LeRobot features dictionary. - """ - features = {} - joint_fts = { - key: ftype - for key, ftype in hw_features.items() - if ftype is float or (isinstance(ftype, PolicyFeature) and ftype.type != FeatureType.VISUAL) - } - cam_fts = {key: shape for key, shape in hw_features.items() if isinstance(shape, tuple)} - - if joint_fts and prefix == ACTION: - features[prefix] = { - "dtype": "float32", - "shape": (len(joint_fts),), - "names": list(joint_fts), - } - - if joint_fts and prefix == OBS_STR: - features[f"{prefix}.state"] = { - "dtype": "float32", - "shape": (len(joint_fts),), - "names": list(joint_fts), - } - - for key, shape in cam_fts.items(): - features[f"{prefix}.images.{key}"] = { - "dtype": "video" if use_video else "image", - "shape": shape, - "names": ["height", "width", "channels"], - } - - _validate_feature_names(features) - return features - - -def build_dataset_frame( - ds_features: dict[str, dict], values: dict[str, Any], prefix: str -) -> dict[str, np.ndarray]: - """Construct a single data frame from raw values based on dataset features. - - A "frame" is a dictionary containing all the data for a single timestep, - formatted as numpy arrays according to the feature specification. 
- - Args: - ds_features (dict): The LeRobot dataset features dictionary. - values (dict): A dictionary of raw values from the hardware/environment. - prefix (str): The prefix to filter features by (e.g., "observation" - or "action"). - - Returns: - dict: A dictionary representing a single frame of data. - """ - frame = {} - for key, ft in ds_features.items(): - if key in DEFAULT_FEATURES or not key.startswith(prefix): - continue - elif ft["dtype"] == "float32" and len(ft["shape"]) == 1: - frame[key] = np.array([values[name] for name in ft["names"]], dtype=np.float32) - elif ft["dtype"] in ["image", "video"]: - frame[key] = values[key.removeprefix(f"{prefix}.images.")] - - return frame - - -def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]: - """Convert dataset features to policy features. - - This function transforms the dataset's feature specification into a format - that a policy can use, classifying features by type (e.g., visual, state, - action) and ensuring correct shapes (e.g., channel-first for images). - - Args: - features (dict): The LeRobot dataset features dictionary. - - Returns: - dict: A dictionary mapping feature keys to `PolicyFeature` objects. - - Raises: - ValueError: If an image feature does not have a 3D shape. - """ - # TODO(aliberts): Implement "type" in dataset features and simplify this - policy_features = {} - for key, ft in features.items(): - shape = ft["shape"] - if ft["dtype"] in ["image", "video"]: - type = FeatureType.VISUAL - if len(shape) != 3: - raise ValueError(f"Number of dimensions of {key} != 3 (shape={shape})") - - names = ft["names"] - # Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets. - if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w) - shape = (shape[2], shape[0], shape[1]) - elif key == OBS_ENV_STATE: - type = FeatureType.ENV - elif key.startswith(OBS_STR): - type = FeatureType.STATE - elif key.startswith(ACTION): - type = FeatureType.ACTION - else: - continue - - policy_features[key] = PolicyFeature( - type=type, - shape=shape, - ) - - return policy_features - - -def combine_feature_dicts(*dicts: dict) -> dict: - """Merge LeRobot grouped feature dicts. - - - For 1D numeric specs (dtype not image/video/string) with "names": we merge the names and recompute the shape. - - For others (e.g. `observation.images.*`), the last one wins (if they are identical). - - Args: - *dicts: A variable number of LeRobot feature dictionaries to merge. - - Returns: - dict: A single merged feature dictionary. - - Raises: - ValueError: If there's a dtype mismatch for a feature being merged. 
- """ - out: dict = {} - for d in dicts: - for key, value in d.items(): - if not isinstance(value, dict): - out[key] = value - continue - - dtype = value.get("dtype") - shape = value.get("shape") - is_vector = ( - dtype not in ("image", "video", "string") - and isinstance(shape, tuple) - and len(shape) == 1 - and "names" in value - ) - - if is_vector: - # Initialize or retrieve the accumulating dict for this feature key - target = out.setdefault(key, {"dtype": dtype, "names": [], "shape": (0,)}) - # Ensure consistent data types across merged entries - if "dtype" in target and dtype != target["dtype"]: - raise ValueError(f"dtype mismatch for '{key}': {target['dtype']} vs {dtype}") - - # Merge feature names: append only new ones to preserve order without duplicates - seen = set(target["names"]) - for n in value["names"]: - if n not in seen: - target["names"].append(n) - seen.add(n) - # Recompute the shape to reflect the updated number of features - target["shape"] = (len(target["names"]),) - else: - # For images/videos and non-1D entries: override with the latest definition - out[key] = value - return out - - -def create_empty_dataset_info( - codebase_version: str, - fps: int, - features: dict, - use_videos: bool, - robot_type: str | None = None, - chunks_size: int | None = None, - data_files_size_in_mb: int | None = None, - video_files_size_in_mb: int | None = None, -) -> dict: - """Create a template dictionary for a new dataset's `info.json`. - - Args: - codebase_version (str): The version of the LeRobot codebase. - fps (int): The frames per second of the data. - features (dict): The LeRobot features dictionary for the dataset. - use_videos (bool): Whether the dataset will store videos. - robot_type (str | None): The type of robot used, if any. - - Returns: - dict: A dictionary with the initial dataset metadata. - """ - return { - "codebase_version": codebase_version, - "robot_type": robot_type, - "total_episodes": 0, - "total_frames": 0, - "total_tasks": 0, - "chunks_size": chunks_size or DEFAULT_CHUNK_SIZE, - "data_files_size_in_mb": data_files_size_in_mb or DEFAULT_DATA_FILE_SIZE_IN_MB, - "video_files_size_in_mb": video_files_size_in_mb or DEFAULT_VIDEO_FILE_SIZE_IN_MB, - "fps": fps, - "splits": {}, - "data_path": DEFAULT_DATA_PATH, - "video_path": DEFAULT_VIDEO_PATH if use_videos else None, - "features": features, - } - - -def check_delta_timestamps( - delta_timestamps: dict[str, list[float]], fps: int, tolerance_s: float, raise_value_error: bool = True -) -> bool: - """Check if delta timestamps are multiples of 1/fps +/- tolerance. - - This ensures that adding these delta timestamps to any existing timestamp in - the dataset will result in a value that aligns with the dataset's frame rate. - - Args: - delta_timestamps (dict): A dictionary where values are lists of time - deltas in seconds. - fps (int): The frames per second of the dataset. - tolerance_s (float): The allowed tolerance in seconds. - raise_value_error (bool): If True, raises an error on failure. - - Returns: - bool: True if all deltas are valid, False otherwise. - - Raises: - ValueError: If any delta is outside the tolerance and `raise_value_error` is True. 
- """ - outside_tolerance = {} - for key, delta_ts in delta_timestamps.items(): - within_tolerance = [abs(ts * fps - round(ts * fps)) / fps <= tolerance_s for ts in delta_ts] - if not all(within_tolerance): - outside_tolerance[key] = [ - ts for ts, is_within in zip(delta_ts, within_tolerance, strict=True) if not is_within - ] - - if len(outside_tolerance) > 0: - if raise_value_error: - raise ValueError( - f""" - The following delta_timestamps are found outside of tolerance range. - Please make sure they are multiples of 1/{fps} +/- tolerance and adjust - their values accordingly. - \n{pformat(outside_tolerance)} - """ - ) - return False - - return True - - -def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dict[str, list[int]]: - """Convert delta timestamps in seconds to delta indices in frames. - - Args: - delta_timestamps (dict): A dictionary of time deltas in seconds. - fps (int): The frames per second of the dataset. - - Returns: - dict: A dictionary of frame delta indices. - """ - delta_indices = {} - for key, delta_ts in delta_timestamps.items(): - delta_indices[key] = [round(d * fps) for d in delta_ts] - - return delta_indices - - def cycle(iterable: Any) -> Iterator[Any]: """Create a dataloader-safe cyclical iterator. @@ -982,229 +410,6 @@ def create_lerobot_dataset_card( ) -def validate_frame(frame: dict, features: dict) -> None: - expected_features = set(features) - set(DEFAULT_FEATURES) - actual_features = set(frame) - - # task is a special required field that's not part of regular features - if "task" not in actual_features: - raise ValueError("Feature mismatch in `frame` dictionary:\nMissing features: {'task'}\n") - - # Remove task from actual_features for regular feature validation - actual_features_for_validation = actual_features - {"task"} - - error_message = validate_features_presence(actual_features_for_validation, expected_features) - - common_features = actual_features_for_validation & expected_features - for name in common_features: - error_message += validate_feature_dtype_and_shape(name, features[name], frame[name]) - - if error_message: - raise ValueError(error_message) - - -def validate_features_presence(actual_features: set[str], expected_features: set[str]) -> str: - """Check for missing or extra features in a frame. - - Args: - actual_features (set[str]): The set of feature names present in the frame. - expected_features (set[str]): The set of feature names expected in the frame. - - Returns: - str: An error message string if there's a mismatch, otherwise an empty string. - """ - error_message = "" - missing_features = expected_features - actual_features - extra_features = actual_features - expected_features - - if missing_features or extra_features: - error_message += "Feature mismatch in `frame` dictionary:\n" - if missing_features: - error_message += f"Missing features: {missing_features}\n" - if extra_features: - error_message += f"Extra features: {extra_features}\n" - - return error_message - - -def validate_feature_dtype_and_shape( - name: str, feature: dict, value: np.ndarray | PILImage.Image | str -) -> str: - """Validate the dtype and shape of a single feature's value. - - Args: - name (str): The name of the feature. - feature (dict): The feature specification from the LeRobot features dictionary. - value: The value of the feature to validate. - - Returns: - str: An error message if validation fails, otherwise an empty string. - - Raises: - NotImplementedError: If the feature dtype is not supported for validation. 
- """ - expected_dtype = feature["dtype"] - expected_shape = feature["shape"] - if is_valid_numpy_dtype_string(expected_dtype): - return validate_feature_numpy_array(name, expected_dtype, expected_shape, value) - elif expected_dtype in ["image", "video"]: - return validate_feature_image_or_video(name, expected_shape, value) - elif expected_dtype == "string": - return validate_feature_string(name, value) - else: - raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.") - - -def validate_feature_numpy_array( - name: str, expected_dtype: str, expected_shape: list[int], value: np.ndarray -) -> str: - """Validate a feature that is expected to be a numpy array. - - Args: - name (str): The name of the feature. - expected_dtype (str): The expected numpy dtype as a string. - expected_shape (list[int]): The expected shape. - value (np.ndarray): The numpy array to validate. - - Returns: - str: An error message if validation fails, otherwise an empty string. - """ - error_message = "" - if isinstance(value, np.ndarray): - actual_dtype = value.dtype - actual_shape = value.shape - - if actual_dtype != np.dtype(expected_dtype): - error_message += f"The feature '{name}' of dtype '{actual_dtype}' is not of the expected dtype '{expected_dtype}'.\n" - - if actual_shape != expected_shape: - error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{expected_shape}'.\n" - else: - error_message += f"The feature '{name}' is not a 'np.ndarray'. Expected type is '{expected_dtype}', but type '{type(value)}' provided instead.\n" - - return error_message - - -def validate_feature_image_or_video( - name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image -) -> str: - """Validate a feature that is expected to be an image or video frame. - - Accepts `np.ndarray` (channel-first or channel-last) or `PIL.Image.Image`. - - Args: - name (str): The name of the feature. - expected_shape (list[str]): The expected shape (C, H, W). - value: The image data to validate. - - Returns: - str: An error message if validation fails, otherwise an empty string. - """ - # Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads. - error_message = "" - if isinstance(value, np.ndarray): - actual_shape = value.shape - c, h, w = expected_shape - if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)): - error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n" - elif isinstance(value, PILImage.Image): - pass - else: - error_message += f"The feature '{name}' is expected to be of type 'PIL.Image' or 'np.ndarray' channel first or channel last, but type '{type(value)}' provided instead.\n" - - return error_message - - -def validate_feature_string(name: str, value: str) -> str: - """Validate a feature that is expected to be a string. - - Args: - name (str): The name of the feature. - value (str): The value to validate. - - Returns: - str: An error message if validation fails, otherwise an empty string. - """ - if not isinstance(value, str): - return f"The feature '{name}' is expected to be of type 'str', but type '{type(value)}' provided instead.\n" - return "" - - -def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: dict) -> None: - """Validate the episode buffer before it's written to disk. 
- - Ensures the buffer has the required keys, contains at least one frame, and - has features consistent with the dataset's specification. - - Args: - episode_buffer (dict): The buffer containing data for a single episode. - total_episodes (int): The current total number of episodes in the dataset. - features (dict): The LeRobot features dictionary for the dataset. - - Raises: - ValueError: If the buffer is invalid. - NotImplementedError: If the episode index is manually set and doesn't match. - """ - if "size" not in episode_buffer: - raise ValueError("size key not found in episode_buffer") - - if "task" not in episode_buffer: - raise ValueError("task key not found in episode_buffer") - - if episode_buffer["episode_index"] != total_episodes: - # TODO(aliberts): Add option to use existing episode_index - raise NotImplementedError( - "You might have manually provided the episode_buffer with an episode_index that doesn't " - "match the total number of episodes already in the dataset. This is not supported for now." - ) - - if episode_buffer["size"] == 0: - raise ValueError("You must add one or several frames with `add_frame` before calling `add_episode`.") - - buffer_keys = set(episode_buffer.keys()) - {"task", "size"} - if not buffer_keys == set(features): - raise ValueError( - f"Features from `episode_buffer` don't match the ones in `features`." - f"In episode_buffer not in features: {buffer_keys - set(features)}" - f"In features not in episode_buffer: {set(features) - buffer_keys}" - ) - - -def to_parquet_with_hf_images( - df: pandas.DataFrame, path: Path, features: datasets.Features | None = None -) -> None: - """This function correctly writes to parquet a panda DataFrame that contains images encoded by HF dataset. - This way, it can be loaded by HF dataset and correctly formatted images are returned. - - Args: - df: DataFrame to write to parquet. - path: Path to write the parquet file. - features: Optional HuggingFace Features schema. If provided, ensures image columns - are properly typed as Image() in the parquet schema. - """ - # TODO(qlhoest): replace this weird synthax by `df.to_parquet(path)` only - ds = datasets.Dataset.from_dict(df.to_dict(orient="list"), features=features) - ds.to_parquet(path) - - -def item_to_torch(item: dict) -> dict: - """Convert all items in a dictionary to PyTorch tensors where appropriate. - - This function is used to convert an item from a streaming dataset to PyTorch tensors. - - Args: - item (dict): Dictionary of items from a dataset. - - Returns: - dict: Dictionary with all tensor-like items converted to torch.Tensor. - """ - for key, val in item.items(): - if isinstance(val, (np.ndarray | list)) and key not in ["task"]: - # Convert numpy arrays and lists to torch tensors - item[key] = torch.tensor(val) - return item - - def is_float_in_list(target, float_list, threshold=1e-6): return any(abs(target - x) <= threshold for x in float_list) @@ -1216,164 +421,6 @@ def find_float_index(target, float_list, threshold=1e-6): return -1 -class LookBackError(Exception): - """ - Exception raised when trying to look back in the history of a Backtrackable object. - """ - - pass - - -class LookAheadError(Exception): - """ - Exception raised when trying to look ahead in the future of a Backtrackable object. - """ - - pass - - -class Backtrackable[T]: - """ - Wrap any iterator/iterable so you can step back up to `history` items - and look ahead up to `lookahead` items. 
- - This is useful for streaming datasets where you need to access previous and future items - but can't load the entire dataset into memory. - - Example: - ------- - ```python - ds = load_dataset("c4", "en", streaming=True, split="train") - rev = Backtrackable(ds, history=3, lookahead=2) - - x0 = next(rev) # forward - x1 = next(rev) - x2 = next(rev) - - # Look ahead - x3_peek = rev.peek_ahead(1) # next item without moving cursor - x4_peek = rev.peek_ahead(2) # two items ahead - - # Look back - x1_again = rev.peek_back(1) # previous item without moving cursor - x0_again = rev.peek_back(2) # two items back - - # Move backward - x1_back = rev.prev() # back one step - next(rev) # returns x2, continues forward from where we were - ``` - """ - - __slots__ = ("_source", "_back_buf", "_ahead_buf", "_cursor", "_history", "_lookahead") - - def __init__(self, iterable: Iterable[T], *, history: int = 1, lookahead: int = 0): - if history < 1: - raise ValueError("history must be >= 1") - if lookahead <= 0: - raise ValueError("lookahead must be > 0") - - self._source: Iterator[T] = iter(iterable) - self._back_buf: deque[T] = deque(maxlen=history) - self._ahead_buf: deque[T] = deque(maxlen=lookahead) if lookahead > 0 else deque() - self._cursor: int = 0 - self._history = history - self._lookahead = lookahead - - def __iter__(self) -> "Backtrackable[T]": - return self - - def __next__(self) -> T: - # If we've stepped back, consume from back buffer first - if self._cursor < 0: # -1 means "last item", etc. - self._cursor += 1 - return self._back_buf[self._cursor] - - # If we have items in the ahead buffer, use them first - item = self._ahead_buf.popleft() if self._ahead_buf else next(self._source) - - # Add current item to back buffer and reset cursor - self._back_buf.append(item) - self._cursor = 0 - return item - - def prev(self) -> T: - """ - Step one item back in history and return it. - Raises IndexError if already at the oldest buffered item. - """ - if len(self._back_buf) + self._cursor <= 1: - raise LookBackError("At start of history") - - self._cursor -= 1 - return self._back_buf[self._cursor] - - def peek_back(self, n: int = 1) -> T: - """ - Look `n` items back (n=1 == previous item) without moving the cursor. - """ - if n < 0 or n + 1 > len(self._back_buf) + self._cursor: - raise LookBackError("peek_back distance out of range") - - return self._back_buf[self._cursor - (n + 1)] - - def peek_ahead(self, n: int = 1) -> T: - """ - Look `n` items ahead (n=1 == next item) without moving the cursor. - Fills the ahead buffer if necessary. - """ - if n < 1: - raise LookAheadError("peek_ahead distance must be 1 or more") - elif n > self._lookahead: - raise LookAheadError("peek_ahead distance exceeds lookahead limit") - - # Fill ahead buffer if we don't have enough items - while len(self._ahead_buf) < n: - try: - item = next(self._source) - self._ahead_buf.append(item) - - except StopIteration as err: - raise LookAheadError("peek_ahead: not enough items in source") from err - - return self._ahead_buf[n - 1] - - def history(self) -> list[T]: - """ - Return a copy of the buffered history (most recent last). - The list length ≤ `history` argument passed at construction. - """ - if self._cursor == 0: - return list(self._back_buf) - - # When cursor<0, slice so the order remains chronological - return list(self._back_buf)[: self._cursor or None] - - def can_peek_back(self, steps: int = 1) -> bool: - """ - Check if we can go back `steps` items without raising an IndexError. 
- """ - return steps <= len(self._back_buf) + self._cursor - - def can_peek_ahead(self, steps: int = 1) -> bool: - """ - Check if we can peek ahead `steps` items. - This may involve trying to fill the ahead buffer. - """ - if self._lookahead > 0 and steps > self._lookahead: - return False - - # Try to fill ahead buffer to check if we can peek that far - try: - while len(self._ahead_buf) < steps: - if self._lookahead > 0 and len(self._ahead_buf) >= self._lookahead: - return False - item = next(self._source) - self._ahead_buf.append(item) - return True - except StopIteration: - return False - - def safe_shard(dataset: datasets.IterableDataset, index: int, num_shards: int) -> datasets.Dataset: """ Safe shards the dataset. diff --git a/src/lerobot/datasets/video_utils.py b/src/lerobot/datasets/video_utils.py index 8c8494b87..e465b79b4 100644 --- a/src/lerobot/datasets/video_utils.py +++ b/src/lerobot/datasets/video_utils.py @@ -37,6 +37,8 @@ import torchvision from datasets.features.features import register_feature from PIL import Image +logger = logging.getLogger(__name__) + # List of hardware encoders to probe for auto-selection. Availability depends on the platform and FFmpeg build. # Determines the order of preference for auto-selection when vcodec="auto" is used. HW_ENCODERS = [ @@ -94,7 +96,7 @@ def detect_available_hw_encoders() -> list[str]: av.codec.Codec(codec_name, "w") available.append(codec_name) except Exception: # nosec B110 - pass # nosec B110 + logger.debug("HW encoder '%s' not available", codec_name) # nosec B110 return available @@ -103,14 +105,14 @@ def resolve_vcodec(vcodec: str) -> str: if vcodec not in VALID_VIDEO_CODECS: raise ValueError(f"Invalid vcodec '{vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}") if vcodec != "auto": - logging.info(f"Using video codec: {vcodec}") + logger.info(f"Using video codec: {vcodec}") return vcodec available = detect_available_hw_encoders() for encoder in HW_ENCODERS: if encoder in available: - logging.info(f"Auto-selected video codec: {encoder}") + logger.info(f"Auto-selected video codec: {encoder}") return encoder - logging.info("No hardware encoder available, falling back to software encoder 'libsvtav1'") + logger.info("No hardware encoder available, falling back to software encoder 'libsvtav1'") return "libsvtav1" @@ -118,7 +120,7 @@ def get_safe_default_codec(): if importlib.util.find_spec("torchcodec"): return "torchcodec" else: - logging.warning( + logger.warning( "'torchcodec' is not available in your platform, falling back to 'pyav' as a default decoder" ) return "pyav" @@ -208,7 +210,7 @@ def decode_video_frames_torchvision( for frame in reader: current_ts = frame["pts"] if log_loaded_timestamps: - logging.info(f"frame loaded at timestamp={current_ts:.4f}") + logger.info(f"frame loaded at timestamp={current_ts:.4f}") loaded_frames.append(frame["data"]) loaded_ts.append(current_ts) if current_ts >= last_ts: @@ -244,7 +246,7 @@ def decode_video_frames_torchvision( closest_ts = loaded_ts[argmin_] if log_loaded_timestamps: - logging.info(f"{closest_ts=}") + logger.info(f"{closest_ts=}") # convert to the pytorch format which is float32 in [0,1] range (and channel first) closest_frames = closest_frames.type(torch.float32) / 255 @@ -348,7 +350,7 @@ def decode_video_frames_torchcodec( loaded_frames.append(frame) loaded_ts.append(pts.item()) if log_loaded_timestamps: - logging.info(f"Frame loaded at timestamp={pts:.4f}") + logger.info(f"Frame loaded at timestamp={pts:.4f}") query_ts = torch.tensor(timestamps) loaded_ts = 
torch.tensor(loaded_ts) @@ -374,7 +376,7 @@ def decode_video_frames_torchcodec( closest_ts = loaded_ts[argmin_] if log_loaded_timestamps: - logging.info(f"{closest_ts=}") + logger.info(f"{closest_ts=}") # convert to float32 in [0,1] range closest_frames = (closest_frames / 255.0).type(torch.float32) @@ -408,14 +410,14 @@ def encode_video_frames( imgs_dir = Path(imgs_dir) if video_path.exists() and not overwrite: - logging.warning(f"Video file already exists: {video_path}. Skipping encoding.") + logger.warning(f"Video file already exists: {video_path}. Skipping encoding.") return video_path.parent.mkdir(parents=True, exist_ok=True) # Encoders/pixel formats incompatibility check if (vcodec == "libsvtav1" or vcodec == "hevc") and pix_fmt == "yuv444p": - logging.warning( + logger.warning( f"Incompatible pixel format 'yuv444p' for codec {vcodec}, auto-selecting format 'yuv420p'" ) pix_fmt = "yuv420p" @@ -508,7 +510,7 @@ def concatenate_video_files( output_video_path = Path(output_video_path) if output_video_path.exists() and not overwrite: - logging.warning(f"Video file already exists: {output_video_path}. Skipping concatenation.") + logger.warning(f"Video file already exists: {output_video_path}. Skipping concatenation.") return output_video_path.parent.mkdir(parents=True, exist_ok=True) @@ -693,7 +695,7 @@ class _CameraEncoderThread(threading.Thread): self.result_queue.put(("ok", None)) except Exception as e: - logging.error(f"Encoder thread error: {e}") + logger.error(f"Encoder thread error: {e}") if container is not None: with contextlib.suppress(Exception): container.close() @@ -819,7 +821,7 @@ class StreamingVideoEncoder: count = self._dropped_frames[video_key] # Log periodically to avoid spam (1st, then every 10th) if count == 1 or count % 10 == 0: - logging.warning( + logger.warning( f"Encoder queue full for {video_key}, dropped {count} frame(s). " f"Consider using vcodec='auto' for hardware encoding or increasing encoder_queue_maxsize." ) @@ -841,7 +843,7 @@ class StreamingVideoEncoder: # Report dropped frames for video_key, count in self._dropped_frames.items(): if count > 0: - logging.warning(f"Episode finished with {count} dropped frame(s) for {video_key}.") + logger.warning(f"Episode finished with {count} dropped frame(s) for {video_key}.") # Send sentinel to all queues for video_key in self._frame_queues: @@ -851,7 +853,7 @@ class StreamingVideoEncoder: for video_key in self._threads: self._threads[video_key].join(timeout=120) if self._threads[video_key].is_alive(): - logging.error(f"Encoder thread for {video_key} did not finish in time") + logger.error(f"Encoder thread for {video_key} did not finish in time") self._stop_events[video_key].set() self._threads[video_key].join(timeout=5) results[video_key] = (self._video_paths[video_key], None) @@ -863,7 +865,7 @@ class StreamingVideoEncoder: raise RuntimeError(f"Encoder thread for {video_key} failed: {data}") results[video_key] = (self._video_paths[video_key], data) except queue.Empty: - logging.error(f"No result from encoder thread for {video_key}") + logger.error(f"No result from encoder thread for {video_key}") results[video_key] = (self._video_paths[video_key], None) self._cleanup() @@ -1071,13 +1073,13 @@ class VideoEncodingManager: elif self.dataset.episodes_since_last_encoding > 0: # Handle any remaining episodes that haven't been batch encoded if exc_type is not None: - logging.info("Exception occurred. Encoding remaining episodes before exit...") + logger.info("Exception occurred. 
Encoding remaining episodes before exit...") else: - logging.info("Recording stopped. Encoding remaining episodes...") + logger.info("Recording stopped. Encoding remaining episodes...") start_ep = self.dataset.num_episodes - self.dataset.episodes_since_last_encoding end_ep = self.dataset.num_episodes - logging.info( + logger.info( f"Encoding remaining {self.dataset.episodes_since_last_encoding} episodes, " f"from episode {start_ep} to {end_ep - 1}" ) @@ -1094,7 +1096,7 @@ class VideoEncodingManager: episode_index=interrupted_episode_index, image_key=key, frame_index=0 ).parent if img_dir.exists(): - logging.debug( + logger.debug( f"Cleaning up interrupted episode images for episode {interrupted_episode_index}, camera {key}" ) shutil.rmtree(img_dir) @@ -1105,8 +1107,8 @@ class VideoEncodingManager: png_files = list(img_dir.rglob("*.png")) if len(png_files) == 0: shutil.rmtree(img_dir) - logging.debug("Cleaned up empty images directory") + logger.debug("Cleaned up empty images directory") else: - logging.debug(f"Images directory is not empty, containing {len(png_files)} PNG files") + logger.debug(f"Images directory is not empty, containing {len(png_files)} PNG files") return False # Don't suppress the original exception diff --git a/src/lerobot/envs/libero.py b/src/lerobot/envs/libero.py index d20dae8ea..6d3589fed 100644 --- a/src/lerobot/envs/libero.py +++ b/src/lerobot/envs/libero.py @@ -29,7 +29,7 @@ from gymnasium import spaces from libero.libero import benchmark, get_libero_path from libero.libero.envs import OffScreenRenderEnv -from lerobot.processor import RobotObservation +from lerobot.types import RobotObservation def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]: diff --git a/src/lerobot/envs/metaworld.py b/src/lerobot/envs/metaworld.py index 4d91e002d..e9e29f304 100644 --- a/src/lerobot/envs/metaworld.py +++ b/src/lerobot/envs/metaworld.py @@ -25,7 +25,7 @@ import metaworld.policies as policies import numpy as np from gymnasium import spaces -from lerobot.processor import RobotObservation +from lerobot.types import RobotObservation # ---- Load configuration data from the external JSON file ---- CONFIG_PATH = Path(__file__).parent / "metaworld_config.json" diff --git a/src/lerobot/envs/utils.py b/src/lerobot/envs/utils.py index 09431a18d..fd17a6762 100644 --- a/src/lerobot/envs/utils.py +++ b/src/lerobot/envs/utils.py @@ -29,7 +29,7 @@ from torch import Tensor from lerobot.configs.types import FeatureType, PolicyFeature from lerobot.envs.configs import EnvConfig -from lerobot.processor import RobotObservation +from lerobot.types import RobotObservation from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE, OBS_STR from lerobot.utils.utils import get_channel_first_image_shape diff --git a/src/lerobot/optim/optimizers.py b/src/lerobot/optim/optimizers.py index 2b75353d9..e2e3d8937 100644 --- a/src/lerobot/optim/optimizers.py +++ b/src/lerobot/optim/optimizers.py @@ -23,7 +23,8 @@ import draccus import torch from safetensors.torch import load_file, save_file -from lerobot.datasets.utils import flatten_dict, unflatten_dict, write_json +from lerobot.datasets.io_utils import write_json +from lerobot.datasets.utils import flatten_dict, unflatten_dict from lerobot.utils.constants import ( OPTIMIZER_PARAM_GROUPS, OPTIMIZER_STATE, diff --git a/src/lerobot/optim/schedulers.py b/src/lerobot/optim/schedulers.py index 4af7f0802..19c3fd7bd 100644 --- a/src/lerobot/optim/schedulers.py +++ b/src/lerobot/optim/schedulers.py @@ -23,7 +23,7 @@ import 
draccus from torch.optim import Optimizer from torch.optim.lr_scheduler import LambdaLR, LRScheduler -from lerobot.datasets.utils import write_json +from lerobot.datasets.io_utils import write_json from lerobot.utils.constants import SCHEDULER_STATE from lerobot.utils.io_utils import deserialize_json_into_object diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index d50d8652a..2320cd624 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -24,8 +24,8 @@ import torch from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.types import FeatureType -from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata -from lerobot.datasets.utils import dataset_to_policy_features +from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata +from lerobot.datasets.feature_utils import dataset_to_policy_features from lerobot.envs.configs import EnvConfig from lerobot.envs.utils import env_to_policy_features from lerobot.policies.act.configuration_act import ACTConfig @@ -43,13 +43,14 @@ from lerobot.policies.utils import validate_visual_features_consistency from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig from lerobot.policies.wall_x.configuration_wall_x import WallXConfig from lerobot.policies.xvla.configuration_xvla import XVLAConfig -from lerobot.processor import PolicyAction, PolicyProcessorPipeline +from lerobot.processor import PolicyProcessorPipeline from lerobot.processor.converters import ( batch_to_transition, policy_action_to_transition, transition_to_batch, transition_to_policy_action, ) +from lerobot.types import PolicyAction from lerobot.utils.constants import ( ACTION, POLICY_POSTPROCESSOR_DEFAULT_NAME, diff --git a/src/lerobot/policies/groot/processor_groot.py b/src/lerobot/policies/groot/processor_groot.py index 14149cf2f..8bf9dabca 100644 --- a/src/lerobot/policies/groot/processor_groot.py +++ b/src/lerobot/policies/groot/processor_groot.py @@ -49,7 +49,7 @@ from lerobot.processor.converters import ( policy_action_to_transition, transition_to_policy_action, ) -from lerobot.processor.core import EnvTransition, TransitionKey +from lerobot.types import EnvTransition, TransitionKey from lerobot.utils.constants import ( ACTION, HF_LEROBOT_HOME, diff --git a/src/lerobot/policies/pi05/processor_pi05.py b/src/lerobot/policies/pi05/processor_pi05.py index 6e01a4e16..425a85577 100644 --- a/src/lerobot/policies/pi05/processor_pi05.py +++ b/src/lerobot/policies/pi05/processor_pi05.py @@ -36,7 +36,7 @@ from lerobot.processor import ( UnnormalizerProcessorStep, ) from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action -from lerobot.processor.core import EnvTransition, TransitionKey +from lerobot.types import EnvTransition, TransitionKey from lerobot.utils.constants import ( OBS_STATE, POLICY_POSTPROCESSOR_DEFAULT_NAME, diff --git a/src/lerobot/policies/pi0_fast/processor_pi0_fast.py b/src/lerobot/policies/pi0_fast/processor_pi0_fast.py index fde7d5c80..46e54432a 100644 --- a/src/lerobot/policies/pi0_fast/processor_pi0_fast.py +++ b/src/lerobot/policies/pi0_fast/processor_pi0_fast.py @@ -37,7 +37,7 @@ from lerobot.processor import ( UnnormalizerProcessorStep, ) from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action -from lerobot.processor.core import EnvTransition, TransitionKey +from lerobot.types import EnvTransition, TransitionKey from lerobot.utils.constants import ( OBS_STATE, 
POLICY_POSTPROCESSOR_DEFAULT_NAME, diff --git a/src/lerobot/policies/sarm/processor_sarm.py b/src/lerobot/policies/sarm/processor_sarm.py index 8f2bc23db..f377a7ffa 100644 --- a/src/lerobot/policies/sarm/processor_sarm.py +++ b/src/lerobot/policies/sarm/processor_sarm.py @@ -48,8 +48,8 @@ from lerobot.processor.converters import ( policy_action_to_transition, transition_to_policy_action, ) -from lerobot.processor.core import EnvTransition, TransitionKey from lerobot.processor.pipeline import PipelineFeatureType +from lerobot.types import EnvTransition, TransitionKey from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME diff --git a/src/lerobot/policies/smolvla/modeling_smolvla.py b/src/lerobot/policies/smolvla/modeling_smolvla.py index 32165eba8..7110ba7d2 100644 --- a/src/lerobot/policies/smolvla/modeling_smolvla.py +++ b/src/lerobot/policies/smolvla/modeling_smolvla.py @@ -68,7 +68,7 @@ from lerobot.policies.utils import ( populate_queues, ) from lerobot.utils.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, OBS_STATE -from lerobot.utils.utils import get_safe_dtype +from lerobot.utils.device_utils import get_safe_dtype class ActionSelectKwargs(TypedDict, total=False): diff --git a/src/lerobot/policies/utils.py b/src/lerobot/policies/utils.py index 1a14b2925..82ab51005 100644 --- a/src/lerobot/policies/utils.py +++ b/src/lerobot/policies/utils.py @@ -23,8 +23,8 @@ from torch import nn from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.types import FeatureType, PolicyFeature -from lerobot.datasets.utils import build_dataset_frame -from lerobot.processor import PolicyAction, RobotAction, RobotObservation +from lerobot.datasets.feature_utils import build_dataset_frame +from lerobot.types import PolicyAction, RobotAction, RobotObservation from lerobot.utils.constants import ACTION, OBS_STR diff --git a/src/lerobot/policies/vqbet/modeling_vqbet.py b/src/lerobot/policies/vqbet/modeling_vqbet.py index 359b4fdb1..6d3976b79 100644 --- a/src/lerobot/policies/vqbet/modeling_vqbet.py +++ b/src/lerobot/policies/vqbet/modeling_vqbet.py @@ -467,8 +467,8 @@ class VQBeTHead(nn.Module): self.vqvae_model.optimized_steps += 1 # if we updated RVQ more than `n_vqvae_training_steps` steps, we freeze the RVQ part. 
if self.vqvae_model.optimized_steps >= n_vqvae_training_steps: - self.vqvae_model.discretized = torch.tensor(True) - self.vqvae_model.vq_layer.freeze_codebook = torch.tensor(True) + self.vqvae_model.discretized.fill_(True) + self.vqvae_model.vq_layer.freeze_codebook.fill_(True) print("Finished discretizing action data!") self.vqvae_model.eval() for param in self.vqvae_model.vq_layer.parameters(): diff --git a/src/lerobot/policies/xvla/processor_xvla.py b/src/lerobot/policies/xvla/processor_xvla.py index c4e3f2d6f..0fa9ffe3f 100644 --- a/src/lerobot/policies/xvla/processor_xvla.py +++ b/src/lerobot/policies/xvla/processor_xvla.py @@ -38,7 +38,7 @@ from lerobot.processor import ( UnnormalizerProcessorStep, ) from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action -from lerobot.processor.core import EnvTransition, TransitionKey +from lerobot.types import EnvTransition, TransitionKey from lerobot.utils.constants import ( OBS_IMAGES, OBS_PREFIX, diff --git a/src/lerobot/processor/__init__.py b/src/lerobot/processor/__init__.py index 0b63e1606..12dcf0c6d 100644 --- a/src/lerobot/processor/__init__.py +++ b/src/lerobot/processor/__init__.py @@ -14,13 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .batch_processor import AddBatchDimensionProcessorStep -from .converters import ( - batch_to_transition, - create_transition, - transition_to_batch, -) -from .core import ( +from lerobot.types import ( EnvAction, EnvTransition, PolicyAction, @@ -28,6 +22,13 @@ from .core import ( RobotObservation, TransitionKey, ) + +from .batch_processor import AddBatchDimensionProcessorStep +from .converters import ( + batch_to_transition, + create_transition, + transition_to_batch, +) from .delta_action_processor import MapDeltaActionToRobotActionStep, MapTensorToDeltaActionDictStep from .device_processor import DeviceProcessorStep from .factory import ( diff --git a/src/lerobot/processor/batch_processor.py b/src/lerobot/processor/batch_processor.py index e1a90421f..c904acf84 100644 --- a/src/lerobot/processor/batch_processor.py +++ b/src/lerobot/processor/batch_processor.py @@ -25,9 +25,9 @@ from dataclasses import dataclass, field from torch import Tensor from lerobot.configs.types import PipelineFeatureType, PolicyFeature +from lerobot.types import EnvTransition, PolicyAction from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE -from .core import EnvTransition, PolicyAction from .pipeline import ( ComplementaryDataProcessorStep, ObservationProcessorStep, diff --git a/src/lerobot/processor/converters.py b/src/lerobot/processor/converters.py index 18c7b0220..ffdf0098c 100644 --- a/src/lerobot/processor/converters.py +++ b/src/lerobot/processor/converters.py @@ -23,10 +23,9 @@ from typing import Any import numpy as np import torch +from lerobot.types import EnvTransition, PolicyAction, RobotAction, RobotObservation, TransitionKey from lerobot.utils.constants import ACTION, DONE, INFO, OBS_PREFIX, REWARD, TRUNCATED -from .core import EnvTransition, PolicyAction, RobotAction, RobotObservation, TransitionKey - @singledispatch def to_tensor( diff --git a/src/lerobot/processor/delta_action_processor.py b/src/lerobot/processor/delta_action_processor.py index a8395637c..f7f5676ac 100644 --- a/src/lerobot/processor/delta_action_processor.py +++ b/src/lerobot/processor/delta_action_processor.py @@ -17,8 +17,8 @@ from dataclasses import dataclass from lerobot.configs.types import 
FeatureType, PipelineFeatureType, PolicyFeature +from lerobot.types import PolicyAction, RobotAction -from .core import PolicyAction, RobotAction from .pipeline import ActionProcessorStep, ProcessorStepRegistry, RobotActionProcessorStep diff --git a/src/lerobot/processor/device_processor.py b/src/lerobot/processor/device_processor.py index 2d0dd0880..36c80e58e 100644 --- a/src/lerobot/processor/device_processor.py +++ b/src/lerobot/processor/device_processor.py @@ -25,9 +25,9 @@ from typing import Any import torch from lerobot.configs.types import PipelineFeatureType, PolicyFeature -from lerobot.utils.utils import get_safe_torch_device +from lerobot.types import EnvTransition, PolicyAction, TransitionKey +from lerobot.utils.device_utils import get_safe_torch_device -from .core import EnvTransition, PolicyAction, TransitionKey from .pipeline import ProcessorStep, ProcessorStepRegistry diff --git a/src/lerobot/processor/factory.py b/src/lerobot/processor/factory.py index 5a0c41072..5028122f1 100644 --- a/src/lerobot/processor/factory.py +++ b/src/lerobot/processor/factory.py @@ -14,13 +14,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from lerobot.types import RobotAction, RobotObservation + from .converters import ( observation_to_transition, robot_action_observation_to_transition, transition_to_observation, transition_to_robot_action, ) -from .core import RobotAction, RobotObservation from .pipeline import IdentityProcessorStep, RobotProcessorPipeline diff --git a/src/lerobot/processor/gym_action_processor.py b/src/lerobot/processor/gym_action_processor.py index 4f225af92..e756ded7f 100644 --- a/src/lerobot/processor/gym_action_processor.py +++ b/src/lerobot/processor/gym_action_processor.py @@ -17,9 +17,9 @@ from dataclasses import dataclass from lerobot.configs.types import PipelineFeatureType, PolicyFeature +from lerobot.types import EnvAction, EnvTransition, PolicyAction from .converters import to_tensor -from .core import EnvAction, EnvTransition, PolicyAction from .hil_processor import TELEOP_ACTION_KEY from .pipeline import ActionProcessorStep, ProcessorStep, ProcessorStepRegistry @@ -75,7 +75,7 @@ class Numpy2TorchActionProcessorStep(ProcessorStep): def __call__(self, transition: EnvTransition) -> EnvTransition: """Converts numpy action to torch tensor if action exists, otherwise passes through.""" - from .core import TransitionKey + from lerobot.types import TransitionKey self._current_transition = transition.copy() new_transition = self._current_transition diff --git a/src/lerobot/processor/hil_processor.py b/src/lerobot/processor/hil_processor.py index 34eaeed51..0b8521c2b 100644 --- a/src/lerobot/processor/hil_processor.py +++ b/src/lerobot/processor/hil_processor.py @@ -30,7 +30,8 @@ from lerobot.teleoperators.utils import TeleopEvents if TYPE_CHECKING: from lerobot.teleoperators.teleoperator import Teleoperator -from .core import EnvTransition, PolicyAction, TransitionKey +from lerobot.types import EnvTransition, PolicyAction, TransitionKey + from .pipeline import ( ComplementaryDataProcessorStep, InfoProcessorStep, diff --git a/src/lerobot/processor/normalize_processor.py b/src/lerobot/processor/normalize_processor.py index 4769b91ac..8a7a1176a 100644 --- a/src/lerobot/processor/normalize_processor.py +++ b/src/lerobot/processor/normalize_processor.py @@ -26,10 +26,10 @@ from torch import Tensor from lerobot.configs.types import FeatureType, NormalizationMode, PipelineFeatureType, PolicyFeature from 
lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.types import EnvTransition, PolicyAction, TransitionKey from lerobot.utils.constants import ACTION from .converters import from_tensor_to_numpy, to_tensor -from .core import EnvTransition, PolicyAction, TransitionKey from .pipeline import PolicyProcessorPipeline, ProcessorStep, ProcessorStepRegistry, RobotObservation diff --git a/src/lerobot/processor/pipeline.py b/src/lerobot/processor/pipeline.py index db1c3015c..abfb31421 100644 --- a/src/lerobot/processor/pipeline.py +++ b/src/lerobot/processor/pipeline.py @@ -46,10 +46,10 @@ from huggingface_hub import hf_hub_download from safetensors.torch import load_file, save_file from lerobot.configs.types import PipelineFeatureType, PolicyFeature +from lerobot.types import EnvAction, EnvTransition, PolicyAction, RobotAction, RobotObservation, TransitionKey from lerobot.utils.hub import HubMixin from .converters import batch_to_transition, create_transition, transition_to_batch -from .core import EnvAction, EnvTransition, PolicyAction, RobotAction, RobotObservation, TransitionKey # Generic type variables for pipeline input and output. TInput = TypeVar("TInput") diff --git a/src/lerobot/processor/tokenizer_processor.py b/src/lerobot/processor/tokenizer_processor.py index da6e600af..2a972ecc8 100644 --- a/src/lerobot/processor/tokenizer_processor.py +++ b/src/lerobot/processor/tokenizer_processor.py @@ -30,6 +30,7 @@ from typing import TYPE_CHECKING, Any import torch from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature +from lerobot.types import EnvTransition, RobotObservation, TransitionKey from lerobot.utils.constants import ( ACTION_TOKEN_MASK, ACTION_TOKENS, @@ -40,7 +41,6 @@ from lerobot.utils.constants import ( ) from lerobot.utils.import_utils import _transformers_available -from .core import EnvTransition, RobotObservation, TransitionKey from .pipeline import ActionProcessorStep, ObservationProcessorStep, ProcessorStepRegistry # Conditional import for type checking and lazy loading diff --git a/src/lerobot/rl/actor.py b/src/lerobot/rl/actor.py index 7427633d2..18c0ca1ea 100644 --- a/src/lerobot/rl/actor.py +++ b/src/lerobot/rl/actor.py @@ -62,7 +62,6 @@ from lerobot.configs import parser from lerobot.configs.train import TrainRLServerPipelineConfig from lerobot.policies.factory import make_policy from lerobot.policies.sac.modeling_sac import SACPolicy -from lerobot.processor import TransitionKey from lerobot.rl.process import ProcessSignalHandler from lerobot.rl.queue import get_last_item_from_queue from lerobot.robots import so_follower # noqa: F401 @@ -77,6 +76,8 @@ from lerobot.transport.utils import ( send_bytes_in_chunks, transitions_to_bytes, ) +from lerobot.types import TransitionKey +from lerobot.utils.device_utils import get_safe_torch_device from lerobot.utils.random_utils import set_seed from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.transition import ( @@ -86,7 +87,6 @@ from lerobot.utils.transition import ( ) from lerobot.utils.utils import ( TimerManager, - get_safe_torch_device, init_logging, ) diff --git a/src/lerobot/rl/learner.py b/src/lerobot/rl/learner.py index ee09ac9ac..2853fbcb3 100644 --- a/src/lerobot/rl/learner.py +++ b/src/lerobot/rl/learner.py @@ -86,6 +86,7 @@ from lerobot.utils.constants import ( PRETRAINED_MODEL_DIR, TRAINING_STATE_DIR, ) +from lerobot.utils.device_utils import get_safe_torch_device from lerobot.utils.random_utils import set_seed from lerobot.utils.train_utils import ( 
get_step_checkpoint_dir, @@ -96,7 +97,6 @@ from lerobot.utils.train_utils import ( from lerobot.utils.transition import move_state_dict_to_device, move_transition_to_device from lerobot.utils.utils import ( format_big_number, - get_safe_torch_device, init_logging, ) diff --git a/src/lerobot/robots/bi_openarm_follower/bi_openarm_follower.py b/src/lerobot/robots/bi_openarm_follower/bi_openarm_follower.py index 2e3885e67..7f5e92271 100644 --- a/src/lerobot/robots/bi_openarm_follower/bi_openarm_follower.py +++ b/src/lerobot/robots/bi_openarm_follower/bi_openarm_follower.py @@ -17,8 +17,8 @@ import logging from functools import cached_property -from lerobot.processor import RobotAction, RobotObservation from lerobot.robots.openarm_follower import OpenArmFollower, OpenArmFollowerConfig +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..robot import Robot diff --git a/src/lerobot/robots/bi_so_follower/bi_so_follower.py b/src/lerobot/robots/bi_so_follower/bi_so_follower.py index 28c58b898..ba1826e29 100644 --- a/src/lerobot/robots/bi_so_follower/bi_so_follower.py +++ b/src/lerobot/robots/bi_so_follower/bi_so_follower.py @@ -17,8 +17,8 @@ import logging from functools import cached_property -from lerobot.processor import RobotAction, RobotObservation from lerobot.robots.so_follower import SOFollower, SOFollowerRobotConfig +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..robot import Robot diff --git a/src/lerobot/robots/earthrover_mini_plus/robot_earthrover_mini_plus.py b/src/lerobot/robots/earthrover_mini_plus/robot_earthrover_mini_plus.py index cdf6efde1..76707a80c 100644 --- a/src/lerobot/robots/earthrover_mini_plus/robot_earthrover_mini_plus.py +++ b/src/lerobot/robots/earthrover_mini_plus/robot_earthrover_mini_plus.py @@ -23,7 +23,7 @@ import cv2 import numpy as np import requests -from lerobot.processor import RobotAction, RobotObservation +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from lerobot.utils.errors import DeviceNotConnectedError @@ -33,21 +33,40 @@ from .config_earthrover_mini_plus import EarthRoverMiniPlusConfig logger = logging.getLogger(__name__) # Action feature keys -ACTION_LINEAR_VEL = "linear.vel" -ACTION_ANGULAR_VEL = "angular.vel" +ACTION_LINEAR_VEL = "linear_velocity" +ACTION_ANGULAR_VEL = "angular_velocity" -# Observation feature keys +# Observation feature keys — cameras OBS_FRONT = "front" OBS_REAR = "rear" -OBS_LINEAR_VEL = "linear.vel" -OBS_BATTERY_LEVEL = "battery.level" -OBS_ORIENTATION_DEG = "orientation.deg" -OBS_GPS_LATITUDE = "gps.latitude" -OBS_GPS_LONGITUDE = "gps.longitude" -OBS_GPS_SIGNAL = "gps.signal" -OBS_SIGNAL_LEVEL = "signal.level" + +# Observation feature keys — telemetry +OBS_SPEED = "speed" +OBS_BATTERY_LEVEL = "battery_level" +OBS_ORIENTATION = "orientation" +OBS_GPS_LATITUDE = "gps_latitude" +OBS_GPS_LONGITUDE = "gps_longitude" +OBS_GPS_SIGNAL = "gps_signal" +OBS_SIGNAL_LEVEL = "signal_level" OBS_VIBRATION = "vibration" -OBS_LAMP_STATE = "lamp.state" +OBS_LAMP = "lamp" + +# Observation feature keys — IMU sensors +OBS_ACCELEROMETER_X = "accelerometer_x" +OBS_ACCELEROMETER_Y = "accelerometer_y" +OBS_ACCELEROMETER_Z = "accelerometer_z" +OBS_GYROSCOPE_X = "gyroscope_x" +OBS_GYROSCOPE_Y = "gyroscope_y" +OBS_GYROSCOPE_Z = "gyroscope_z" 
+OBS_MAGNETOMETER_X = "magnetometer_filtered_x" +OBS_MAGNETOMETER_Y = "magnetometer_filtered_y" +OBS_MAGNETOMETER_Z = "magnetometer_filtered_z" + +# Observation feature keys — wheel RPMs +OBS_WHEEL_RPM_0 = "wheel_rpm_0" +OBS_WHEEL_RPM_1 = "wheel_rpm_1" +OBS_WHEEL_RPM_2 = "wheel_rpm_2" +OBS_WHEEL_RPM_3 = "wheel_rpm_3" class EarthRoverMiniPlus(Robot): @@ -154,33 +173,60 @@ class EarthRoverMiniPlus(Robot): dict: Observation features with types/shapes: - front: (480, 640, 3) - Front camera RGB image - rear: (480, 640, 3) - Rear camera RGB image - - linear.vel: float - Current speed (0-1, SDK reports only positive speeds) - - battery.level: float - Battery level (0-1, normalized from 0-100) - - orientation.deg: float - Robot orientation (0-1, normalized from raw value) - - gps.latitude: float - GPS latitude coordinate - - gps.longitude: float - GPS longitude coordinate - - gps.signal: float - GPS signal strength (0-1, normalized from percentage) - - signal.level: float - Network signal level (0-1, normalized from 0-5) + - speed: float - Current speed (raw SDK value) + - battery_level: float - Battery level (0-100) + - orientation: float - Robot orientation in degrees + - gps_latitude: float - GPS latitude coordinate + - gps_longitude: float - GPS longitude coordinate + - gps_signal: float - GPS signal strength (percentage) + - signal_level: float - Network signal level (0-5) - vibration: float - Vibration sensor reading - - lamp.state: float - Lamp state (0=off, 1=on) + - lamp: float - Lamp state (0=off, 1=on) + - accelerometer_x: float - Accelerometer X axis (raw SDK value) + - accelerometer_y: float - Accelerometer Y axis (raw SDK value) + - accelerometer_z: float - Accelerometer Z axis (raw SDK value) + - gyroscope_x: float - Gyroscope X axis (raw SDK value) + - gyroscope_y: float - Gyroscope Y axis (raw SDK value) + - gyroscope_z: float - Gyroscope Z axis (raw SDK value) + - magnetometer_filtered_x: float - Magnetometer X axis (raw SDK value) + - magnetometer_filtered_y: float - Magnetometer Y axis (raw SDK value) + - magnetometer_filtered_z: float - Magnetometer Z axis (raw SDK value) + - wheel_rpm_0: float - Wheel 0 RPM + - wheel_rpm_1: float - Wheel 1 RPM + - wheel_rpm_2: float - Wheel 2 RPM + - wheel_rpm_3: float - Wheel 3 RPM """ return { # Cameras (height, width, channels) OBS_FRONT: (480, 640, 3), OBS_REAR: (480, 640, 3), - # Motion state - OBS_LINEAR_VEL: float, - # Robot state + # Telemetry + OBS_SPEED: float, OBS_BATTERY_LEVEL: float, - OBS_ORIENTATION_DEG: float, - # GPS + OBS_ORIENTATION: float, OBS_GPS_LATITUDE: float, OBS_GPS_LONGITUDE: float, OBS_GPS_SIGNAL: float, - # Sensors OBS_SIGNAL_LEVEL: float, OBS_VIBRATION: float, - OBS_LAMP_STATE: float, + OBS_LAMP: float, + # IMU — accelerometer + OBS_ACCELEROMETER_X: float, + OBS_ACCELEROMETER_Y: float, + OBS_ACCELEROMETER_Z: float, + # IMU — gyroscope + OBS_GYROSCOPE_X: float, + OBS_GYROSCOPE_Y: float, + OBS_GYROSCOPE_Z: float, + # IMU — magnetometer + OBS_MAGNETOMETER_X: float, + OBS_MAGNETOMETER_Y: float, + OBS_MAGNETOMETER_Z: float, + # Wheel RPMs + OBS_WHEEL_RPM_0: float, + OBS_WHEEL_RPM_1: float, + OBS_WHEEL_RPM_2: float, + OBS_WHEEL_RPM_3: float, } @cached_property @@ -189,8 +235,8 @@ class EarthRoverMiniPlus(Robot): Returns: dict: Action features with types: - - linear.vel: float - Target linear velocity - - angular.vel: float - Target angular velocity + - linear_velocity: float - Target linear velocity (-1 to 1) + - angular_velocity: float - Target angular velocity (-1 to 1) """ return { ACTION_LINEAR_VEL: float, @@ 
-201,19 +247,29 @@ class EarthRoverMiniPlus(Robot): def get_observation(self) -> RobotObservation: """Get current robot observation from SDK. + Camera frames are retrieved from SDK endpoints /v2/front and /v2/rear. + Frames are decoded from base64 and converted from BGR to RGB format. + Robot telemetry is retrieved from /data endpoint. + Sensor arrays (accels, gyros, mags, rpms) each contain entries of + [values..., timestamp]; the latest reading from each array is used. + Returns: RobotObservation: Observation containing: - front: Front camera image (480, 640, 3) in RGB format - rear: Rear camera image (480, 640, 3) in RGB format - - linear.vel: Current speed (0-1, SDK reports only positive speeds) - - battery.level: Battery level (0-1, normalized from 0-100) - - orientation.deg: Robot orientation (0-1, normalized from raw value) - - gps.latitude: GPS latitude coordinate - - gps.longitude: GPS longitude coordinate - - gps.signal: GPS signal strength (0-1, normalized from percentage) - - signal.level: Network signal level (0-1, normalized from 0-5) - - vibration: Vibration sensor reading - - lamp.state: Lamp state (0=off, 1=on) + - speed: float - Current speed (raw SDK value) + - battery_level: float - Battery level (0-100) + - orientation: float - Robot orientation in degrees + - gps_latitude: float - GPS latitude coordinate + - gps_longitude: float - GPS longitude coordinate + - gps_signal: float - GPS signal strength (percentage) + - signal_level: float - Network signal level (0-5) + - vibration: float - Vibration sensor reading + - lamp: float - Lamp state (0=off, 1=on) + - accelerometer_x/y/z: float - Accelerometer axes (raw SDK value) + - gyroscope_x/y/z: float - Gyroscope axes (raw SDK value) + - magnetometer_filtered_x/y/z: float - Magnetometer axes (raw SDK value) + - wheel_rpm_0/1/2/3: float - Wheel RPMs Raises: DeviceNotConnectedError: If robot is not connected @@ -235,22 +291,41 @@ class EarthRoverMiniPlus(Robot): # Get robot state from SDK robot_data = self._get_robot_data() - # Motion state - observation[OBS_LINEAR_VEL] = robot_data["speed"] / 100.0 # Normalize 0-100 to 0-1 + # Telemetry + observation[OBS_SPEED] = float(robot_data["speed"]) + observation[OBS_BATTERY_LEVEL] = float(robot_data["battery"]) + observation[OBS_ORIENTATION] = float(robot_data["orientation"]) + observation[OBS_GPS_LATITUDE] = float(robot_data["latitude"]) + observation[OBS_GPS_LONGITUDE] = float(robot_data["longitude"]) + observation[OBS_GPS_SIGNAL] = float(robot_data["gps_signal"]) + observation[OBS_SIGNAL_LEVEL] = float(robot_data["signal_level"]) + observation[OBS_VIBRATION] = float(robot_data["vibration"]) + observation[OBS_LAMP] = float(robot_data["lamp"]) - # Robot state - observation[OBS_BATTERY_LEVEL] = robot_data["battery"] / 100.0 # Normalize 0-100 to 0-1 - observation[OBS_ORIENTATION_DEG] = robot_data["orientation"] / 360.0 # Normalize to 0-1 + # Accelerometer — latest reading from accels array [x, y, z, ts] + accel = self._latest_sensor_reading(robot_data, "accels", n_values=3) + observation[OBS_ACCELEROMETER_X] = accel[0] + observation[OBS_ACCELEROMETER_Y] = accel[1] + observation[OBS_ACCELEROMETER_Z] = accel[2] - # GPS data - observation[OBS_GPS_LATITUDE] = robot_data["latitude"] - observation[OBS_GPS_LONGITUDE] = robot_data["longitude"] - observation[OBS_GPS_SIGNAL] = robot_data["gps_signal"] / 100.0 # Normalize percentage to 0-1 + # Gyroscope — latest reading from gyros array [x, y, z, ts] + gyro = self._latest_sensor_reading(robot_data, "gyros", n_values=3) + 
observation[OBS_GYROSCOPE_X] = gyro[0] + observation[OBS_GYROSCOPE_Y] = gyro[1] + observation[OBS_GYROSCOPE_Z] = gyro[2] - # Sensors - observation[OBS_SIGNAL_LEVEL] = robot_data["signal_level"] / 5.0 # Normalize 0-5 to 0-1 - observation[OBS_VIBRATION] = robot_data["vibration"] - observation[OBS_LAMP_STATE] = float(robot_data["lamp"]) # 0 or 1 + # Magnetometer — latest reading from mags array [x, y, z, ts] + mag = self._latest_sensor_reading(robot_data, "mags", n_values=3) + observation[OBS_MAGNETOMETER_X] = mag[0] + observation[OBS_MAGNETOMETER_Y] = mag[1] + observation[OBS_MAGNETOMETER_Z] = mag[2] + + # Wheel RPMs — latest reading from rpms array [w0, w1, w2, w3, ts] + rpm = self._latest_sensor_reading(robot_data, "rpms", n_values=4) + observation[OBS_WHEEL_RPM_0] = rpm[0] + observation[OBS_WHEEL_RPM_1] = rpm[1] + observation[OBS_WHEEL_RPM_2] = rpm[2] + observation[OBS_WHEEL_RPM_3] = rpm[3] return observation @@ -260,11 +335,12 @@ class EarthRoverMiniPlus(Robot): Args: action: Action dict with keys: - - linear.vel: Target linear velocity (-1 to 1) - - angular.vel: Target angular velocity (-1 to 1) + - linear_velocity: Target linear velocity (-1 to 1) + - angular_velocity: Target angular velocity (-1 to 1) Returns: RobotAction: The action that was sent (matches action_features keys) + Raises: DeviceNotConnectedError: If robot is not connected @@ -272,18 +348,14 @@ class EarthRoverMiniPlus(Robot): Actions are sent to SDK via POST /control endpoint. SDK expects commands in range [-1, 1]. """ - - # Extract action values and convert to float linear = float(action.get(ACTION_LINEAR_VEL, 0.0)) angular = float(action.get(ACTION_ANGULAR_VEL, 0.0)) - # Send command to SDK try: self._send_command_to_sdk(linear, angular) except Exception as e: logger.error(f"Error sending action: {e}") - # Return action in format matching action_features return { ACTION_LINEAR_VEL: linear, ACTION_ANGULAR_VEL: angular, @@ -394,11 +466,27 @@ class EarthRoverMiniPlus(Robot): logger.error(f"Error decoding image: {e}") return None + @staticmethod + def _latest_sensor_reading(robot_data: dict, key: str, n_values: int) -> list[float]: + """Extract the latest sensor reading from an SDK sensor array. + + The SDK returns sensor arrays like ``accels``, ``gyros``, ``mags``, + ``rpms`` where each entry is ``[value_0, ..., value_n, timestamp]``. + This helper returns the *n_values* leading floats from the last entry, + falling back to zeros when the key is missing or the array is empty. + """ + readings = robot_data.get(key) + if readings and len(readings) > 0: + latest = readings[-1] + return [float(v) for v in latest[:n_values]] + return [0.0] * n_values + def _get_robot_data(self) -> dict: """Get robot telemetry data from SDK. 
Returns: - dict: Robot telemetry data including battery, speed, orientation, GPS, etc: + dict: Robot telemetry data including battery, speed, orientation, GPS, + and sensor arrays (accels, gyros, mags, rpms): - Current data (if request succeeds) - Cached data (if request fails but cache exists) - Default values (if request fails and no cache exists yet) @@ -420,19 +508,23 @@ class EarthRoverMiniPlus(Robot): # Fallback: use cache or default values if self._last_robot_data is not None: return self._last_robot_data - else: - # Return dict with default values (used only on first failure before any cache exists) - return { - "speed": 0, - "battery": 0, - "orientation": 0, - "latitude": 0.0, - "longitude": 0.0, - "gps_signal": 0, - "signal_level": 0, - "vibration": 0.0, - "lamp": 0, - } + + # Return dict with default values (used only on first failure before any cache exists) + return { + "speed": 0, + "battery": 0, + "orientation": 0, + "latitude": 0.0, + "longitude": 0.0, + "gps_signal": 0, + "signal_level": 0, + "vibration": 0.0, + "lamp": 0, + "accels": [], + "gyros": [], + "mags": [], + "rpms": [], + } def _send_command_to_sdk(self, linear: float, angular: float, lamp: int = 0) -> bool: """Send control command to SDK. diff --git a/src/lerobot/robots/hope_jr/hope_jr_arm.py b/src/lerobot/robots/hope_jr/hope_jr_arm.py index e8269ae46..7f6492ef0 100644 --- a/src/lerobot/robots/hope_jr/hope_jr_arm.py +++ b/src/lerobot/robots/hope_jr/hope_jr_arm.py @@ -24,7 +24,7 @@ from lerobot.motors.calibration_gui import RangeFinderGUI from lerobot.motors.feetech import ( FeetechMotorsBus, ) -from lerobot.processor import RobotAction, RobotObservation +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..robot import Robot diff --git a/src/lerobot/robots/hope_jr/hope_jr_hand.py b/src/lerobot/robots/hope_jr/hope_jr_hand.py index a05c4bbcb..784804836 100644 --- a/src/lerobot/robots/hope_jr/hope_jr_hand.py +++ b/src/lerobot/robots/hope_jr/hope_jr_hand.py @@ -24,7 +24,7 @@ from lerobot.motors.calibration_gui import RangeFinderGUI from lerobot.motors.feetech import ( FeetechMotorsBus, ) -from lerobot.processor import RobotAction, RobotObservation +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..robot import Robot diff --git a/src/lerobot/robots/koch_follower/koch_follower.py b/src/lerobot/robots/koch_follower/koch_follower.py index 53a32beed..44e83f6a3 100644 --- a/src/lerobot/robots/koch_follower/koch_follower.py +++ b/src/lerobot/robots/koch_follower/koch_follower.py @@ -24,7 +24,7 @@ from lerobot.motors.dynamixel import ( DynamixelMotorsBus, OperatingMode, ) -from lerobot.processor import RobotAction, RobotObservation +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..robot import Robot diff --git a/src/lerobot/robots/lekiwi/lekiwi.py b/src/lerobot/robots/lekiwi/lekiwi.py index 9d11a000f..60fac89e5 100644 --- a/src/lerobot/robots/lekiwi/lekiwi.py +++ b/src/lerobot/robots/lekiwi/lekiwi.py @@ -28,7 +28,7 @@ from lerobot.motors.feetech import ( FeetechMotorsBus, OperatingMode, ) -from lerobot.processor import RobotAction, RobotObservation +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..robot import Robot diff --git 
a/src/lerobot/robots/lekiwi/lekiwi_client.py b/src/lerobot/robots/lekiwi/lekiwi_client.py index 1d5ea64a6..fd43e84fe 100644 --- a/src/lerobot/robots/lekiwi/lekiwi_client.py +++ b/src/lerobot/robots/lekiwi/lekiwi_client.py @@ -22,7 +22,7 @@ from functools import cached_property import cv2 import numpy as np -from lerobot.processor import RobotAction, RobotObservation +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.constants import ACTION, OBS_STATE from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from lerobot.utils.errors import DeviceNotConnectedError diff --git a/src/lerobot/robots/omx_follower/omx_follower.py b/src/lerobot/robots/omx_follower/omx_follower.py index e0b612c60..5d161daa2 100644 --- a/src/lerobot/robots/omx_follower/omx_follower.py +++ b/src/lerobot/robots/omx_follower/omx_follower.py @@ -25,7 +25,7 @@ from lerobot.motors.dynamixel import ( DynamixelMotorsBus, OperatingMode, ) -from lerobot.processor import RobotAction, RobotObservation +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..robot import Robot diff --git a/src/lerobot/robots/openarm_follower/openarm_follower.py b/src/lerobot/robots/openarm_follower/openarm_follower.py index c865f1ec1..99e8b920b 100644 --- a/src/lerobot/robots/openarm_follower/openarm_follower.py +++ b/src/lerobot/robots/openarm_follower/openarm_follower.py @@ -22,7 +22,7 @@ from typing import Any from lerobot.cameras.utils import make_cameras_from_configs from lerobot.motors import Motor, MotorCalibration, MotorNormMode from lerobot.motors.damiao import DamiaoMotorsBus -from lerobot.processor import RobotAction, RobotObservation +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..robot import Robot diff --git a/src/lerobot/robots/reachy2/robot_reachy2.py b/src/lerobot/robots/reachy2/robot_reachy2.py index fb466f85b..5227a096a 100644 --- a/src/lerobot/robots/reachy2/robot_reachy2.py +++ b/src/lerobot/robots/reachy2/robot_reachy2.py @@ -19,7 +19,7 @@ import time from typing import TYPE_CHECKING, Any from lerobot.cameras.utils import make_cameras_from_configs -from lerobot.processor import RobotAction, RobotObservation +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.import_utils import _reachy2_sdk_available from ..robot import Robot diff --git a/src/lerobot/robots/robot.py b/src/lerobot/robots/robot.py index d165886b9..1b556f963 100644 --- a/src/lerobot/robots/robot.py +++ b/src/lerobot/robots/robot.py @@ -19,7 +19,7 @@ from pathlib import Path import draccus from lerobot.motors import MotorCalibration -from lerobot.processor import RobotAction, RobotObservation +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.constants import HF_LEROBOT_CALIBRATION, ROBOTS from .config import RobotConfig diff --git a/src/lerobot/robots/so_follower/so_follower.py b/src/lerobot/robots/so_follower/so_follower.py index c898e9137..ca132d102 100644 --- a/src/lerobot/robots/so_follower/so_follower.py +++ b/src/lerobot/robots/so_follower/so_follower.py @@ -24,7 +24,7 @@ from lerobot.motors.feetech import ( FeetechMotorsBus, OperatingMode, ) -from lerobot.processor import RobotAction, RobotObservation +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..robot 
import Robot diff --git a/src/lerobot/robots/unitree_g1/unitree_g1.py b/src/lerobot/robots/unitree_g1/unitree_g1.py index 41146ebe6..9e373c05f 100644 --- a/src/lerobot/robots/unitree_g1/unitree_g1.py +++ b/src/lerobot/robots/unitree_g1/unitree_g1.py @@ -26,8 +26,6 @@ from typing import TYPE_CHECKING, Protocol, runtime_checkable import numpy as np from lerobot.cameras.utils import make_cameras_from_configs -from lerobot.envs.factory import make_env -from lerobot.processor import RobotAction, RobotObservation from lerobot.robots.unitree_g1.g1_kinematics import G1_29_ArmIK from lerobot.robots.unitree_g1.g1_utils import ( REMOTE_AXES, @@ -37,6 +35,7 @@ from lerobot.robots.unitree_g1.g1_utils import ( default_remote_input, make_locomotion_controller, ) +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.import_utils import _unitree_sdk_available from ..robot import Robot @@ -291,6 +290,8 @@ class UnitreeG1(Robot): def connect(self, calibrate: bool = True) -> None: # connect to DDS # Initialize DDS channel and simulation environment if self.config.is_simulation: + from lerobot.envs.factory import make_env + self._ChannelFactoryInitialize(0, "lo") self._env_wrapper = make_env("lerobot/unitree-g1-mujoco", trust_remote_code=True) # Extract the actual gym env from the dict structure diff --git a/src/lerobot/datasets/v30/augment_dataset_quantile_stats.py b/src/lerobot/scripts/augment_dataset_quantile_stats.py similarity index 97% rename from src/lerobot/datasets/v30/augment_dataset_quantile_stats.py rename to src/lerobot/scripts/augment_dataset_quantile_stats.py index 900a43a4f..4d80c9332 100644 --- a/src/lerobot/datasets/v30/augment_dataset_quantile_stats.py +++ b/src/lerobot/scripts/augment_dataset_quantile_stats.py @@ -28,7 +28,7 @@ quantile statistics (q01, q10, q50, q90, q99) in their metadata. 
This script: Usage: ```bash -python src/lerobot/datasets/v30/augment_dataset_quantile_stats.py \ +python src/lerobot/scripts/augment_dataset_quantile_stats.py \ --repo-id=lerobot/pusht \ ``` """ @@ -45,8 +45,9 @@ from requests import HTTPError from tqdm import tqdm from lerobot.datasets.compute_stats import DEFAULT_QUANTILES, aggregate_stats, get_feature_stats -from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset -from lerobot.datasets.utils import write_stats +from lerobot.datasets.dataset_metadata import CODEBASE_VERSION +from lerobot.datasets.io_utils import write_stats +from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.utils.utils import init_logging diff --git a/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py b/src/lerobot/scripts/convert_dataset_v21_to_v30.py similarity index 98% rename from src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py rename to src/lerobot/scripts/convert_dataset_v21_to_v30.py index 81de05686..2b6dcf732 100644 --- a/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py +++ b/src/lerobot/scripts/convert_dataset_v21_to_v30.py @@ -28,13 +28,13 @@ Usage: Convert a dataset from the hub: ```bash -python src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py \ +python src/lerobot/scripts/convert_dataset_v21_to_v30.py \ --repo-id=lerobot/pusht ``` Convert a local dataset (works in place): ```bash -python src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py \ +python src/lerobot/scripts/convert_dataset_v21_to_v30.py \ --repo-id=lerobot/pusht \ --root=/path/to/local/dataset/directory \ --push-to-hub=false @@ -60,7 +60,19 @@ from huggingface_hub import HfApi, snapshot_download from requests import HTTPError from lerobot.datasets.compute_stats import aggregate_stats -from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset +from lerobot.datasets.dataset_metadata import CODEBASE_VERSION +from lerobot.datasets.io_utils import ( + cast_stats_to_numpy, + get_file_size_in_mb, + get_parquet_file_size_in_mb, + get_parquet_num_frames, + load_info, + write_episodes, + write_info, + write_stats, + write_tasks, +) +from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.utils import ( DEFAULT_CHUNK_SIZE, DEFAULT_DATA_FILE_SIZE_IN_MB, @@ -70,17 +82,8 @@ from lerobot.datasets.utils import ( LEGACY_EPISODES_PATH, LEGACY_EPISODES_STATS_PATH, LEGACY_TASKS_PATH, - cast_stats_to_numpy, flatten_dict, - get_file_size_in_mb, - get_parquet_file_size_in_mb, - get_parquet_num_frames, - load_info, update_chunk_file_indices, - write_episodes, - write_info, - write_stats, - write_tasks, ) from lerobot.datasets.video_utils import concatenate_video_files, get_video_duration_in_s from lerobot.utils.constants import HF_LEROBOT_HOME diff --git a/src/lerobot/scripts/lerobot_eval.py b/src/lerobot/scripts/lerobot_eval.py index e32b80404..6d814f498 100644 --- a/src/lerobot/scripts/lerobot_eval.py +++ b/src/lerobot/scripts/lerobot_eval.py @@ -80,13 +80,14 @@ from lerobot.envs.utils import ( ) from lerobot.policies.factory import make_policy, make_pre_post_processors from lerobot.policies.pretrained import PreTrainedPolicy -from lerobot.processor import PolicyAction, PolicyProcessorPipeline +from lerobot.processor import PolicyProcessorPipeline +from lerobot.types import PolicyAction from lerobot.utils.constants import ACTION, DONE, OBS_STR, REWARD +from lerobot.utils.device_utils import get_safe_torch_device from lerobot.utils.import_utils import register_third_party_plugins from 
lerobot.utils.io_utils import write_video from lerobot.utils.random_utils import set_seed from lerobot.utils.utils import ( - get_safe_torch_device, init_logging, inside_slurm, ) diff --git a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py index dc682fe6f..819634ba2 100644 --- a/src/lerobot/scripts/lerobot_record.py +++ b/src/lerobot/scripts/lerobot_record.py @@ -83,10 +83,10 @@ from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraCon from lerobot.cameras.zmq.configuration_zmq import ZMQCameraConfig # noqa: F401 from lerobot.configs import parser from lerobot.configs.policies import PreTrainedConfig +from lerobot.datasets.feature_utils import build_dataset_frame, combine_feature_dicts from lerobot.datasets.image_writer import safe_stop_image_writer from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features -from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts from lerobot.datasets.video_utils import VideoEncodingManager from lerobot.policies.factory import make_policy, make_pre_post_processors from lerobot.policies.pretrained import PreTrainedPolicy @@ -139,10 +139,10 @@ from lerobot.utils.control_utils import ( sanity_check_dataset_name, sanity_check_dataset_robot_compatibility, ) +from lerobot.utils.device_utils import get_safe_torch_device from lerobot.utils.import_utils import register_third_party_plugins from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.utils import ( - get_safe_torch_device, init_logging, log_say, ) diff --git a/src/lerobot/teleoperators/bi_openarm_leader/bi_openarm_leader.py b/src/lerobot/teleoperators/bi_openarm_leader/bi_openarm_leader.py index 74b0c9b83..b44f1fbea 100644 --- a/src/lerobot/teleoperators/bi_openarm_leader/bi_openarm_leader.py +++ b/src/lerobot/teleoperators/bi_openarm_leader/bi_openarm_leader.py @@ -17,8 +17,8 @@ import logging from functools import cached_property -from lerobot.processor import RobotAction from lerobot.teleoperators.openarm_leader import OpenArmLeaderConfig +from lerobot.types import RobotAction from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..openarm_leader import OpenArmLeader diff --git a/src/lerobot/teleoperators/gamepad/teleop_gamepad.py b/src/lerobot/teleoperators/gamepad/teleop_gamepad.py index 69cb0f971..8c1796e45 100644 --- a/src/lerobot/teleoperators/gamepad/teleop_gamepad.py +++ b/src/lerobot/teleoperators/gamepad/teleop_gamepad.py @@ -20,7 +20,7 @@ from typing import Any import numpy as np -from lerobot.processor import RobotAction +from lerobot.types import RobotAction from lerobot.utils.decorators import check_if_not_connected from ..teleoperator import Teleoperator diff --git a/src/lerobot/teleoperators/keyboard/teleop_keyboard.py b/src/lerobot/teleoperators/keyboard/teleop_keyboard.py index 919f463d3..090aa7fae 100644 --- a/src/lerobot/teleoperators/keyboard/teleop_keyboard.py +++ b/src/lerobot/teleoperators/keyboard/teleop_keyboard.py @@ -21,7 +21,7 @@ import time from queue import Queue from typing import Any -from lerobot.processor import RobotAction +from lerobot.types import RobotAction from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..teleoperator import Teleoperator @@ -341,8 +341,8 @@ class KeyboardRoverTeleop(KeyboardTeleop): def action_features(self) -> dict: """Return action format for rover (linear and 
angular velocities).""" return { - "linear.vel": float, - "angular.vel": float, + "linear_velocity": float, + "angular_velocity": float, } @property @@ -366,7 +366,7 @@ class KeyboardRoverTeleop(KeyboardTeleop): Get the current action based on pressed keys. Returns: - RobotAction with 'linear.vel' and 'angular.vel' keys + RobotAction with 'linear_velocity' and 'angular_velocity' keys. """ before_read_t = time.perf_counter() @@ -427,6 +427,6 @@ class KeyboardRoverTeleop(KeyboardTeleop): self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t return { - "linear.vel": linear_velocity, - "angular.vel": angular_velocity, + "linear_velocity": linear_velocity, + "angular_velocity": angular_velocity, } diff --git a/src/lerobot/teleoperators/openarm_leader/openarm_leader.py b/src/lerobot/teleoperators/openarm_leader/openarm_leader.py index d9eaabe0f..65da7416a 100644 --- a/src/lerobot/teleoperators/openarm_leader/openarm_leader.py +++ b/src/lerobot/teleoperators/openarm_leader/openarm_leader.py @@ -20,7 +20,7 @@ from typing import Any from lerobot.motors import Motor, MotorCalibration, MotorNormMode from lerobot.motors.damiao import DamiaoMotorsBus -from lerobot.processor import RobotAction +from lerobot.types import RobotAction from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..teleoperator import Teleoperator diff --git a/src/lerobot/teleoperators/openarm_mini/openarm_mini.py b/src/lerobot/teleoperators/openarm_mini/openarm_mini.py index 3fbcecf24..23594caa9 100644 --- a/src/lerobot/teleoperators/openarm_mini/openarm_mini.py +++ b/src/lerobot/teleoperators/openarm_mini/openarm_mini.py @@ -23,7 +23,7 @@ from lerobot.motors.feetech import ( FeetechMotorsBus, OperatingMode, ) -from lerobot.processor import RobotAction +from lerobot.types import RobotAction from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..teleoperator import Teleoperator diff --git a/src/lerobot/teleoperators/phone/phone_processor.py b/src/lerobot/teleoperators/phone/phone_processor.py index 67e64c7d5..c498bed7d 100644 --- a/src/lerobot/teleoperators/phone/phone_processor.py +++ b/src/lerobot/teleoperators/phone/phone_processor.py @@ -17,8 +17,9 @@ from dataclasses import dataclass, field from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature -from lerobot.processor import ProcessorStepRegistry, RobotAction, RobotActionProcessorStep +from lerobot.processor import ProcessorStepRegistry, RobotActionProcessorStep from lerobot.teleoperators.phone.config_phone import PhoneOS +from lerobot.types import RobotAction @ProcessorStepRegistry.register("map_phone_action_to_robot_action") diff --git a/src/lerobot/teleoperators/teleoperator.py b/src/lerobot/teleoperators/teleoperator.py index 847b88b7f..f47904423 100644 --- a/src/lerobot/teleoperators/teleoperator.py +++ b/src/lerobot/teleoperators/teleoperator.py @@ -20,7 +20,7 @@ from typing import Any import draccus from lerobot.motors.motors_bus import MotorCalibration -from lerobot.processor import RobotAction +from lerobot.types import RobotAction from lerobot.utils.constants import HF_LEROBOT_CALIBRATION, TELEOPERATORS from .config import TeleoperatorConfig diff --git a/src/lerobot/processor/core.py b/src/lerobot/types.py similarity index 100% rename from src/lerobot/processor/core.py rename to src/lerobot/types.py diff --git a/src/lerobot/utils/control_utils.py b/src/lerobot/utils/control_utils.py index 7c605af17..94cd82fa1 100644 --- 
a/src/lerobot/utils/control_utils.py +++ b/src/lerobot/utils/control_utils.py @@ -32,8 +32,9 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.utils import DEFAULT_FEATURES from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.utils import prepare_observation_for_inference -from lerobot.processor import PolicyAction, PolicyProcessorPipeline +from lerobot.processor import PolicyProcessorPipeline from lerobot.robots import Robot +from lerobot.types import PolicyAction @cache diff --git a/src/lerobot/utils/device_utils.py b/src/lerobot/utils/device_utils.py new file mode 100644 index 000000000..37981f07f --- /dev/null +++ b/src/lerobot/utils/device_utils.py @@ -0,0 +1,109 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +import torch + + +def auto_select_torch_device() -> torch.device: + """Tries to select automatically a torch device.""" + if torch.cuda.is_available(): + logging.info("Cuda backend detected, using cuda.") + return torch.device("cuda") + elif torch.backends.mps.is_available(): + logging.info("Metal backend detected, using mps.") + return torch.device("mps") + elif torch.xpu.is_available(): + logging.info("Intel XPU backend detected, using xpu.") + return torch.device("xpu") + else: + logging.warning("No accelerated backend detected. Using default cpu, this will be slow.") + return torch.device("cpu") + + +# TODO(Steven): Remove log. log shouldn't be an argument, this should be handled by the logger level +def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device: + """Given a string, return a torch.device with checks on whether the device is available.""" + try_device = str(try_device) + if try_device.startswith("cuda"): + assert torch.cuda.is_available() + device = torch.device(try_device) + elif try_device == "mps": + assert torch.backends.mps.is_available() + device = torch.device("mps") + elif try_device == "xpu": + assert torch.xpu.is_available() + device = torch.device("xpu") + elif try_device == "cpu": + device = torch.device("cpu") + if log: + logging.warning("Using CPU, this will be slow.") + else: + device = torch.device(try_device) + if log: + logging.warning(f"Using custom {try_device} device.") + return device + + +def get_safe_dtype(dtype: torch.dtype, device: str | torch.device): + """ + mps is currently not compatible with float64 + """ + if isinstance(device, torch.device): + device = device.type + if device == "mps" and dtype == torch.float64: + return torch.float32 + if device == "xpu" and dtype == torch.float64: + if hasattr(torch.xpu, "get_device_capability"): + device_capability = torch.xpu.get_device_capability() + # NOTE: Some Intel XPU devices do not support double precision (FP64). + # The `has_fp64` flag is returned by `torch.xpu.get_device_capability()` + # when available; if False, we fall back to float32 for compatibility. 
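For context on the new `lerobot.utils.device_utils` module introduced in this hunk, here is a hedged usage sketch of how the helpers compose; the function names and signatures come from the hunk itself, while the surrounding call site is only an assumption:

```python
import torch

from lerobot.utils.device_utils import (
    get_safe_dtype,
    get_safe_torch_device,
    is_torch_device_available,
)

# Resolve a user-supplied device string, falling back to CPU if it is unavailable.
requested = "mps"
device = get_safe_torch_device(requested if is_torch_device_available(requested) else "cpu")

# Downgrade dtypes the device cannot handle (float64 on mps, and on some Intel XPUs).
dtype = get_safe_dtype(torch.float64, device)
tensor = torch.zeros(10, dtype=dtype, device=device)
```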
+ if not device_capability.get("has_fp64", False): + logging.warning(f"Device {device} does not support float64, using float32 instead.") + return torch.float32 + else: + logging.warning( + f"Device {device} capability check failed. Assuming no support for float64, using float32 instead." + ) + return torch.float32 + return dtype + else: + return dtype + + +def is_torch_device_available(try_device: str) -> bool: + try_device = str(try_device) # Ensure try_device is a string + if try_device.startswith("cuda"): + return torch.cuda.is_available() + elif try_device == "mps": + return torch.backends.mps.is_available() + elif try_device == "xpu": + return torch.xpu.is_available() + elif try_device == "cpu": + return True + else: + raise ValueError(f"Unknown device {try_device}. Supported devices are: cuda, mps, xpu or cpu.") + + +def is_amp_available(device: str): + if device in ["cuda", "xpu", "cpu"]: + return True + elif device == "mps": + return False + else: + raise ValueError(f"Unknown device '{device}.") diff --git a/src/lerobot/utils/train_utils.py b/src/lerobot/utils/train_utils.py index d8481f4b9..02f6aebb3 100644 --- a/src/lerobot/utils/train_utils.py +++ b/src/lerobot/utils/train_utils.py @@ -19,7 +19,7 @@ from torch.optim import Optimizer from torch.optim.lr_scheduler import LRScheduler from lerobot.configs.train import TrainPipelineConfig -from lerobot.datasets.utils import load_json, write_json +from lerobot.datasets.io_utils import load_json, write_json from lerobot.optim.optimizers import load_optimizer_state, save_optimizer_state from lerobot.optim.schedulers import load_scheduler_state, save_scheduler_state from lerobot.policies.pretrained import PreTrainedPolicy diff --git a/src/lerobot/utils/utils.py b/src/lerobot/utils/utils.py index c7ad2bbdb..b9f8441d6 100644 --- a/src/lerobot/utils/utils.py +++ b/src/lerobot/utils/utils.py @@ -13,6 +13,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import logging import os import platform @@ -24,11 +26,12 @@ from copy import copy, deepcopy from datetime import datetime from pathlib import Path from statistics import mean +from typing import TYPE_CHECKING import numpy as np -import torch -from accelerate import Accelerator -from datasets.utils.logging import disable_progress_bar, enable_progress_bar + +if TYPE_CHECKING: + from accelerate import Accelerator def inside_slurm(): @@ -37,96 +40,6 @@ def inside_slurm(): return "SLURM_JOB_ID" in os.environ -def auto_select_torch_device() -> torch.device: - """Tries to select automatically a torch device.""" - if torch.cuda.is_available(): - logging.info("Cuda backend detected, using cuda.") - return torch.device("cuda") - elif torch.backends.mps.is_available(): - logging.info("Metal backend detected, using mps.") - return torch.device("mps") - elif torch.xpu.is_available(): - logging.info("Intel XPU backend detected, using xpu.") - return torch.device("xpu") - else: - logging.warning("No accelerated backend detected. Using default cpu, this will be slow.") - return torch.device("cpu") - - -# TODO(Steven): Remove log. 
log shouldn't be an argument, this should be handled by the logger level -def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device: - """Given a string, return a torch.device with checks on whether the device is available.""" - try_device = str(try_device) - if try_device.startswith("cuda"): - assert torch.cuda.is_available() - device = torch.device(try_device) - elif try_device == "mps": - assert torch.backends.mps.is_available() - device = torch.device("mps") - elif try_device == "xpu": - assert torch.xpu.is_available() - device = torch.device("xpu") - elif try_device == "cpu": - device = torch.device("cpu") - if log: - logging.warning("Using CPU, this will be slow.") - else: - device = torch.device(try_device) - if log: - logging.warning(f"Using custom {try_device} device.") - return device - - -def get_safe_dtype(dtype: torch.dtype, device: str | torch.device): - """ - mps is currently not compatible with float64 - """ - if isinstance(device, torch.device): - device = device.type - if device == "mps" and dtype == torch.float64: - return torch.float32 - if device == "xpu" and dtype == torch.float64: - if hasattr(torch.xpu, "get_device_capability"): - device_capability = torch.xpu.get_device_capability() - # NOTE: Some Intel XPU devices do not support double precision (FP64). - # The `has_fp64` flag is returned by `torch.xpu.get_device_capability()` - # when available; if False, we fall back to float32 for compatibility. - if not device_capability.get("has_fp64", False): - logging.warning(f"Device {device} does not support float64, using float32 instead.") - return torch.float32 - else: - logging.warning( - f"Device {device} capability check failed. Assuming no support for float64, using float32 instead." - ) - return torch.float32 - return dtype - else: - return dtype - - -def is_torch_device_available(try_device: str) -> bool: - try_device = str(try_device) # Ensure try_device is a string - if try_device.startswith("cuda"): - return torch.cuda.is_available() - elif try_device == "mps": - return torch.backends.mps.is_available() - elif try_device == "xpu": - return torch.xpu.is_available() - elif try_device == "cpu": - return True - else: - raise ValueError(f"Unknown device {try_device}. 
Supported devices are: cuda, mps, xpu or cpu.") - - -def is_amp_available(device: str): - if device in ["cuda", "xpu", "cpu"]: - return True - elif device == "mps": - return False - else: - raise ValueError(f"Unknown device '{device}.") - - def init_logging( log_file: Path | None = None, display_pid: bool = False, @@ -297,9 +210,13 @@ class SuppressProgressBars: """ def __enter__(self): + from datasets.utils.logging import disable_progress_bar + disable_progress_bar() def __exit__(self, exc_type, exc_val, exc_tb): + from datasets.utils.logging import enable_progress_bar + enable_progress_bar() diff --git a/src/lerobot/utils/visualization_utils.py b/src/lerobot/utils/visualization_utils.py index 31ca8d247..782358c9e 100644 --- a/src/lerobot/utils/visualization_utils.py +++ b/src/lerobot/utils/visualization_utils.py @@ -18,7 +18,7 @@ import os import numpy as np import rerun as rr -from lerobot.processor import RobotAction, RobotObservation +from lerobot.types import RobotAction, RobotObservation from .constants import ACTION, ACTION_PREFIX, OBS_PREFIX, OBS_STR diff --git a/tests/configs/test_default.py b/tests/configs/test_default.py new file mode 100644 index 000000000..238b8bacd --- /dev/null +++ b/tests/configs/test_default.py @@ -0,0 +1,38 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
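The `SuppressProgressBars` change above defers the `datasets` progress-bar helpers to `__enter__`/`__exit__`, so importing `lerobot.utils.utils` no longer requires the `datasets` package at module load time. A hedged usage sketch, with a hypothetical stand-in for the noisy call:

```python
from lerobot.utils.utils import SuppressProgressBars


def noisy_download() -> None:
    """Hypothetical stand-in for any call that would normally emit `datasets` progress bars."""


# disable_progress_bar/enable_progress_bar are now imported lazily inside the
# context manager, so this works even if `datasets` was never imported beforehand.
with SuppressProgressBars():
    noisy_download()
```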
+import pytest + +from lerobot.configs.default import DatasetConfig + + +def test_dataset_config_valid(): + DatasetConfig(repo_id="user/repo", episodes=[0, 1, 2]) + + +def test_dataset_config_negative_episodes(): + with pytest.raises(ValueError, match="non-negative"): + DatasetConfig(repo_id="user/repo", episodes=[0, -1, 2]) + + +def test_dataset_config_duplicate_episodes(): + with pytest.raises(ValueError, match="duplicates"): + DatasetConfig(repo_id="user/repo", episodes=[0, 1, 1, 2]) + + +def test_dataset_config_none_episodes_ok(): + DatasetConfig(repo_id="user/repo", episodes=None) + + +def test_dataset_config_empty_episodes_ok(): + DatasetConfig(repo_id="user/repo", episodes=[]) diff --git a/tests/datasets/test_aggregate.py b/tests/datasets/test_aggregate.py index 3609bac24..4ac7e001a 100644 --- a/tests/datasets/test_aggregate.py +++ b/tests/datasets/test_aggregate.py @@ -260,8 +260,8 @@ def test_aggregate_datasets(tmp_path, lerobot_dataset_factory): # Mock the revision to prevent Hub calls during dataset loading with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(tmp_path / "test_aggr") @@ -311,8 +311,8 @@ def test_aggregate_with_low_threshold(tmp_path, lerobot_dataset_factory): # Mock the revision to prevent Hub calls during dataset loading with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(tmp_path / "small_aggr") @@ -367,8 +367,8 @@ def test_video_timestamps_regression(tmp_path, lerobot_dataset_factory): ) with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(tmp_path / "regression_aggr") @@ -492,8 +492,8 @@ def test_aggregate_image_datasets(tmp_path, lerobot_dataset_factory): # Load the aggregated dataset with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(tmp_path / "image_aggr") @@ -562,8 +562,8 @@ def test_aggregate_already_merged_dataset(tmp_path, lerobot_dataset_factory): ) with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - 
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(tmp_path / "ds_ab") @@ -590,8 +590,8 @@ def test_aggregate_already_merged_dataset(tmp_path, lerobot_dataset_factory): ) with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(tmp_path / "ds_abc") diff --git a/tests/datasets/test_dataset_tools.py b/tests/datasets/test_dataset_tools.py index 1de199630..5ed7aa1a3 100644 --- a/tests/datasets/test_dataset_tools.py +++ b/tests/datasets/test_dataset_tools.py @@ -67,8 +67,8 @@ def test_delete_single_episode(sample_dataset, tmp_path): output_dir = tmp_path / "filtered" with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(output_dir) @@ -93,8 +93,8 @@ def test_delete_multiple_episodes(sample_dataset, tmp_path): output_dir = tmp_path / "filtered" with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(output_dir) @@ -150,8 +150,8 @@ def test_split_by_episodes(sample_dataset, tmp_path): } with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" @@ -193,8 +193,8 @@ def test_split_by_fractions(sample_dataset, tmp_path): } with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" @@ -270,8 +270,8 @@ def test_merge_two_datasets(sample_dataset, tmp_path, empty_lerobot_dataset_fact dataset2.finalize() with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as 
mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(tmp_path / "merged_dataset") @@ -310,8 +310,8 @@ def test_add_features_with_values(sample_dataset, tmp_path): } with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(tmp_path / "with_reward") @@ -346,8 +346,8 @@ def test_add_features_with_callable(sample_dataset, tmp_path): "reward": (compute_reward, feature_info), } with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(tmp_path / "with_reward") @@ -401,8 +401,8 @@ def test_modify_features_add_and_remove(sample_dataset, tmp_path): feature_info = {"dtype": "float32", "shape": (1,), "names": None} with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(tmp_path / "modified") @@ -434,8 +434,8 @@ def test_modify_features_only_add(sample_dataset, tmp_path): feature_info = {"dtype": "float32", "shape": (1,), "names": None} with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(tmp_path / "modified") @@ -457,8 +457,8 @@ def test_modify_features_only_remove(sample_dataset, tmp_path): feature_info = {"dtype": "float32", "shape": (1,), "names": None} with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path)) @@ -494,8 +494,8 @@ def test_remove_single_feature(sample_dataset, tmp_path): "reward": (np.random.randn(50, 1).astype(np.float32), feature_info), } with ( - 
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path)) @@ -521,8 +521,8 @@ def test_remove_single_feature(sample_dataset, tmp_path): def test_remove_multiple_features(sample_dataset, tmp_path): """Test removing multiple features at once.""" with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path)) @@ -576,8 +576,8 @@ def test_remove_camera_feature(sample_dataset, tmp_path): camera_to_remove = camera_keys[0] with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(tmp_path / "without_camera") @@ -598,8 +598,8 @@ def test_remove_camera_feature(sample_dataset, tmp_path): def test_complex_workflow_integration(sample_dataset, tmp_path): """Test a complex workflow combining multiple operations.""" with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path)) @@ -647,8 +647,8 @@ def test_delete_episodes_preserves_stats(sample_dataset, tmp_path): output_dir = tmp_path / "filtered" with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(output_dir) @@ -671,8 +671,8 @@ def test_delete_episodes_preserves_tasks(sample_dataset, tmp_path): output_dir = tmp_path / "filtered" with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as 
mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(output_dir) @@ -699,8 +699,8 @@ def test_split_three_ways(sample_dataset, tmp_path): } with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" @@ -732,8 +732,8 @@ def test_split_preserves_stats(sample_dataset, tmp_path): splits = {"train": [0, 1, 2], "val": [3, 4]} with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" @@ -790,8 +790,8 @@ def test_merge_three_datasets(sample_dataset, tmp_path, empty_lerobot_dataset_fa datasets.append(dataset) with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(tmp_path / "merged_dataset") @@ -832,8 +832,8 @@ def test_merge_preserves_stats(sample_dataset, tmp_path, empty_lerobot_dataset_f dataset2.finalize() with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(tmp_path / "merged_dataset") @@ -866,8 +866,8 @@ def test_add_features_preserves_existing_stats(sample_dataset, tmp_path): } with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(tmp_path / "with_reward") @@ -890,8 +890,8 @@ def test_remove_feature_updates_stats(sample_dataset, tmp_path): feature_info = {"dtype": "float32", "shape": (1,), "names": None} with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.side_effect = lambda repo_id, **kwargs: 
str(kwargs.get("local_dir", tmp_path)) @@ -919,8 +919,8 @@ def test_delete_consecutive_episodes(sample_dataset, tmp_path): output_dir = tmp_path / "filtered" with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(output_dir) @@ -943,8 +943,8 @@ def test_delete_first_and_last_episodes(sample_dataset, tmp_path): output_dir = tmp_path / "filtered" with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(output_dir) @@ -971,8 +971,8 @@ def test_split_all_episodes_assigned(sample_dataset, tmp_path): } with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" @@ -999,8 +999,8 @@ def test_modify_features_preserves_file_structure(sample_dataset, tmp_path): feature_info = {"dtype": "float32", "shape": (1,), "names": None} with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" @@ -1020,7 +1020,7 @@ def test_modify_features_preserves_file_structure(sample_dataset, tmp_path): # Get original chunk/file indices from first episode if train_dataset.meta.episodes is None: - from lerobot.datasets.utils import load_episodes + from lerobot.datasets.io_utils import load_episodes train_dataset.meta.episodes = load_episodes(train_dataset.meta.root) original_chunk_indices = [ep["data/chunk_index"] for ep in train_dataset.meta.episodes] @@ -1040,7 +1040,7 @@ def test_modify_features_preserves_file_structure(sample_dataset, tmp_path): # Check that chunk/file indices are preserved if modified_dataset.meta.episodes is None: - from lerobot.datasets.utils import load_episodes + from lerobot.datasets.io_utils import load_episodes modified_dataset.meta.episodes = load_episodes(modified_dataset.meta.root) new_chunk_indices = [ep["data/chunk_index"] for ep in modified_dataset.meta.episodes] @@ -1194,7 +1194,7 @@ def test_modify_tasks_in_place(sample_dataset): def test_modify_tasks_keeps_original_when_not_overridden(sample_dataset): """Test that original tasks are kept when using episode_tasks without new_task.""" - from lerobot.datasets.utils import load_episodes + from lerobot.datasets.io_utils import load_episodes # Ensure episodes metadata is loaded if 
sample_dataset.meta.episodes is None: @@ -1229,8 +1229,8 @@ def test_convert_image_to_video_dataset(tmp_path): output_dir = tmp_path / "pusht_video" with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(output_dir) @@ -1292,8 +1292,8 @@ def test_convert_image_to_video_dataset_subset_episodes(tmp_path): output_dir = tmp_path / "pusht_video_subset" with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(output_dir) diff --git a/tests/datasets/test_dataset_utils.py b/tests/datasets/test_dataset_utils.py index 99b832e55..874099e2b 100644 --- a/tests/datasets/test_dataset_utils.py +++ b/tests/datasets/test_dataset_utils.py @@ -19,11 +19,28 @@ import torch from datasets import Dataset from huggingface_hub import DatasetCard -from lerobot.datasets.push_dataset_to_hub.utils import calculate_episode_data_index -from lerobot.datasets.utils import combine_feature_dicts, create_lerobot_dataset_card, hf_transform_to_torch +from lerobot.datasets.feature_utils import combine_feature_dicts +from lerobot.datasets.io_utils import hf_transform_to_torch +from lerobot.datasets.utils import create_lerobot_dataset_card from lerobot.utils.constants import ACTION, OBS_IMAGES +def calculate_episode_data_index(hf_dataset: Dataset) -> dict[str, torch.Tensor]: + """Calculate episode data index for testing. 
Returns {"from": Tensor, "to": Tensor}.""" + episode_data_index: dict[str, list[int]] = {"from": [], "to": []} + current_episode = None + if len(hf_dataset) == 0: + return {"from": torch.tensor([]), "to": torch.tensor([])} + for idx, episode_idx in enumerate(hf_dataset["episode_index"]): + if episode_idx != current_episode: + episode_data_index["from"].append(idx) + if current_episode is not None: + episode_data_index["to"].append(idx) + current_episode = episode_idx + episode_data_index["to"].append(idx + 1) + return {k: torch.tensor(v) for k, v in episode_data_index.items()} + + def test_default_parameters(): card = create_lerobot_dataset_card() assert isinstance(card, DatasetCard) diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index 6f99eb301..67878d8f6 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -29,20 +29,19 @@ import lerobot from lerobot.configs.default import DatasetConfig from lerobot.configs.train import TrainPipelineConfig from lerobot.datasets.factory import make_dataset +from lerobot.datasets.feature_utils import get_hf_features_from_features, hw_to_dataset_features from lerobot.datasets.image_writer import image_array_to_pil_image +from lerobot.datasets.io_utils import hf_transform_to_torch from lerobot.datasets.lerobot_dataset import ( LeRobotDataset, - MultiLeRobotDataset, _encode_video_worker, ) +from lerobot.datasets.multi_dataset import MultiLeRobotDataset from lerobot.datasets.utils import ( DEFAULT_CHUNK_SIZE, DEFAULT_DATA_FILE_SIZE_IN_MB, DEFAULT_VIDEO_FILE_SIZE_IN_MB, create_branch, - get_hf_features_from_features, - hf_transform_to_torch, - hw_to_dataset_features, ) from lerobot.datasets.video_utils import VALID_VIDEO_CODECS from lerobot.envs.factory import make_env_config @@ -1329,7 +1328,7 @@ def test_frames_in_current_file_calculation(tmp_path, empty_lerobot_dataset_fact dataset.finalize() - from lerobot.datasets.utils import load_episodes + from lerobot.datasets.io_utils import load_episodes dataset.meta.episodes = load_episodes(dataset.root) assert dataset.meta.episodes is not None diff --git a/tests/datasets/test_delta_timestamps.py b/tests/datasets/test_delta_timestamps.py index 72f69bc72..8d9529f68 100644 --- a/tests/datasets/test_delta_timestamps.py +++ b/tests/datasets/test_delta_timestamps.py @@ -13,7 +13,7 @@ # limitations under the License. 
import pytest -from lerobot.datasets.utils import ( +from lerobot.datasets.feature_utils import ( check_delta_timestamps, get_delta_indices, ) diff --git a/tests/datasets/test_image_writer.py b/tests/datasets/test_image_writer.py index 99c8b24fc..e02755171 100644 --- a/tests/datasets/test_image_writer.py +++ b/tests/datasets/test_image_writer.py @@ -142,9 +142,9 @@ def test_write_image_image(tmp_path, img_factory): def test_write_image_exception(tmp_path): image_array = "invalid data" fpath = tmp_path / DUMMY_IMAGE - with patch("builtins.print") as mock_print: + with patch("lerobot.datasets.image_writer.logger") as mock_logger: write_image(image_array, fpath) - mock_print.assert_called() + mock_logger.error.assert_called() assert not fpath.exists() @@ -243,10 +243,10 @@ def test_save_image_invalid_data(tmp_path): image_array = "invalid data" fpath = tmp_path / DUMMY_IMAGE fpath.parent.mkdir(parents=True, exist_ok=True) - with patch("builtins.print") as mock_print: + with patch("lerobot.datasets.image_writer.logger") as mock_logger: writer.save_image(image_array, fpath) writer.wait_until_done() - mock_print.assert_called() + mock_logger.error.assert_called() assert not fpath.exists() finally: writer.stop() diff --git a/tests/datasets/test_online_buffer.py b/tests/datasets/test_online_buffer.py deleted file mode 100644 index 887da6041..000000000 --- a/tests/datasets/test_online_buffer.py +++ /dev/null @@ -1,282 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License.d -from copy import deepcopy -from uuid import uuid4 - -import numpy as np -import pytest -import torch - -from lerobot.datasets.online_buffer import OnlineBuffer, compute_sampler_weights - -# Some constants for OnlineBuffer tests. 
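An aside on the repeated `patch(...)` retargeting in the test hunks above: `unittest.mock.patch` must name the module where the attribute is looked up at call time, so once `get_safe_version` and `snapshot_download` are imported by `lerobot.datasets.dataset_metadata`, that is the path to patch. A small self-contained illustration of the principle; the `consumer` module here is invented for the example and is not LeRobot code:

```python
import sys
import types
from unittest.mock import patch

# Build a tiny "consumer" module that binds a name at import time, the way
# lerobot.datasets.dataset_metadata binds snapshot_download.
consumer = types.ModuleType("consumer")
exec(
    "from json import dumps\n"
    "def serialize(obj):\n"
    "    return dumps(obj)\n",
    consumer.__dict__,
)
sys.modules["consumer"] = consumer

# Patching the defining module does NOT affect consumer.serialize: the consumer
# still holds its own reference to the original function object.
with patch("json.dumps", return_value="patched"):
    print(consumer.serialize({}))  # -> {}

# Patching where the name is looked up does take effect.
with patch("consumer.dumps", return_value="patched"):
    print(consumer.serialize({}))  # -> patched
```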
-data_key = "data" -data_shape = (2, 3) # just some arbitrary > 1D shape -buffer_capacity = 100 -fps = 10 - - -def make_new_buffer( - write_dir: str | None = None, delta_timestamps: dict[str, list[float]] | None = None -) -> tuple[OnlineBuffer, str]: - if write_dir is None: - write_dir = f"/tmp/online_buffer_{uuid4().hex}" - buffer = OnlineBuffer( - write_dir, - data_spec={data_key: {"shape": data_shape, "dtype": np.dtype("float32")}}, - buffer_capacity=buffer_capacity, - fps=fps, - delta_timestamps=delta_timestamps, - ) - return buffer, write_dir - - -def make_spoof_data_frames(n_episodes: int, n_frames_per_episode: int) -> dict[str, np.ndarray]: - new_data = { - data_key: np.arange(n_frames_per_episode * n_episodes * np.prod(data_shape)).reshape(-1, *data_shape), - OnlineBuffer.INDEX_KEY: np.arange(n_frames_per_episode * n_episodes), - OnlineBuffer.EPISODE_INDEX_KEY: np.repeat(np.arange(n_episodes), n_frames_per_episode), - OnlineBuffer.FRAME_INDEX_KEY: np.tile(np.arange(n_frames_per_episode), n_episodes), - OnlineBuffer.TIMESTAMP_KEY: np.tile(np.arange(n_frames_per_episode) / fps, n_episodes), - } - return new_data - - -def test_non_mutate(): - """Checks that the data provided to the add_data method is copied rather than passed by reference. - - This means that mutating the data in the buffer does not mutate the original data. - - NOTE: If this test fails, it means some of the other tests may be compromised. For example, we can't trust - a success case for `test_write_read`. - """ - buffer, _ = make_new_buffer() - new_data = make_spoof_data_frames(2, buffer_capacity // 4) - new_data_copy = deepcopy(new_data) - buffer.add_data(new_data) - buffer._data[data_key][:] += 1 - assert all(np.array_equal(new_data[k], new_data_copy[k]) for k in new_data) - - -def test_index_error_no_data(): - buffer, _ = make_new_buffer() - with pytest.raises(IndexError): - buffer[0] - - -def test_index_error_with_data(): - buffer, _ = make_new_buffer() - n_frames = buffer_capacity // 2 - new_data = make_spoof_data_frames(1, n_frames) - buffer.add_data(new_data) - with pytest.raises(IndexError): - buffer[n_frames] - with pytest.raises(IndexError): - buffer[-n_frames - 1] - - -@pytest.mark.parametrize("do_reload", [False, True]) -def test_write_read(do_reload: bool): - """Checks that data can be added to the buffer and read back. - - If do_reload we delete the buffer object and load the buffer back from disk before reading. - """ - buffer, write_dir = make_new_buffer() - n_episodes = 2 - n_frames_per_episode = buffer_capacity // 4 - new_data = make_spoof_data_frames(n_episodes, n_frames_per_episode) - buffer.add_data(new_data) - - if do_reload: - del buffer - buffer, _ = make_new_buffer(write_dir) - - assert len(buffer) == n_frames_per_episode * n_episodes - for i, item in enumerate(buffer): - assert all(isinstance(item[k], torch.Tensor) for k in item) - assert np.array_equal(item[data_key].numpy(), new_data[data_key][i]) - - -def test_read_data_key(): - """Tests that data can be added to a buffer and all data for a. 
specific key can be read back.""" - buffer, _ = make_new_buffer() - n_episodes = 2 - n_frames_per_episode = buffer_capacity // 4 - new_data = make_spoof_data_frames(n_episodes, n_frames_per_episode) - buffer.add_data(new_data) - - data_from_buffer = buffer.get_data_by_key(data_key) - assert isinstance(data_from_buffer, torch.Tensor) - assert np.array_equal(data_from_buffer.numpy(), new_data[data_key]) - - -def test_fifo(): - """Checks that if data is added beyond the buffer capacity, we discard the oldest data first.""" - buffer, _ = make_new_buffer() - n_frames_per_episode = buffer_capacity // 4 - n_episodes = 3 - new_data = make_spoof_data_frames(n_episodes, n_frames_per_episode) - buffer.add_data(new_data) - n_more_episodes = 2 - # Developer sanity check (in case someone changes the global `buffer_capacity`). - assert (n_episodes + n_more_episodes) * n_frames_per_episode > buffer_capacity, ( - "Something went wrong with the test code." - ) - more_new_data = make_spoof_data_frames(n_more_episodes, n_frames_per_episode) - buffer.add_data(more_new_data) - assert len(buffer) == buffer_capacity, "The buffer should be full." - - expected_data = {} - for k in new_data: - # Concatenate, left-truncate, then roll, to imitate the cyclical FIFO pattern in OnlineBuffer. - expected_data[k] = np.roll( - np.concatenate([new_data[k], more_new_data[k]])[-buffer_capacity:], - shift=len(new_data[k]) + len(more_new_data[k]) - buffer_capacity, - axis=0, - ) - - for i, item in enumerate(buffer): - assert all(isinstance(item[k], torch.Tensor) for k in item) - assert np.array_equal(item[data_key].numpy(), expected_data[data_key][i]) - - -def test_delta_timestamps_within_tolerance(): - """Check that getting an item with delta_timestamps within tolerance succeeds. - - Note: Copied from `test_datasets.py::test_load_previous_and_future_frames_within_tolerance`. - """ - # Sanity check on global fps as we are assuming it is 10 here. - assert fps == 10, "This test assumes fps==10" - buffer, _ = make_new_buffer(delta_timestamps={"index": [-0.2, 0, 0.139]}) - new_data = make_spoof_data_frames(n_episodes=1, n_frames_per_episode=5) - buffer.add_data(new_data) - buffer.tolerance_s = 0.04 - item = buffer[2] - data, is_pad = item["index"], item[f"index{OnlineBuffer.IS_PAD_POSTFIX}"] - torch.testing.assert_close(data, torch.tensor([0, 2, 3]), msg="Data does not match expected values") - assert not is_pad.any(), "Unexpected padding detected" - - -def test_delta_timestamps_outside_tolerance_inside_episode_range(): - """Check that getting an item with delta_timestamps outside of tolerance fails. - - We expect it to fail if and only if the requested timestamps are within the episode range. - - Note: Copied from - `test_datasets.py::test_load_previous_and_future_frames_outside_tolerance_inside_episode_range` - """ - # Sanity check on global fps as we are assuming it is 10 here. - assert fps == 10, "This test assumes fps==10" - buffer, _ = make_new_buffer(delta_timestamps={"index": [-0.2, 0, 0.141]}) - new_data = make_spoof_data_frames(n_episodes=1, n_frames_per_episode=5) - buffer.add_data(new_data) - buffer.tolerance_s = 0.04 - with pytest.raises(AssertionError): - buffer[2] - - -def test_delta_timestamps_outside_tolerance_outside_episode_range(): - """Check that copy-padding of timestamps outside of the episode range works. - - Note: Copied from - `test_datasets.py::test_load_previous_and_future_frames_outside_tolerance_outside_episode_range` - """ - # Sanity check on global fps as we are assuming it is 10 here. 
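Although the `OnlineBuffer` tests are being deleted here, the `delta_timestamps` behaviour they exercised still exists on `LeRobotDataset`. A hedged sketch of the equivalent dataset-side usage; the repo id is illustrative and the padding-key naming is assumed from the `_is_pad` convention visible in these tests:

```python
from lerobot.datasets.lerobot_dataset import LeRobotDataset

# For every frame, also fetch the action 0.1s before and 0.1s after that frame.
# Requests outside the episode range are expected to be copy-padded and flagged.
dataset = LeRobotDataset(
    "lerobot/pusht",  # illustrative repo id
    delta_timestamps={"action": [-0.1, 0.0, 0.1]},
)

item = dataset[0]
print(item["action"].shape)    # stacked along a new leading dim of size 3
print(item["action_is_pad"])   # boolean mask marking padded timestamps
```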
- assert fps == 10, "This test assumes fps==10" - buffer, _ = make_new_buffer(delta_timestamps={"index": [-0.3, -0.24, 0, 0.26, 0.3]}) - new_data = make_spoof_data_frames(n_episodes=1, n_frames_per_episode=5) - buffer.add_data(new_data) - buffer.tolerance_s = 0.04 - item = buffer[2] - data, is_pad = item["index"], item["index_is_pad"] - assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), "Data does not match expected values" - assert torch.equal(is_pad, torch.tensor([True, False, False, True, True])), ( - "Padding does not match expected values" - ) - - -# Arbitrarily set small dataset sizes, making sure to have uneven sizes. -@pytest.mark.parametrize("offline_dataset_size", [1, 6]) -@pytest.mark.parametrize("online_dataset_size", [0, 4]) -@pytest.mark.parametrize("online_sampling_ratio", [0.0, 1.0]) -def test_compute_sampler_weights_trivial( - lerobot_dataset_factory, - tmp_path, - offline_dataset_size: int, - online_dataset_size: int, - online_sampling_ratio: float, -): - offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=offline_dataset_size) - online_dataset, _ = make_new_buffer() - if online_dataset_size > 0: - online_dataset.add_data( - make_spoof_data_frames(n_episodes=2, n_frames_per_episode=online_dataset_size // 2) - ) - - weights = compute_sampler_weights( - offline_dataset, online_dataset=online_dataset, online_sampling_ratio=online_sampling_ratio - ) - if offline_dataset_size == 0 or online_dataset_size == 0: - expected_weights = torch.ones(offline_dataset_size + online_dataset_size) - elif online_sampling_ratio == 0: - expected_weights = torch.cat([torch.ones(offline_dataset_size), torch.zeros(online_dataset_size)]) - elif online_sampling_ratio == 1: - expected_weights = torch.cat([torch.zeros(offline_dataset_size), torch.ones(online_dataset_size)]) - expected_weights /= expected_weights.sum() - torch.testing.assert_close(weights, expected_weights) - - -def test_compute_sampler_weights_nontrivial_ratio(lerobot_dataset_factory, tmp_path): - # Arbitrarily set small dataset sizes, making sure to have uneven sizes. - offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=4) - online_dataset, _ = make_new_buffer() - online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2)) - online_sampling_ratio = 0.8 - weights = compute_sampler_weights( - offline_dataset, online_dataset=online_dataset, online_sampling_ratio=online_sampling_ratio - ) - torch.testing.assert_close( - weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]) - ) - - -def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(lerobot_dataset_factory, tmp_path): - # Arbitrarily set small dataset sizes, making sure to have uneven sizes. 
-    offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=4)
-    online_dataset, _ = make_new_buffer()
-    online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
-    weights = compute_sampler_weights(
-        offline_dataset, online_dataset=online_dataset, online_sampling_ratio=0.8, online_drop_n_last_frames=1
-    )
-    torch.testing.assert_close(
-        weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0])
-    )
-
-
-def test_compute_sampler_weights_drop_n_last_frames(lerobot_dataset_factory, tmp_path):
-    """Note: test copied from test_sampler."""
-    offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=2)
-    online_dataset, _ = make_new_buffer()
-    online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
-
-    weights = compute_sampler_weights(
-        offline_dataset,
-        offline_drop_n_last_frames=1,
-        online_dataset=online_dataset,
-        online_sampling_ratio=0.5,
-        online_drop_n_last_frames=1,
-    )
-    torch.testing.assert_close(weights, torch.tensor([0.5, 0, 0.125, 0, 0.125, 0, 0.125, 0, 0.125, 0]))
diff --git a/tests/datasets/test_sampler.py b/tests/datasets/test_sampler.py
index fd7a6e380..18fb1c8ac 100644
--- a/tests/datasets/test_sampler.py
+++ b/tests/datasets/test_sampler.py
@@ -13,13 +13,32 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
+
+import pytest
+import torch
 from datasets import Dataset
 
-from lerobot.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
-from lerobot.datasets.sampler import EpisodeAwareSampler
-from lerobot.datasets.utils import (
+from lerobot.datasets.io_utils import (
     hf_transform_to_torch,
 )
+from lerobot.datasets.sampler import EpisodeAwareSampler
+
+
+def calculate_episode_data_index(hf_dataset: Dataset) -> dict[str, torch.Tensor]:
+    """Calculate episode data index for testing. Returns {"from": Tensor, "to": Tensor}."""
+    episode_data_index: dict[str, list[int]] = {"from": [], "to": []}
+    current_episode = None
+    if len(hf_dataset) == 0:
+        return {"from": torch.tensor([]), "to": torch.tensor([])}
+    for idx, episode_idx in enumerate(hf_dataset["episode_index"]):
+        if episode_idx != current_episode:
+            episode_data_index["from"].append(idx)
+            if current_episode is not None:
+                episode_data_index["to"].append(idx)
+            current_episode = episode_idx
+    episode_data_index["to"].append(idx + 1)
+    return {k: torch.tensor(v) for k, v in episode_data_index.items()}
 
 
 def test_drop_n_first_frames():
@@ -90,3 +109,28 @@ def test_shuffle():
     assert sampler.indices == [0, 1, 2, 3, 4, 5]
     assert len(sampler) == 6
     assert set(sampler) == {0, 1, 2, 3, 4, 5}
+
+
+def test_negative_drop_first_frames_raises():
+    with pytest.raises(ValueError, match="drop_n_first_frames must be >= 0"):
+        EpisodeAwareSampler([0], [10], drop_n_first_frames=-1)
+
+
+def test_negative_drop_last_frames_raises():
+    with pytest.raises(ValueError, match="drop_n_last_frames must be >= 0"):
+        EpisodeAwareSampler([0], [10], drop_n_last_frames=-1)
+
+
+def test_all_episodes_dropped_raises():
+    # All episodes have 1 frame, drop_n_first_frames=1 removes all
+    with pytest.raises(ValueError, match="No valid frames remain"):
+        EpisodeAwareSampler([0, 1, 2], [1, 2, 3], drop_n_first_frames=1)
+
+
+def test_partial_episode_drop_warns(caplog):
+    # Episode 0: 1 frame (dropped), Episode 1: 5 frames (kept)
+    with caplog.at_level(logging.WARNING, logger="lerobot.datasets.sampler"):
+        sampler = EpisodeAwareSampler([0, 1], [1, 6], drop_n_first_frames=1)
+    # Episode 0 is skipped (1 frame, drop 1), Episode 1 keeps frames 2-5
+    assert sampler.indices == [2, 3, 4, 5]
+    assert "Episode 0" in caplog.text
diff --git a/tests/fixtures/dataset_factories.py b/tests/fixtures/dataset_factories.py
index f8dd01fec..5ecb52145 100644
--- a/tests/fixtures/dataset_factories.py
+++ b/tests/fixtures/dataset_factories.py
@@ -26,7 +26,10 @@ import pytest
 import torch
 from datasets import Dataset
 
-from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata
+from lerobot.datasets.dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata
+from lerobot.datasets.feature_utils import get_hf_features_from_features
+from lerobot.datasets.io_utils import hf_transform_to_torch
+from lerobot.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.datasets.utils import (
     DEFAULT_CHUNK_SIZE,
     DEFAULT_DATA_FILE_SIZE_IN_MB,
@@ -35,8 +38,6 @@ from lerobot.datasets.utils import (
     DEFAULT_VIDEO_FILE_SIZE_IN_MB,
     DEFAULT_VIDEO_PATH,
     flatten_dict,
-    get_hf_features_from_features,
-    hf_transform_to_torch,
 )
 from lerobot.datasets.video_utils import encode_video_frames
 from tests.fixtures.constants import (
@@ -453,8 +454,8 @@ def lerobot_dataset_metadata_factory(
         episodes=episodes,
     )
     with (
-        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version_patch,
-        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download_patch,
+        patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version_patch,
+        patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download_patch,
     ):
         mock_get_safe_version_patch.side_effect = lambda repo_id, version: version
         mock_snapshot_download_patch.side_effect = mock_snapshot_download
diff --git a/tests/fixtures/files.py b/tests/fixtures/files.py
index 11f3fa94a..92d9ca1e2 100644
--- a/tests/fixtures/files.py
+++ b/tests/fixtures/files.py
@@ -20,17 +20,19 @@ import pandas as pd
 import pytest
 from datasets import Dataset
 
-from lerobot.datasets.utils import (
-    DEFAULT_CHUNK_SIZE,
-    DEFAULT_DATA_FILE_SIZE_IN_MB,
-    DEFAULT_DATA_PATH,
+from lerobot.datasets.io_utils import (
     get_hf_dataset_size_in_mb,
-    update_chunk_file_indices,
     write_episodes,
     write_info,
     write_stats,
     write_tasks,
 )
+from lerobot.datasets.utils import (
+    DEFAULT_CHUNK_SIZE,
+    DEFAULT_DATA_FILE_SIZE_IN_MB,
+    DEFAULT_DATA_PATH,
+    update_chunk_file_indices,
+)
 
 
 def write_hf_dataset(
diff --git a/tests/mocks/mock_robot.py b/tests/mocks/mock_robot.py
index f69a2c02a..5504b30bf 100644
--- a/tests/mocks/mock_robot.py
+++ b/tests/mocks/mock_robot.py
@@ -20,8 +20,8 @@ from functools import cached_property
 
 from lerobot.cameras import CameraConfig, make_cameras_from_configs
 from lerobot.motors.motors_bus import Motor, MotorNormMode
-from lerobot.processor import RobotAction, RobotObservation
 from lerobot.robots import Robot, RobotConfig
+from lerobot.types import RobotAction, RobotObservation
 from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
 from tests.mocks.mock_motors_bus import MockMotorsBus
diff --git a/tests/mocks/mock_teleop.py b/tests/mocks/mock_teleop.py
index 89174dadf..b84b2b891 100644
--- a/tests/mocks/mock_teleop.py
+++ b/tests/mocks/mock_teleop.py
@@ -19,8 +19,8 @@ from dataclasses import dataclass
 from functools import cached_property
 from typing import Any
 
-from lerobot.processor import RobotAction
 from lerobot.teleoperators import Teleoperator, TeleoperatorConfig
+from lerobot.types import RobotAction
 from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
diff --git a/tests/policies/groot/test_groot_lerobot.py b/tests/policies/groot/test_groot_lerobot.py
index 760f13a5f..e299a34e2 100644
--- a/tests/policies/groot/test_groot_lerobot.py
+++ b/tests/policies/groot/test_groot_lerobot.py
@@ -28,8 +28,9 @@ import torch
 from lerobot.policies.groot.configuration_groot import GrootConfig
 from lerobot.policies.groot.modeling_groot import GrootPolicy
 from lerobot.policies.groot.processor_groot import make_groot_pre_post_processors
-from lerobot.processor import PolicyAction, PolicyProcessorPipeline
-from lerobot.utils.utils import auto_select_torch_device
+from lerobot.processor import PolicyProcessorPipeline
+from lerobot.types import PolicyAction
+from lerobot.utils.device_utils import auto_select_torch_device
 from tests.utils import require_cuda  # noqa: E402
 
 pytest.importorskip("transformers")
diff --git a/tests/policies/groot/test_groot_vs_original.py b/tests/policies/groot/test_groot_vs_original.py
index e9dd1df00..0adad96ca 100644
--- a/tests/policies/groot/test_groot_vs_original.py
+++ b/tests/policies/groot/test_groot_vs_original.py
@@ -28,7 +28,8 @@ import torch
 from lerobot.policies.groot.configuration_groot import GrootConfig
 from lerobot.policies.groot.modeling_groot import GrootPolicy
 from lerobot.policies.groot.processor_groot import make_groot_pre_post_processors
-from lerobot.processor import PolicyAction, PolicyProcessorPipeline
+from lerobot.processor import PolicyProcessorPipeline
+from lerobot.types import PolicyAction
 
 pytest.importorskip("gr00t")
 pytest.importorskip("transformers")
diff --git a/tests/policies/pi0_fast/test_pi0_fast_original_vs_lerobot.py b/tests/policies/pi0_fast/test_pi0_fast_original_vs_lerobot.py
index d24bb11d7..b757d5a94 100644
--- a/tests/policies/pi0_fast/test_pi0_fast_original_vs_lerobot.py
+++ b/tests/policies/pi0_fast/test_pi0_fast_original_vs_lerobot.py
@@ -31,7 +31,8 @@ pytest.importorskip("scipy")
 from lerobot.policies.pi0_fast.configuration_pi0_fast import PI0FastConfig
 from lerobot.policies.pi0_fast.modeling_pi0_fast import PI0FastPolicy
 from lerobot.policies.pi0_fast.processor_pi0_fast import make_pi0_fast_pre_post_processors
-from lerobot.processor import PolicyAction, PolicyProcessorPipeline  # noqa: E402
+from lerobot.processor import PolicyProcessorPipeline  # noqa: E402
+from lerobot.types import PolicyAction  # noqa: E402
 from lerobot.utils.constants import (
     ACTION_TOKEN_MASK,
     ACTION_TOKENS,
diff --git a/tests/policies/pi0_pi05/test_pi05_original_vs_lerobot.py b/tests/policies/pi0_pi05/test_pi05_original_vs_lerobot.py
index f70707262..a965132b0 100644
--- a/tests/policies/pi0_pi05/test_pi05_original_vs_lerobot.py
+++ b/tests/policies/pi0_pi05/test_pi05_original_vs_lerobot.py
@@ -42,7 +42,8 @@ from transformers import AutoTokenizer  # noqa: E402
 
 from lerobot.policies.pi05 import PI05Config, PI05Policy  # noqa: E402
 from lerobot.policies.pi05.processor_pi05 import make_pi05_pre_post_processors  # noqa: E402
-from lerobot.processor import PolicyAction, PolicyProcessorPipeline  # noqa: E402
+from lerobot.processor import PolicyProcessorPipeline  # noqa: E402
+from lerobot.types import PolicyAction  # noqa: E402
 
 # TODO: ADDING DEFAULT IMAGES_FEATURES TO CONFIG
 DUMMY_ACTION_DIM = 32
diff --git a/tests/policies/pi0_pi05/test_pi0_original_vs_lerobot.py b/tests/policies/pi0_pi05/test_pi0_original_vs_lerobot.py
index d3d1c1908..62e34b70d 100644
--- a/tests/policies/pi0_pi05/test_pi0_original_vs_lerobot.py
+++ b/tests/policies/pi0_pi05/test_pi0_original_vs_lerobot.py
@@ -41,7 +41,8 @@ from transformers import AutoTokenizer  # noqa: E402
 
 from lerobot.policies.pi0 import PI0Config, PI0Policy  # noqa: E402
 from lerobot.policies.pi0.processor_pi0 import make_pi0_pre_post_processors  # noqa: E402
-from lerobot.processor import PolicyAction, PolicyProcessorPipeline  # noqa: E402
+from lerobot.processor import PolicyProcessorPipeline  # noqa: E402
+from lerobot.types import PolicyAction  # noqa: E402
 
 # TODO: ADDING DEFAULT IMAGES_FEATURES TO CONFIG
 DUMMY_ACTION_DIM = 32
diff --git a/tests/policies/test_policies.py b/tests/policies/test_policies.py
index 1ba82ffd0..77a74d60e 100644
--- a/tests/policies/test_policies.py
+++ b/tests/policies/test_policies.py
@@ -28,7 +28,8 @@ from lerobot.configs.default import DatasetConfig
 from lerobot.configs.train import TrainPipelineConfig
 from lerobot.configs.types import FeatureType, PolicyFeature
 from lerobot.datasets.factory import make_dataset
-from lerobot.datasets.utils import cycle, dataset_to_policy_features
+from lerobot.datasets.feature_utils import dataset_to_policy_features
+from lerobot.datasets.utils import cycle
 from lerobot.envs.factory import make_env, make_env_config
 from lerobot.envs.utils import preprocess_observation
 from lerobot.optim.factory import make_optimizer_and_scheduler
@@ -41,6 +42,8 @@ from lerobot.policies.factory import (
     make_pre_post_processors,
 )
 from lerobot.policies.pretrained import PreTrainedPolicy
+from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
+from lerobot.policies.vqbet.modeling_vqbet import VQBeTHead
 from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
 from lerobot.utils.random_utils import seeded_context
 from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats
@@ -459,3 +462,45 @@ def test_act_temporal_ensembler():
     assert torch.all(offline_avg <= einops.reduce(seq_slice, "b s 1 -> b 1", "max"))
     # Selected atol=1e-4 keeping in mind actions in [-1, 1] and excepting 0.01% error.
     torch.testing.assert_close(online_avg, offline_avg, rtol=1e-4, atol=1e-4)
+
+
+def test_vqbet_discretize_keeps_buffers_on_device():
+    """Regression test: VQBeTHead.discretize() must not move registered buffers off the model device.
+
+    Previously, `self.vqvae_model.discretized = torch.tensor(True)` replaced the
+    registered buffer with a new CPU tensor, causing DDP to crash with:
+    RuntimeError: No backend type associated with device type cpu
+    The fix uses `.fill_(True)` to update in-place, preserving device placement.
+    """
+    config = VQBeTConfig()
+    config.input_features = {
+        OBS_IMAGES: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 96, 96)),
+        OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(6,)),
+    }
+    config.output_features = {
+        ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(6,)),
+    }
+    # Tiny sizes for fast CPU/GPU execution.
+    config.n_vqvae_training_steps = 3
+    config.vqvae_n_embed = 8
+    config.vqvae_embedding_dim = 32
+    config.vqvae_enc_hidden_dim = 32
+    config.action_chunk_size = 2
+    config.crop_shape = (84, 84)
+
+    head = VQBeTHead(config).to(DEVICE)
+    vqvae = head.vqvae_model
+
+    dummy_actions = torch.randn(4, config.action_chunk_size, config.action_feature.shape[0], device=DEVICE)
+    n_steps = config.n_vqvae_training_steps
+    for _ in range(n_steps):
+        head.discretize(n_steps, dummy_actions)
+
+    assert vqvae.discretized.device.type == torch.device(DEVICE).type, (
+        "vqvae_model.discretized was moved off the model device after discretize(). "
+        "Use .fill_(True) instead of = torch.tensor(True) to keep the buffer on device."
+    )
+    assert vqvae.vq_layer.freeze_codebook.device.type == torch.device(DEVICE).type, (
+        "vq_layer.freeze_codebook was moved off the model device after discretize(). "
+        "Use .fill_(True) instead of = torch.tensor(True) to keep the buffer on device."
+    )
diff --git a/tests/policies/test_sarm_processor.py b/tests/policies/test_sarm_processor.py
index 66404f663..5b90784a6 100644
--- a/tests/policies/test_sarm_processor.py
+++ b/tests/policies/test_sarm_processor.py
@@ -25,7 +25,7 @@ import pandas as pd
 import pytest
 import torch
 
-from lerobot.processor.core import TransitionKey
+from lerobot.types import TransitionKey
 
 
 class MockDatasetMeta:
diff --git a/tests/policies/xvla/test_xvla_original_vs_lerobot.py b/tests/policies/xvla/test_xvla_original_vs_lerobot.py
index e36d14d01..3cea11329 100644
--- a/tests/policies/xvla/test_xvla_original_vs_lerobot.py
+++ b/tests/policies/xvla/test_xvla_original_vs_lerobot.py
@@ -30,7 +30,8 @@ pytest.importorskip("transformers")
 from lerobot.policies.xvla.configuration_xvla import XVLAConfig
 from lerobot.policies.xvla.modeling_xvla import XVLAPolicy
 from lerobot.policies.xvla.processor_xvla import make_xvla_pre_post_processors
-from lerobot.processor import PolicyAction, PolicyProcessorPipeline  # noqa: E402
+from lerobot.processor import PolicyProcessorPipeline  # noqa: E402
+from lerobot.types import PolicyAction  # noqa: E402
 from lerobot.utils.constants import OBS_IMAGES, OBS_STATE  # noqa: E402
 from tests.utils import require_cuda  # noqa: E402
diff --git a/tests/processor/test_batch_conversion.py b/tests/processor/test_batch_conversion.py
index 477381618..d589b6c5e 100644
--- a/tests/processor/test_batch_conversion.py
+++ b/tests/processor/test_batch_conversion.py
@@ -16,8 +16,9 @@
 
 import torch
 
-from lerobot.processor import DataProcessorPipeline, TransitionKey
+from lerobot.processor import DataProcessorPipeline
 from lerobot.processor.converters import batch_to_transition, transition_to_batch
+from lerobot.types import TransitionKey
 from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, OBS_PREFIX, OBS_STATE, REWARD, TRUNCATED
diff --git a/tests/processor/test_converters.py b/tests/processor/test_converters.py
index 47a6eea18..91afdd0e5 100644
--- a/tests/processor/test_converters.py
+++ b/tests/processor/test_converters.py
@@ -18,13 +18,13 @@ import numpy as np
 import pytest
 import torch
 
-from lerobot.processor import TransitionKey
 from lerobot.processor.converters import (
     batch_to_transition,
     create_transition,
     to_tensor,
     transition_to_batch,
 )
+from lerobot.types import TransitionKey
 from lerobot.utils.constants import ACTION, DONE, OBS_STATE, OBS_STR, REWARD
diff --git a/tests/processor/test_device_processor.py b/tests/processor/test_device_processor.py
index bb7d467bf..57b923076 100644
--- a/tests/processor/test_device_processor.py
+++ b/tests/processor/test_device_processor.py
@@ -19,8 +19,9 @@ import pytest
 import torch
 
 from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
-from lerobot.processor import DataProcessorPipeline, DeviceProcessorStep, TransitionKey
+from lerobot.processor import DataProcessorPipeline, DeviceProcessorStep
 from lerobot.processor.converters import create_transition, identity_transition
+from lerobot.types import TransitionKey
 from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE
diff --git a/tests/processor/test_normalize_processor.py b/tests/processor/test_normalize_processor.py
index 208a6b5c5..cd5c75005 100644
--- a/tests/processor/test_normalize_processor.py
+++ b/tests/processor/test_normalize_processor.py
@@ -30,7 +30,7 @@ from lerobot.processor import (
 )
 from lerobot.processor.converters import create_transition, identity_transition, to_tensor
 from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE, OBS_STR
-from lerobot.utils.utils import auto_select_torch_device
+from lerobot.utils.device_utils import auto_select_torch_device
 
 
 def test_numpy_conversion():
diff --git a/tests/processor/test_observation_processor.py b/tests/processor/test_observation_processor.py
index 11b58a66c..923059210 100644
--- a/tests/processor/test_observation_processor.py
+++ b/tests/processor/test_observation_processor.py
@@ -19,8 +19,9 @@ import pytest
 import torch
 
 from lerobot.configs.types import FeatureType, PipelineFeatureType
-from lerobot.processor import TransitionKey, VanillaObservationProcessorStep
+from lerobot.processor import VanillaObservationProcessorStep
 from lerobot.processor.converters import create_transition
+from lerobot.types import TransitionKey
 from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
 from tests.conftest import assert_contract_is_typed
diff --git a/tests/processor/test_tokenizer_processor.py b/tests/processor/test_tokenizer_processor.py
index 64cc8aac8..2f1c4cc9c 100644
--- a/tests/processor/test_tokenizer_processor.py
+++ b/tests/processor/test_tokenizer_processor.py
@@ -25,8 +25,9 @@ import pytest
 import torch
 
 from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
-from lerobot.processor import DataProcessorPipeline, TokenizerProcessorStep, TransitionKey
+from lerobot.processor import DataProcessorPipeline, TokenizerProcessorStep
 from lerobot.processor.converters import create_transition, identity_transition
+from lerobot.types import TransitionKey
 from lerobot.utils.constants import (
     ACTION,
     OBS_IMAGE,
diff --git a/tests/test_control_robot.py b/tests/test_control_robot.py
index ace0aea49..772588467 100644
--- a/tests/test_control_robot.py
+++ b/tests/test_control_robot.py
@@ -71,8 +71,8 @@ def test_record_and_resume(tmp_path):
     cfg.resume = True
     # Mock the revision to prevent Hub calls during resume
    with (
-        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
-        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
+        patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version,
+        patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download,
     ):
         mock_get_safe_version.return_value = "v3.0"
         mock_snapshot_download.return_value = str(tmp_path / "record")
@@ -115,8 +115,8 @@ def test_record_and_replay(tmp_path):
 
     # Mock the revision to prevent Hub calls during replay
     with (
-        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
-        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
+        patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version,
+        patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download,
     ):
         mock_get_safe_version.return_value = "v3.0"
         mock_snapshot_download.return_value = str(tmp_path / "record_and_replay")
diff --git a/tests/training/test_visual_validation.py b/tests/training/test_visual_validation.py
index af693fe5e..89351e3c2 100644
--- a/tests/training/test_visual_validation.py
+++ b/tests/training/test_visual_validation.py
@@ -37,7 +37,7 @@ from lerobot.configs.train import TrainPipelineConfig
 from lerobot.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.policies.factory import make_policy_config
 from lerobot.scripts.lerobot_train import train
-from lerobot.utils.utils import auto_select_torch_device
+from lerobot.utils.device_utils import auto_select_torch_device
 
 
 pytest.importorskip("transformers")
diff --git a/tests/utils.py b/tests/utils.py
index a77082ea9..33c554804 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -21,8 +21,8 @@ import pytest
 import torch
 
 from lerobot import available_cameras, available_motors, available_robots
+from lerobot.utils.device_utils import auto_select_torch_device
 from lerobot.utils.import_utils import is_package_available
-from lerobot.utils.utils import auto_select_torch_device
 
 DEVICE = os.environ.get("LEROBOT_TEST_DEVICE", str(auto_select_torch_device()))
diff --git a/tests/utils/test_visualization_utils.py b/tests/utils/test_visualization_utils.py
index 408f636cb..c8e5a92a8 100644
--- a/tests/utils/test_visualization_utils.py
+++ b/tests/utils/test_visualization_utils.py
@@ -21,7 +21,7 @@ from types import SimpleNamespace
 import numpy as np
 import pytest
 
-from lerobot.processor import TransitionKey
+from lerobot.types import TransitionKey
 from lerobot.utils.constants import OBS_STATE